diff options
Diffstat (limited to 'packages/sundials/src')
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 145 |
1 files changed, 112 insertions, 33 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 8358954..0973c82 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -6,8 +6,34 @@ | |||
6 | {-# LANGUAGE OverloadedStrings #-} | 6 | {-# LANGUAGE OverloadedStrings #-} |
7 | {-# LANGUAGE ScopedTypeVariables #-} | 7 | {-# LANGUAGE ScopedTypeVariables #-} |
8 | 8 | ||
9 | ----------------------------------------------------------------------------- | ||
9 | -- | | 10 | -- | |
10 | -- Module: Numeric.Sundials.ARKode | 11 | -- Module : Numeric.Sundials.ARKode |
12 | -- Copyright : Dominic Steinitz 2018, | ||
13 | -- Novadiscovery 2018 | ||
14 | -- License : BSD | ||
15 | -- Maintainer : Dominic Steinitz | ||
16 | -- Stability : provisional | ||
17 | -- | ||
18 | -- Solution of ordinary differential equation (ODE) initial value problems. | ||
19 | -- | ||
20 | -- <https://computation.llnl.gov/projects/sundials/sundials-software> | ||
21 | -- | ||
22 | -- A simple example: | ||
23 | -- | ||
24 | -- @ | ||
25 | -- import Numeric.Sundials.ARKode | ||
26 | -- import Numeric.LinearAlgebra | ||
27 | -- import Graphics.Plot(mplot) | ||
28 | -- | ||
29 | -- xdot t [x,v] = [v, -0.95*x - 0.1*v] | ||
30 | -- | ||
31 | -- ts = linspace 100 (0,20 :: Double) | ||
32 | -- | ||
33 | -- sol = odeSolve xdot [10,0] ts | ||
34 | -- | ||
35 | -- main = mplot (ts : toColumns sol) | ||
36 | -- @ | ||
11 | -- | 37 | -- |
12 | -- KVAERNO_4_2_3 | 38 | -- KVAERNO_4_2_3 |
13 | -- | 39 | -- |
@@ -29,19 +55,34 @@ | |||
29 | -- \end{array} | 55 | -- \end{array} |
30 | -- \] | 56 | -- \] |
31 | -- | 57 | -- |
32 | module Numeric.Sundials.ARKode.ODE ( solveOde | 58 | -- SDIRK_5_3_4 |
33 | , odeSolve | 59 | -- |
60 | -- \[ | ||
61 | -- \begin{array}{c|ccccc} | ||
62 | -- c_1 & 0.25 & 0.0 & 0.0 & 0.0 & 0.0 \\ | ||
63 | -- c_2 & 0.5 & 0.25 & 0.0 & 0.0 & 0.0 \\ | ||
64 | -- c_3 & 0.34 & -4.0e-2 & 0.25 & 0.0 & 0.0 \\ | ||
65 | -- c_4 & 0.2727941176470588 & -5.036764705882353e-2 & 2.7573529411764705e-2 & 0.25 & 0.0 \\ | ||
66 | -- c_5 & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\ | ||
67 | -- \end{array} | ||
68 | -- \] | ||
69 | ----------------------------------------------------------------------------- | ||
70 | module Numeric.Sundials.ARKode.ODE ( odeSolve | ||
71 | , odeSolveV | ||
72 | , odeSolveVWith | ||
73 | , odeSolve' | ||
34 | , getButcherTable | 74 | , getButcherTable |
35 | , getBT | 75 | , getBT |
36 | , btGet | 76 | , btGet |
37 | , ODEMethod(..) | 77 | , ODEMethod(..) |
38 | , odeSolveV | 78 | , StepControl(..) |
39 | ) where | 79 | ) where |
40 | 80 | ||
41 | import qualified Language.C.Inline as C | 81 | import qualified Language.C.Inline as C |
42 | import qualified Language.C.Inline.Unsafe as CU | 82 | import qualified Language.C.Inline.Unsafe as CU |
43 | 83 | ||
44 | import Data.Monoid ((<>)) | 84 | import Data.Monoid ((<>)) |
85 | import Data.Maybe (isJust) | ||
45 | 86 | ||
46 | import Foreign.C.Types | 87 | import Foreign.C.Types |
47 | import Foreign.Ptr (Ptr) | 88 | import Foreign.Ptr (Ptr) |
@@ -60,9 +101,11 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), | |||
60 | subMatrix, rows, cols, toLists) | 101 | subMatrix, rows, cols, toLists) |
61 | 102 | ||
62 | import qualified Types as T | 103 | import qualified Types as T |
63 | import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3) | 104 | import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3, sDIRK_5_3_4) |
64 | import qualified Arkode as B | 105 | import qualified Arkode as B |
65 | 106 | ||
107 | import Debug.Trace | ||
108 | |||
66 | 109 | ||
67 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | 110 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) |
68 | 111 | ||
@@ -139,14 +182,17 @@ data SundialsDiagnostics = SundialsDiagnostics { | |||
139 | , _aRKDlsGetNumRhsEvals :: Int | 182 | , _aRKDlsGetNumRhsEvals :: Int |
140 | } deriving Show | 183 | } deriving Show |
141 | 184 | ||
185 | type Jacobian = Double -> Vector Double -> Matrix Double | ||
186 | |||
142 | -- | Stepping functions | 187 | -- | Stepping functions |
143 | data ODEMethod = SDIRK_2_1_2 | 188 | data ODEMethod = SDIRK_2_1_2 Jacobian |
144 | | KVAERNO_4_2_3 | 189 | | KVAERNO_4_2_3 Jacobian |
190 | | SDIRK_5_3_4 Jacobian | ||
145 | 191 | ||
146 | instance Enum ODEMethod where | 192 | getMethod :: ODEMethod -> Int |
147 | fromEnum SDIRK_2_1_2 = sDIRK_2_1_2 | 193 | getMethod (SDIRK_2_1_2 _) = sDIRK_2_1_2 |
148 | fromEnum KVAERNO_4_2_3 = kVAERNO_4_2_3 | 194 | getMethod (KVAERNO_4_2_3 _) = kVAERNO_4_2_3 |
149 | toEnum _ = error "toEnum not defined for ODEMethod" | 195 | getMethod (SDIRK_5_3_4 _) = sDIRK_5_3_4 |
150 | 196 | ||
151 | -- | A version of 'odeSolveVWith' with reasonable default step control. | 197 | -- | A version of 'odeSolveVWith' with reasonable default step control. |
152 | odeSolveV | 198 | odeSolveV |
@@ -160,16 +206,34 @@ odeSolveV | |||
160 | -> Matrix Double -- ^ solution | 206 | -> Matrix Double -- ^ solution |
161 | odeSolveV _meth _hi _epsAbs _epsRel = undefined | 207 | odeSolveV _meth _hi _epsAbs _epsRel = undefined |
162 | 208 | ||
163 | odeSolve :: ODEMethod | 209 | -- | A version of 'odeSolveV' with reasonable default parameters and |
210 | -- system of equations defined using lists. FIXME: we should say | ||
211 | -- something about the fact we could use the Jacobian but don't for | ||
212 | -- compatibility with hmatrix-gsl. | ||
213 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
214 | -> [Double] -- ^ initial conditions | ||
215 | -> Vector Double -- ^ desired solution times | ||
216 | -> Matrix Double -- ^ solution | ||
217 | odeSolve f y0 ts = | ||
218 | case odeSolveVWith (SDIRK_5_3_4 undefined) Nothing 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of | ||
219 | Left c -> error $ show c -- FIXME | ||
220 | Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) | ||
221 | where | ||
222 | us = toList ts | ||
223 | nR = length us | ||
224 | nC = length y0 | ||
225 | g t x0 = V.fromList $ f t (V.toList x0) | ||
226 | |||
227 | odeSolve' :: ODEMethod | ||
164 | -> (Double -> Vector Double -> Matrix Double) | 228 | -> (Double -> Vector Double -> Matrix Double) |
165 | -> (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 229 | -> (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
166 | -> [Double] -- ^ initial conditions | 230 | -> [Double] -- ^ initial conditions |
167 | -> Vector Double -- ^ desired solution times | 231 | -> Vector Double -- ^ desired solution times |
168 | -> Matrix Double -- ^ solution | 232 | -> Matrix Double -- ^ solution |
169 | odeSolve method jac f y0 ts = | 233 | odeSolve' method jac f y0 ts = |
170 | case solveOde method jac' 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of | 234 | case odeSolveVWith method (pure jac') 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of |
171 | Left c -> error $ show c -- FIXME | 235 | Left c -> error $ show c -- FIXME |
172 | Right (v, _) -> (nR >< nC) (V.toList v) | 236 | Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) |
173 | where | 237 | where |
174 | us = toList ts | 238 | us = toList ts |
175 | nR = length us | 239 | nR = length us |
@@ -182,26 +246,27 @@ odeSolve method jac f y0 ts = | |||
182 | nc = fromIntegral $ cols m | 246 | nc = fromIntegral $ cols m |
183 | vs = V.fromList $ map coerce $ concat $ toLists m | 247 | vs = V.fromList $ map coerce $ concat $ toLists m |
184 | 248 | ||
185 | solveOde :: | 249 | odeSolveVWith :: |
186 | ODEMethod | 250 | ODEMethod |
187 | -> (Double -> V.Vector Double -> T.SunMatrix) | 251 | -> (Maybe (Double -> V.Vector Double -> T.SunMatrix)) |
188 | -> Double | 252 | -> Double |
189 | -> Double | 253 | -> Double |
190 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 254 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
191 | -> V.Vector Double -- ^ Initial conditions | 255 | -> V.Vector Double -- ^ Initial conditions |
192 | -> V.Vector Double -- ^ Desired solution times | 256 | -> V.Vector Double -- ^ Desired solution times |
193 | -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution | 257 | -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution |
194 | solveOde method jac relTol absTol f y0 tt = | 258 | odeSolveVWith method jac relTol absTol f y0 tt = |
195 | case solveOdeC (fromIntegral $ fromEnum method) jacH (CDouble relTol) (CDouble absTol) | 259 | case solveOdeC (fromIntegral $ getMethod method) jacH (CDouble relTol) (CDouble absTol) |
196 | (coerce f) (coerce y0) (coerce tt) of | 260 | (coerce f) (coerce y0) (coerce tt) of |
197 | Left c -> Left $ fromIntegral c | 261 | Left c -> Left $ fromIntegral c |
198 | Right (v, d) -> Right (coerce v, d) | 262 | Right (v, d) -> Right (coerce v, d) |
199 | where | 263 | where |
200 | jacH t v = jac (coerce t) (coerce v) | 264 | jacH :: Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix) |
265 | jacH = fmap (\g -> (\t v -> g (coerce t) (coerce v))) jac | ||
201 | 266 | ||
202 | solveOdeC :: | 267 | solveOdeC :: |
203 | CInt -> | 268 | CInt -> |
204 | (CDouble -> V.Vector CDouble -> T.SunMatrix) -> | 269 | (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> |
205 | CDouble -> | 270 | CDouble -> |
206 | CDouble -> | 271 | CDouble -> |
207 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 272 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
@@ -235,15 +300,19 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do | |||
235 | -- FIXME: I don't understand what this comment means | 300 | -- FIXME: I don't understand what this comment means |
236 | -- Unsafe since the function will be called many times. | 301 | -- Unsafe since the function will be called many times. |
237 | [CU.exp| int{ 0 } |] | 302 | [CU.exp| int{ 0 } |] |
238 | let jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix -> | 303 | let isJac :: CInt |
304 | isJac = fromIntegral $ fromEnum $ isJust jacH | ||
305 | jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix -> | ||
239 | Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector -> | 306 | Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector -> |
240 | IO CInt | 307 | IO CInt |
241 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do | 308 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do |
242 | j <- jacH t <$> getDataFromContents dim y | 309 | case jacH of |
243 | putMatrixDataFromContents j jacS | 310 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" |
244 | -- FIXME: I don't understand what this comment means | 311 | Just jacI -> do j <- jacI t <$> getDataFromContents dim y |
245 | -- Unsafe since the function will be called many times. | 312 | putMatrixDataFromContents j jacS |
246 | [CU.exp| int{ 0 } |] | 313 | -- FIXME: I don't understand what this comment means |
314 | -- Unsafe since the function will be called many times. | ||
315 | [CU.exp| int{ 0 } |] | ||
247 | 316 | ||
248 | res <- [C.block| int { | 317 | res <- [C.block| int { |
249 | /* general problem variables */ | 318 | /* general problem variables */ |
@@ -292,9 +361,10 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do | |||
292 | /* Linear solver interface */ | 361 | /* Linear solver interface */ |
293 | flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ | 362 | flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ |
294 | 363 | ||
295 | flag = ARKDlsSetJacFn(arkode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[]))); | 364 | if ($(int isJac)) { |
296 | if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1; | 365 | flag = ARKDlsSetJacFn(arkode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[]))); |
297 | 366 | if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1; | |
367 | } | ||
298 | 368 | ||
299 | /* Store initial conditions */ | 369 | /* Store initial conditions */ |
300 | for (j = 0; j < NEQ; j++) { | 370 | for (j = 0; j < NEQ; j++) { |
@@ -389,9 +459,10 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do | |||
389 | btGet :: ODEMethod -> Matrix Double | 459 | btGet :: ODEMethod -> Matrix Double |
390 | btGet method = case getBT method of | 460 | btGet method = case getBT method of |
391 | Left c -> error $ show c -- FIXME | 461 | Left c -> error $ show c -- FIXME |
392 | -- FIXME | 462 | Right (v, sqp) -> subMatrix (0, 0) (s, s) $ |
393 | Right (v, _sqp) -> subMatrix (0, 0) (2, 2) $ | ||
394 | (B.arkSMax >< B.arkSMax) (V.toList v) | 463 | (B.arkSMax >< B.arkSMax) (V.toList v) |
464 | where | ||
465 | s = fromIntegral $ sqp V.! 0 | ||
395 | 466 | ||
396 | getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int) | 467 | getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int) |
397 | getBT method = case getButcherTable method of | 468 | getBT method = case getButcherTable method of |
@@ -410,7 +481,7 @@ getButcherTable method = unsafePerformIO $ do | |||
410 | nEq :: CLong | 481 | nEq :: CLong |
411 | nEq = fromIntegral dim | 482 | nEq = fromIntegral dim |
412 | mN :: CInt | 483 | mN :: CInt |
413 | mN = fromIntegral $ fromEnum method | 484 | mN = fromIntegral $ getMethod method |
414 | 485 | ||
415 | -- FIXME: I believe these gets taken from the ghc heap and so should | 486 | -- FIXME: I believe these gets taken from the ghc heap and so should |
416 | -- be subject to garbage collection. | 487 | -- be subject to garbage collection. |
@@ -493,3 +564,11 @@ getButcherTable method = unsafePerformIO $ do | |||
493 | return $ Right (x, y) | 564 | return $ Right (x, y) |
494 | else do | 565 | else do |
495 | return $ Left res | 566 | return $ Left res |
567 | |||
568 | -- | Adaptive step-size control functions. FIXME: It may not be | ||
569 | -- possible to scale the tolerances for the derivatives in sundials so | ||
570 | -- for now we ignore them and emit a warning. | ||
571 | data StepControl = X Double Double -- ^ abs. and rel. tolerance for x(t) | ||
572 | | X' Double Double -- ^ abs. and rel. tolerance for x'(t) | ||
573 | | XX' Double Double Double Double -- ^ include both via rel. tolerance scaling factors a_x, a_x' | ||
574 | | ScXX' Double Double Double Double (Vector Double) -- ^ scale abs. tolerance of x(t) components | ||