diff options
author | Dominic Steinitz <dominic@steinitz.org> | 2018-04-20 11:45:30 +0100 |
---|---|---|
committer | Dominic Steinitz <dominic@steinitz.org> | 2018-04-20 11:45:30 +0100 |
commit | 411bee3b1f984459ce7a496c655dda333ddf6f32 (patch) | |
tree | de5e51bbd23e2f221636ea4a0598b39efbcaf30f /packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |
parent | 51cac5fbc8571b11ac1841ec21cbff66150c7a62 (diff) |
Support all implicit methods and 1 explicit method
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs')
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 152 |
1 files changed, 117 insertions, 35 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index f6f6884..1460680 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -8,7 +8,7 @@ | |||
8 | 8 | ||
9 | ----------------------------------------------------------------------------- | 9 | ----------------------------------------------------------------------------- |
10 | -- | | 10 | -- | |
11 | -- Module : Numeric.Sundials.ARKode | 11 | -- Module : Numeric.Sundials.ARKode.ODE |
12 | -- Copyright : Dominic Steinitz 2018, | 12 | -- Copyright : Dominic Steinitz 2018, |
13 | -- Novadiscovery 2018 | 13 | -- Novadiscovery 2018 |
14 | -- License : BSD | 14 | -- License : BSD |
@@ -138,11 +138,9 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), | |||
138 | size, subVector) | 138 | size, subVector) |
139 | 139 | ||
140 | import qualified Types as T | 140 | import qualified Types as T |
141 | import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3, sDIRK_5_3_4, fEHLBERG_6_4_5) | 141 | import Arkode |
142 | import qualified Arkode as B | 142 | import qualified Arkode as B |
143 | 143 | ||
144 | import Debug.Trace | ||
145 | |||
146 | 144 | ||
147 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | 145 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) |
148 | 146 | ||
@@ -222,28 +220,88 @@ data SundialsDiagnostics = SundialsDiagnostics { | |||
222 | type Jacobian = Double -> Vector Double -> Matrix Double | 220 | type Jacobian = Double -> Vector Double -> Matrix Double |
223 | 221 | ||
224 | -- | Stepping functions | 222 | -- | Stepping functions |
225 | data ODEMethod = SDIRK_2_1_2 Jacobian | 223 | data ODEMethod = SDIRK_2_1_2 Jacobian |
226 | | KVAERNO_4_2_3 Jacobian | 224 | | SDIRK_2_1_2' |
227 | | SDIRK_5_3_4 Jacobian | 225 | | BILLINGTON_3_3_2 Jacobian |
226 | | BILLINGTON_3_3_2' | ||
227 | | TRBDF2_3_3_2 Jacobian | ||
228 | | TRBDF2_3_3_2' | ||
229 | | KVAERNO_4_2_3 Jacobian | ||
230 | | KVAERNO_4_2_3' | ||
231 | | ARK324L2SA_DIRK_4_2_3 Jacobian | ||
232 | | ARK324L2SA_DIRK_4_2_3' | ||
233 | | CASH_5_2_4 Jacobian | ||
234 | | CASH_5_2_4' | ||
235 | | CASH_5_3_4 Jacobian | ||
236 | | CASH_5_3_4' | ||
237 | | SDIRK_5_3_4 Jacobian | ||
228 | | SDIRK_5_3_4' | 238 | | SDIRK_5_3_4' |
229 | | FEHLBERG_6_4_5 Jacobian | 239 | | KVAERNO_5_3_4 Jacobian |
240 | | KVAERNO_5_3_4' | ||
241 | | ARK436L2SA_DIRK_6_3_4 Jacobian | ||
242 | | ARK436L2SA_DIRK_6_3_4' | ||
243 | | KVAERNO_7_4_5 Jacobian | ||
244 | | KVAERNO_7_4_5' | ||
245 | | ARK548L2SA_DIRK_8_4_5 Jacobian | ||
246 | | ARK548L2SA_DIRK_8_4_5' | ||
247 | | FEHLBERG_6_4_5 Jacobian | ||
230 | | FEHLBERG_6_4_5' | 248 | | FEHLBERG_6_4_5' |
231 | 249 | ||
232 | getMethod :: ODEMethod -> Int | 250 | getMethod :: ODEMethod -> Int |
233 | getMethod (SDIRK_2_1_2 _) = sDIRK_2_1_2 | 251 | getMethod (SDIRK_2_1_2 _) = sDIRK_2_1_2 |
234 | getMethod (KVAERNO_4_2_3 _) = kVAERNO_4_2_3 | 252 | getMethod (SDIRK_2_1_2') = sDIRK_2_1_2 |
235 | getMethod (SDIRK_5_3_4 _) = sDIRK_5_3_4 | 253 | getMethod (BILLINGTON_3_3_2 _) = bILLINGTON_3_3_2 |
236 | getMethod (SDIRK_5_3_4' ) = sDIRK_5_3_4 | 254 | getMethod (BILLINGTON_3_3_2') = bILLINGTON_3_3_2 |
237 | getMethod (FEHLBERG_6_4_5 _) = fEHLBERG_6_4_5 | 255 | getMethod (TRBDF2_3_3_2 _) = tRBDF2_3_3_2 |
238 | getMethod (FEHLBERG_6_4_5' ) = fEHLBERG_6_4_5 | 256 | getMethod (TRBDF2_3_3_2') = tRBDF2_3_3_2 |
257 | getMethod (KVAERNO_4_2_3 _) = kVAERNO_4_2_3 | ||
258 | getMethod (KVAERNO_4_2_3') = kVAERNO_4_2_3 | ||
259 | getMethod (ARK324L2SA_DIRK_4_2_3 _) = aRK324L2SA_DIRK_4_2_3 | ||
260 | getMethod (ARK324L2SA_DIRK_4_2_3') = aRK324L2SA_DIRK_4_2_3 | ||
261 | getMethod (CASH_5_2_4 _) = cASH_5_2_4 | ||
262 | getMethod (CASH_5_2_4') = cASH_5_2_4 | ||
263 | getMethod (CASH_5_3_4 _) = cASH_5_3_4 | ||
264 | getMethod (CASH_5_3_4') = cASH_5_3_4 | ||
265 | getMethod (SDIRK_5_3_4 _) = sDIRK_5_3_4 | ||
266 | getMethod (SDIRK_5_3_4') = sDIRK_5_3_4 | ||
267 | getMethod (KVAERNO_5_3_4 _) = kVAERNO_5_3_4 | ||
268 | getMethod (KVAERNO_5_3_4') = kVAERNO_5_3_4 | ||
269 | getMethod (ARK436L2SA_DIRK_6_3_4 _) = aRK436L2SA_DIRK_6_3_4 | ||
270 | getMethod (ARK436L2SA_DIRK_6_3_4') = aRK436L2SA_DIRK_6_3_4 | ||
271 | getMethod (KVAERNO_7_4_5 _) = kVAERNO_7_4_5 | ||
272 | getMethod (KVAERNO_7_4_5') = kVAERNO_7_4_5 | ||
273 | getMethod (ARK548L2SA_DIRK_8_4_5 _) = aRK548L2SA_DIRK_8_4_5 | ||
274 | getMethod (ARK548L2SA_DIRK_8_4_5') = aRK548L2SA_DIRK_8_4_5 | ||
275 | getMethod (FEHLBERG_6_4_5 _) = fEHLBERG_6_4_5 | ||
276 | getMethod (FEHLBERG_6_4_5' ) = fEHLBERG_6_4_5 | ||
239 | 277 | ||
240 | getJacobian :: ODEMethod -> Maybe Jacobian | 278 | getJacobian :: ODEMethod -> Maybe Jacobian |
241 | getJacobian (SDIRK_2_1_2 j) = Just j | 279 | getJacobian (SDIRK_2_1_2 j) = Just j |
242 | getJacobian (KVAERNO_4_2_3 j) = Just j | 280 | getJacobian (SDIRK_2_1_2') = Nothing |
243 | getJacobian (SDIRK_5_3_4 j) = Just j | 281 | getJacobian (BILLINGTON_3_3_2 j) = Just j |
244 | getJacobian (SDIRK_5_3_4' ) = Nothing | 282 | getJacobian (BILLINGTON_3_3_2') = Nothing |
245 | getJacobian (FEHLBERG_6_4_5 j) = Just j | 283 | getJacobian (TRBDF2_3_3_2 j) = Just j |
246 | getJacobian (FEHLBERG_6_4_5' ) = Nothing | 284 | getJacobian (TRBDF2_3_3_2') = Nothing |
285 | getJacobian (KVAERNO_4_2_3 j) = Just j | ||
286 | getJacobian (KVAERNO_4_2_3') = Nothing | ||
287 | getJacobian (ARK324L2SA_DIRK_4_2_3 j) = Just j | ||
288 | getJacobian (ARK324L2SA_DIRK_4_2_3') = Nothing | ||
289 | getJacobian (CASH_5_2_4 j) = Just j | ||
290 | getJacobian (CASH_5_2_4') = Nothing | ||
291 | getJacobian (CASH_5_3_4 j) = Just j | ||
292 | getJacobian (CASH_5_3_4') = Nothing | ||
293 | getJacobian (SDIRK_5_3_4 j) = Just j | ||
294 | getJacobian (SDIRK_5_3_4') = Nothing | ||
295 | getJacobian (KVAERNO_5_3_4 j) = Just j | ||
296 | getJacobian (KVAERNO_5_3_4') = Nothing | ||
297 | getJacobian (ARK436L2SA_DIRK_6_3_4 j) = Just j | ||
298 | getJacobian (ARK436L2SA_DIRK_6_3_4') = Nothing | ||
299 | getJacobian (KVAERNO_7_4_5 j) = Just j | ||
300 | getJacobian (KVAERNO_7_4_5') = Nothing | ||
301 | getJacobian (ARK548L2SA_DIRK_8_4_5 j) = Just j | ||
302 | getJacobian (ARK548L2SA_DIRK_8_4_5') = Nothing | ||
303 | getJacobian (FEHLBERG_6_4_5 j) = Just j | ||
304 | getJacobian (FEHLBERG_6_4_5' ) = Nothing | ||
247 | 305 | ||
248 | -- | A version of 'odeSolveVWith' with reasonable default step control. | 306 | -- | A version of 'odeSolveVWith' with reasonable default step control. |
249 | odeSolveV | 307 | odeSolveV |
@@ -262,9 +320,9 @@ odeSolveV | |||
262 | -> Matrix Double -- ^ solution | 320 | -> Matrix Double -- ^ solution |
263 | odeSolveV meth hi epsAbs epsRel f y0 ts = | 321 | odeSolveV meth hi epsAbs epsRel f y0 ts = |
264 | case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of | 322 | case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of |
265 | Left c -> error $ show c -- FIXME | 323 | Left c -> error $ show c -- FIXME |
266 | -- FIXME: Can we do better than using lists? | 324 | -- FIXME: Can we do better than using lists? |
267 | Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) | 325 | Right (v, _d) -> (nR >< nC) (V.toList v) |
268 | where | 326 | where |
269 | us = toList ts | 327 | us = toList ts |
270 | nR = length us | 328 | nR = length us |
@@ -282,8 +340,8 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y | |||
282 | odeSolve f y0 ts = | 340 | odeSolve f y0 ts = |
283 | -- FIXME: These tolerances are different from the ones in GSL | 341 | -- FIXME: These tolerances are different from the ones in GSL |
284 | case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of | 342 | case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of |
285 | Left c -> error $ show c -- FIXME | 343 | Left c -> error $ show c -- FIXME |
286 | Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) | 344 | Right (v, _d) -> (nR >< nC) (V.toList v) |
287 | where | 345 | where |
288 | us = toList ts | 346 | us = toList ts |
289 | nR = length us | 347 | nR = length us |
@@ -449,8 +507,13 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
449 | 507 | ||
450 | /* Here we use the C types defined in helpers.h which tie up with */ | 508 | /* Here we use the C types defined in helpers.h which tie up with */ |
451 | /* the Haskell types defined in Types */ | 509 | /* the Haskell types defined in Types */ |
452 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); | 510 | if ($(int method) < MIN_DIRK_NUM) { |
453 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | 511 | flag = ARKodeInit(arkode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), NULL, T0, y); |
512 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | ||
513 | } else { | ||
514 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); | ||
515 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | ||
516 | } | ||
454 | 517 | ||
455 | /* Set routines */ | 518 | /* Set routines */ |
456 | flag = ARKodeSVtolerances(arkode_mem, $(double relTol), tv); | 519 | flag = ARKodeSVtolerances(arkode_mem, $(double relTol), tv); |
@@ -486,8 +549,13 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
486 | } | 549 | } |
487 | 550 | ||
488 | /* Explicitly set the method */ | 551 | /* Explicitly set the method */ |
489 | flag = ARKodeSetIRKTableNum(arkode_mem, $(int method)); | 552 | if ($(int method) >= MIN_DIRK_NUM) { |
490 | if (check_flag(&flag, "ARKode", 1)) return 1; | 553 | flag = ARKodeSetIRKTableNum(arkode_mem, $(int method)); |
554 | if (check_flag(&flag, "ARKodeSetIRKTableNum", 1)) return 1; | ||
555 | } else { | ||
556 | flag = ARKodeSetERKTableNum(arkode_mem, $(int method)); | ||
557 | if (check_flag(&flag, "ARKodeSetERKTableNum", 1)) return 1; | ||
558 | } | ||
491 | 559 | ||
492 | /* Main time-stepping loop: calls ARKode to perform the integration */ | 560 | /* Main time-stepping loop: calls ARKode to perform the integration */ |
493 | /* Stops when the final time has been reached */ | 561 | /* Stops when the final time has been reached */ |
@@ -614,8 +682,10 @@ getButcherTable :: ODEMethod | |||
614 | getButcherTable method = unsafePerformIO $ do | 682 | getButcherTable method = unsafePerformIO $ do |
615 | -- ARKode seems to want an ODE in order to set and then get the | 683 | -- ARKode seems to want an ODE in order to set and then get the |
616 | -- Butcher tableau so here's one to keep it happy | 684 | -- Butcher tableau so here's one to keep it happy |
617 | let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble | 685 | let funI :: CDouble -> V.Vector CDouble -> V.Vector CDouble |
618 | fun _t ys = V.fromList [ ys V.! 0 ] | 686 | funI _t ys = V.fromList [ ys V.! 0 ] |
687 | let funE :: CDouble -> V.Vector CDouble -> V.Vector CDouble | ||
688 | funE _t ys = V.fromList [ ys V.! 0 ] | ||
619 | f0 = V.fromList [ 1.0 ] | 689 | f0 = V.fromList [ 1.0 ] |
620 | ts = V.fromList [ 0.0 ] | 690 | ts = V.fromList [ 0.0 ] |
621 | dim = V.length f0 | 691 | dim = V.length f0 |
@@ -634,9 +704,16 @@ getButcherTable method = unsafePerformIO $ do | |||
634 | btCsMut <- V.thaw btCs | 704 | btCsMut <- V.thaw btCs |
635 | btBsMut <- V.thaw btBs | 705 | btBsMut <- V.thaw btBs |
636 | btB2sMut <- V.thaw btB2s | 706 | btB2sMut <- V.thaw btB2s |
637 | let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | 707 | let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt |
638 | funIO x y f _ptr = do | 708 | funIOI x y f _ptr = do |
639 | fImm <- fun x <$> getDataFromContents dim y | 709 | fImm <- funI x <$> getDataFromContents dim y |
710 | putDataInContents fImm dim f | ||
711 | -- FIXME: I don't understand what this comment means | ||
712 | -- Unsafe since the function will be called many times. | ||
713 | [CU.exp| int{ 0 } |] | ||
714 | let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | ||
715 | funIOE x y f _ptr = do | ||
716 | fImm <- funE x <$> getDataFromContents dim y | ||
640 | putDataInContents fImm dim f | 717 | putDataInContents fImm dim f |
641 | -- FIXME: I don't understand what this comment means | 718 | -- FIXME: I don't understand what this comment means |
642 | -- Unsafe since the function will be called many times. | 719 | -- Unsafe since the function will be called many times. |
@@ -665,11 +742,16 @@ getButcherTable method = unsafePerformIO $ do | |||
665 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | 742 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ |
666 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; | 743 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; |
667 | 744 | ||
668 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); | 745 | flag = ARKodeInit(arkode_mem, $fun:(int (* funIOE) (double t, SunVector y[], SunVector dydt[], void * params)), $fun:(int (* funIOI) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); |
669 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | 746 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; |
670 | 747 | ||
748 | if ($(int mN) >= MIN_DIRK_NUM) { | ||
671 | flag = ARKodeSetIRKTableNum(arkode_mem, $(int mN)); | 749 | flag = ARKodeSetIRKTableNum(arkode_mem, $(int mN)); |
672 | if (check_flag(&flag, "ARKode", 1)) return 1; | 750 | if (check_flag(&flag, "ARKodeSetIRKTableNum", 1)) return 1; |
751 | } else { | ||
752 | flag = ARKodeSetERKTableNum(arkode_mem, $(int mN)); | ||
753 | if (check_flag(&flag, "ARKodeSetERKTableNum", 1)) return 1; | ||
754 | } | ||
673 | 755 | ||
674 | int s, q, p; | 756 | int s, q, p; |
675 | realtype *ai = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype)); | 757 | realtype *ai = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype)); |