From 6148374bbf11e443ed3c64eb4add3f20d612f362 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Wed, 28 Mar 2018 18:50:54 +0100 Subject: Make more methods available --- packages/sundials/src/Bar.hsc | 23 +++++- packages/sundials/src/Main.hs | 6 +- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 96 +++++++++++++++------- 3 files changed, 91 insertions(+), 34 deletions(-) (limited to 'packages') diff --git a/packages/sundials/src/Bar.hsc b/packages/sundials/src/Bar.hsc index 7db0d4a..7d53af9 100644 --- a/packages/sundials/src/Bar.hsc +++ b/packages/sundials/src/Bar.hsc @@ -23,6 +23,23 @@ getData ptr = (#peek BazType, data) ptr arkSMax :: Int arkSMax = #const ARK_S_MAX - - - +-- FIXME: We could just use inline-c instead +sDIRK_2_1_2 :: Int +sDIRK_2_1_2 = #const SDIRK_2_1_2 +-- #define BILLINGTON_3_3_2 13 +-- #define TRBDF2_3_3_2 14 +kVAERNO_4_2_3 :: Int +kVAERNO_4_2_3 = #const KVAERNO_4_2_3 +-- #define ARK324L2SA_DIRK_4_2_3 16 +-- #define CASH_5_2_4 17 +-- #define CASH_5_3_4 18 +-- #define SDIRK_5_3_4 19 +-- #define KVAERNO_5_3_4 20 +-- #define ARK436L2SA_DIRK_6_3_4 21 +-- #define KVAERNO_7_4_5 22 +-- #define ARK548L2SA_DIRK_8_4_5 23 + +-- #define DEFAULT_DIRK_2 SDIRK_2_1_2 +-- #define DEFAULT_DIRK_3 ARK324L2SA_DIRK_4_2_3 +-- #define DEFAULT_DIRK_4 SDIRK_5_3_4 +-- #define DEFAULT_DIRK_5 ARK548L2SA_DIRK_8_4_5 diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index c81d1a3..71bcbac 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs @@ -73,17 +73,17 @@ main = do -- \end{array} -- $$ - let res = btGet + let res = btGet SDIRK_2_1_2 putStrLn $ show res putStrLn $ butcherTableauTex res - let res = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) + let res = odeSolve KVAERNO_4_2_3 brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) putStrLn $ show res renderRasterific "diagrams/brusselator.png" (D.dims2D 500.0 500.0) (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res)) - let res = odeSolve stiffish [0.0] (fromList [0.0, 0.1 .. 10.0]) + let res = odeSolve KVAERNO_4_2_3 stiffish [0.0] (fromList [0.0, 0.1 .. 10.0]) putStrLn $ show res renderRasterific "diagrams/stiffish.png" (D.dims2D 500.0 500.0) diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index ff4ede8..d0c58dc 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -9,7 +9,7 @@ -- | -- Module: Numeric.Sundials.ARKode -- --- Blah +-- KVAERNO_4_2_3 -- -- \[ -- \begin{array}{c|cccc} @@ -20,11 +20,21 @@ -- \end{array} -- \] -- +-- SDIRK_2_1_2 +-- +-- \[ +-- \begin{array}{c|cc} +-- c_1 & 1.0 & 0.0 \\ +-- c_2 & -1.0 & 1.0 \\ +-- \end{array} +-- \] +-- module Numeric.Sundials.Arkode.ODE ( solveOde , odeSolve , getButcherTable , getBT , btGet + , ODEMethod(..) ) where import qualified Language.C.Inline as C @@ -48,6 +58,7 @@ import Numeric.LinearAlgebra.Devel (createVector) import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix) import qualified Types as T +import Bar (sDIRK_2_1_2, kVAERNO_4_2_3) import qualified Bar as B C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) @@ -113,13 +124,36 @@ data SundialsDiagnostics = SundialsDiagnostics { , aRKDlsGetNumRhsEvals :: Int } deriving Show -odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) +-- | Stepping functions +data ODEMethod = SDIRK_2_1_2 + | KVAERNO_4_2_3 + +instance Enum ODEMethod where + fromEnum SDIRK_2_1_2 = sDIRK_2_1_2 + fromEnum KVAERNO_4_2_3 = kVAERNO_4_2_3 + toEnum _ = error "toEnum not defined for ODEMethod" + +-- | A version of 'odeSolveVWith' with reasonable default step control. +odeSolveV + :: ODEMethod + -> Double -- ^ initial step size + -> Double -- ^ absolute tolerance for the state vector + -> Double -- ^ relative tolerance for the state vector + -> (Double -> Vector Double -> Vector Double) -- ^ x'(t,x) + -> Vector Double -- ^ initial conditions + -> Vector Double -- ^ desired solution times + -> Matrix Double -- ^ solution +odeSolveV meth hi epsAbs epsRel = undefined + +odeSolve :: ODEMethod + -> (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> [Double] -- ^ initial conditions -> Vector Double -- ^ desired solution times -> Matrix Double -- ^ solution -odeSolve f y0 ts = case solveOde g (V.fromList y0) (V.fromList $ toList ts) of - Left c -> error $ show c -- FIXME - Right (v, _) -> (nR >< nC) (V.toList v) +odeSolve method f y0 ts = + case solveOde method g (V.fromList y0) (V.fromList $ toList ts) of + Left c -> error $ show c -- FIXME + Right (v, _) -> (nR >< nC) (V.toList v) where us = toList ts nR = length us @@ -127,20 +161,23 @@ odeSolve f y0 ts = case solveOde g (V.fromList y0) (V.fromList $ toList ts) of g t x0 = V.fromList $ f t (V.toList x0) solveOde :: - (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) - -> V.Vector Double -- ^ Initial conditions - -> V.Vector Double -- ^ Desired solution times - -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution -solveOde f y0 tt = case solveOdeC (coerce f) (coerce y0) (coerce tt) of - Left c -> Left $ fromIntegral c - Right (v, d) -> Right (coerce v, d) + ODEMethod + -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) + -> V.Vector Double -- ^ Initial conditions + -> V.Vector Double -- ^ Desired solution times + -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution +solveOde method f y0 tt = + case solveOdeC (fromIntegral $ fromEnum method) (coerce f) (coerce y0) (coerce tt) of + Left c -> Left $ fromIntegral c + Right (v, d) -> Right (coerce v, d) solveOdeC :: + CInt -> (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) - -> V.Vector CDouble -- ^ Initial conditions - -> V.Vector CDouble -- ^ Desired solution times - -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution -solveOdeC fun f0 ts = unsafePerformIO $ do + -> V.Vector CDouble -- ^ Initial conditions + -> V.Vector CDouble -- ^ Desired solution times + -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution +solveOdeC method fun f0 ts = unsafePerformIO $ do let dim = V.length f0 nEq :: CLong nEq = fromIntegral dim @@ -226,7 +263,7 @@ solveOdeC fun f0 ts = unsafePerformIO $ do ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); } - flag = ARKodeSetIRKTableNum(arkode_mem, KVAERNO_4_2_3); + flag = ARKodeSetIRKTableNum(arkode_mem, $(int method)); if (check_flag(&flag, "ARKode", 1)) return 1; int s, q, p; @@ -330,18 +367,19 @@ solveOdeC fun f0 ts = unsafePerformIO $ do else do return $ Left res -btGet :: Matrix Double -btGet = case getBT of - Left c -> error $ show c -- FIXME - Right (v, sqp) -> subMatrix (0, 0) (4, 4) $ (B.arkSMax >< B.arkSMax) (V.toList v) +btGet :: ODEMethod -> Matrix Double +btGet method = case getBT method of + Left c -> error $ show c -- FIXME + -- FIXME + Right (v, sqp) -> subMatrix (0, 0) (2, 2) $ (B.arkSMax >< B.arkSMax) (V.toList v) -getBT :: Either Int (V.Vector Double, V.Vector Int) -getBT = case getButcherTable of - Left c -> Left $ fromIntegral c - Right (v, sqp) -> Right $ (coerce v, V.map fromIntegral sqp) +getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int) +getBT method = case getButcherTable method of + Left c -> Left $ fromIntegral c + Right (v, sqp) -> Right $ (coerce v, V.map fromIntegral sqp) -getButcherTable :: Either CInt ((V.Vector CDouble), V.Vector CInt) -getButcherTable = unsafePerformIO $ do +getButcherTable :: ODEMethod -> Either CInt ((V.Vector CDouble), V.Vector CInt) +getButcherTable method = unsafePerformIO $ do -- arkode seems to want an ODE in order to set and then get the -- Butcher tableau so here's one to keep it happy let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble @@ -351,6 +389,8 @@ getButcherTable = unsafePerformIO $ do dim = V.length f0 nEq :: CLong nEq = fromIntegral dim + mN :: CInt + mN = fromIntegral $ fromEnum method -- FIXME: I believe these gets taken from the ghc heap and so should -- be subject to garbage collection. @@ -396,7 +436,7 @@ getButcherTable = unsafePerformIO $ do flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); if (check_flag(&flag, "ARKodeInit", 1)) return 1; - flag = ARKodeSetIRKTableNum(arkode_mem, KVAERNO_4_2_3); + flag = ARKodeSetIRKTableNum(arkode_mem, $(int mN)); if (check_flag(&flag, "ARKode", 1)) return 1; int s, q, p; -- cgit v1.2.3