diff options
Diffstat (limited to 'packages/sundials/src/Numeric')
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 96 |
1 files changed, 68 insertions, 28 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index ff4ede8..d0c58dc 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -9,7 +9,7 @@ | |||
9 | -- | | 9 | -- | |
10 | -- Module: Numeric.Sundials.ARKode | 10 | -- Module: Numeric.Sundials.ARKode |
11 | -- | 11 | -- |
12 | -- Blah | 12 | -- KVAERNO_4_2_3 |
13 | -- | 13 | -- |
14 | -- \[ | 14 | -- \[ |
15 | -- \begin{array}{c|cccc} | 15 | -- \begin{array}{c|cccc} |
@@ -20,11 +20,21 @@ | |||
20 | -- \end{array} | 20 | -- \end{array} |
21 | -- \] | 21 | -- \] |
22 | -- | 22 | -- |
23 | -- SDIRK_2_1_2 | ||
24 | -- | ||
25 | -- \[ | ||
26 | -- \begin{array}{c|cc} | ||
27 | -- c_1 & 1.0 & 0.0 \\ | ||
28 | -- c_2 & -1.0 & 1.0 \\ | ||
29 | -- \end{array} | ||
30 | -- \] | ||
31 | -- | ||
23 | module Numeric.Sundials.Arkode.ODE ( solveOde | 32 | module Numeric.Sundials.Arkode.ODE ( solveOde |
24 | , odeSolve | 33 | , odeSolve |
25 | , getButcherTable | 34 | , getButcherTable |
26 | , getBT | 35 | , getBT |
27 | , btGet | 36 | , btGet |
37 | , ODEMethod(..) | ||
28 | ) where | 38 | ) where |
29 | 39 | ||
30 | import qualified Language.C.Inline as C | 40 | import qualified Language.C.Inline as C |
@@ -48,6 +58,7 @@ import Numeric.LinearAlgebra.Devel (createVector) | |||
48 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix) | 58 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix) |
49 | 59 | ||
50 | import qualified Types as T | 60 | import qualified Types as T |
61 | import Bar (sDIRK_2_1_2, kVAERNO_4_2_3) | ||
51 | import qualified Bar as B | 62 | import qualified Bar as B |
52 | 63 | ||
53 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | 64 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) |
@@ -113,13 +124,36 @@ data SundialsDiagnostics = SundialsDiagnostics { | |||
113 | , aRKDlsGetNumRhsEvals :: Int | 124 | , aRKDlsGetNumRhsEvals :: Int |
114 | } deriving Show | 125 | } deriving Show |
115 | 126 | ||
116 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 127 | -- | Stepping functions |
128 | data ODEMethod = SDIRK_2_1_2 | ||
129 | | KVAERNO_4_2_3 | ||
130 | |||
131 | instance Enum ODEMethod where | ||
132 | fromEnum SDIRK_2_1_2 = sDIRK_2_1_2 | ||
133 | fromEnum KVAERNO_4_2_3 = kVAERNO_4_2_3 | ||
134 | toEnum _ = error "toEnum not defined for ODEMethod" | ||
135 | |||
136 | -- | A version of 'odeSolveVWith' with reasonable default step control. | ||
137 | odeSolveV | ||
138 | :: ODEMethod | ||
139 | -> Double -- ^ initial step size | ||
140 | -> Double -- ^ absolute tolerance for the state vector | ||
141 | -> Double -- ^ relative tolerance for the state vector | ||
142 | -> (Double -> Vector Double -> Vector Double) -- ^ x'(t,x) | ||
143 | -> Vector Double -- ^ initial conditions | ||
144 | -> Vector Double -- ^ desired solution times | ||
145 | -> Matrix Double -- ^ solution | ||
146 | odeSolveV meth hi epsAbs epsRel = undefined | ||
147 | |||
148 | odeSolve :: ODEMethod | ||
149 | -> (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
117 | -> [Double] -- ^ initial conditions | 150 | -> [Double] -- ^ initial conditions |
118 | -> Vector Double -- ^ desired solution times | 151 | -> Vector Double -- ^ desired solution times |
119 | -> Matrix Double -- ^ solution | 152 | -> Matrix Double -- ^ solution |
120 | odeSolve f y0 ts = case solveOde g (V.fromList y0) (V.fromList $ toList ts) of | 153 | odeSolve method f y0 ts = |
121 | Left c -> error $ show c -- FIXME | 154 | case solveOde method g (V.fromList y0) (V.fromList $ toList ts) of |
122 | Right (v, _) -> (nR >< nC) (V.toList v) | 155 | Left c -> error $ show c -- FIXME |
156 | Right (v, _) -> (nR >< nC) (V.toList v) | ||
123 | where | 157 | where |
124 | us = toList ts | 158 | us = toList ts |
125 | nR = length us | 159 | nR = length us |
@@ -127,20 +161,23 @@ odeSolve f y0 ts = case solveOde g (V.fromList y0) (V.fromList $ toList ts) of | |||
127 | g t x0 = V.fromList $ f t (V.toList x0) | 161 | g t x0 = V.fromList $ f t (V.toList x0) |
128 | 162 | ||
129 | solveOde :: | 163 | solveOde :: |
130 | (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 164 | ODEMethod |
131 | -> V.Vector Double -- ^ Initial conditions | 165 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
132 | -> V.Vector Double -- ^ Desired solution times | 166 | -> V.Vector Double -- ^ Initial conditions |
133 | -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution | 167 | -> V.Vector Double -- ^ Desired solution times |
134 | solveOde f y0 tt = case solveOdeC (coerce f) (coerce y0) (coerce tt) of | 168 | -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution |
135 | Left c -> Left $ fromIntegral c | 169 | solveOde method f y0 tt = |
136 | Right (v, d) -> Right (coerce v, d) | 170 | case solveOdeC (fromIntegral $ fromEnum method) (coerce f) (coerce y0) (coerce tt) of |
171 | Left c -> Left $ fromIntegral c | ||
172 | Right (v, d) -> Right (coerce v, d) | ||
137 | 173 | ||
138 | solveOdeC :: | 174 | solveOdeC :: |
175 | CInt -> | ||
139 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 176 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
140 | -> V.Vector CDouble -- ^ Initial conditions | 177 | -> V.Vector CDouble -- ^ Initial conditions |
141 | -> V.Vector CDouble -- ^ Desired solution times | 178 | -> V.Vector CDouble -- ^ Desired solution times |
142 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution | 179 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution |
143 | solveOdeC fun f0 ts = unsafePerformIO $ do | 180 | solveOdeC method fun f0 ts = unsafePerformIO $ do |
144 | let dim = V.length f0 | 181 | let dim = V.length f0 |
145 | nEq :: CLong | 182 | nEq :: CLong |
146 | nEq = fromIntegral dim | 183 | nEq = fromIntegral dim |
@@ -226,7 +263,7 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
226 | ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); | 263 | ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); |
227 | } | 264 | } |
228 | 265 | ||
229 | flag = ARKodeSetIRKTableNum(arkode_mem, KVAERNO_4_2_3); | 266 | flag = ARKodeSetIRKTableNum(arkode_mem, $(int method)); |
230 | if (check_flag(&flag, "ARKode", 1)) return 1; | 267 | if (check_flag(&flag, "ARKode", 1)) return 1; |
231 | 268 | ||
232 | int s, q, p; | 269 | int s, q, p; |
@@ -330,18 +367,19 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
330 | else do | 367 | else do |
331 | return $ Left res | 368 | return $ Left res |
332 | 369 | ||
333 | btGet :: Matrix Double | 370 | btGet :: ODEMethod -> Matrix Double |
334 | btGet = case getBT of | 371 | btGet method = case getBT method of |
335 | Left c -> error $ show c -- FIXME | 372 | Left c -> error $ show c -- FIXME |
336 | Right (v, sqp) -> subMatrix (0, 0) (4, 4) $ (B.arkSMax >< B.arkSMax) (V.toList v) | 373 | -- FIXME |
374 | Right (v, sqp) -> subMatrix (0, 0) (2, 2) $ (B.arkSMax >< B.arkSMax) (V.toList v) | ||
337 | 375 | ||
338 | getBT :: Either Int (V.Vector Double, V.Vector Int) | 376 | getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int) |
339 | getBT = case getButcherTable of | 377 | getBT method = case getButcherTable method of |
340 | Left c -> Left $ fromIntegral c | 378 | Left c -> Left $ fromIntegral c |
341 | Right (v, sqp) -> Right $ (coerce v, V.map fromIntegral sqp) | 379 | Right (v, sqp) -> Right $ (coerce v, V.map fromIntegral sqp) |
342 | 380 | ||
343 | getButcherTable :: Either CInt ((V.Vector CDouble), V.Vector CInt) | 381 | getButcherTable :: ODEMethod -> Either CInt ((V.Vector CDouble), V.Vector CInt) |
344 | getButcherTable = unsafePerformIO $ do | 382 | getButcherTable method = unsafePerformIO $ do |
345 | -- arkode seems to want an ODE in order to set and then get the | 383 | -- arkode seems to want an ODE in order to set and then get the |
346 | -- Butcher tableau so here's one to keep it happy | 384 | -- Butcher tableau so here's one to keep it happy |
347 | let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble | 385 | let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble |
@@ -351,6 +389,8 @@ getButcherTable = unsafePerformIO $ do | |||
351 | dim = V.length f0 | 389 | dim = V.length f0 |
352 | nEq :: CLong | 390 | nEq :: CLong |
353 | nEq = fromIntegral dim | 391 | nEq = fromIntegral dim |
392 | mN :: CInt | ||
393 | mN = fromIntegral $ fromEnum method | ||
354 | 394 | ||
355 | -- FIXME: I believe these gets taken from the ghc heap and so should | 395 | -- FIXME: I believe these gets taken from the ghc heap and so should |
356 | -- be subject to garbage collection. | 396 | -- be subject to garbage collection. |
@@ -396,7 +436,7 @@ getButcherTable = unsafePerformIO $ do | |||
396 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); | 436 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); |
397 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | 437 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; |
398 | 438 | ||
399 | flag = ARKodeSetIRKTableNum(arkode_mem, KVAERNO_4_2_3); | 439 | flag = ARKodeSetIRKTableNum(arkode_mem, $(int mN)); |
400 | if (check_flag(&flag, "ARKode", 1)) return 1; | 440 | if (check_flag(&flag, "ARKode", 1)) return 1; |
401 | 441 | ||
402 | int s, q, p; | 442 | int s, q, p; |