summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs')
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs903
1 files changed, 903 insertions, 0 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
new file mode 100644
index 0000000..fafc237
--- /dev/null
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -0,0 +1,903 @@
1{-# LANGUAGE QuasiQuotes #-}
2{-# LANGUAGE TemplateHaskell #-}
3{-# LANGUAGE MultiWayIf #-}
4{-# LANGUAGE OverloadedStrings #-}
5{-# LANGUAGE ScopedTypeVariables #-}
6{-# LANGUAGE DeriveGeneric #-}
7{-# LANGUAGE TypeOperators #-}
8{-# LANGUAGE KindSignatures #-}
9{-# LANGUAGE TypeSynonymInstances #-}
10{-# LANGUAGE FlexibleInstances #-}
11{-# LANGUAGE FlexibleContexts #-}
12
13-----------------------------------------------------------------------------
14-- |
15-- Module : Numeric.Sundials.ARKode.ODE
16-- Copyright : Dominic Steinitz 2018,
17-- Novadiscovery 2018
18-- License : BSD
19-- Maintainer : Dominic Steinitz
20-- Stability : provisional
21--
22-- Solution of ordinary differential equation (ODE) initial value problems.
23-- See <https://computation.llnl.gov/projects/sundials/sundials-software> for more detail.
24--
25-- A simple example:
26--
27-- <<diagrams/brusselator.png#diagram=brusselator&height=400&width=500>>
28--
29-- @
30-- import Numeric.Sundials.ARKode.ODE
31-- import Numeric.LinearAlgebra
32--
33-- import Plots as P
34-- import qualified Diagrams.Prelude as D
35-- import Diagrams.Backend.Rasterific
36--
37-- brusselator :: Double -> [Double] -> [Double]
38-- brusselator _t x = [ a - (w + 1) * u + v * u * u
39-- , w * u - v * u * u
40-- , (b - w) / eps - w * u
41-- ]
42-- where
43-- a = 1.0
44-- b = 3.5
45-- eps = 5.0e-6
46-- u = x !! 0
47-- v = x !! 1
48-- w = x !! 2
49--
50-- lSaxis :: [[Double]] -> P.Axis B D.V2 Double
51-- lSaxis xs = P.r2Axis &~ do
52-- let ts = xs!!0
53-- us = xs!!1
54-- vs = xs!!2
55-- ws = xs!!3
56-- P.linePlot' $ zip ts us
57-- P.linePlot' $ zip ts vs
58-- P.linePlot' $ zip ts ws
59--
60-- main = do
61-- let res1 = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0])
62-- renderRasterific "diagrams/brusselator.png"
63-- (D.dims2D 500.0 500.0)
64-- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1))
65-- @
66--
67-- With Sundials ARKode, it is possible to retrieve the Butcher tableau for the solver.
68--
69-- @
70-- import Numeric.Sundials.ARKode.ODE
71-- import Numeric.LinearAlgebra
72--
73-- import Data.List (intercalate)
74--
75-- import Text.PrettyPrint.HughesPJClass
76--
77--
78-- butcherTableauTex :: ButcherTable -> String
79-- butcherTableauTex (ButcherTable m c b b2) =
80-- render $
81-- vcat [ text ("\n\\begin{array}{c|" ++ (concat $ replicate n "c") ++ "}")
82-- , us
83-- , text "\\hline"
84-- , text bs <+> text "\\\\"
85-- , text b2s <+> text "\\\\"
86-- , text "\\end{array}"
87-- ]
88-- where
89-- n = rows m
90-- rs = toLists m
91-- ss = map (\r -> intercalate " & " $ map show r) rs
92-- ts = zipWith (\i r -> show i ++ " & " ++ r) (toList c) ss
93-- us = vcat $ map (\r -> text r <+> text "\\\\") ts
94-- bs = " & " ++ (intercalate " & " $ map show $ toList b)
95-- b2s = " & " ++ (intercalate " & " $ map show $ toList b2)
96--
97-- main :: IO ()
98-- main = do
99--
100-- let res = butcherTable (SDIRK_2_1_2 undefined)
101-- putStrLn $ show res
102-- putStrLn $ butcherTableauTex res
103--
104-- let resA = butcherTable (KVAERNO_4_2_3 undefined)
105-- putStrLn $ show resA
106-- putStrLn $ butcherTableauTex resA
107--
108-- let resB = butcherTable (SDIRK_5_3_4 undefined)
109-- putStrLn $ show resB
110-- putStrLn $ butcherTableauTex resB
111-- @
112--
113-- Using the code above from the examples gives
114--
115-- KVAERNO_4_2_3
116--
117-- \[
118-- \begin{array}{c|cccc}
119-- 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\
120-- 0.871733043 & 0.4358665215 & 0.4358665215 & 0.0 & 0.0 \\
121-- 1.0 & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\
122-- 1.0 & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\
123-- \hline
124-- & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\
125-- & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\
126-- \end{array}
127-- \]
128--
129-- SDIRK_2_1_2
130--
131-- \[
132-- \begin{array}{c|cc}
133-- 1.0 & 1.0 & 0.0 \\
134-- 0.0 & -1.0 & 1.0 \\
135-- \hline
136-- & 0.5 & 0.5 \\
137-- & 1.0 & 0.0 \\
138-- \end{array}
139-- \]
140--
141-- SDIRK_5_3_4
142--
143-- \[
144-- \begin{array}{c|ccccc}
145-- 0.25 & 0.25 & 0.0 & 0.0 & 0.0 & 0.0 \\
146-- 0.75 & 0.5 & 0.25 & 0.0 & 0.0 & 0.0 \\
147-- 0.55 & 0.34 & -4.0e-2 & 0.25 & 0.0 & 0.0 \\
148-- 0.5 & 0.2727941176470588 & -5.036764705882353e-2 & 2.7573529411764705e-2 & 0.25 & 0.0 \\
149-- 1.0 & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\
150-- \hline
151-- & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\
152-- & 1.2291666666666667 & -0.17708333333333334 & 7.03125 & -7.083333333333333 & 0.0 \\
153-- \end{array}
154-- \]
155-----------------------------------------------------------------------------
156module Numeric.Sundials.ARKode.ODE ( odeSolve
157 , odeSolveV
158 , odeSolveVWith
159 , odeSolveVWith'
160 , ButcherTable(..)
161 , butcherTable
162 , ODEMethod(..)
163 , StepControl(..)
164 ) where
165
166import qualified Language.C.Inline as C
167import qualified Language.C.Inline.Unsafe as CU
168
169import Data.Monoid ((<>))
170import Data.Maybe (isJust)
171
172import Foreign.C.Types (CDouble, CInt, CLong)
173import Foreign.Ptr (Ptr)
174import Foreign.Storable (poke)
175
176import qualified Data.Vector.Storable as V
177
178import Data.Coerce (coerce)
179import System.IO.Unsafe (unsafePerformIO)
180import GHC.Generics (C1, Constructor, (:+:)(..), D1, Rep, Generic, M1(..),
181 from, conName)
182
183import Numeric.LinearAlgebra.Devel (createVector)
184
185import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows,
186 cols, toLists, size, reshape,
187 subVector, subMatrix, (><))
188
189import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian, SundialsDiagnostics(..))
190import qualified Numeric.Sundials.Arkode as T
191import Numeric.Sundials.Arkode (getDataFromContents, putDataInContents, arkSMax,
192 sDIRK_2_1_2,
193 bILLINGTON_3_3_2,
194 tRBDF2_3_3_2,
195 kVAERNO_4_2_3,
196 aRK324L2SA_DIRK_4_2_3,
197 cASH_5_2_4,
198 cASH_5_3_4,
199 sDIRK_5_3_4,
200 kVAERNO_5_3_4,
201 aRK436L2SA_DIRK_6_3_4,
202 kVAERNO_7_4_5,
203 aRK548L2SA_DIRK_8_4_5,
204 hEUN_EULER_2_1_2,
205 bOGACKI_SHAMPINE_4_2_3,
206 aRK324L2SA_ERK_4_2_3,
207 zONNEVELD_5_3_4,
208 aRK436L2SA_ERK_6_3_4,
209 sAYFY_ABURUB_6_3_4,
210 cASH_KARP_6_4_5,
211 fEHLBERG_6_4_5,
212 dORMAND_PRINCE_7_4_5,
213 aRK548L2SA_ERK_8_4_5,
214 vERNER_8_5_6,
215 fEHLBERG_13_7_8)
216
217
218C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx)
219
220C.include "<stdlib.h>"
221C.include "<stdio.h>"
222C.include "<math.h>"
223C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts.
224C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros
225C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix
226C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver
227C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface
228C.include "<sundials/sundials_types.h>" -- definition of type realtype
229C.include "<sundials/sundials_math.h>"
230C.include "../../../helpers.h"
231C.include "Numeric/Sundials/Arkode_hsc.h"
232
233
234-- | Stepping functions
235data ODEMethod = SDIRK_2_1_2 Jacobian
236 | SDIRK_2_1_2'
237 | BILLINGTON_3_3_2 Jacobian
238 | BILLINGTON_3_3_2'
239 | TRBDF2_3_3_2 Jacobian
240 | TRBDF2_3_3_2'
241 | KVAERNO_4_2_3 Jacobian
242 | KVAERNO_4_2_3'
243 | ARK324L2SA_DIRK_4_2_3 Jacobian
244 | ARK324L2SA_DIRK_4_2_3'
245 | CASH_5_2_4 Jacobian
246 | CASH_5_2_4'
247 | CASH_5_3_4 Jacobian
248 | CASH_5_3_4'
249 | SDIRK_5_3_4 Jacobian
250 | SDIRK_5_3_4'
251 | KVAERNO_5_3_4 Jacobian
252 | KVAERNO_5_3_4'
253 | ARK436L2SA_DIRK_6_3_4 Jacobian
254 | ARK436L2SA_DIRK_6_3_4'
255 | KVAERNO_7_4_5 Jacobian
256 | KVAERNO_7_4_5'
257 | ARK548L2SA_DIRK_8_4_5 Jacobian
258 | ARK548L2SA_DIRK_8_4_5'
259 | HEUN_EULER_2_1_2 Jacobian
260 | HEUN_EULER_2_1_2'
261 | BOGACKI_SHAMPINE_4_2_3 Jacobian
262 | BOGACKI_SHAMPINE_4_2_3'
263 | ARK324L2SA_ERK_4_2_3 Jacobian
264 | ARK324L2SA_ERK_4_2_3'
265 | ZONNEVELD_5_3_4 Jacobian
266 | ZONNEVELD_5_3_4'
267 | ARK436L2SA_ERK_6_3_4 Jacobian
268 | ARK436L2SA_ERK_6_3_4'
269 | SAYFY_ABURUB_6_3_4 Jacobian
270 | SAYFY_ABURUB_6_3_4'
271 | CASH_KARP_6_4_5 Jacobian
272 | CASH_KARP_6_4_5'
273 | FEHLBERG_6_4_5 Jacobian
274 | FEHLBERG_6_4_5'
275 | DORMAND_PRINCE_7_4_5 Jacobian
276 | DORMAND_PRINCE_7_4_5'
277 | ARK548L2SA_ERK_8_4_5 Jacobian
278 | ARK548L2SA_ERK_8_4_5'
279 | VERNER_8_5_6 Jacobian
280 | VERNER_8_5_6'
281 | FEHLBERG_13_7_8 Jacobian
282 | FEHLBERG_13_7_8'
283 deriving Generic
284
285constrName :: (HasConstructor (Rep a), Generic a)=> a -> String
286constrName = genericConstrName . from
287
288class HasConstructor (f :: * -> *) where
289 genericConstrName :: f x -> String
290
291instance HasConstructor f => HasConstructor (D1 c f) where
292 genericConstrName (M1 x) = genericConstrName x
293
294instance (HasConstructor x, HasConstructor y) => HasConstructor (x :+: y) where
295 genericConstrName (L1 l) = genericConstrName l
296 genericConstrName (R1 r) = genericConstrName r
297
298instance Constructor c => HasConstructor (C1 c f) where
299 genericConstrName x = conName x
300
301instance Show ODEMethod where
302 show x = constrName x
303
304-- FIXME: We can probably do better here with generics
305getMethod :: ODEMethod -> Int
306getMethod (SDIRK_2_1_2 _) = sDIRK_2_1_2
307getMethod (SDIRK_2_1_2') = sDIRK_2_1_2
308getMethod (BILLINGTON_3_3_2 _) = bILLINGTON_3_3_2
309getMethod (BILLINGTON_3_3_2') = bILLINGTON_3_3_2
310getMethod (TRBDF2_3_3_2 _) = tRBDF2_3_3_2
311getMethod (TRBDF2_3_3_2') = tRBDF2_3_3_2
312getMethod (KVAERNO_4_2_3 _) = kVAERNO_4_2_3
313getMethod (KVAERNO_4_2_3') = kVAERNO_4_2_3
314getMethod (ARK324L2SA_DIRK_4_2_3 _) = aRK324L2SA_DIRK_4_2_3
315getMethod (ARK324L2SA_DIRK_4_2_3') = aRK324L2SA_DIRK_4_2_3
316getMethod (CASH_5_2_4 _) = cASH_5_2_4
317getMethod (CASH_5_2_4') = cASH_5_2_4
318getMethod (CASH_5_3_4 _) = cASH_5_3_4
319getMethod (CASH_5_3_4') = cASH_5_3_4
320getMethod (SDIRK_5_3_4 _) = sDIRK_5_3_4
321getMethod (SDIRK_5_3_4') = sDIRK_5_3_4
322getMethod (KVAERNO_5_3_4 _) = kVAERNO_5_3_4
323getMethod (KVAERNO_5_3_4') = kVAERNO_5_3_4
324getMethod (ARK436L2SA_DIRK_6_3_4 _) = aRK436L2SA_DIRK_6_3_4
325getMethod (ARK436L2SA_DIRK_6_3_4') = aRK436L2SA_DIRK_6_3_4
326getMethod (KVAERNO_7_4_5 _) = kVAERNO_7_4_5
327getMethod (KVAERNO_7_4_5') = kVAERNO_7_4_5
328getMethod (ARK548L2SA_DIRK_8_4_5 _) = aRK548L2SA_DIRK_8_4_5
329getMethod (ARK548L2SA_DIRK_8_4_5') = aRK548L2SA_DIRK_8_4_5
330getMethod (HEUN_EULER_2_1_2 _) = hEUN_EULER_2_1_2
331getMethod (HEUN_EULER_2_1_2') = hEUN_EULER_2_1_2
332getMethod (BOGACKI_SHAMPINE_4_2_3 _) = bOGACKI_SHAMPINE_4_2_3
333getMethod (BOGACKI_SHAMPINE_4_2_3') = bOGACKI_SHAMPINE_4_2_3
334getMethod (ARK324L2SA_ERK_4_2_3 _) = aRK324L2SA_ERK_4_2_3
335getMethod (ARK324L2SA_ERK_4_2_3') = aRK324L2SA_ERK_4_2_3
336getMethod (ZONNEVELD_5_3_4 _) = zONNEVELD_5_3_4
337getMethod (ZONNEVELD_5_3_4') = zONNEVELD_5_3_4
338getMethod (ARK436L2SA_ERK_6_3_4 _) = aRK436L2SA_ERK_6_3_4
339getMethod (ARK436L2SA_ERK_6_3_4') = aRK436L2SA_ERK_6_3_4
340getMethod (SAYFY_ABURUB_6_3_4 _) = sAYFY_ABURUB_6_3_4
341getMethod (SAYFY_ABURUB_6_3_4') = sAYFY_ABURUB_6_3_4
342getMethod (CASH_KARP_6_4_5 _) = cASH_KARP_6_4_5
343getMethod (CASH_KARP_6_4_5') = cASH_KARP_6_4_5
344getMethod (FEHLBERG_6_4_5 _) = fEHLBERG_6_4_5
345getMethod (FEHLBERG_6_4_5' ) = fEHLBERG_6_4_5
346getMethod (DORMAND_PRINCE_7_4_5 _) = dORMAND_PRINCE_7_4_5
347getMethod (DORMAND_PRINCE_7_4_5') = dORMAND_PRINCE_7_4_5
348getMethod (ARK548L2SA_ERK_8_4_5 _) = aRK548L2SA_ERK_8_4_5
349getMethod (ARK548L2SA_ERK_8_4_5') = aRK548L2SA_ERK_8_4_5
350getMethod (VERNER_8_5_6 _) = vERNER_8_5_6
351getMethod (VERNER_8_5_6') = vERNER_8_5_6
352getMethod (FEHLBERG_13_7_8 _) = fEHLBERG_13_7_8
353getMethod (FEHLBERG_13_7_8') = fEHLBERG_13_7_8
354
355getJacobian :: ODEMethod -> Maybe Jacobian
356getJacobian (SDIRK_2_1_2 j) = Just j
357getJacobian (BILLINGTON_3_3_2 j) = Just j
358getJacobian (TRBDF2_3_3_2 j) = Just j
359getJacobian (KVAERNO_4_2_3 j) = Just j
360getJacobian (ARK324L2SA_DIRK_4_2_3 j) = Just j
361getJacobian (CASH_5_2_4 j) = Just j
362getJacobian (CASH_5_3_4 j) = Just j
363getJacobian (SDIRK_5_3_4 j) = Just j
364getJacobian (KVAERNO_5_3_4 j) = Just j
365getJacobian (ARK436L2SA_DIRK_6_3_4 j) = Just j
366getJacobian (KVAERNO_7_4_5 j) = Just j
367getJacobian (ARK548L2SA_DIRK_8_4_5 j) = Just j
368getJacobian (HEUN_EULER_2_1_2 j) = Just j
369getJacobian (BOGACKI_SHAMPINE_4_2_3 j) = Just j
370getJacobian (ARK324L2SA_ERK_4_2_3 j) = Just j
371getJacobian (ZONNEVELD_5_3_4 j) = Just j
372getJacobian (ARK436L2SA_ERK_6_3_4 j) = Just j
373getJacobian (SAYFY_ABURUB_6_3_4 j) = Just j
374getJacobian (CASH_KARP_6_4_5 j) = Just j
375getJacobian (FEHLBERG_6_4_5 j) = Just j
376getJacobian (DORMAND_PRINCE_7_4_5 j) = Just j
377getJacobian (ARK548L2SA_ERK_8_4_5 j) = Just j
378getJacobian (VERNER_8_5_6 j) = Just j
379getJacobian (FEHLBERG_13_7_8 j) = Just j
380getJacobian _ = Nothing
381
382-- | A version of 'odeSolveVWith' with reasonable default step control.
383odeSolveV
384 :: ODEMethod
385 -> Maybe Double -- ^ initial step size - by default, ARKode
386 -- estimates the initial step size to be the
387 -- solution \(h\) of the equation
388 -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where
389 -- \(\ddot{y}\) is an estimated value of the
390 -- second derivative of the solution at \(t_0\)
391 -> Double -- ^ absolute tolerance for the state vector
392 -> Double -- ^ relative tolerance for the state vector
393 -> (Double -> Vector Double -> Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
394 -> Vector Double -- ^ initial conditions
395 -> Vector Double -- ^ desired solution times
396 -> Matrix Double -- ^ solution
397odeSolveV meth hi epsAbs epsRel f y0 ts =
398 odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts
399 where
400 g t x0 = coerce $ f t x0
401
402-- | A version of 'odeSolveV' with reasonable default parameters and
403-- system of equations defined using lists. FIXME: we should say
404-- something about the fact we could use the Jacobian but don't for
405-- compatibility with hmatrix-gsl.
406odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
407 -> [Double] -- ^ initial conditions
408 -> Vector Double -- ^ desired solution times
409 -> Matrix Double -- ^ solution
410odeSolve f y0 ts =
411 -- FIXME: These tolerances are different from the ones in GSL
412 odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts)
413 where
414 g t x0 = V.fromList $ f t (V.toList x0)
415
416odeSolveVWith ::
417 ODEMethod
418 -> StepControl
419 -> Maybe Double -- ^ initial step size - by default, ARKode
420 -- estimates the initial step size to be the
421 -- solution \(h\) of the equation
422 -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where
423 -- \(\ddot{y}\) is an estimated value of the second
424 -- derivative of the solution at \(t_0\)
425 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
426 -> V.Vector Double -- ^ Initial conditions
427 -> V.Vector Double -- ^ Desired solution times
428 -> Matrix Double -- ^ Error code or solution
429odeSolveVWith method control initStepSize f y0 tt =
430 case odeSolveVWith' opts method control initStepSize f y0 tt of
431 Left c -> error $ show c -- FIXME
432 Right (v, _d) -> v
433 where
434 opts = ODEOpts { maxNumSteps = 10000
435 , minStep = 1.0e-12
436 , relTol = error "relTol"
437 , absTols = error "absTol"
438 , initStep = error "initStep"
439 , maxFail = 10
440 }
441
442odeSolveVWith' ::
443 ODEOpts
444 -> ODEMethod
445 -> StepControl
446 -> Maybe Double -- ^ initial step size - by default, ARKode
447 -- estimates the initial step size to be the
448 -- solution \(h\) of the equation
449 -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where
450 -- \(\ddot{y}\) is an estimated value of the second
451 -- derivative of the solution at \(t_0\)
452 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
453 -> V.Vector Double -- ^ Initial conditions
454 -> V.Vector Double -- ^ Desired solution times
455 -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution
456odeSolveVWith' opts method control initStepSize f y0 tt =
457 case solveOdeC (fromIntegral $ maxFail opts)
458 (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts)
459 (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control)
460 (coerce f) (coerce y0) (coerce tt) of
461 Left c -> Left $ fromIntegral c
462 Right (v, d) -> Right (reshape l (coerce v), d)
463 where
464 l = size y0
465 scise (X aTol rTol) = coerce (V.replicate l aTol, rTol)
466 scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol)
467 scise (XX' aTol rTol yScale _yDotScale) = coerce (V.replicate l aTol, yScale * rTol)
468 -- FIXME; Should we check that the length of ss is correct?
469 scise (ScXX' aTol rTol yScale _yDotScale ss) = coerce (V.map (* aTol) ss, yScale * rTol)
470 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $
471 getJacobian method
472 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs }
473 where
474 nr = fromIntegral $ rows m
475 nc = fromIntegral $ cols m
476 -- FIXME: efficiency
477 vs = V.fromList $ map coerce $ concat $ toLists m
478
479solveOdeC ::
480 CInt ->
481 CLong ->
482 CDouble ->
483 CInt ->
484 Maybe CDouble ->
485 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) ->
486 (V.Vector CDouble, CDouble) ->
487 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
488 -> V.Vector CDouble -- ^ Initial conditions
489 -> V.Vector CDouble -- ^ Desired solution times
490 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution
491solveOdeC maxErrTestFails maxNumSteps_ minStep_ method initStepSize
492 jacH (aTols, rTol) fun f0 ts = unsafePerformIO $ do
493
494 let isInitStepSize :: CInt
495 isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize
496 ss :: CDouble
497 ss = case initStepSize of
498 -- It would be better to put an error message here but
499 -- inline-c seems to evaluate this even if it is never
500 -- used :(
501 Nothing -> 0.0
502 Just x -> x
503
504 let dim = V.length f0
505 nEq :: CLong
506 nEq = fromIntegral dim
507 nTs :: CInt
508 nTs = fromIntegral $ V.length ts
509 -- FIXME: I believe this gets taken from the ghc heap and so should
510 -- be subject to garbage collection.
511 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs))
512 qMatMut <- V.thaw quasiMatrixRes
513 diagnostics :: V.Vector CLong <- createVector 10 -- FIXME
514 diagMut <- V.thaw diagnostics
515 -- We need the types that sundials expects. These are tied together
516 -- in 'CLangToHaskellTypes'. FIXME: The Haskell type is currently empty!
517 let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt
518 funIO x y f _ptr = do
519 -- Convert the pointer we get from C (y) to a vector, and then
520 -- apply the user-supplied function.
521 fImm <- fun x <$> getDataFromContents dim y
522 -- Fill in the provided pointer with the resulting vector.
523 putDataInContents fImm dim f
524 -- FIXME: I don't understand what this comment means
525 -- Unsafe since the function will be called many times.
526 [CU.exp| int{ 0 } |]
527 let isJac :: CInt
528 isJac = fromIntegral $ fromEnum $ isJust jacH
529 jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix ->
530 Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector ->
531 IO CInt
532 jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do
533 case jacH of
534 Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined"
535 Just jacI -> do j <- jacI t <$> getDataFromContents dim y
536 poke jacS j
537 -- FIXME: I don't understand what this comment means
538 -- Unsafe since the function will be called many times.
539 [CU.exp| int{ 0 } |]
540
541 res <- [C.block| int {
542 /* general problem variables */
543
544 int flag; /* reusable error-checking flag */
545 int i, j; /* reusable loop indices */
546 N_Vector y = NULL; /* empty vector for storing solution */
547 N_Vector tv = NULL; /* empty vector for storing absolute tolerances */
548 SUNMatrix A = NULL; /* empty matrix for linear solver */
549 SUNLinearSolver LS = NULL; /* empty linear solver object */
550 void *arkode_mem = NULL; /* empty ARKode memory structure */
551 realtype t;
552 long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf;
553
554 /* general problem parameters */
555
556 realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */
557 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */
558
559 /* Initialize data structures */
560
561 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
562 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
563 /* Specify initial condition */
564 for (i = 0; i < NEQ; i++) {
565 NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i];
566 };
567
568 tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */
569 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1;
570 /* Specify tolerances */
571 for (i = 0; i < NEQ; i++) {
572 NV_Ith_S(tv,i) = ($vec-ptr:(double *aTols))[i];
573 };
574
575 arkode_mem = ARKodeCreate(); /* Create the solver memory */
576 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
577
578 /* Call ARKodeInit to initialize the integrator memory and specify the */
579 /* right-hand side function in y'=f(t,y), the inital time T0, and */
580 /* the initial dependent variable vector y. Note: we treat the */
581 /* problem as fully implicit and set f_E to NULL and f_I to f. */
582
583 /* Here we use the C types defined in helpers.h which tie up with */
584 /* the Haskell types defined in CLangToHaskellTypes */
585 if ($(int method) < MIN_DIRK_NUM) {
586 flag = ARKodeInit(arkode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), NULL, T0, y);
587 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
588 } else {
589 flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y);
590 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
591 }
592
593 flag = ARKodeSetMinStep(arkode_mem, $(double minStep_));
594 if (check_flag(&flag, "ARKodeSetMinStep", 1)) return 1;
595 flag = ARKodeSetMaxNumSteps(arkode_mem, $(long int maxNumSteps_));
596 if (check_flag(&flag, "ARKodeSetMaxNumSteps", 1)) return 1;
597 flag = ARKodeSetMaxErrTestFails(arkode_mem, $(int maxErrTestFails));
598 if (check_flag(&flag, "ARKodeSetMaxErrTestFails", 1)) return 1;
599
600 /* Set routines */
601 flag = ARKodeSVtolerances(arkode_mem, $(double rTol), tv);
602 if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1;
603
604 /* Initialize dense matrix data structure and solver */
605 A = SUNDenseMatrix(NEQ, NEQ);
606 if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1;
607 LS = SUNDenseLinearSolver(y, A);
608 if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1;
609
610 /* Attach matrix and linear solver */
611 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A);
612 if (check_flag(&flag, "ARKDlsSetLinearSolver", 1)) return 1;
613
614 /* Set the initial step size if there is one */
615 if ($(int isInitStepSize)) {
616 /* FIXME: We could check if the initial step size is 0 */
617 /* or even NaN and then throw an error */
618 flag = ARKodeSetInitStep(arkode_mem, $(double ss));
619 if (check_flag(&flag, "ARKodeSetInitStep", 1)) return 1;
620 }
621
622 /* Set the Jacobian if there is one */
623 if ($(int isJac)) {
624 flag = ARKDlsSetJacFn(arkode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[])));
625 if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1;
626 }
627
628 /* Store initial conditions */
629 for (j = 0; j < NEQ; j++) {
630 ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j);
631 }
632
633 /* Explicitly set the method */
634 if ($(int method) >= MIN_DIRK_NUM) {
635 flag = ARKodeSetIRKTableNum(arkode_mem, $(int method));
636 if (check_flag(&flag, "ARKodeSetIRKTableNum", 1)) return 1;
637 } else {
638 flag = ARKodeSetERKTableNum(arkode_mem, $(int method));
639 if (check_flag(&flag, "ARKodeSetERKTableNum", 1)) return 1;
640 }
641
642 /* Main time-stepping loop: calls ARKode to perform the integration */
643 /* Stops when the final time has been reached */
644 for (i = 1; i < $(int nTs); i++) {
645
646 flag = ARKode(arkode_mem, ($vec-ptr:(double *ts))[i], y, &t, ARK_NORMAL); /* call integrator */
647 if (check_flag(&flag, "ARKode", 1)) break;
648
649 /* Store the results for Haskell */
650 for (j = 0; j < NEQ; j++) {
651 ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j);
652 }
653
654 /* unsuccessful solve: break */
655 if (flag < 0) {
656 fprintf(stderr,"Solver failure, stopping integration\n");
657 break;
658 }
659 }
660
661 /* Get some final statistics on how the solve progressed */
662
663 flag = ARKodeGetNumSteps(arkode_mem, &nst);
664 check_flag(&flag, "ARKodeGetNumSteps", 1);
665 ($vec-ptr:(long int *diagMut))[0] = nst;
666
667 flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a);
668 check_flag(&flag, "ARKodeGetNumStepAttempts", 1);
669 ($vec-ptr:(long int *diagMut))[1] = nst_a;
670
671 flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi);
672 check_flag(&flag, "ARKodeGetNumRhsEvals", 1);
673 ($vec-ptr:(long int *diagMut))[2] = nfe;
674 ($vec-ptr:(long int *diagMut))[3] = nfi;
675
676 flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups);
677 check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1);
678 ($vec-ptr:(long int *diagMut))[4] = nsetups;
679
680 flag = ARKodeGetNumErrTestFails(arkode_mem, &netf);
681 check_flag(&flag, "ARKodeGetNumErrTestFails", 1);
682 ($vec-ptr:(long int *diagMut))[5] = netf;
683
684 flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni);
685 check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1);
686 ($vec-ptr:(long int *diagMut))[6] = nni;
687
688 flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn);
689 check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1);
690 ($vec-ptr:(long int *diagMut))[7] = ncfn;
691
692 flag = ARKDlsGetNumJacEvals(arkode_mem, &nje);
693 check_flag(&flag, "ARKDlsGetNumJacEvals", 1);
694 ($vec-ptr:(long int *diagMut))[8] = ncfn;
695
696 flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS);
697 check_flag(&flag, "ARKDlsGetNumRhsEvals", 1);
698 ($vec-ptr:(long int *diagMut))[9] = ncfn;
699
700 /* Clean up and return */
701 N_VDestroy(y); /* Free y vector */
702 N_VDestroy(tv); /* Free tv vector */
703 ARKodeFree(&arkode_mem); /* Free integrator memory */
704 SUNLinSolFree(LS); /* Free linear solver */
705 SUNMatDestroy(A); /* Free A matrix */
706
707 return flag;
708 } |]
709 if res == 0
710 then do
711 preD <- V.freeze diagMut
712 let d = SundialsDiagnostics (fromIntegral $ preD V.!0)
713 (fromIntegral $ preD V.!1)
714 (fromIntegral $ preD V.!2)
715 (fromIntegral $ preD V.!3)
716 (fromIntegral $ preD V.!4)
717 (fromIntegral $ preD V.!5)
718 (fromIntegral $ preD V.!6)
719 (fromIntegral $ preD V.!7)
720 (fromIntegral $ preD V.!8)
721 (fromIntegral $ preD V.!9)
722 m <- V.freeze qMatMut
723 return $ Right (m, d)
724 else do
725 return $ Left res
726
727data ButcherTable = ButcherTable { am :: Matrix Double
728 , cv :: Vector Double
729 , bv :: Vector Double
730 , b2v :: Vector Double
731 }
732 deriving Show
733
734data ButcherTable' a = ButcherTable' { am' :: V.Vector a
735 , cv' :: V.Vector a
736 , bv' :: V.Vector a
737 , b2v' :: V.Vector a
738 }
739 deriving Show
740
741butcherTable :: ODEMethod -> ButcherTable
742butcherTable method =
743 case getBT method of
744 Left c -> error $ show c -- FIXME
745 Right (ButcherTable' v w x y, sqp) ->
746 ButcherTable { am = subMatrix (0, 0) (s, s) $ (arkSMax >< arkSMax) (V.toList v)
747 , cv = subVector 0 s w
748 , bv = subVector 0 s x
749 , b2v = subVector 0 s y
750 }
751 where
752 s = fromIntegral $ sqp V.! 0
753
754getBT :: ODEMethod -> Either Int (ButcherTable' Double, V.Vector Int)
755getBT method = case getButcherTable method of
756 Left c ->
757 Left $ fromIntegral c
758 Right (ButcherTable' a b c d, sqp) ->
759 Right $ ( ButcherTable' (coerce a) (coerce b) (coerce c) (coerce d)
760 , V.map fromIntegral sqp )
761
762getButcherTable :: ODEMethod
763 -> Either CInt (ButcherTable' CDouble, V.Vector CInt)
764getButcherTable method = unsafePerformIO $ do
765 -- ARKode seems to want an ODE in order to set and then get the
766 -- Butcher tableau so here's one to keep it happy
767 let funI :: CDouble -> V.Vector CDouble -> V.Vector CDouble
768 funI _t ys = V.fromList [ ys V.! 0 ]
769 let funE :: CDouble -> V.Vector CDouble -> V.Vector CDouble
770 funE _t ys = V.fromList [ ys V.! 0 ]
771 f0 = V.fromList [ 1.0 ]
772 ts = V.fromList [ 0.0 ]
773 dim = V.length f0
774 nEq :: CLong
775 nEq = fromIntegral dim
776 mN :: CInt
777 mN = fromIntegral $ getMethod method
778
779 btSQP :: V.Vector CInt <- createVector 3
780 btSQPMut <- V.thaw btSQP
781 btAs :: V.Vector CDouble <- createVector (arkSMax * arkSMax)
782 btAsMut <- V.thaw btAs
783 btCs :: V.Vector CDouble <- createVector arkSMax
784 btBs :: V.Vector CDouble <- createVector arkSMax
785 btB2s :: V.Vector CDouble <- createVector arkSMax
786 btCsMut <- V.thaw btCs
787 btBsMut <- V.thaw btBs
788 btB2sMut <- V.thaw btB2s
789 let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt
790 funIOI x y f _ptr = do
791 fImm <- funI x <$> getDataFromContents dim y
792 putDataInContents fImm dim f
793 -- FIXME: I don't understand what this comment means
794 -- Unsafe since the function will be called many times.
795 [CU.exp| int{ 0 } |]
796 let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt
797 funIOE x y f _ptr = do
798 fImm <- funE x <$> getDataFromContents dim y
799 putDataInContents fImm dim f
800 -- FIXME: I don't understand what this comment means
801 -- Unsafe since the function will be called many times.
802 [CU.exp| int{ 0 } |]
803 res <- [C.block| int {
804 /* general problem variables */
805
806 int flag; /* reusable error-checking flag */
807 N_Vector y = NULL; /* empty vector for storing solution */
808 void *arkode_mem = NULL; /* empty ARKode memory structure */
809 int i, j; /* reusable loop indices */
810
811 /* general problem parameters */
812
813 realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */
814 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars */
815
816 /* Initialize data structures */
817
818 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
819 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
820 /* Specify initial condition */
821 for (i = 0; i < NEQ; i++) {
822 NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i];
823 };
824 arkode_mem = ARKodeCreate(); /* Create the solver memory */
825 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
826
827 flag = ARKodeInit(arkode_mem, $fun:(int (* funIOE) (double t, SunVector y[], SunVector dydt[], void * params)), $fun:(int (* funIOI) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y);
828 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
829
830 if ($(int mN) >= MIN_DIRK_NUM) {
831 flag = ARKodeSetIRKTableNum(arkode_mem, $(int mN));
832 if (check_flag(&flag, "ARKodeSetIRKTableNum", 1)) return 1;
833 } else {
834 flag = ARKodeSetERKTableNum(arkode_mem, $(int mN));
835 if (check_flag(&flag, "ARKodeSetERKTableNum", 1)) return 1;
836 }
837
838 int s, q, p;
839 realtype *ai = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype));
840 realtype *ae = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype));
841 realtype *ci = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
842 realtype *ce = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
843 realtype *bi = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
844 realtype *be = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
845 realtype *b2i = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
846 realtype *b2e = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
847 flag = ARKodeGetCurrentButcherTables(arkode_mem, &s, &q, &p, ai, ae, ci, ce, bi, be, b2i, b2e);
848 if (check_flag(&flag, "ARKode", 1)) return 1;
849 $vec-ptr:(int *btSQPMut)[0] = s;
850 $vec-ptr:(int *btSQPMut)[1] = q;
851 $vec-ptr:(int *btSQPMut)[2] = p;
852 for (i = 0; i < s; i++) {
853 for (j = 0; j < s; j++) {
854 /* FIXME: double should be realtype */
855 ($vec-ptr:(double *btAsMut))[i * ARK_S_MAX + j] = ai[i * ARK_S_MAX + j];
856 }
857 }
858
859 for (i = 0; i < s; i++) {
860 ($vec-ptr:(double *btCsMut))[i] = ci[i];
861 ($vec-ptr:(double *btBsMut))[i] = bi[i];
862 ($vec-ptr:(double *btB2sMut))[i] = b2i[i];
863 }
864
865 /* Clean up and return */
866 N_VDestroy(y); /* Free y vector */
867 ARKodeFree(&arkode_mem); /* Free integrator memory */
868
869 return flag;
870 } |]
871 if res == 0
872 then do
873 x <- V.freeze btAsMut
874 y <- V.freeze btSQPMut
875 z <- V.freeze btCsMut
876 u <- V.freeze btBsMut
877 v <- V.freeze btB2sMut
878 return $ Right (ButcherTable' { am' = x, cv' = z, bv' = u, b2v' = v }, y)
879 else do
880 return $ Left res
881
882-- | Adaptive step-size control
883-- functions.
884--
885-- [GSL](https://www.gnu.org/software/gsl/doc/html/ode-initval.html#adaptive-step-size-control)
886-- allows the user to control the step size adjustment using
887-- \(D_i = \epsilon^{abs}s_i + \epsilon^{rel}(a_{y} |y_i| + a_{dy/dt} h |\dot{y}_i|)\) where
888-- \(\epsilon^{abs}\) is the required absolute error, \(\epsilon^{rel}\)
889-- is the required relative error, \(s_i\) is a vector of scaling
890-- factors, \(a_{y}\) is a scaling factor for the solution \(y\) and
891-- \(a_{dydt}\) is a scaling factor for the derivative of the solution \(dy/dt\).
892--
893-- [ARKode](https://computation.llnl.gov/projects/sundials/arkode)
894-- allows the user to control the step size adjustment using
895-- \(\eta^{rel}|y_i| + \eta^{abs}_i\). For compatibility with
896-- [hmatrix-gsl](https://hackage.haskell.org/package/hmatrix-gsl),
897-- tolerances for \(y\) and \(\dot{y}\) can be specified but the latter have no
898-- effect.
899data StepControl = X Double Double -- ^ absolute and relative tolerance for \(y\); in GSL terms, \(a_{y} = 1\) and \(a_{dy/dt} = 0\); in ARKode terms, the \(\eta^{abs}_i\) are identical
900 | X' Double Double -- ^ absolute and relative tolerance for \(\dot{y}\); in GSL terms, \(a_{y} = 0\) and \(a_{dy/dt} = 1\); in ARKode terms, the latter is treated as the relative tolerance for \(y\) so this is the same as specifying 'X' which may be entirely incorrect for the given problem
901 | XX' Double Double Double Double -- ^ include both via relative tolerance
902 -- scaling factors \(a_y\), \(a_{{dy}/{dt}}\); in ARKode terms, the latter is ignored and \(\eta^{rel} = a_{y}\epsilon^{rel}\)
903 | ScXX' Double Double Double Double (Vector Double) -- ^ scale absolute tolerance of \(y_i\); in ARKode terms, \(a_{{dy}/{dt}}\) is ignored, \(\eta^{abs}_i = s_i \epsilon^{abs}\) and \(\eta^{rel} = a_{y}\epsilon^{rel}\)