From 35a8e1f9d942af92eac7c9340c91ffb3d5e710a0 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Sat, 7 Apr 2018 13:58:27 +0100 Subject: Pass through tolerances --- packages/sundials/src/Main.hs | 10 ++++++ .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 38 +++++++++++++++------- 2 files changed, 36 insertions(+), 12 deletions(-) (limited to 'packages') diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index 9978aa5..6d6a397 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs @@ -69,6 +69,12 @@ stiffish t v = [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] lamda = -100.0 u = v !! 0 +stiffishV :: Double -> Vector Double -> Vector Double +stiffishV t v = fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] + where + lamda = -100.0 + u = v ! 0 + stiffJac :: Double -> Vector Double -> Matrix Double stiffJac _t _v = (1><1) [ lamda ] where @@ -143,6 +149,10 @@ main = do (D.dims2D 500.0 500.0) (renderAxis $ kSaxis $ zip [0.0, 0.1 .. 10.0] (concat $ toLists res2)) + let res2a = odeSolveV (SDIRK_5_3_4 stiffJac) 0.1 1e-3 1e-6 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) + putStrLn "Lower tolerances" + putStrLn $ show res2a + let res3 = odeSolve' (SDIRK_5_3_4 lorenzJac) lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0]) putStrLn $ show $ last ((toLists $ tr res3)!!0) diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 9ddb3df..85f1b3d 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -215,7 +215,7 @@ odeSolveV -> Vector Double -- ^ desired solution times -> Matrix Double -- ^ solution odeSolveV meth _hi epsAbs epsRel f y0 ts = - case odeSolveVWith meth (XX' epsAbs epsRel 1 1) epsAbs epsAbs g y0 ts of + case odeSolveVWith meth (X epsAbs epsRel) g y0 ts of Left c -> error $ show c -- FIXME -- FIXME: Can we do better than using lists? Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) @@ -235,7 +235,7 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y -> Matrix Double -- ^ solution odeSolve f y0 ts = -- FIXME: These tolerances are different from the ones in GSL - case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of + case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) g (V.fromList y0) (V.fromList $ toList ts) of Left c -> error $ show c -- FIXME Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) where @@ -250,7 +250,7 @@ odeSolve' :: ODEMethod -> Vector Double -- ^ desired solution times -> Matrix Double -- ^ solution odeSolve' method f y0 ts = - case odeSolveVWith method (XX' 1.0e-6 1.0e-10 1 1) 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of + case odeSolveVWith method (XX' 1.0e-6 1.0e-10 1 1) g (V.fromList y0) (V.fromList $ toList ts) of Left c -> error $ show c -- FIXME Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) where @@ -262,36 +262,40 @@ odeSolve' method f y0 ts = odeSolveVWith :: ODEMethod -> StepControl - -> Double - -> Double -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector Double -- ^ Initial conditions -> V.Vector Double -- ^ Desired solution times -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution -odeSolveVWith method _control relTol absTol f y0 tt = - case solveOdeC (fromIntegral $ getMethod method) jacH (CDouble relTol) (CDouble absTol) +odeSolveVWith method control f y0 tt = + case solveOdeC (fromIntegral $ getMethod method) jacH (scise control) (coerce f) (coerce y0) (coerce tt) of Left c -> Left $ fromIntegral c Right (v, d) -> Right (coerce v, d) where + l = size y0 + scise (X absTol relTol) = coerce (V.replicate l absTol, relTol) + scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol) + scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol) + -- FIXME; Should we check that the length of ss is correct? + scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol) jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ getJacobian method matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } where nr = fromIntegral $ rows m nc = fromIntegral $ cols m + -- FIXME: efficiency vs = V.fromList $ map coerce $ concat $ toLists m solveOdeC :: CInt -> (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> - CDouble -> - CDouble -> + (V.Vector CDouble, CDouble) -> (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector CDouble -- ^ Initial conditions -> V.Vector CDouble -- ^ Desired solution times -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution -solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do +solveOdeC method jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do let dim = V.length f0 nEq :: CLong nEq = fromIntegral dim @@ -336,6 +340,9 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do /* general problem variables */ int flag; /* reusable error-checking flag */ N_Vector y = NULL; /* empty vector for storing solution */ + /* empty vector for storing absolute tolerances */ + N_Vector tv = NULL; + SUNMatrix A = NULL; /* empty matrix for linear solver */ SUNLinearSolver LS = NULL; /* empty linear solver object */ void *arkode_mem = NULL; /* empty ARKode memory structure */ @@ -353,6 +360,13 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do for (i = 0; i < NEQ; i++) { NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; }; /* Specify initial condition */ + + tv = N_VNew_Serial(NEQ); /* Create serial vector for solution */ + if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; + for (i = 0; i < NEQ; i++) { + NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i]; + }; + arkode_mem = ARKodeCreate(); /* Create the solver memory */ if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; @@ -367,8 +381,8 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do if (check_flag(&flag, "ARKodeInit", 1)) return 1; /* Set routines */ - flag = ARKodeSStolerances(arkode_mem, $(double relTol), $(double absTol)); - if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; + flag = ARKodeSVtolerances(arkode_mem, $(double relTol), tv); + if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1; /* Initialize dense matrix data structure and solver */ A = SUNDenseMatrix(NEQ, NEQ); -- cgit v1.2.3