summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
diff options
context:
space:
mode:
authorDominic Steinitz <dominic@steinitz.org>2018-04-20 11:45:30 +0100
committerDominic Steinitz <dominic@steinitz.org>2018-04-20 11:45:30 +0100
commit411bee3b1f984459ce7a496c655dda333ddf6f32 (patch)
treede5e51bbd23e2f221636ea4a0598b39efbcaf30f /packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
parent51cac5fbc8571b11ac1841ec21cbff66150c7a62 (diff)
Support all implicit methods and 1 explicit method
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs')
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs152
1 files changed, 117 insertions, 35 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index f6f6884..1460680 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -8,7 +8,7 @@
8 8
9----------------------------------------------------------------------------- 9-----------------------------------------------------------------------------
10-- | 10-- |
11-- Module : Numeric.Sundials.ARKode 11-- Module : Numeric.Sundials.ARKode.ODE
12-- Copyright : Dominic Steinitz 2018, 12-- Copyright : Dominic Steinitz 2018,
13-- Novadiscovery 2018 13-- Novadiscovery 2018
14-- License : BSD 14-- License : BSD
@@ -138,11 +138,9 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><),
138 size, subVector) 138 size, subVector)
139 139
140import qualified Types as T 140import qualified Types as T
141import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3, sDIRK_5_3_4, fEHLBERG_6_4_5) 141import Arkode
142import qualified Arkode as B 142import qualified Arkode as B
143 143
144import Debug.Trace
145
146 144
147C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) 145C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx)
148 146
@@ -222,28 +220,88 @@ data SundialsDiagnostics = SundialsDiagnostics {
222type Jacobian = Double -> Vector Double -> Matrix Double 220type Jacobian = Double -> Vector Double -> Matrix Double
223 221
224-- | Stepping functions 222-- | Stepping functions
225data ODEMethod = SDIRK_2_1_2 Jacobian 223data ODEMethod = SDIRK_2_1_2 Jacobian
226 | KVAERNO_4_2_3 Jacobian 224 | SDIRK_2_1_2'
227 | SDIRK_5_3_4 Jacobian 225 | BILLINGTON_3_3_2 Jacobian
226 | BILLINGTON_3_3_2'
227 | TRBDF2_3_3_2 Jacobian
228 | TRBDF2_3_3_2'
229 | KVAERNO_4_2_3 Jacobian
230 | KVAERNO_4_2_3'
231 | ARK324L2SA_DIRK_4_2_3 Jacobian
232 | ARK324L2SA_DIRK_4_2_3'
233 | CASH_5_2_4 Jacobian
234 | CASH_5_2_4'
235 | CASH_5_3_4 Jacobian
236 | CASH_5_3_4'
237 | SDIRK_5_3_4 Jacobian
228 | SDIRK_5_3_4' 238 | SDIRK_5_3_4'
229 | FEHLBERG_6_4_5 Jacobian 239 | KVAERNO_5_3_4 Jacobian
240 | KVAERNO_5_3_4'
241 | ARK436L2SA_DIRK_6_3_4 Jacobian
242 | ARK436L2SA_DIRK_6_3_4'
243 | KVAERNO_7_4_5 Jacobian
244 | KVAERNO_7_4_5'
245 | ARK548L2SA_DIRK_8_4_5 Jacobian
246 | ARK548L2SA_DIRK_8_4_5'
247 | FEHLBERG_6_4_5 Jacobian
230 | FEHLBERG_6_4_5' 248 | FEHLBERG_6_4_5'
231 249
232getMethod :: ODEMethod -> Int 250getMethod :: ODEMethod -> Int
233getMethod (SDIRK_2_1_2 _) = sDIRK_2_1_2 251getMethod (SDIRK_2_1_2 _) = sDIRK_2_1_2
234getMethod (KVAERNO_4_2_3 _) = kVAERNO_4_2_3 252getMethod (SDIRK_2_1_2') = sDIRK_2_1_2
235getMethod (SDIRK_5_3_4 _) = sDIRK_5_3_4 253getMethod (BILLINGTON_3_3_2 _) = bILLINGTON_3_3_2
236getMethod (SDIRK_5_3_4' ) = sDIRK_5_3_4 254getMethod (BILLINGTON_3_3_2') = bILLINGTON_3_3_2
237getMethod (FEHLBERG_6_4_5 _) = fEHLBERG_6_4_5 255getMethod (TRBDF2_3_3_2 _) = tRBDF2_3_3_2
238getMethod (FEHLBERG_6_4_5' ) = fEHLBERG_6_4_5 256getMethod (TRBDF2_3_3_2') = tRBDF2_3_3_2
257getMethod (KVAERNO_4_2_3 _) = kVAERNO_4_2_3
258getMethod (KVAERNO_4_2_3') = kVAERNO_4_2_3
259getMethod (ARK324L2SA_DIRK_4_2_3 _) = aRK324L2SA_DIRK_4_2_3
260getMethod (ARK324L2SA_DIRK_4_2_3') = aRK324L2SA_DIRK_4_2_3
261getMethod (CASH_5_2_4 _) = cASH_5_2_4
262getMethod (CASH_5_2_4') = cASH_5_2_4
263getMethod (CASH_5_3_4 _) = cASH_5_3_4
264getMethod (CASH_5_3_4') = cASH_5_3_4
265getMethod (SDIRK_5_3_4 _) = sDIRK_5_3_4
266getMethod (SDIRK_5_3_4') = sDIRK_5_3_4
267getMethod (KVAERNO_5_3_4 _) = kVAERNO_5_3_4
268getMethod (KVAERNO_5_3_4') = kVAERNO_5_3_4
269getMethod (ARK436L2SA_DIRK_6_3_4 _) = aRK436L2SA_DIRK_6_3_4
270getMethod (ARK436L2SA_DIRK_6_3_4') = aRK436L2SA_DIRK_6_3_4
271getMethod (KVAERNO_7_4_5 _) = kVAERNO_7_4_5
272getMethod (KVAERNO_7_4_5') = kVAERNO_7_4_5
273getMethod (ARK548L2SA_DIRK_8_4_5 _) = aRK548L2SA_DIRK_8_4_5
274getMethod (ARK548L2SA_DIRK_8_4_5') = aRK548L2SA_DIRK_8_4_5
275getMethod (FEHLBERG_6_4_5 _) = fEHLBERG_6_4_5
276getMethod (FEHLBERG_6_4_5' ) = fEHLBERG_6_4_5
239 277
240getJacobian :: ODEMethod -> Maybe Jacobian 278getJacobian :: ODEMethod -> Maybe Jacobian
241getJacobian (SDIRK_2_1_2 j) = Just j 279getJacobian (SDIRK_2_1_2 j) = Just j
242getJacobian (KVAERNO_4_2_3 j) = Just j 280getJacobian (SDIRK_2_1_2') = Nothing
243getJacobian (SDIRK_5_3_4 j) = Just j 281getJacobian (BILLINGTON_3_3_2 j) = Just j
244getJacobian (SDIRK_5_3_4' ) = Nothing 282getJacobian (BILLINGTON_3_3_2') = Nothing
245getJacobian (FEHLBERG_6_4_5 j) = Just j 283getJacobian (TRBDF2_3_3_2 j) = Just j
246getJacobian (FEHLBERG_6_4_5' ) = Nothing 284getJacobian (TRBDF2_3_3_2') = Nothing
285getJacobian (KVAERNO_4_2_3 j) = Just j
286getJacobian (KVAERNO_4_2_3') = Nothing
287getJacobian (ARK324L2SA_DIRK_4_2_3 j) = Just j
288getJacobian (ARK324L2SA_DIRK_4_2_3') = Nothing
289getJacobian (CASH_5_2_4 j) = Just j
290getJacobian (CASH_5_2_4') = Nothing
291getJacobian (CASH_5_3_4 j) = Just j
292getJacobian (CASH_5_3_4') = Nothing
293getJacobian (SDIRK_5_3_4 j) = Just j
294getJacobian (SDIRK_5_3_4') = Nothing
295getJacobian (KVAERNO_5_3_4 j) = Just j
296getJacobian (KVAERNO_5_3_4') = Nothing
297getJacobian (ARK436L2SA_DIRK_6_3_4 j) = Just j
298getJacobian (ARK436L2SA_DIRK_6_3_4') = Nothing
299getJacobian (KVAERNO_7_4_5 j) = Just j
300getJacobian (KVAERNO_7_4_5') = Nothing
301getJacobian (ARK548L2SA_DIRK_8_4_5 j) = Just j
302getJacobian (ARK548L2SA_DIRK_8_4_5') = Nothing
303getJacobian (FEHLBERG_6_4_5 j) = Just j
304getJacobian (FEHLBERG_6_4_5' ) = Nothing
247 305
248-- | A version of 'odeSolveVWith' with reasonable default step control. 306-- | A version of 'odeSolveVWith' with reasonable default step control.
249odeSolveV 307odeSolveV
@@ -262,9 +320,9 @@ odeSolveV
262 -> Matrix Double -- ^ solution 320 -> Matrix Double -- ^ solution
263odeSolveV meth hi epsAbs epsRel f y0 ts = 321odeSolveV meth hi epsAbs epsRel f y0 ts =
264 case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of 322 case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of
265 Left c -> error $ show c -- FIXME 323 Left c -> error $ show c -- FIXME
266 -- FIXME: Can we do better than using lists? 324 -- FIXME: Can we do better than using lists?
267 Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) 325 Right (v, _d) -> (nR >< nC) (V.toList v)
268 where 326 where
269 us = toList ts 327 us = toList ts
270 nR = length us 328 nR = length us
@@ -282,8 +340,8 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y
282odeSolve f y0 ts = 340odeSolve f y0 ts =
283 -- FIXME: These tolerances are different from the ones in GSL 341 -- FIXME: These tolerances are different from the ones in GSL
284 case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of 342 case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of
285 Left c -> error $ show c -- FIXME 343 Left c -> error $ show c -- FIXME
286 Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) 344 Right (v, _d) -> (nR >< nC) (V.toList v)
287 where 345 where
288 us = toList ts 346 us = toList ts
289 nR = length us 347 nR = length us
@@ -449,8 +507,13 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
449 507
450 /* Here we use the C types defined in helpers.h which tie up with */ 508 /* Here we use the C types defined in helpers.h which tie up with */
451 /* the Haskell types defined in Types */ 509 /* the Haskell types defined in Types */
452 flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); 510 if ($(int method) < MIN_DIRK_NUM) {
453 if (check_flag(&flag, "ARKodeInit", 1)) return 1; 511 flag = ARKodeInit(arkode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), NULL, T0, y);
512 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
513 } else {
514 flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y);
515 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
516 }
454 517
455 /* Set routines */ 518 /* Set routines */
456 flag = ARKodeSVtolerances(arkode_mem, $(double relTol), tv); 519 flag = ARKodeSVtolerances(arkode_mem, $(double relTol), tv);
@@ -486,8 +549,13 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
486 } 549 }
487 550
488 /* Explicitly set the method */ 551 /* Explicitly set the method */
489 flag = ARKodeSetIRKTableNum(arkode_mem, $(int method)); 552 if ($(int method) >= MIN_DIRK_NUM) {
490 if (check_flag(&flag, "ARKode", 1)) return 1; 553 flag = ARKodeSetIRKTableNum(arkode_mem, $(int method));
554 if (check_flag(&flag, "ARKodeSetIRKTableNum", 1)) return 1;
555 } else {
556 flag = ARKodeSetERKTableNum(arkode_mem, $(int method));
557 if (check_flag(&flag, "ARKodeSetERKTableNum", 1)) return 1;
558 }
491 559
492 /* Main time-stepping loop: calls ARKode to perform the integration */ 560 /* Main time-stepping loop: calls ARKode to perform the integration */
493 /* Stops when the final time has been reached */ 561 /* Stops when the final time has been reached */
@@ -614,8 +682,10 @@ getButcherTable :: ODEMethod
614getButcherTable method = unsafePerformIO $ do 682getButcherTable method = unsafePerformIO $ do
615 -- ARKode seems to want an ODE in order to set and then get the 683 -- ARKode seems to want an ODE in order to set and then get the
616 -- Butcher tableau so here's one to keep it happy 684 -- Butcher tableau so here's one to keep it happy
617 let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble 685 let funI :: CDouble -> V.Vector CDouble -> V.Vector CDouble
618 fun _t ys = V.fromList [ ys V.! 0 ] 686 funI _t ys = V.fromList [ ys V.! 0 ]
687 let funE :: CDouble -> V.Vector CDouble -> V.Vector CDouble
688 funE _t ys = V.fromList [ ys V.! 0 ]
619 f0 = V.fromList [ 1.0 ] 689 f0 = V.fromList [ 1.0 ]
620 ts = V.fromList [ 0.0 ] 690 ts = V.fromList [ 0.0 ]
621 dim = V.length f0 691 dim = V.length f0
@@ -634,9 +704,16 @@ getButcherTable method = unsafePerformIO $ do
634 btCsMut <- V.thaw btCs 704 btCsMut <- V.thaw btCs
635 btBsMut <- V.thaw btBs 705 btBsMut <- V.thaw btBs
636 btB2sMut <- V.thaw btB2s 706 btB2sMut <- V.thaw btB2s
637 let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt 707 let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt
638 funIO x y f _ptr = do 708 funIOI x y f _ptr = do
639 fImm <- fun x <$> getDataFromContents dim y 709 fImm <- funI x <$> getDataFromContents dim y
710 putDataInContents fImm dim f
711 -- FIXME: I don't understand what this comment means
712 -- Unsafe since the function will be called many times.
713 [CU.exp| int{ 0 } |]
714 let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt
715 funIOE x y f _ptr = do
716 fImm <- funE x <$> getDataFromContents dim y
640 putDataInContents fImm dim f 717 putDataInContents fImm dim f
641 -- FIXME: I don't understand what this comment means 718 -- FIXME: I don't understand what this comment means
642 -- Unsafe since the function will be called many times. 719 -- Unsafe since the function will be called many times.
@@ -665,11 +742,16 @@ getButcherTable method = unsafePerformIO $ do
665 arkode_mem = ARKodeCreate(); /* Create the solver memory */ 742 arkode_mem = ARKodeCreate(); /* Create the solver memory */
666 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; 743 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
667 744
668 flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); 745 flag = ARKodeInit(arkode_mem, $fun:(int (* funIOE) (double t, SunVector y[], SunVector dydt[], void * params)), $fun:(int (* funIOI) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y);
669 if (check_flag(&flag, "ARKodeInit", 1)) return 1; 746 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
670 747
748 if ($(int mN) >= MIN_DIRK_NUM) {
671 flag = ARKodeSetIRKTableNum(arkode_mem, $(int mN)); 749 flag = ARKodeSetIRKTableNum(arkode_mem, $(int mN));
672 if (check_flag(&flag, "ARKode", 1)) return 1; 750 if (check_flag(&flag, "ARKodeSetIRKTableNum", 1)) return 1;
751 } else {
752 flag = ARKodeSetERKTableNum(arkode_mem, $(int mN));
753 if (check_flag(&flag, "ARKodeSetERKTableNum", 1)) return 1;
754 }
673 755
674 int s, q, p; 756 int s, q, p;
675 realtype *ai = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype)); 757 realtype *ai = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype));