summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDominic Steinitz <dominic@steinitz.org>2018-04-05 12:10:29 +0100
committerDominic Steinitz <dominic@steinitz.org>2018-04-05 12:10:29 +0100
commit6fcc1b01cecc88f1a8eb1608667368c7e72048aa (patch)
treefa64d59de7e3adb1aef9490090a86002de91e2b9
parenta918ec611a6a54c1260349591369ec33d8e873c3 (diff)
Start of mirroring hmatrix-gsl ODE module
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs145
1 files changed, 112 insertions, 33 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index 8358954..0973c82 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -6,8 +6,34 @@
6{-# LANGUAGE OverloadedStrings #-} 6{-# LANGUAGE OverloadedStrings #-}
7{-# LANGUAGE ScopedTypeVariables #-} 7{-# LANGUAGE ScopedTypeVariables #-}
8 8
9-----------------------------------------------------------------------------
9-- | 10-- |
10-- Module: Numeric.Sundials.ARKode 11-- Module : Numeric.Sundials.ARKode
12-- Copyright : Dominic Steinitz 2018,
13-- Novadiscovery 2018
14-- License : BSD
15-- Maintainer : Dominic Steinitz
16-- Stability : provisional
17--
18-- Solution of ordinary differential equation (ODE) initial value problems.
19--
20-- <https://computation.llnl.gov/projects/sundials/sundials-software>
21--
22-- A simple example:
23--
24-- @
25-- import Numeric.Sundials.ARKode
26-- import Numeric.LinearAlgebra
27-- import Graphics.Plot(mplot)
28--
29-- xdot t [x,v] = [v, -0.95*x - 0.1*v]
30--
31-- ts = linspace 100 (0,20 :: Double)
32--
33-- sol = odeSolve xdot [10,0] ts
34--
35-- main = mplot (ts : toColumns sol)
36-- @
11-- 37--
12-- KVAERNO_4_2_3 38-- KVAERNO_4_2_3
13-- 39--
@@ -29,19 +55,34 @@
29-- \end{array} 55-- \end{array}
30-- \] 56-- \]
31-- 57--
32module Numeric.Sundials.ARKode.ODE ( solveOde 58-- SDIRK_5_3_4
33 , odeSolve 59--
60-- \[
61-- \begin{array}{c|ccccc}
62-- c_1 & 0.25 & 0.0 & 0.0 & 0.0 & 0.0 \\
63-- c_2 & 0.5 & 0.25 & 0.0 & 0.0 & 0.0 \\
64-- c_3 & 0.34 & -4.0e-2 & 0.25 & 0.0 & 0.0 \\
65-- c_4 & 0.2727941176470588 & -5.036764705882353e-2 & 2.7573529411764705e-2 & 0.25 & 0.0 \\
66-- c_5 & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\
67-- \end{array}
68-- \]
69-----------------------------------------------------------------------------
70module Numeric.Sundials.ARKode.ODE ( odeSolve
71 , odeSolveV
72 , odeSolveVWith
73 , odeSolve'
34 , getButcherTable 74 , getButcherTable
35 , getBT 75 , getBT
36 , btGet 76 , btGet
37 , ODEMethod(..) 77 , ODEMethod(..)
38 , odeSolveV 78 , StepControl(..)
39 ) where 79 ) where
40 80
41import qualified Language.C.Inline as C 81import qualified Language.C.Inline as C
42import qualified Language.C.Inline.Unsafe as CU 82import qualified Language.C.Inline.Unsafe as CU
43 83
44import Data.Monoid ((<>)) 84import Data.Monoid ((<>))
85import Data.Maybe (isJust)
45 86
46import Foreign.C.Types 87import Foreign.C.Types
47import Foreign.Ptr (Ptr) 88import Foreign.Ptr (Ptr)
@@ -60,9 +101,11 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><),
60 subMatrix, rows, cols, toLists) 101 subMatrix, rows, cols, toLists)
61 102
62import qualified Types as T 103import qualified Types as T
63import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3) 104import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3, sDIRK_5_3_4)
64import qualified Arkode as B 105import qualified Arkode as B
65 106
107import Debug.Trace
108
66 109
67C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) 110C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx)
68 111
@@ -139,14 +182,17 @@ data SundialsDiagnostics = SundialsDiagnostics {
139 , _aRKDlsGetNumRhsEvals :: Int 182 , _aRKDlsGetNumRhsEvals :: Int
140 } deriving Show 183 } deriving Show
141 184
185type Jacobian = Double -> Vector Double -> Matrix Double
186
142-- | Stepping functions 187-- | Stepping functions
143data ODEMethod = SDIRK_2_1_2 188data ODEMethod = SDIRK_2_1_2 Jacobian
144 | KVAERNO_4_2_3 189 | KVAERNO_4_2_3 Jacobian
190 | SDIRK_5_3_4 Jacobian
145 191
146instance Enum ODEMethod where 192getMethod :: ODEMethod -> Int
147 fromEnum SDIRK_2_1_2 = sDIRK_2_1_2 193getMethod (SDIRK_2_1_2 _) = sDIRK_2_1_2
148 fromEnum KVAERNO_4_2_3 = kVAERNO_4_2_3 194getMethod (KVAERNO_4_2_3 _) = kVAERNO_4_2_3
149 toEnum _ = error "toEnum not defined for ODEMethod" 195getMethod (SDIRK_5_3_4 _) = sDIRK_5_3_4
150 196
151-- | A version of 'odeSolveVWith' with reasonable default step control. 197-- | A version of 'odeSolveVWith' with reasonable default step control.
152odeSolveV 198odeSolveV
@@ -160,16 +206,34 @@ odeSolveV
160 -> Matrix Double -- ^ solution 206 -> Matrix Double -- ^ solution
161odeSolveV _meth _hi _epsAbs _epsRel = undefined 207odeSolveV _meth _hi _epsAbs _epsRel = undefined
162 208
163odeSolve :: ODEMethod 209-- | A version of 'odeSolveV' with reasonable default parameters and
210-- system of equations defined using lists. FIXME: we should say
211-- something about the fact we could use the Jacobian but don't for
212-- compatibility with hmatrix-gsl.
213odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
214 -> [Double] -- ^ initial conditions
215 -> Vector Double -- ^ desired solution times
216 -> Matrix Double -- ^ solution
217odeSolve f y0 ts =
218 case odeSolveVWith (SDIRK_5_3_4 undefined) Nothing 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of
219 Left c -> error $ show c -- FIXME
220 Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v)
221 where
222 us = toList ts
223 nR = length us
224 nC = length y0
225 g t x0 = V.fromList $ f t (V.toList x0)
226
227odeSolve' :: ODEMethod
164 -> (Double -> Vector Double -> Matrix Double) 228 -> (Double -> Vector Double -> Matrix Double)
165 -> (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 229 -> (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
166 -> [Double] -- ^ initial conditions 230 -> [Double] -- ^ initial conditions
167 -> Vector Double -- ^ desired solution times 231 -> Vector Double -- ^ desired solution times
168 -> Matrix Double -- ^ solution 232 -> Matrix Double -- ^ solution
169odeSolve method jac f y0 ts = 233odeSolve' method jac f y0 ts =
170 case solveOde method jac' 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of 234 case odeSolveVWith method (pure jac') 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of
171 Left c -> error $ show c -- FIXME 235 Left c -> error $ show c -- FIXME
172 Right (v, _) -> (nR >< nC) (V.toList v) 236 Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v)
173 where 237 where
174 us = toList ts 238 us = toList ts
175 nR = length us 239 nR = length us
@@ -182,26 +246,27 @@ odeSolve method jac f y0 ts =
182 nc = fromIntegral $ cols m 246 nc = fromIntegral $ cols m
183 vs = V.fromList $ map coerce $ concat $ toLists m 247 vs = V.fromList $ map coerce $ concat $ toLists m
184 248
185solveOde :: 249odeSolveVWith ::
186 ODEMethod 250 ODEMethod
187 -> (Double -> V.Vector Double -> T.SunMatrix) 251 -> (Maybe (Double -> V.Vector Double -> T.SunMatrix))
188 -> Double 252 -> Double
189 -> Double 253 -> Double
190 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 254 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
191 -> V.Vector Double -- ^ Initial conditions 255 -> V.Vector Double -- ^ Initial conditions
192 -> V.Vector Double -- ^ Desired solution times 256 -> V.Vector Double -- ^ Desired solution times
193 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution 257 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution
194solveOde method jac relTol absTol f y0 tt = 258odeSolveVWith method jac relTol absTol f y0 tt =
195 case solveOdeC (fromIntegral $ fromEnum method) jacH (CDouble relTol) (CDouble absTol) 259 case solveOdeC (fromIntegral $ getMethod method) jacH (CDouble relTol) (CDouble absTol)
196 (coerce f) (coerce y0) (coerce tt) of 260 (coerce f) (coerce y0) (coerce tt) of
197 Left c -> Left $ fromIntegral c 261 Left c -> Left $ fromIntegral c
198 Right (v, d) -> Right (coerce v, d) 262 Right (v, d) -> Right (coerce v, d)
199 where 263 where
200 jacH t v = jac (coerce t) (coerce v) 264 jacH :: Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)
265 jacH = fmap (\g -> (\t v -> g (coerce t) (coerce v))) jac
201 266
202solveOdeC :: 267solveOdeC ::
203 CInt -> 268 CInt ->
204 (CDouble -> V.Vector CDouble -> T.SunMatrix) -> 269 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) ->
205 CDouble -> 270 CDouble ->
206 CDouble -> 271 CDouble ->
207 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 272 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
@@ -235,15 +300,19 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do
235 -- FIXME: I don't understand what this comment means 300 -- FIXME: I don't understand what this comment means
236 -- Unsafe since the function will be called many times. 301 -- Unsafe since the function will be called many times.
237 [CU.exp| int{ 0 } |] 302 [CU.exp| int{ 0 } |]
238 let jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix -> 303 let isJac :: CInt
304 isJac = fromIntegral $ fromEnum $ isJust jacH
305 jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix ->
239 Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector -> 306 Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector ->
240 IO CInt 307 IO CInt
241 jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do 308 jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do
242 j <- jacH t <$> getDataFromContents dim y 309 case jacH of
243 putMatrixDataFromContents j jacS 310 Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined"
244 -- FIXME: I don't understand what this comment means 311 Just jacI -> do j <- jacI t <$> getDataFromContents dim y
245 -- Unsafe since the function will be called many times. 312 putMatrixDataFromContents j jacS
246 [CU.exp| int{ 0 } |] 313 -- FIXME: I don't understand what this comment means
314 -- Unsafe since the function will be called many times.
315 [CU.exp| int{ 0 } |]
247 316
248 res <- [C.block| int { 317 res <- [C.block| int {
249 /* general problem variables */ 318 /* general problem variables */
@@ -292,9 +361,10 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do
292 /* Linear solver interface */ 361 /* Linear solver interface */
293 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ 362 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */
294 363
295 flag = ARKDlsSetJacFn(arkode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[]))); 364 if ($(int isJac)) {
296 if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1; 365 flag = ARKDlsSetJacFn(arkode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[])));
297 366 if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1;
367 }
298 368
299 /* Store initial conditions */ 369 /* Store initial conditions */
300 for (j = 0; j < NEQ; j++) { 370 for (j = 0; j < NEQ; j++) {
@@ -389,9 +459,10 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do
389btGet :: ODEMethod -> Matrix Double 459btGet :: ODEMethod -> Matrix Double
390btGet method = case getBT method of 460btGet method = case getBT method of
391 Left c -> error $ show c -- FIXME 461 Left c -> error $ show c -- FIXME
392 -- FIXME 462 Right (v, sqp) -> subMatrix (0, 0) (s, s) $
393 Right (v, _sqp) -> subMatrix (0, 0) (2, 2) $
394 (B.arkSMax >< B.arkSMax) (V.toList v) 463 (B.arkSMax >< B.arkSMax) (V.toList v)
464 where
465 s = fromIntegral $ sqp V.! 0
395 466
396getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int) 467getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int)
397getBT method = case getButcherTable method of 468getBT method = case getButcherTable method of
@@ -410,7 +481,7 @@ getButcherTable method = unsafePerformIO $ do
410 nEq :: CLong 481 nEq :: CLong
411 nEq = fromIntegral dim 482 nEq = fromIntegral dim
412 mN :: CInt 483 mN :: CInt
413 mN = fromIntegral $ fromEnum method 484 mN = fromIntegral $ getMethod method
414 485
415 -- FIXME: I believe these gets taken from the ghc heap and so should 486 -- FIXME: I believe these gets taken from the ghc heap and so should
416 -- be subject to garbage collection. 487 -- be subject to garbage collection.
@@ -493,3 +564,11 @@ getButcherTable method = unsafePerformIO $ do
493 return $ Right (x, y) 564 return $ Right (x, y)
494 else do 565 else do
495 return $ Left res 566 return $ Left res
567
568-- | Adaptive step-size control functions. FIXME: It may not be
569-- possible to scale the tolerances for the derivatives in sundials so
570-- for now we ignore them and emit a warning.
571data StepControl = X Double Double -- ^ abs. and rel. tolerance for x(t)
572 | X' Double Double -- ^ abs. and rel. tolerance for x'(t)
573 | XX' Double Double Double Double -- ^ include both via rel. tolerance scaling factors a_x, a_x'
574 | ScXX' Double Double Double Double (Vector Double) -- ^ scale abs. tolerance of x(t) components