summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
diff options
context:
space:
mode:
authorDominic Steinitz <dominic@steinitz.org>2018-03-28 18:50:54 +0100
committerDominic Steinitz <dominic@steinitz.org>2018-03-28 18:50:54 +0100
commit6148374bbf11e443ed3c64eb4add3f20d612f362 (patch)
treec8e14a7f8a0ea85fda3c151d3beb77a0ee534b56 /packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
parent8612df2834fa652e6f9ab9b6e18617b6eac267a6 (diff)
Make more methods available
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs')
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs96
1 files changed, 68 insertions, 28 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index ff4ede8..d0c58dc 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -9,7 +9,7 @@
9-- | 9-- |
10-- Module: Numeric.Sundials.ARKode 10-- Module: Numeric.Sundials.ARKode
11-- 11--
12-- Blah 12-- KVAERNO_4_2_3
13-- 13--
14-- \[ 14-- \[
15-- \begin{array}{c|cccc} 15-- \begin{array}{c|cccc}
@@ -20,11 +20,21 @@
20-- \end{array} 20-- \end{array}
21-- \] 21-- \]
22-- 22--
23-- SDIRK_2_1_2
24--
25-- \[
26-- \begin{array}{c|cc}
27-- c_1 & 1.0 & 0.0 \\
28-- c_2 & -1.0 & 1.0 \\
29-- \end{array}
30-- \]
31--
23module Numeric.Sundials.Arkode.ODE ( solveOde 32module Numeric.Sundials.Arkode.ODE ( solveOde
24 , odeSolve 33 , odeSolve
25 , getButcherTable 34 , getButcherTable
26 , getBT 35 , getBT
27 , btGet 36 , btGet
37 , ODEMethod(..)
28 ) where 38 ) where
29 39
30import qualified Language.C.Inline as C 40import qualified Language.C.Inline as C
@@ -48,6 +58,7 @@ import Numeric.LinearAlgebra.Devel (createVector)
48import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix) 58import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix)
49 59
50import qualified Types as T 60import qualified Types as T
61import Bar (sDIRK_2_1_2, kVAERNO_4_2_3)
51import qualified Bar as B 62import qualified Bar as B
52 63
53C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) 64C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx)
@@ -113,13 +124,36 @@ data SundialsDiagnostics = SundialsDiagnostics {
113 , aRKDlsGetNumRhsEvals :: Int 124 , aRKDlsGetNumRhsEvals :: Int
114 } deriving Show 125 } deriving Show
115 126
116odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 127-- | Stepping functions
128data ODEMethod = SDIRK_2_1_2
129 | KVAERNO_4_2_3
130
131instance Enum ODEMethod where
132 fromEnum SDIRK_2_1_2 = sDIRK_2_1_2
133 fromEnum KVAERNO_4_2_3 = kVAERNO_4_2_3
134 toEnum _ = error "toEnum not defined for ODEMethod"
135
136-- | A version of 'odeSolveVWith' with reasonable default step control.
137odeSolveV
138 :: ODEMethod
139 -> Double -- ^ initial step size
140 -> Double -- ^ absolute tolerance for the state vector
141 -> Double -- ^ relative tolerance for the state vector
142 -> (Double -> Vector Double -> Vector Double) -- ^ x'(t,x)
143 -> Vector Double -- ^ initial conditions
144 -> Vector Double -- ^ desired solution times
145 -> Matrix Double -- ^ solution
146odeSolveV meth hi epsAbs epsRel = undefined
147
148odeSolve :: ODEMethod
149 -> (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
117 -> [Double] -- ^ initial conditions 150 -> [Double] -- ^ initial conditions
118 -> Vector Double -- ^ desired solution times 151 -> Vector Double -- ^ desired solution times
119 -> Matrix Double -- ^ solution 152 -> Matrix Double -- ^ solution
120odeSolve f y0 ts = case solveOde g (V.fromList y0) (V.fromList $ toList ts) of 153odeSolve method f y0 ts =
121 Left c -> error $ show c -- FIXME 154 case solveOde method g (V.fromList y0) (V.fromList $ toList ts) of
122 Right (v, _) -> (nR >< nC) (V.toList v) 155 Left c -> error $ show c -- FIXME
156 Right (v, _) -> (nR >< nC) (V.toList v)
123 where 157 where
124 us = toList ts 158 us = toList ts
125 nR = length us 159 nR = length us
@@ -127,20 +161,23 @@ odeSolve f y0 ts = case solveOde g (V.fromList y0) (V.fromList $ toList ts) of
127 g t x0 = V.fromList $ f t (V.toList x0) 161 g t x0 = V.fromList $ f t (V.toList x0)
128 162
129solveOde :: 163solveOde ::
130 (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 164 ODEMethod
131 -> V.Vector Double -- ^ Initial conditions 165 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
132 -> V.Vector Double -- ^ Desired solution times 166 -> V.Vector Double -- ^ Initial conditions
133 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution 167 -> V.Vector Double -- ^ Desired solution times
134solveOde f y0 tt = case solveOdeC (coerce f) (coerce y0) (coerce tt) of 168 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution
135 Left c -> Left $ fromIntegral c 169solveOde method f y0 tt =
136 Right (v, d) -> Right (coerce v, d) 170 case solveOdeC (fromIntegral $ fromEnum method) (coerce f) (coerce y0) (coerce tt) of
171 Left c -> Left $ fromIntegral c
172 Right (v, d) -> Right (coerce v, d)
137 173
138solveOdeC :: 174solveOdeC ::
175 CInt ->
139 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 176 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
140 -> V.Vector CDouble -- ^ Initial conditions 177 -> V.Vector CDouble -- ^ Initial conditions
141 -> V.Vector CDouble -- ^ Desired solution times 178 -> V.Vector CDouble -- ^ Desired solution times
142 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution 179 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution
143solveOdeC fun f0 ts = unsafePerformIO $ do 180solveOdeC method fun f0 ts = unsafePerformIO $ do
144 let dim = V.length f0 181 let dim = V.length f0
145 nEq :: CLong 182 nEq :: CLong
146 nEq = fromIntegral dim 183 nEq = fromIntegral dim
@@ -226,7 +263,7 @@ solveOdeC fun f0 ts = unsafePerformIO $ do
226 ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); 263 ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j);
227 } 264 }
228 265
229 flag = ARKodeSetIRKTableNum(arkode_mem, KVAERNO_4_2_3); 266 flag = ARKodeSetIRKTableNum(arkode_mem, $(int method));
230 if (check_flag(&flag, "ARKode", 1)) return 1; 267 if (check_flag(&flag, "ARKode", 1)) return 1;
231 268
232 int s, q, p; 269 int s, q, p;
@@ -330,18 +367,19 @@ solveOdeC fun f0 ts = unsafePerformIO $ do
330 else do 367 else do
331 return $ Left res 368 return $ Left res
332 369
333btGet :: Matrix Double 370btGet :: ODEMethod -> Matrix Double
334btGet = case getBT of 371btGet method = case getBT method of
335 Left c -> error $ show c -- FIXME 372 Left c -> error $ show c -- FIXME
336 Right (v, sqp) -> subMatrix (0, 0) (4, 4) $ (B.arkSMax >< B.arkSMax) (V.toList v) 373 -- FIXME
374 Right (v, sqp) -> subMatrix (0, 0) (2, 2) $ (B.arkSMax >< B.arkSMax) (V.toList v)
337 375
338getBT :: Either Int (V.Vector Double, V.Vector Int) 376getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int)
339getBT = case getButcherTable of 377getBT method = case getButcherTable method of
340 Left c -> Left $ fromIntegral c 378 Left c -> Left $ fromIntegral c
341 Right (v, sqp) -> Right $ (coerce v, V.map fromIntegral sqp) 379 Right (v, sqp) -> Right $ (coerce v, V.map fromIntegral sqp)
342 380
343getButcherTable :: Either CInt ((V.Vector CDouble), V.Vector CInt) 381getButcherTable :: ODEMethod -> Either CInt ((V.Vector CDouble), V.Vector CInt)
344getButcherTable = unsafePerformIO $ do 382getButcherTable method = unsafePerformIO $ do
345 -- arkode seems to want an ODE in order to set and then get the 383 -- arkode seems to want an ODE in order to set and then get the
346 -- Butcher tableau so here's one to keep it happy 384 -- Butcher tableau so here's one to keep it happy
347 let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble 385 let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble
@@ -351,6 +389,8 @@ getButcherTable = unsafePerformIO $ do
351 dim = V.length f0 389 dim = V.length f0
352 nEq :: CLong 390 nEq :: CLong
353 nEq = fromIntegral dim 391 nEq = fromIntegral dim
392 mN :: CInt
393 mN = fromIntegral $ fromEnum method
354 394
355 -- FIXME: I believe these gets taken from the ghc heap and so should 395 -- FIXME: I believe these gets taken from the ghc heap and so should
356 -- be subject to garbage collection. 396 -- be subject to garbage collection.
@@ -396,7 +436,7 @@ getButcherTable = unsafePerformIO $ do
396 flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); 436 flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y);
397 if (check_flag(&flag, "ARKodeInit", 1)) return 1; 437 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
398 438
399 flag = ARKodeSetIRKTableNum(arkode_mem, KVAERNO_4_2_3); 439 flag = ARKodeSetIRKTableNum(arkode_mem, $(int mN));
400 if (check_flag(&flag, "ARKode", 1)) return 1; 440 if (check_flag(&flag, "ARKode", 1)) return 1;
401 441
402 int s, q, p; 442 int s, q, p;