summaryrefslogtreecommitdiff
path: root/packages/sundials/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/sundials/src')
-rw-r--r--packages/sundials/src/Arkode.hsc114
-rw-r--r--packages/sundials/src/Main.hs138
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs898
-rw-r--r--packages/sundials/src/Types.hs40
-rw-r--r--packages/sundials/src/helpers.c44
-rw-r--r--packages/sundials/src/helpers.h9
6 files changed, 1243 insertions, 0 deletions
diff --git a/packages/sundials/src/Arkode.hsc b/packages/sundials/src/Arkode.hsc
new file mode 100644
index 0000000..9db37b5
--- /dev/null
+++ b/packages/sundials/src/Arkode.hsc
@@ -0,0 +1,114 @@
1module Arkode where
2
3import Foreign
4import Foreign.C.Types
5
6
7#include <stdio.h>
8#include <sundials/sundials_nvector.h>
9#include <sundials/sundials_matrix.h>
10#include <nvector/nvector_serial.h>
11#include <sunmatrix/sunmatrix_dense.h>
12#include <arkode/arkode.h>
13
14
15#def typedef struct _generic_N_Vector SunVector;
16#def typedef struct _N_VectorContent_Serial SunContent;
17
18#def typedef struct _generic_SUNMatrix SunMatrix;
19#def typedef struct _SUNMatrixContent_Dense SunMatrixContent;
20
21getContentMatrixPtr :: Storable a => Ptr b -> IO a
22getContentMatrixPtr ptr = (#peek SunMatrix, content) ptr
23
24getNRows :: Ptr b -> IO CInt
25getNRows ptr = (#peek SunMatrixContent, M) ptr
26putNRows :: CInt -> Ptr b -> IO ()
27putNRows nr ptr = (#poke SunMatrixContent, M) ptr nr
28
29getNCols :: Ptr b -> IO CInt
30getNCols ptr = (#peek SunMatrixContent, N) ptr
31putNCols :: CInt -> Ptr b -> IO ()
32putNCols nc ptr = (#poke SunMatrixContent, N) ptr nc
33
34getMatrixData :: Storable a => Ptr b -> IO a
35getMatrixData ptr = (#peek SunMatrixContent, data) ptr
36
37getContentPtr :: Storable a => Ptr b -> IO a
38getContentPtr ptr = (#peek SunVector, content) ptr
39
40getData :: Storable a => Ptr b -> IO a
41getData ptr = (#peek SunContent, data) ptr
42
43arkSMax :: Int
44arkSMax = #const ARK_S_MAX
45
46mIN_DIRK_NUM, mAX_DIRK_NUM :: Int
47mIN_DIRK_NUM = #const MIN_DIRK_NUM
48mAX_DIRK_NUM = #const MAX_DIRK_NUM
49
50-- FIXME: We could just use inline-c instead
51
52-- Butcher table accessors -- implicit
53sDIRK_2_1_2 :: Int
54sDIRK_2_1_2 = #const SDIRK_2_1_2
55bILLINGTON_3_3_2 :: Int
56bILLINGTON_3_3_2 = #const BILLINGTON_3_3_2
57tRBDF2_3_3_2 :: Int
58tRBDF2_3_3_2 = #const TRBDF2_3_3_2
59kVAERNO_4_2_3 :: Int
60kVAERNO_4_2_3 = #const KVAERNO_4_2_3
61aRK324L2SA_DIRK_4_2_3 :: Int
62aRK324L2SA_DIRK_4_2_3 = #const ARK324L2SA_DIRK_4_2_3
63cASH_5_2_4 :: Int
64cASH_5_2_4 = #const CASH_5_2_4
65cASH_5_3_4 :: Int
66cASH_5_3_4 = #const CASH_5_3_4
67sDIRK_5_3_4 :: Int
68sDIRK_5_3_4 = #const SDIRK_5_3_4
69kVAERNO_5_3_4 :: Int
70kVAERNO_5_3_4 = #const KVAERNO_5_3_4
71aRK436L2SA_DIRK_6_3_4 :: Int
72aRK436L2SA_DIRK_6_3_4 = #const ARK436L2SA_DIRK_6_3_4
73kVAERNO_7_4_5 :: Int
74kVAERNO_7_4_5 = #const KVAERNO_7_4_5
75aRK548L2SA_DIRK_8_4_5 :: Int
76aRK548L2SA_DIRK_8_4_5 = #const ARK548L2SA_DIRK_8_4_5
77
78-- #define DEFAULT_DIRK_2 SDIRK_2_1_2
79-- #define DEFAULT_DIRK_3 ARK324L2SA_DIRK_4_2_3
80-- #define DEFAULT_DIRK_4 SDIRK_5_3_4
81-- #define DEFAULT_DIRK_5 ARK548L2SA_DIRK_8_4_5
82
83-- Butcher table accessors -- explicit
84hEUN_EULER_2_1_2 :: Int
85hEUN_EULER_2_1_2 = #const HEUN_EULER_2_1_2
86bOGACKI_SHAMPINE_4_2_3 :: Int
87bOGACKI_SHAMPINE_4_2_3 = #const BOGACKI_SHAMPINE_4_2_3
88aRK324L2SA_ERK_4_2_3 :: Int
89aRK324L2SA_ERK_4_2_3 = #const ARK324L2SA_ERK_4_2_3
90zONNEVELD_5_3_4 :: Int
91zONNEVELD_5_3_4 = #const ZONNEVELD_5_3_4
92aRK436L2SA_ERK_6_3_4 :: Int
93aRK436L2SA_ERK_6_3_4 = #const ARK436L2SA_ERK_6_3_4
94sAYFY_ABURUB_6_3_4 :: Int
95sAYFY_ABURUB_6_3_4 = #const SAYFY_ABURUB_6_3_4
96cASH_KARP_6_4_5 :: Int
97cASH_KARP_6_4_5 = #const CASH_KARP_6_4_5
98fEHLBERG_6_4_5 :: Int
99fEHLBERG_6_4_5 = #const FEHLBERG_6_4_5
100dORMAND_PRINCE_7_4_5 :: Int
101dORMAND_PRINCE_7_4_5 = #const DORMAND_PRINCE_7_4_5
102aRK548L2SA_ERK_8_4_5 :: Int
103aRK548L2SA_ERK_8_4_5 = #const ARK548L2SA_ERK_8_4_5
104vERNER_8_5_6 :: Int
105vERNER_8_5_6 = #const VERNER_8_5_6
106fEHLBERG_13_7_8 :: Int
107fEHLBERG_13_7_8 = #const FEHLBERG_13_7_8
108
109-- #define DEFAULT_ERK_2 HEUN_EULER_2_1_2
110-- #define DEFAULT_ERK_3 BOGACKI_SHAMPINE_4_2_3
111-- #define DEFAULT_ERK_4 ZONNEVELD_5_3_4
112-- #define DEFAULT_ERK_5 CASH_KARP_6_4_5
113-- #define DEFAULT_ERK_6 VERNER_8_5_6
114-- #define DEFAULT_ERK_8 FEHLBERG_13_7_8
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs
new file mode 100644
index 0000000..729d35a
--- /dev/null
+++ b/packages/sundials/src/Main.hs
@@ -0,0 +1,138 @@
1{-# OPTIONS_GHC -Wall #-}
2
3import Numeric.Sundials.ARKode.ODE
4import Numeric.LinearAlgebra
5
6import Plots as P
7import qualified Diagrams.Prelude as D
8import Diagrams.Backend.Rasterific
9
10import Control.Lens
11
12import Test.Hspec
13
14
15lorenz :: Double -> [Double] -> [Double]
16lorenz _t u = [ sigma * (y - x)
17 , x * (rho - z) - y
18 , x * y - beta * z
19 ]
20 where
21 rho = 28.0
22 sigma = 10.0
23 beta = 8.0 / 3.0
24 x = u !! 0
25 y = u !! 1
26 z = u !! 2
27
28_lorenzJac :: Double -> Vector Double -> Matrix Double
29_lorenzJac _t u = (3><3) [ (-sigma), rho - z, y
30 , sigma , -1.0 , x
31 , 0.0 , (-x) , (-beta)
32 ]
33 where
34 rho = 28.0
35 sigma = 10.0
36 beta = 8.0 / 3.0
37 x = u ! 0
38 y = u ! 1
39 z = u ! 2
40
41brusselator :: Double -> [Double] -> [Double]
42brusselator _t x = [ a - (w + 1) * u + v * u * u
43 , w * u - v * u * u
44 , (b - w) / eps - w * u
45 ]
46 where
47 a = 1.0
48 b = 3.5
49 eps = 5.0e-6
50 u = x !! 0
51 v = x !! 1
52 w = x !! 2
53
54_brussJac :: Double -> Vector Double -> Matrix Double
55_brussJac _t x = (3><3) [ (-(w + 1.0)) + 2.0 * u * v, w - 2.0 * u * v, (-w)
56 , u * u , (-(u * u)) , 0.0
57 , (-u) , u , (-1.0) / eps - u
58 ]
59 where
60 y = toList x
61 u = y !! 0
62 v = y !! 1
63 w = y !! 2
64 eps = 5.0e-6
65
66stiffish :: Double -> [Double] -> [Double]
67stiffish t v = [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ]
68 where
69 lamda = -100.0
70 u = v !! 0
71
72stiffishV :: Double -> Vector Double -> Vector Double
73stiffishV t v = fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ]
74 where
75 lamda = -100.0
76 u = v ! 0
77
78_stiffJac :: Double -> Vector Double -> Matrix Double
79_stiffJac _t _v = (1><1) [ lamda ]
80 where
81 lamda = -100.0
82
83lSaxis :: [[Double]] -> P.Axis B D.V2 Double
84lSaxis xs = P.r2Axis &~ do
85 let ts = xs!!0
86 us = xs!!1
87 vs = xs!!2
88 ws = xs!!3
89 P.linePlot' $ zip ts us
90 P.linePlot' $ zip ts vs
91 P.linePlot' $ zip ts ws
92
93kSaxis :: [(Double, Double)] -> P.Axis B D.V2 Double
94kSaxis xs = P.r2Axis &~ do
95 P.linePlot' xs
96
97main :: IO ()
98main = do
99
100 let res1 = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0])
101 renderRasterific "diagrams/brusselator.png"
102 (D.dims2D 500.0 500.0)
103 (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1))
104
105 let res1a = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0])
106 renderRasterific "diagrams/brusselatorA.png"
107 (D.dims2D 500.0 500.0)
108 (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1a))
109
110 let res2 = odeSolve stiffish [0.0] (fromList [0.0, 0.1 .. 10.0])
111 renderRasterific "diagrams/stiffish.png"
112 (D.dims2D 500.0 500.0)
113 (renderAxis $ kSaxis $ zip [0.0, 0.1 .. 10.0] (concat $ toLists res2))
114
115 let res2a = odeSolveV (SDIRK_5_3_4') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0])
116
117 let res2b = odeSolveV (TRBDF2_3_3_2') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0])
118
119 let maxDiff = maximum $ map abs $
120 zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2b)!!0)
121
122 hspec $ describe "Compare results" $ do
123 it "for two different RK methods" $
124 maxDiff < 1.0e-6
125
126 let res3 = odeSolve lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0])
127
128 renderRasterific "diagrams/lorenz.png"
129 (D.dims2D 500.0 500.0)
130 (renderAxis $ kSaxis $ zip ((toLists $ tr res3)!!0) ((toLists $ tr res3)!!1))
131
132 renderRasterific "diagrams/lorenz1.png"
133 (D.dims2D 500.0 500.0)
134 (renderAxis $ kSaxis $ zip ((toLists $ tr res3)!!0) ((toLists $ tr res3)!!2))
135
136 renderRasterific "diagrams/lorenz2.png"
137 (D.dims2D 500.0 500.0)
138 (renderAxis $ kSaxis $ zip ((toLists $ tr res3)!!1) ((toLists $ tr res3)!!2))
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
new file mode 100644
index 0000000..e5a2e4d
--- /dev/null
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -0,0 +1,898 @@
1{-# OPTIONS_GHC -Wall #-}
2
3{-# LANGUAGE QuasiQuotes #-}
4{-# LANGUAGE TemplateHaskell #-}
5{-# LANGUAGE MultiWayIf #-}
6{-# LANGUAGE OverloadedStrings #-}
7{-# LANGUAGE ScopedTypeVariables #-}
8{-# LANGUAGE DeriveGeneric #-}
9{-# LANGUAGE TypeOperators #-}
10{-# LANGUAGE KindSignatures #-}
11{-# LANGUAGE TypeSynonymInstances #-}
12{-# LANGUAGE FlexibleInstances #-}
13{-# LANGUAGE FlexibleContexts #-}
14
15-----------------------------------------------------------------------------
16-- |
17-- Module : Numeric.Sundials.ARKode.ODE
18-- Copyright : Dominic Steinitz 2018,
19-- Novadiscovery 2018
20-- License : BSD
21-- Maintainer : Dominic Steinitz
22-- Stability : provisional
23--
24-- Solution of ordinary differential equation (ODE) initial value problems.
25--
26-- <https://computation.llnl.gov/projects/sundials/sundials-software>
27--
28-- A simple example:
29--
30-- <<diagrams/brusselator.png#diagram=brusselator&height=400&width=500>>
31--
32-- @
33-- import Numeric.Sundials.ARKode.ODE
34-- import Numeric.LinearAlgebra
35--
36-- import Plots as P
37-- import qualified Diagrams.Prelude as D
38-- import Diagrams.Backend.Rasterific
39--
40-- brusselator :: Double -> [Double] -> [Double]
41-- brusselator _t x = [ a - (w + 1) * u + v * u * u
42-- , w * u - v * u * u
43-- , (b - w) / eps - w * u
44-- ]
45-- where
46-- a = 1.0
47-- b = 3.5
48-- eps = 5.0e-6
49-- u = x !! 0
50-- v = x !! 1
51-- w = x !! 2
52--
53-- lSaxis :: [[Double]] -> P.Axis B D.V2 Double
54-- lSaxis xs = P.r2Axis &~ do
55-- let ts = xs!!0
56-- us = xs!!1
57-- vs = xs!!2
58-- ws = xs!!3
59-- P.linePlot' $ zip ts us
60-- P.linePlot' $ zip ts vs
61-- P.linePlot' $ zip ts ws
62--
63-- main = do
64-- let res1 = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0])
65-- renderRasterific "diagrams/brusselator.png"
66-- (D.dims2D 500.0 500.0)
67-- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1))
68-- @
69--
70-- KVAERNO_4_2_3
71--
72-- \[
73-- \begin{array}{c|cccc}
74-- 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\
75-- 0.871733043 & 0.4358665215 & 0.4358665215 & 0.0 & 0.0 \\
76-- 1.0 & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\
77-- 1.0 & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\
78-- \hline
79-- & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\
80-- & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\
81-- \end{array}
82-- \]
83--
84-- SDIRK_2_1_2
85--
86-- \[
87-- \begin{array}{c|cc}
88-- 1.0 & 1.0 & 0.0 \\
89-- 0.0 & -1.0 & 1.0 \\
90-- \hline
91-- & 0.5 & 0.5 \\
92-- & 1.0 & 0.0 \\
93-- \end{array}
94-- \]
95--
96-- SDIRK_5_3_4
97--
98-- \[
99-- \begin{array}{c|ccccc}
100-- 0.25 & 0.25 & 0.0 & 0.0 & 0.0 & 0.0 \\
101-- 0.75 & 0.5 & 0.25 & 0.0 & 0.0 & 0.0 \\
102-- 0.55 & 0.34 & -4.0e-2 & 0.25 & 0.0 & 0.0 \\
103-- 0.5 & 0.2727941176470588 & -5.036764705882353e-2 & 2.7573529411764705e-2 & 0.25 & 0.0 \\
104-- 1.0 & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\
105-- \hline
106-- & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\
107-- & 1.2291666666666667 & -0.17708333333333334 & 7.03125 & -7.083333333333333 & 0.0 \\
108-- \end{array}
109-- \]
110-----------------------------------------------------------------------------
111module Numeric.Sundials.ARKode.ODE ( odeSolve
112 , odeSolveV
113 , odeSolveVWith
114 , odeSolveVWith'
115 , ButcherTable(..)
116 , butcherTable
117 , ODEMethod(..)
118 , StepControl(..)
119 , Jacobian
120 , SundialsDiagnostics(..)
121 ) where
122
123import qualified Language.C.Inline as C
124import qualified Language.C.Inline.Unsafe as CU
125
126import Data.Monoid ((<>))
127import Data.Maybe (isJust)
128
129import Foreign.C.Types
130import Foreign.Ptr (Ptr)
131import Foreign.ForeignPtr (newForeignPtr_)
132import Foreign.Storable (Storable)
133
134import qualified Data.Vector.Storable as V
135import qualified Data.Vector.Storable.Mutable as VM
136
137import Data.Coerce (coerce)
138import System.IO.Unsafe (unsafePerformIO)
139import GHC.Generics
140
141import Numeric.LinearAlgebra.Devel (createVector)
142
143import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><),
144 subMatrix, rows, cols, toLists,
145 size, subVector)
146
147import qualified Types as T
148import Arkode
149import qualified Arkode as B
150
151
152C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx)
153
154C.include "<stdlib.h>"
155C.include "<stdio.h>"
156C.include "<math.h>"
157C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts.
158C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros
159C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix
160C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver
161C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface
162C.include "<sundials/sundials_types.h>" -- definition of type realtype
163C.include "<sundials/sundials_math.h>"
164C.include "../../../helpers.h"
165C.include "Arkode_hsc.h"
166
167
168getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble)
169getDataFromContents len ptr = do
170 qtr <- B.getContentPtr ptr
171 rtr <- B.getData qtr
172 vectorFromC len rtr
173
174-- FIXME: Potentially an instance of Storable
175_getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix
176_getMatrixDataFromContents ptr = do
177 qtr <- B.getContentMatrixPtr ptr
178 rs <- B.getNRows qtr
179 cs <- B.getNCols qtr
180 rtr <- B.getMatrixData qtr
181 vs <- vectorFromC (fromIntegral $ rs * cs) rtr
182 return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs }
183
184putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO ()
185putMatrixDataFromContents mat ptr = do
186 let rs = T.rows mat
187 cs = T.cols mat
188 vs = T.vals mat
189 qtr <- B.getContentMatrixPtr ptr
190 B.putNRows rs qtr
191 B.putNCols cs qtr
192 rtr <- B.getMatrixData qtr
193 vectorToC vs (fromIntegral $ rs * cs) rtr
194-- FIXME: END
195
196putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO ()
197putDataInContents vec len ptr = do
198 qtr <- B.getContentPtr ptr
199 rtr <- B.getData qtr
200 vectorToC vec len rtr
201
202-- Utils
203
204vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a)
205vectorFromC len ptr = do
206 ptr' <- newForeignPtr_ ptr
207 V.freeze $ VM.unsafeFromForeignPtr0 ptr' len
208
209vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO ()
210vectorToC vec len ptr = do
211 ptr' <- newForeignPtr_ ptr
212 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
213
214data SundialsDiagnostics = SundialsDiagnostics {
215 aRKodeGetNumSteps :: Int
216 , aRKodeGetNumStepAttempts :: Int
217 , aRKodeGetNumRhsEvals_fe :: Int
218 , aRKodeGetNumRhsEvals_fi :: Int
219 , aRKodeGetNumLinSolvSetups :: Int
220 , aRKodeGetNumErrTestFails :: Int
221 , aRKodeGetNumNonlinSolvIters :: Int
222 , aRKodeGetNumNonlinSolvConvFails :: Int
223 , aRKDlsGetNumJacEvals :: Int
224 , aRKDlsGetNumRhsEvals :: Int
225 } deriving Show
226
227type Jacobian = Double -> Vector Double -> Matrix Double
228
229-- | Stepping functions
230data ODEMethod = SDIRK_2_1_2 Jacobian
231 | SDIRK_2_1_2'
232 | BILLINGTON_3_3_2 Jacobian
233 | BILLINGTON_3_3_2'
234 | TRBDF2_3_3_2 Jacobian
235 | TRBDF2_3_3_2'
236 | KVAERNO_4_2_3 Jacobian
237 | KVAERNO_4_2_3'
238 | ARK324L2SA_DIRK_4_2_3 Jacobian
239 | ARK324L2SA_DIRK_4_2_3'
240 | CASH_5_2_4 Jacobian
241 | CASH_5_2_4'
242 | CASH_5_3_4 Jacobian
243 | CASH_5_3_4'
244 | SDIRK_5_3_4 Jacobian
245 | SDIRK_5_3_4'
246 | KVAERNO_5_3_4 Jacobian
247 | KVAERNO_5_3_4'
248 | ARK436L2SA_DIRK_6_3_4 Jacobian
249 | ARK436L2SA_DIRK_6_3_4'
250 | KVAERNO_7_4_5 Jacobian
251 | KVAERNO_7_4_5'
252 | ARK548L2SA_DIRK_8_4_5 Jacobian
253 | ARK548L2SA_DIRK_8_4_5'
254 | HEUN_EULER_2_1_2 Jacobian
255 | HEUN_EULER_2_1_2'
256 | BOGACKI_SHAMPINE_4_2_3 Jacobian
257 | BOGACKI_SHAMPINE_4_2_3'
258 | ARK324L2SA_ERK_4_2_3 Jacobian
259 | ARK324L2SA_ERK_4_2_3'
260 | ZONNEVELD_5_3_4 Jacobian
261 | ZONNEVELD_5_3_4'
262 | ARK436L2SA_ERK_6_3_4 Jacobian
263 | ARK436L2SA_ERK_6_3_4'
264 | SAYFY_ABURUB_6_3_4 Jacobian
265 | SAYFY_ABURUB_6_3_4'
266 | CASH_KARP_6_4_5 Jacobian
267 | CASH_KARP_6_4_5'
268 | FEHLBERG_6_4_5 Jacobian
269 | FEHLBERG_6_4_5'
270 | DORMAND_PRINCE_7_4_5 Jacobian
271 | DORMAND_PRINCE_7_4_5'
272 | ARK548L2SA_ERK_8_4_5 Jacobian
273 | ARK548L2SA_ERK_8_4_5'
274 | VERNER_8_5_6 Jacobian
275 | VERNER_8_5_6'
276 | FEHLBERG_13_7_8 Jacobian
277 | FEHLBERG_13_7_8'
278 deriving Generic
279
280constrName :: (HasConstructor (Rep a), Generic a)=> a -> String
281constrName = genericConstrName . from
282
283class HasConstructor (f :: * -> *) where
284 genericConstrName :: f x -> String
285
286instance HasConstructor f => HasConstructor (D1 c f) where
287 genericConstrName (M1 x) = genericConstrName x
288
289instance (HasConstructor x, HasConstructor y) => HasConstructor (x :+: y) where
290 genericConstrName (L1 l) = genericConstrName l
291 genericConstrName (R1 r) = genericConstrName r
292
293instance Constructor c => HasConstructor (C1 c f) where
294 genericConstrName x = conName x
295
296instance Show ODEMethod where
297 show x = constrName x
298
299-- FIXME: We can probably do better here with generics
300getMethod :: ODEMethod -> Int
301getMethod (SDIRK_2_1_2 _) = sDIRK_2_1_2
302getMethod (SDIRK_2_1_2') = sDIRK_2_1_2
303getMethod (BILLINGTON_3_3_2 _) = bILLINGTON_3_3_2
304getMethod (BILLINGTON_3_3_2') = bILLINGTON_3_3_2
305getMethod (TRBDF2_3_3_2 _) = tRBDF2_3_3_2
306getMethod (TRBDF2_3_3_2') = tRBDF2_3_3_2
307getMethod (KVAERNO_4_2_3 _) = kVAERNO_4_2_3
308getMethod (KVAERNO_4_2_3') = kVAERNO_4_2_3
309getMethod (ARK324L2SA_DIRK_4_2_3 _) = aRK324L2SA_DIRK_4_2_3
310getMethod (ARK324L2SA_DIRK_4_2_3') = aRK324L2SA_DIRK_4_2_3
311getMethod (CASH_5_2_4 _) = cASH_5_2_4
312getMethod (CASH_5_2_4') = cASH_5_2_4
313getMethod (CASH_5_3_4 _) = cASH_5_3_4
314getMethod (CASH_5_3_4') = cASH_5_3_4
315getMethod (SDIRK_5_3_4 _) = sDIRK_5_3_4
316getMethod (SDIRK_5_3_4') = sDIRK_5_3_4
317getMethod (KVAERNO_5_3_4 _) = kVAERNO_5_3_4
318getMethod (KVAERNO_5_3_4') = kVAERNO_5_3_4
319getMethod (ARK436L2SA_DIRK_6_3_4 _) = aRK436L2SA_DIRK_6_3_4
320getMethod (ARK436L2SA_DIRK_6_3_4') = aRK436L2SA_DIRK_6_3_4
321getMethod (KVAERNO_7_4_5 _) = kVAERNO_7_4_5
322getMethod (KVAERNO_7_4_5') = kVAERNO_7_4_5
323getMethod (ARK548L2SA_DIRK_8_4_5 _) = aRK548L2SA_DIRK_8_4_5
324getMethod (ARK548L2SA_DIRK_8_4_5') = aRK548L2SA_DIRK_8_4_5
325getMethod (HEUN_EULER_2_1_2 _) = hEUN_EULER_2_1_2
326getMethod (HEUN_EULER_2_1_2') = hEUN_EULER_2_1_2
327getMethod (BOGACKI_SHAMPINE_4_2_3 _) = bOGACKI_SHAMPINE_4_2_3
328getMethod (BOGACKI_SHAMPINE_4_2_3') = bOGACKI_SHAMPINE_4_2_3
329getMethod (ARK324L2SA_ERK_4_2_3 _) = aRK324L2SA_ERK_4_2_3
330getMethod (ARK324L2SA_ERK_4_2_3') = aRK324L2SA_ERK_4_2_3
331getMethod (ZONNEVELD_5_3_4 _) = zONNEVELD_5_3_4
332getMethod (ZONNEVELD_5_3_4') = zONNEVELD_5_3_4
333getMethod (ARK436L2SA_ERK_6_3_4 _) = aRK436L2SA_ERK_6_3_4
334getMethod (ARK436L2SA_ERK_6_3_4') = aRK436L2SA_ERK_6_3_4
335getMethod (SAYFY_ABURUB_6_3_4 _) = sAYFY_ABURUB_6_3_4
336getMethod (SAYFY_ABURUB_6_3_4') = sAYFY_ABURUB_6_3_4
337getMethod (CASH_KARP_6_4_5 _) = cASH_KARP_6_4_5
338getMethod (CASH_KARP_6_4_5') = cASH_KARP_6_4_5
339getMethod (FEHLBERG_6_4_5 _) = fEHLBERG_6_4_5
340getMethod (FEHLBERG_6_4_5' ) = fEHLBERG_6_4_5
341getMethod (DORMAND_PRINCE_7_4_5 _) = dORMAND_PRINCE_7_4_5
342getMethod (DORMAND_PRINCE_7_4_5') = dORMAND_PRINCE_7_4_5
343getMethod (ARK548L2SA_ERK_8_4_5 _) = aRK548L2SA_ERK_8_4_5
344getMethod (ARK548L2SA_ERK_8_4_5') = aRK548L2SA_ERK_8_4_5
345getMethod (VERNER_8_5_6 _) = vERNER_8_5_6
346getMethod (VERNER_8_5_6') = vERNER_8_5_6
347getMethod (FEHLBERG_13_7_8 _) = fEHLBERG_13_7_8
348getMethod (FEHLBERG_13_7_8') = fEHLBERG_13_7_8
349
350getJacobian :: ODEMethod -> Maybe Jacobian
351getJacobian (SDIRK_2_1_2 j) = Just j
352getJacobian (BILLINGTON_3_3_2 j) = Just j
353getJacobian (TRBDF2_3_3_2 j) = Just j
354getJacobian (KVAERNO_4_2_3 j) = Just j
355getJacobian (ARK324L2SA_DIRK_4_2_3 j) = Just j
356getJacobian (CASH_5_2_4 j) = Just j
357getJacobian (CASH_5_3_4 j) = Just j
358getJacobian (SDIRK_5_3_4 j) = Just j
359getJacobian (KVAERNO_5_3_4 j) = Just j
360getJacobian (ARK436L2SA_DIRK_6_3_4 j) = Just j
361getJacobian (KVAERNO_7_4_5 j) = Just j
362getJacobian (ARK548L2SA_DIRK_8_4_5 j) = Just j
363getJacobian (HEUN_EULER_2_1_2 j) = Just j
364getJacobian (BOGACKI_SHAMPINE_4_2_3 j) = Just j
365getJacobian (ARK324L2SA_ERK_4_2_3 j) = Just j
366getJacobian (ZONNEVELD_5_3_4 j) = Just j
367getJacobian (ARK436L2SA_ERK_6_3_4 j) = Just j
368getJacobian (SAYFY_ABURUB_6_3_4 j) = Just j
369getJacobian (CASH_KARP_6_4_5 j) = Just j
370getJacobian (FEHLBERG_6_4_5 j) = Just j
371getJacobian (DORMAND_PRINCE_7_4_5 j) = Just j
372getJacobian (ARK548L2SA_ERK_8_4_5 j) = Just j
373getJacobian (VERNER_8_5_6 j) = Just j
374getJacobian (FEHLBERG_13_7_8 j) = Just j
375getJacobian _ = Nothing
376
377-- | A version of 'odeSolveVWith' with reasonable default step control.
378odeSolveV
379 :: ODEMethod
380 -> Maybe Double -- ^ initial step size - by default, ARKode
381 -- estimates the initial step size to be the
382 -- solution \(h\) of the equation
383 -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where
384 -- \(\ddot{y}\) is an estimated value of the
385 -- second derivative of the solution at \(t_0\)
386 -> Double -- ^ absolute tolerance for the state vector
387 -> Double -- ^ relative tolerance for the state vector
388 -> (Double -> Vector Double -> Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
389 -> Vector Double -- ^ initial conditions
390 -> Vector Double -- ^ desired solution times
391 -> Matrix Double -- ^ solution
392odeSolveV meth hi epsAbs epsRel f y0 ts =
393 case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of
394 Left c -> error $ show c -- FIXME
395 -- FIXME: Can we do better than using lists?
396 Right (v, _d) -> (nR >< nC) (V.toList v)
397 where
398 us = toList ts
399 nR = length us
400 nC = size y0
401 g t x0 = coerce $ f t x0
402
403-- | A version of 'odeSolveV' with reasonable default parameters and
404-- system of equations defined using lists. FIXME: we should say
405-- something about the fact we could use the Jacobian but don't for
406-- compatibility with hmatrix-gsl.
407odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
408 -> [Double] -- ^ initial conditions
409 -> Vector Double -- ^ desired solution times
410 -> Matrix Double -- ^ solution
411odeSolve f y0 ts =
412 -- FIXME: These tolerances are different from the ones in GSL
413 case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of
414 Left c -> error $ show c -- FIXME
415 Right (v, _d) -> (nR >< nC) (V.toList v)
416 where
417 us = toList ts
418 nR = length us
419 nC = length y0
420 g t x0 = V.fromList $ f t (V.toList x0)
421
422odeSolveVWith' ::
423 ODEMethod
424 -> StepControl
425 -> Maybe Double -- ^ initial step size - by default, ARKode
426 -- estimates the initial step size to be the
427 -- solution \(h\) of the equation
428 -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where
429 -- \(\ddot{y}\) is an estimated value of the second
430 -- derivative of the solution at \(t_0\)
431 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
432 -> V.Vector Double -- ^ Initial conditions
433 -> V.Vector Double -- ^ Desired solution times
434 -> Matrix Double -- ^ Error code or solution
435odeSolveVWith' method control initStepSize f y0 tt =
436 case odeSolveVWith method control initStepSize f y0 tt of
437 Left c -> error $ show c -- FIXME
438 Right (v, _d) -> (nR >< nC) (V.toList v)
439 where
440 nR = V.length tt
441 nC = V.length y0
442
443odeSolveVWith ::
444 ODEMethod
445 -> StepControl
446 -> Maybe Double -- ^ initial step size - by default, ARKode
447 -- estimates the initial step size to be the
448 -- solution \(h\) of the equation
449 -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where
450 -- \(\ddot{y}\) is an estimated value of the second
451 -- derivative of the solution at \(t_0\)
452 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
453 -> V.Vector Double -- ^ Initial conditions
454 -> V.Vector Double -- ^ Desired solution times
455 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution
456odeSolveVWith method control initStepSize f y0 tt =
457 case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control)
458 (coerce f) (coerce y0) (coerce tt) of
459 Left c -> Left $ fromIntegral c
460 Right (v, d) -> Right (coerce v, d)
461 where
462 l = size y0
463 scise (X absTol relTol) = coerce (V.replicate l absTol, relTol)
464 scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol)
465 scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol)
466 -- FIXME; Should we check that the length of ss is correct?
467 scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol)
468 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $
469 getJacobian method
470 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs }
471 where
472 nr = fromIntegral $ rows m
473 nc = fromIntegral $ cols m
474 -- FIXME: efficiency
475 vs = V.fromList $ map coerce $ concat $ toLists m
476
477solveOdeC ::
478 CInt ->
479 Maybe CDouble ->
480 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) ->
481 (V.Vector CDouble, CDouble) ->
482 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
483 -> V.Vector CDouble -- ^ Initial conditions
484 -> V.Vector CDouble -- ^ Desired solution times
485 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution
486solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do
487
488 let isInitStepSize :: CInt
489 isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize
490 ss :: CDouble
491 ss = case initStepSize of
492 -- It would be better to put an error message here but
493 -- inline-c seems to evaluate this even if it is never
494 -- used :(
495 Nothing -> 0.0
496 Just x -> x
497 let dim = V.length f0
498 nEq :: CLong
499 nEq = fromIntegral dim
500 nTs :: CInt
501 nTs = fromIntegral $ V.length ts
502 -- FIXME: fMut is not actually mutatated
503 fMut <- V.thaw f0
504 tMut <- V.thaw ts
505 -- FIXME: I believe this gets taken from the ghc heap and so should
506 -- be subject to garbage collection.
507 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs))
508 qMatMut <- V.thaw quasiMatrixRes
509 diagnostics :: V.Vector CLong <- createVector 10 -- FIXME
510 diagMut <- V.thaw diagnostics
511 -- We need the types that sundials expects. These are tied together
512 -- in 'Types'. FIXME: The Haskell type is currently empty!
513 let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt
514 funIO x y f _ptr = do
515 -- Convert the pointer we get from C (y) to a vector, and then
516 -- apply the user-supplied function.
517 fImm <- fun x <$> getDataFromContents dim y
518 -- Fill in the provided pointer with the resulting vector.
519 putDataInContents fImm dim f
520 -- FIXME: I don't understand what this comment means
521 -- Unsafe since the function will be called many times.
522 [CU.exp| int{ 0 } |]
523 let isJac :: CInt
524 isJac = fromIntegral $ fromEnum $ isJust jacH
525 jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix ->
526 Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector ->
527 IO CInt
528 jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do
529 case jacH of
530 Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined"
531 Just jacI -> do j <- jacI t <$> getDataFromContents dim y
532 putMatrixDataFromContents j jacS
533 -- FIXME: I don't understand what this comment means
534 -- Unsafe since the function will be called many times.
535 [CU.exp| int{ 0 } |]
536
537 res <- [C.block| int {
538 /* general problem variables */
539
540 int flag; /* reusable error-checking flag */
541 int i, j; /* reusable loop indices */
542 N_Vector y = NULL; /* empty vector for storing solution */
543 N_Vector tv = NULL; /* empty vector for storing absolute tolerances */
544 SUNMatrix A = NULL; /* empty matrix for linear solver */
545 SUNLinearSolver LS = NULL; /* empty linear solver object */
546 void *arkode_mem = NULL; /* empty ARKode memory structure */
547 realtype t;
548 long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf;
549
550 /* general problem parameters */
551
552 realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */
553 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */
554
555 /* Initialize data structures */
556
557 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
558 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
559 /* Specify initial condition */
560 for (i = 0; i < NEQ; i++) {
561 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i];
562 };
563
564 tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */
565 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1;
566 /* Specify tolerances */
567 for (i = 0; i < NEQ; i++) {
568 NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i];
569 };
570
571 arkode_mem = ARKodeCreate(); /* Create the solver memory */
572 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
573
574 /* Call ARKodeInit to initialize the integrator memory and specify the */
575 /* right-hand side function in y'=f(t,y), the inital time T0, and */
576 /* the initial dependent variable vector y. Note: we treat the */
577 /* problem as fully implicit and set f_E to NULL and f_I to f. */
578
579 /* Here we use the C types defined in helpers.h which tie up with */
580 /* the Haskell types defined in Types */
581 if ($(int method) < MIN_DIRK_NUM) {
582 flag = ARKodeInit(arkode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), NULL, T0, y);
583 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
584 } else {
585 flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y);
586 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
587 }
588
589 /* FIXME: A hack for initial testing */
590 flag = ARKodeSetMinStep(arkode_mem, 1.0e-12);
591 if (check_flag(&flag, "ARKodeSetMinStep", 1)) return 1;
592 flag = ARKodeSetMaxNumSteps(arkode_mem, 10000);
593 if (check_flag(&flag, "ARKodeSetMaxNumSteps", 1)) return 1;
594
595 /* Set routines */
596 flag = ARKodeSVtolerances(arkode_mem, $(double relTol), tv);
597 if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1;
598
599 /* Initialize dense matrix data structure and solver */
600 A = SUNDenseMatrix(NEQ, NEQ);
601 if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1;
602 LS = SUNDenseLinearSolver(y, A);
603 if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1;
604
605 /* Attach matrix and linear solver */
606 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A);
607 if (check_flag(&flag, "ARKDlsSetLinearSolver", 1)) return 1;
608
609 /* Set the initial step size if there is one */
610 if ($(int isInitStepSize)) {
611 /* FIXME: We could check if the initial step size is 0 */
612 /* or even NaN and then throw an error */
613 flag = ARKodeSetInitStep(arkode_mem, $(double ss));
614 if (check_flag(&flag, "ARKodeSetInitStep", 1)) return 1;
615 }
616
617 /* Set the Jacobian if there is one */
618 if ($(int isJac)) {
619 flag = ARKDlsSetJacFn(arkode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[])));
620 if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1;
621 }
622
623 /* Store initial conditions */
624 for (j = 0; j < NEQ; j++) {
625 ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j);
626 }
627
628 /* Explicitly set the method */
629 if ($(int method) >= MIN_DIRK_NUM) {
630 flag = ARKodeSetIRKTableNum(arkode_mem, $(int method));
631 if (check_flag(&flag, "ARKodeSetIRKTableNum", 1)) return 1;
632 } else {
633 flag = ARKodeSetERKTableNum(arkode_mem, $(int method));
634 if (check_flag(&flag, "ARKodeSetERKTableNum", 1)) return 1;
635 }
636
637 /* Main time-stepping loop: calls ARKode to perform the integration */
638 /* Stops when the final time has been reached */
639 for (i = 1; i < $(int nTs); i++) {
640
641 flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */
642 if (check_flag(&flag, "ARKode", 1)) break;
643
644 /* Store the results for Haskell */
645 for (j = 0; j < NEQ; j++) {
646 ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j);
647 }
648
649 /* unsuccessful solve: break */
650 if (flag < 0) {
651 fprintf(stderr,"Solver failure, stopping integration\n");
652 break;
653 }
654 }
655
656 /* Get some final statistics on how the solve progressed */
657
658 flag = ARKodeGetNumSteps(arkode_mem, &nst);
659 check_flag(&flag, "ARKodeGetNumSteps", 1);
660 ($vec-ptr:(long int *diagMut))[0] = nst;
661
662 flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a);
663 check_flag(&flag, "ARKodeGetNumStepAttempts", 1);
664 ($vec-ptr:(long int *diagMut))[1] = nst_a;
665
666 flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi);
667 check_flag(&flag, "ARKodeGetNumRhsEvals", 1);
668 ($vec-ptr:(long int *diagMut))[2] = nfe;
669 ($vec-ptr:(long int *diagMut))[3] = nfi;
670
671 flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups);
672 check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1);
673 ($vec-ptr:(long int *diagMut))[4] = nsetups;
674
675 flag = ARKodeGetNumErrTestFails(arkode_mem, &netf);
676 check_flag(&flag, "ARKodeGetNumErrTestFails", 1);
677 ($vec-ptr:(long int *diagMut))[5] = netf;
678
679 flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni);
680 check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1);
681 ($vec-ptr:(long int *diagMut))[6] = nni;
682
683 flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn);
684 check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1);
685 ($vec-ptr:(long int *diagMut))[7] = ncfn;
686
687 flag = ARKDlsGetNumJacEvals(arkode_mem, &nje);
688 check_flag(&flag, "ARKDlsGetNumJacEvals", 1);
689 ($vec-ptr:(long int *diagMut))[8] = ncfn;
690
691 flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS);
692 check_flag(&flag, "ARKDlsGetNumRhsEvals", 1);
693 ($vec-ptr:(long int *diagMut))[9] = ncfn;
694
695 /* Clean up and return */
696 N_VDestroy(y); /* Free y vector */
697 N_VDestroy(tv); /* Free tv vector */
698 ARKodeFree(&arkode_mem); /* Free integrator memory */
699 SUNLinSolFree(LS); /* Free linear solver */
700 SUNMatDestroy(A); /* Free A matrix */
701
702 return flag;
703 } |]
704 if res == 0
705 then do
706 preD <- V.freeze diagMut
707 let d = SundialsDiagnostics (fromIntegral $ preD V.!0)
708 (fromIntegral $ preD V.!1)
709 (fromIntegral $ preD V.!2)
710 (fromIntegral $ preD V.!3)
711 (fromIntegral $ preD V.!4)
712 (fromIntegral $ preD V.!5)
713 (fromIntegral $ preD V.!6)
714 (fromIntegral $ preD V.!7)
715 (fromIntegral $ preD V.!8)
716 (fromIntegral $ preD V.!9)
717 m <- V.freeze qMatMut
718 return $ Right (m, d)
719 else do
720 return $ Left res
721
722data ButcherTable = ButcherTable { am :: Matrix Double
723 , cv :: Vector Double
724 , bv :: Vector Double
725 , b2v :: Vector Double
726 }
727 deriving Show
728
729data ButcherTable' a = ButcherTable' { am' :: V.Vector a
730 , cv' :: V.Vector a
731 , bv' :: V.Vector a
732 , b2v' :: V.Vector a
733 }
734 deriving Show
735
736butcherTable :: ODEMethod -> ButcherTable
737butcherTable method =
738 case getBT method of
739 Left c -> error $ show c -- FIXME
740 Right (ButcherTable' v w x y, sqp) ->
741 ButcherTable { am = subMatrix (0, 0) (s, s) $ (B.arkSMax >< B.arkSMax) (V.toList v)
742 , cv = subVector 0 s w
743 , bv = subVector 0 s x
744 , b2v = subVector 0 s y
745 }
746 where
747 s = fromIntegral $ sqp V.! 0
748
749getBT :: ODEMethod -> Either Int (ButcherTable' Double, V.Vector Int)
750getBT method = case getButcherTable method of
751 Left c ->
752 Left $ fromIntegral c
753 Right (ButcherTable' a b c d, sqp) ->
754 Right $ ( ButcherTable' (coerce a) (coerce b) (coerce c) (coerce d)
755 , V.map fromIntegral sqp )
756
757getButcherTable :: ODEMethod
758 -> Either CInt (ButcherTable' CDouble, V.Vector CInt)
759getButcherTable method = unsafePerformIO $ do
760 -- ARKode seems to want an ODE in order to set and then get the
761 -- Butcher tableau so here's one to keep it happy
762 let funI :: CDouble -> V.Vector CDouble -> V.Vector CDouble
763 funI _t ys = V.fromList [ ys V.! 0 ]
764 let funE :: CDouble -> V.Vector CDouble -> V.Vector CDouble
765 funE _t ys = V.fromList [ ys V.! 0 ]
766 f0 = V.fromList [ 1.0 ]
767 ts = V.fromList [ 0.0 ]
768 dim = V.length f0
769 nEq :: CLong
770 nEq = fromIntegral dim
771 mN :: CInt
772 mN = fromIntegral $ getMethod method
773
774 btSQP :: V.Vector CInt <- createVector 3
775 btSQPMut <- V.thaw btSQP
776 btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax)
777 btAsMut <- V.thaw btAs
778 btCs :: V.Vector CDouble <- createVector B.arkSMax
779 btBs :: V.Vector CDouble <- createVector B.arkSMax
780 btB2s :: V.Vector CDouble <- createVector B.arkSMax
781 btCsMut <- V.thaw btCs
782 btBsMut <- V.thaw btBs
783 btB2sMut <- V.thaw btB2s
784 let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt
785 funIOI x y f _ptr = do
786 fImm <- funI x <$> getDataFromContents dim y
787 putDataInContents fImm dim f
788 -- FIXME: I don't understand what this comment means
789 -- Unsafe since the function will be called many times.
790 [CU.exp| int{ 0 } |]
791 let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt
792 funIOE x y f _ptr = do
793 fImm <- funE x <$> getDataFromContents dim y
794 putDataInContents fImm dim f
795 -- FIXME: I don't understand what this comment means
796 -- Unsafe since the function will be called many times.
797 [CU.exp| int{ 0 } |]
798 res <- [C.block| int {
799 /* general problem variables */
800
801 int flag; /* reusable error-checking flag */
802 N_Vector y = NULL; /* empty vector for storing solution */
803 void *arkode_mem = NULL; /* empty ARKode memory structure */
804 int i, j; /* reusable loop indices */
805
806 /* general problem parameters */
807
808 realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */
809 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars */
810
811 /* Initialize data structures */
812
813 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
814 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
815 /* Specify initial condition */
816 for (i = 0; i < NEQ; i++) {
817 NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i];
818 };
819 arkode_mem = ARKodeCreate(); /* Create the solver memory */
820 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
821
822 flag = ARKodeInit(arkode_mem, $fun:(int (* funIOE) (double t, SunVector y[], SunVector dydt[], void * params)), $fun:(int (* funIOI) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y);
823 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
824
825 if ($(int mN) >= MIN_DIRK_NUM) {
826 flag = ARKodeSetIRKTableNum(arkode_mem, $(int mN));
827 if (check_flag(&flag, "ARKodeSetIRKTableNum", 1)) return 1;
828 } else {
829 flag = ARKodeSetERKTableNum(arkode_mem, $(int mN));
830 if (check_flag(&flag, "ARKodeSetERKTableNum", 1)) return 1;
831 }
832
833 int s, q, p;
834 realtype *ai = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype));
835 realtype *ae = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype));
836 realtype *ci = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
837 realtype *ce = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
838 realtype *bi = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
839 realtype *be = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
840 realtype *b2i = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
841 realtype *b2e = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
842 flag = ARKodeGetCurrentButcherTables(arkode_mem, &s, &q, &p, ai, ae, ci, ce, bi, be, b2i, b2e);
843 if (check_flag(&flag, "ARKode", 1)) return 1;
844 $vec-ptr:(int *btSQPMut)[0] = s;
845 $vec-ptr:(int *btSQPMut)[1] = q;
846 $vec-ptr:(int *btSQPMut)[2] = p;
847 for (i = 0; i < s; i++) {
848 for (j = 0; j < s; j++) {
849 /* FIXME: double should be realtype */
850 ($vec-ptr:(double *btAsMut))[i * ARK_S_MAX + j] = ai[i * ARK_S_MAX + j];
851 }
852 }
853
854 for (i = 0; i < s; i++) {
855 ($vec-ptr:(double *btCsMut))[i] = ci[i];
856 ($vec-ptr:(double *btBsMut))[i] = bi[i];
857 ($vec-ptr:(double *btB2sMut))[i] = b2i[i];
858 }
859
860 /* Clean up and return */
861 N_VDestroy(y); /* Free y vector */
862 ARKodeFree(&arkode_mem); /* Free integrator memory */
863
864 return flag;
865 } |]
866 if res == 0
867 then do
868 x <- V.freeze btAsMut
869 y <- V.freeze btSQPMut
870 z <- V.freeze btCsMut
871 u <- V.freeze btBsMut
872 v <- V.freeze btB2sMut
873 return $ Right (ButcherTable' { am' = x, cv' = z, bv' = u, b2v' = v }, y)
874 else do
875 return $ Left res
876
877-- | Adaptive step-size control
878-- functions.
879--
880-- [GSL](https://www.gnu.org/software/gsl/doc/html/ode-initval.html#adaptive-step-size-control)
881-- allows the user to control the step size adjustment using
882-- \(D_i = \epsilon^{abs}s_i + \epsilon^{rel}(a_{y} |y_i| + a_{dy/dt} h |\dot{y}_i|)\) where
883-- \(\epsilon^{abs}\) is the required absolute error, \(\epsilon^{rel}\)
884-- is the required relative error, \(s_i\) is a vector of scaling
885-- factors, \(a_{y}\) is a scaling factor for the solution \(y\) and
886-- \(a_{dydt}\) is a scaling factor for the derivative of the solution \(dy/dt\).
887--
888-- [ARKode](https://computation.llnl.gov/projects/sundials/arkode)
889-- allows the user to control the step size adjustment using
890-- \(\eta^{rel}|y_i| + \eta^{abs}_i\). For compatibility with
891-- [hmatrix-gsl](https://hackage.haskell.org/package/hmatrix-gsl),
892-- tolerances for \(y\) and \(\dot{y}\) can be specified but the latter have no
893-- effect.
894data StepControl = X Double Double -- ^ absolute and relative tolerance for \(y\); in GSL terms, \(a_{y} = 1\) and \(a_{dy/dt} = 0\); in ARKode terms, the \(\eta^{abs}_i\) are identical
895 | X' Double Double -- ^ absolute and relative tolerance for \(\dot{y}\); in GSL terms, \(a_{y} = 0\) and \(a_{dy/dt} = 1\); in ARKode terms, the latter is treated as the relative tolerance for \(y\) so this is the same as specifying 'X' which may be entirely incorrect for the given problem
896 | XX' Double Double Double Double -- ^ include both via relative tolerance
897 -- scaling factors \(a_y\), \(a_{{dy}/{dt}}\); in ARKode terms, the latter is ignored and \(\eta^{rel} = a_{y}\epsilon^{rel}\)
898 | ScXX' Double Double Double Double (Vector Double) -- ^ scale absolute tolerance of \(y_i\); in ARKode terms, \(a_{{dy}/{dt}}\) is ignored, \(\eta^{abs}_i = s_i \epsilon^{abs}\) and \(\eta^{rel} = a_{y}\epsilon^{rel}\)
diff --git a/packages/sundials/src/Types.hs b/packages/sundials/src/Types.hs
new file mode 100644
index 0000000..04e4280
--- /dev/null
+++ b/packages/sundials/src/Types.hs
@@ -0,0 +1,40 @@
1{-# OPTIONS_GHC -Wall #-}
2
3{-# LANGUAGE QuasiQuotes #-}
4{-# LANGUAGE TemplateHaskell #-}
5{-# LANGUAGE MultiWayIf #-}
6{-# LANGUAGE OverloadedStrings #-}
7{-# LANGUAGE EmptyDataDecls #-}
8
9module Types where
10
11import Foreign.C.Types
12
13import qualified Language.Haskell.TH as TH
14import qualified Language.C.Types as CT
15import qualified Data.Map as Map
16import Language.C.Inline.Context
17
18import qualified Data.Vector.Storable as V
19
20
21data SunVector
22data SunMatrix = SunMatrix { rows :: CInt
23 , cols :: CInt
24 , vals :: V.Vector CDouble
25 }
26
27-- FIXME: Is this true?
28type SunIndexType = CLong
29
30sunTypesTable :: Map.Map CT.TypeSpecifier TH.TypeQ
31sunTypesTable = Map.fromList
32 [
33 (CT.TypeName "sunindextype", [t| SunIndexType |] )
34 , (CT.TypeName "SunVector", [t| SunVector |] )
35 , (CT.TypeName "SunMatrix", [t| SunMatrix |] )
36 ]
37
38sunCtx :: Context
39sunCtx = mempty {ctxTypesTable = sunTypesTable}
40
diff --git a/packages/sundials/src/helpers.c b/packages/sundials/src/helpers.c
new file mode 100644
index 0000000..f0ca592
--- /dev/null
+++ b/packages/sundials/src/helpers.c
@@ -0,0 +1,44 @@
1#include <stdio.h>
2#include <math.h>
3#include <arkode/arkode.h> /* prototypes for ARKODE fcts., consts. */
4#include <nvector/nvector_serial.h> /* serial N_Vector types, fcts., macros */
5#include <sunmatrix/sunmatrix_dense.h> /* access to dense SUNMatrix */
6#include <sunlinsol/sunlinsol_dense.h> /* access to dense SUNLinearSolver */
7#include <arkode/arkode_direct.h> /* access to ARKDls interface */
8#include <sundials/sundials_types.h> /* definition of type realtype */
9#include <sundials/sundials_math.h>
10
11/* Check function return value...
12 opt == 0 means SUNDIALS function allocates memory so check if
13 returned NULL pointer
14 opt == 1 means SUNDIALS function returns a flag so check if
15 flag >= 0
16 opt == 2 means function allocates memory so check if returned
17 NULL pointer
18*/
19int check_flag(void *flagvalue, const char *funcname, int opt)
20{
21 int *errflag;
22
23 /* Check if SUNDIALS function returned NULL pointer - no memory allocated */
24 if (opt == 0 && flagvalue == NULL) {
25 fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
26 funcname);
27 return 1; }
28
29 /* Check if flag < 0 */
30 else if (opt == 1) {
31 errflag = (int *) flagvalue;
32 if (*errflag < 0) {
33 fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with flag = %d\n\n",
34 funcname, *errflag);
35 return 1; }}
36
37 /* Check if function returned NULL pointer - no memory allocated */
38 else if (opt == 2 && flagvalue == NULL) {
39 fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n",
40 funcname);
41 return 1; }
42
43 return 0;
44}
diff --git a/packages/sundials/src/helpers.h b/packages/sundials/src/helpers.h
new file mode 100644
index 0000000..3d8fbc0
--- /dev/null
+++ b/packages/sundials/src/helpers.h
@@ -0,0 +1,9 @@
1/* Check function return value...
2 opt == 0 means SUNDIALS function allocates memory so check if
3 returned NULL pointer
4 opt == 1 means SUNDIALS function returns a flag so check if
5 flag >= 0
6 opt == 2 means function allocates memory so check if returned
7 NULL pointer
8*/
9int check_flag(void *flagvalue, const char *funcname, int opt);