From 411bee3b1f984459ce7a496c655dda333ddf6f32 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Fri, 20 Apr 2018 11:45:30 +0100 Subject: Support all implicit methods and 1 explicit method --- packages/sundials/src/Arkode.hsc | 70 ++++++---- packages/sundials/src/Main.hs | 6 +- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 152 ++++++++++++++++----- 3 files changed, 166 insertions(+), 62 deletions(-) diff --git a/packages/sundials/src/Arkode.hsc b/packages/sundials/src/Arkode.hsc index 4ef95e2..023d102 100644 --- a/packages/sundials/src/Arkode.hsc +++ b/packages/sundials/src/Arkode.hsc @@ -45,46 +45,68 @@ getData ptr = (#peek SunContent, data) ptr arkSMax :: Int arkSMax = #const ARK_S_MAX +mIN_DIRK_NUM, mAX_DIRK_NUM :: Int +mIN_DIRK_NUM = #const MIN_DIRK_NUM +mAX_DIRK_NUM = #const MAX_DIRK_NUM + -- FIXME: We could just use inline-c instead --- /* Butcher table accessors -- implicit */ +-- Butcher table accessors -- implicit sDIRK_2_1_2 :: Int sDIRK_2_1_2 = #const SDIRK_2_1_2 --- #define BILLINGTON_3_3_2 13 --- #define TRBDF2_3_3_2 14 +bILLINGTON_3_3_2 :: Int +bILLINGTON_3_3_2 = #const BILLINGTON_3_3_2 +tRBDF2_3_3_2 :: Int +tRBDF2_3_3_2 = #const TRBDF2_3_3_2 kVAERNO_4_2_3 :: Int kVAERNO_4_2_3 = #const KVAERNO_4_2_3 --- #define ARK324L2SA_DIRK_4_2_3 16 --- #define CASH_5_2_4 17 --- #define CASH_5_3_4 18 --- #define SDIRK_5_3_4 19 +aRK324L2SA_DIRK_4_2_3 :: Int +aRK324L2SA_DIRK_4_2_3 = #const ARK324L2SA_DIRK_4_2_3 +cASH_5_2_4 :: Int +cASH_5_2_4 = #const CASH_5_2_4 +cASH_5_3_4 :: Int +cASH_5_3_4 = #const CASH_5_3_4 sDIRK_5_3_4 :: Int sDIRK_5_3_4 = #const SDIRK_5_3_4 --- #define KVAERNO_5_3_4 20 --- #define ARK436L2SA_DIRK_6_3_4 21 --- #define KVAERNO_7_4_5 22 --- #define ARK548L2SA_DIRK_8_4_5 23 +kVAERNO_5_3_4 :: Int +kVAERNO_5_3_4 = #const KVAERNO_5_3_4 +aRK436L2SA_DIRK_6_3_4 :: Int +aRK436L2SA_DIRK_6_3_4 = #const ARK436L2SA_DIRK_6_3_4 +kVAERNO_7_4_5 :: Int +kVAERNO_7_4_5 = #const KVAERNO_7_4_5 +aRK548L2SA_DIRK_8_4_5 :: Int +aRK548L2SA_DIRK_8_4_5 = #const ARK548L2SA_DIRK_8_4_5 -- #define DEFAULT_DIRK_2 SDIRK_2_1_2 -- #define DEFAULT_DIRK_3 ARK324L2SA_DIRK_4_2_3 -- #define DEFAULT_DIRK_4 SDIRK_5_3_4 -- #define DEFAULT_DIRK_5 ARK548L2SA_DIRK_8_4_5 --- /* Butcher table accessors -- explicit */ --- #define HEUN_EULER_2_1_2 0 --- #define BOGACKI_SHAMPINE_4_2_3 1 --- #define ARK324L2SA_ERK_4_2_3 2 --- #define ZONNEVELD_5_3_4 3 --- #define ARK436L2SA_ERK_6_3_4 4 --- #define SAYFY_ABURUB_6_3_4 5 --- #define CASH_KARP_6_4_5 6 +-- Butcher table accessors -- explicit +hEUN_EULER_2_1_2 :: Int +hEUN_EULER_2_1_2 = #const HEUN_EULER_2_1_2 +bOGACKI_SHAMPINE_4_2_3 :: Int +bOGACKI_SHAMPINE_4_2_3 = #const BOGACKI_SHAMPINE_4_2_3 +aRK324L2SA_ERK_4_2_3 :: Int +aRK324L2SA_ERK_4_2_3 = #const ARK324L2SA_ERK_4_2_3 +zONNEVELD_5_3_4 :: Int +zONNEVELD_5_3_4 = #const ZONNEVELD_5_3_4 +aRK436L2SA_ERK_6_3_4 :: Int +aRK436L2SA_ERK_6_3_4 = #const ARK436L2SA_ERK_6_3_4 +sAYFY_ABURUB_6_3_4 :: Int +sAYFY_ABURUB_6_3_4 = #const SAYFY_ABURUB_6_3_4 +cASH_KARP_6_4_5 :: Int +cASH_KARP_6_4_5 = #const CASH_KARP_6_4_5 fEHLBERG_6_4_5 :: Int fEHLBERG_6_4_5 = #const FEHLBERG_6_4_5 --- #define FEHLBERG_6_4_5 7 --- #define DORMAND_PRINCE_7_4_5 8 --- #define ARK548L2SA_ERK_8_4_5 9 --- #define VERNER_8_5_6 10 --- #define FEHLBERG_13_7_8 11 +dORMAND_PRINCE_7_4_5 :: Int +dORMAND_PRINCE_7_4_5 = #const DORMAND_PRINCE_7_4_5 +aRK548L2SA_ERK_8_4_5 :: Int +aRK548L2SA_ERK_8_4_5 = #const ARK548L2SA_ERK_8_4_5 +vERNER_8_5_6 :: Int +vERNER_8_5_6 = #const VERNER_8_5_6 +fEHLBERG_13_7_8 :: Int +fEHLBERG_13_7_8 = #const FEHLBERG_13_7_8 -- #define DEFAULT_ERK_2 HEUN_EULER_2_1_2 -- #define DEFAULT_ERK_3 BOGACKI_SHAMPINE_4_2_3 diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index d8f3f3d..78921a7 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs @@ -128,9 +128,9 @@ main = do putStrLn $ show resB putStrLn $ butcherTableauTex resB - -- let resC = butcherTable (FEHLBERG_6_4_5 undefined) - -- putStrLn $ show resC - -- putStrLn $ butcherTableauTex resC + let resC = butcherTable (FEHLBERG_6_4_5 undefined) + putStrLn $ show resC + putStrLn $ butcherTableauTex resC let res1 = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) renderRasterific "diagrams/brusselator.png" diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index f6f6884..1460680 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -8,7 +8,7 @@ ----------------------------------------------------------------------------- -- | --- Module : Numeric.Sundials.ARKode +-- Module : Numeric.Sundials.ARKode.ODE -- Copyright : Dominic Steinitz 2018, -- Novadiscovery 2018 -- License : BSD @@ -138,11 +138,9 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), size, subVector) import qualified Types as T -import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3, sDIRK_5_3_4, fEHLBERG_6_4_5) +import Arkode import qualified Arkode as B -import Debug.Trace - C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) @@ -222,28 +220,88 @@ data SundialsDiagnostics = SundialsDiagnostics { type Jacobian = Double -> Vector Double -> Matrix Double -- | Stepping functions -data ODEMethod = SDIRK_2_1_2 Jacobian - | KVAERNO_4_2_3 Jacobian - | SDIRK_5_3_4 Jacobian +data ODEMethod = SDIRK_2_1_2 Jacobian + | SDIRK_2_1_2' + | BILLINGTON_3_3_2 Jacobian + | BILLINGTON_3_3_2' + | TRBDF2_3_3_2 Jacobian + | TRBDF2_3_3_2' + | KVAERNO_4_2_3 Jacobian + | KVAERNO_4_2_3' + | ARK324L2SA_DIRK_4_2_3 Jacobian + | ARK324L2SA_DIRK_4_2_3' + | CASH_5_2_4 Jacobian + | CASH_5_2_4' + | CASH_5_3_4 Jacobian + | CASH_5_3_4' + | SDIRK_5_3_4 Jacobian | SDIRK_5_3_4' - | FEHLBERG_6_4_5 Jacobian + | KVAERNO_5_3_4 Jacobian + | KVAERNO_5_3_4' + | ARK436L2SA_DIRK_6_3_4 Jacobian + | ARK436L2SA_DIRK_6_3_4' + | KVAERNO_7_4_5 Jacobian + | KVAERNO_7_4_5' + | ARK548L2SA_DIRK_8_4_5 Jacobian + | ARK548L2SA_DIRK_8_4_5' + | FEHLBERG_6_4_5 Jacobian | FEHLBERG_6_4_5' getMethod :: ODEMethod -> Int -getMethod (SDIRK_2_1_2 _) = sDIRK_2_1_2 -getMethod (KVAERNO_4_2_3 _) = kVAERNO_4_2_3 -getMethod (SDIRK_5_3_4 _) = sDIRK_5_3_4 -getMethod (SDIRK_5_3_4' ) = sDIRK_5_3_4 -getMethod (FEHLBERG_6_4_5 _) = fEHLBERG_6_4_5 -getMethod (FEHLBERG_6_4_5' ) = fEHLBERG_6_4_5 +getMethod (SDIRK_2_1_2 _) = sDIRK_2_1_2 +getMethod (SDIRK_2_1_2') = sDIRK_2_1_2 +getMethod (BILLINGTON_3_3_2 _) = bILLINGTON_3_3_2 +getMethod (BILLINGTON_3_3_2') = bILLINGTON_3_3_2 +getMethod (TRBDF2_3_3_2 _) = tRBDF2_3_3_2 +getMethod (TRBDF2_3_3_2') = tRBDF2_3_3_2 +getMethod (KVAERNO_4_2_3 _) = kVAERNO_4_2_3 +getMethod (KVAERNO_4_2_3') = kVAERNO_4_2_3 +getMethod (ARK324L2SA_DIRK_4_2_3 _) = aRK324L2SA_DIRK_4_2_3 +getMethod (ARK324L2SA_DIRK_4_2_3') = aRK324L2SA_DIRK_4_2_3 +getMethod (CASH_5_2_4 _) = cASH_5_2_4 +getMethod (CASH_5_2_4') = cASH_5_2_4 +getMethod (CASH_5_3_4 _) = cASH_5_3_4 +getMethod (CASH_5_3_4') = cASH_5_3_4 +getMethod (SDIRK_5_3_4 _) = sDIRK_5_3_4 +getMethod (SDIRK_5_3_4') = sDIRK_5_3_4 +getMethod (KVAERNO_5_3_4 _) = kVAERNO_5_3_4 +getMethod (KVAERNO_5_3_4') = kVAERNO_5_3_4 +getMethod (ARK436L2SA_DIRK_6_3_4 _) = aRK436L2SA_DIRK_6_3_4 +getMethod (ARK436L2SA_DIRK_6_3_4') = aRK436L2SA_DIRK_6_3_4 +getMethod (KVAERNO_7_4_5 _) = kVAERNO_7_4_5 +getMethod (KVAERNO_7_4_5') = kVAERNO_7_4_5 +getMethod (ARK548L2SA_DIRK_8_4_5 _) = aRK548L2SA_DIRK_8_4_5 +getMethod (ARK548L2SA_DIRK_8_4_5') = aRK548L2SA_DIRK_8_4_5 +getMethod (FEHLBERG_6_4_5 _) = fEHLBERG_6_4_5 +getMethod (FEHLBERG_6_4_5' ) = fEHLBERG_6_4_5 getJacobian :: ODEMethod -> Maybe Jacobian -getJacobian (SDIRK_2_1_2 j) = Just j -getJacobian (KVAERNO_4_2_3 j) = Just j -getJacobian (SDIRK_5_3_4 j) = Just j -getJacobian (SDIRK_5_3_4' ) = Nothing -getJacobian (FEHLBERG_6_4_5 j) = Just j -getJacobian (FEHLBERG_6_4_5' ) = Nothing +getJacobian (SDIRK_2_1_2 j) = Just j +getJacobian (SDIRK_2_1_2') = Nothing +getJacobian (BILLINGTON_3_3_2 j) = Just j +getJacobian (BILLINGTON_3_3_2') = Nothing +getJacobian (TRBDF2_3_3_2 j) = Just j +getJacobian (TRBDF2_3_3_2') = Nothing +getJacobian (KVAERNO_4_2_3 j) = Just j +getJacobian (KVAERNO_4_2_3') = Nothing +getJacobian (ARK324L2SA_DIRK_4_2_3 j) = Just j +getJacobian (ARK324L2SA_DIRK_4_2_3') = Nothing +getJacobian (CASH_5_2_4 j) = Just j +getJacobian (CASH_5_2_4') = Nothing +getJacobian (CASH_5_3_4 j) = Just j +getJacobian (CASH_5_3_4') = Nothing +getJacobian (SDIRK_5_3_4 j) = Just j +getJacobian (SDIRK_5_3_4') = Nothing +getJacobian (KVAERNO_5_3_4 j) = Just j +getJacobian (KVAERNO_5_3_4') = Nothing +getJacobian (ARK436L2SA_DIRK_6_3_4 j) = Just j +getJacobian (ARK436L2SA_DIRK_6_3_4') = Nothing +getJacobian (KVAERNO_7_4_5 j) = Just j +getJacobian (KVAERNO_7_4_5') = Nothing +getJacobian (ARK548L2SA_DIRK_8_4_5 j) = Just j +getJacobian (ARK548L2SA_DIRK_8_4_5') = Nothing +getJacobian (FEHLBERG_6_4_5 j) = Just j +getJacobian (FEHLBERG_6_4_5' ) = Nothing -- | A version of 'odeSolveVWith' with reasonable default step control. odeSolveV @@ -262,9 +320,9 @@ odeSolveV -> Matrix Double -- ^ solution odeSolveV meth hi epsAbs epsRel f y0 ts = case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of - Left c -> error $ show c -- FIXME + Left c -> error $ show c -- FIXME -- FIXME: Can we do better than using lists? - Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) + Right (v, _d) -> (nR >< nC) (V.toList v) where us = toList ts nR = length us @@ -282,8 +340,8 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y odeSolve f y0 ts = -- FIXME: These tolerances are different from the ones in GSL case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of - Left c -> error $ show c -- FIXME - Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) + Left c -> error $ show c -- FIXME + Right (v, _d) -> (nR >< nC) (V.toList v) where us = toList ts nR = length us @@ -449,8 +507,13 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO /* Here we use the C types defined in helpers.h which tie up with */ /* the Haskell types defined in Types */ - flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); - if (check_flag(&flag, "ARKodeInit", 1)) return 1; + if ($(int method) < MIN_DIRK_NUM) { + flag = ARKodeInit(arkode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), NULL, T0, y); + if (check_flag(&flag, "ARKodeInit", 1)) return 1; + } else { + flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); + if (check_flag(&flag, "ARKodeInit", 1)) return 1; + } /* Set routines */ flag = ARKodeSVtolerances(arkode_mem, $(double relTol), tv); @@ -486,8 +549,13 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO } /* Explicitly set the method */ - flag = ARKodeSetIRKTableNum(arkode_mem, $(int method)); - if (check_flag(&flag, "ARKode", 1)) return 1; + if ($(int method) >= MIN_DIRK_NUM) { + flag = ARKodeSetIRKTableNum(arkode_mem, $(int method)); + if (check_flag(&flag, "ARKodeSetIRKTableNum", 1)) return 1; + } else { + flag = ARKodeSetERKTableNum(arkode_mem, $(int method)); + if (check_flag(&flag, "ARKodeSetERKTableNum", 1)) return 1; + } /* Main time-stepping loop: calls ARKode to perform the integration */ /* Stops when the final time has been reached */ @@ -614,8 +682,10 @@ getButcherTable :: ODEMethod getButcherTable method = unsafePerformIO $ do -- ARKode seems to want an ODE in order to set and then get the -- Butcher tableau so here's one to keep it happy - let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble - fun _t ys = V.fromList [ ys V.! 0 ] + let funI :: CDouble -> V.Vector CDouble -> V.Vector CDouble + funI _t ys = V.fromList [ ys V.! 0 ] + let funE :: CDouble -> V.Vector CDouble -> V.Vector CDouble + funE _t ys = V.fromList [ ys V.! 0 ] f0 = V.fromList [ 1.0 ] ts = V.fromList [ 0.0 ] dim = V.length f0 @@ -634,9 +704,16 @@ getButcherTable method = unsafePerformIO $ do btCsMut <- V.thaw btCs btBsMut <- V.thaw btBs btB2sMut <- V.thaw btB2s - let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt - funIO x y f _ptr = do - fImm <- fun x <$> getDataFromContents dim y + let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt + funIOI x y f _ptr = do + fImm <- funI x <$> getDataFromContents dim y + putDataInContents fImm dim f + -- FIXME: I don't understand what this comment means + -- Unsafe since the function will be called many times. + [CU.exp| int{ 0 } |] + let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt + funIOE x y f _ptr = do + fImm <- funE x <$> getDataFromContents dim y putDataInContents fImm dim f -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. @@ -665,11 +742,16 @@ getButcherTable method = unsafePerformIO $ do arkode_mem = ARKodeCreate(); /* Create the solver memory */ if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; - flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); + flag = ARKodeInit(arkode_mem, $fun:(int (* funIOE) (double t, SunVector y[], SunVector dydt[], void * params)), $fun:(int (* funIOI) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); if (check_flag(&flag, "ARKodeInit", 1)) return 1; + if ($(int mN) >= MIN_DIRK_NUM) { flag = ARKodeSetIRKTableNum(arkode_mem, $(int mN)); - if (check_flag(&flag, "ARKode", 1)) return 1; + if (check_flag(&flag, "ARKodeSetIRKTableNum", 1)) return 1; + } else { + flag = ARKodeSetERKTableNum(arkode_mem, $(int mN)); + if (check_flag(&flag, "ARKodeSetERKTableNum", 1)) return 1; + } int s, q, p; realtype *ai = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype)); -- cgit v1.2.3