diff options
author | Peter Dobsan <pdobsan@gmail.com> | 2018-05-03 20:08:33 +0200 |
---|---|---|
committer | Peter Dobsan <pdobsan@gmail.com> | 2018-05-03 20:08:33 +0200 |
commit | cafdc664c01ea7392c81c352b5c5444dc2963531 (patch) | |
tree | c6d9a758fa7c36730c0468b393a6dc8c47cbfac2 /packages | |
parent | ea1bfea4486f8f2c646f82dabd1ff9a222b68506 (diff) | |
parent | 1675813d8f540af9832a78c7a7a40bbdf1cec42c (diff) |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'packages')
-rw-r--r-- | packages/sundials/hmatrix-sundials.cabal | 18 | ||||
-rw-r--r-- | packages/sundials/src/Main.hs | 72 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 269 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/Arkode.hsc (renamed from packages/sundials/src/Arkode.hsc) | 98 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/CVode/ODE.hs | 476 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ODEOpts.hs | 32 | ||||
-rw-r--r-- | packages/sundials/src/Types.hs | 40 |
7 files changed, 810 insertions, 195 deletions
diff --git a/packages/sundials/hmatrix-sundials.cabal b/packages/sundials/hmatrix-sundials.cabal index 388f1db..cd2be4e 100644 --- a/packages/sundials/hmatrix-sundials.cabal +++ b/packages/sundials/hmatrix-sundials.cabal | |||
@@ -25,21 +25,24 @@ library | |||
25 | template-haskell >=2.12 && <2.13, | 25 | template-haskell >=2.12 && <2.13, |
26 | containers >=0.5 && <0.6, | 26 | containers >=0.5 && <0.6, |
27 | hmatrix>=0.18 | 27 | hmatrix>=0.18 |
28 | extra-libraries: sundials_arkode | 28 | extra-libraries: sundials_arkode, |
29 | sundials_cvode | ||
29 | other-extensions: QuasiQuotes | 30 | other-extensions: QuasiQuotes |
30 | hs-source-dirs: src | 31 | hs-source-dirs: src |
31 | exposed-modules: Numeric.Sundials.ARKode.ODE | 32 | exposed-modules: Numeric.Sundials.ODEOpts, |
32 | other-modules: Types, | 33 | Numeric.Sundials.ARKode.ODE, |
33 | Arkode | 34 | Numeric.Sundials.CVode.ODE |
35 | other-modules: Numeric.Sundials.Arkode | ||
34 | c-sources: src/helpers.c src/helpers.h | 36 | c-sources: src/helpers.c src/helpers.h |
35 | default-language: Haskell2010 | 37 | default-language: Haskell2010 |
36 | 38 | ||
37 | test-suite hmatrix-sundials-testsuite | 39 | test-suite hmatrix-sundials-testsuite |
38 | type: exitcode-stdio-1.0 | 40 | type: exitcode-stdio-1.0 |
39 | main-is: Main.hs | 41 | main-is: Main.hs |
40 | other-modules: Types, | 42 | other-modules: Numeric.Sundials.ODEOpts, |
41 | Numeric.Sundials.ARKode.ODE, | 43 | Numeric.Sundials.ARKode.ODE, |
42 | Arkode | 44 | Numeric.Sundials.CVode.ODE, |
45 | Numeric.Sundials.Arkode | ||
43 | build-depends: base >=4.10 && <4.11, | 46 | build-depends: base >=4.10 && <4.11, |
44 | inline-c >=0.6 && <0.7, | 47 | inline-c >=0.6 && <0.7, |
45 | vector >=0.12 && <0.13, | 48 | vector >=0.12 && <0.13, |
@@ -52,6 +55,7 @@ test-suite hmatrix-sundials-testsuite | |||
52 | lens, | 55 | lens, |
53 | hspec | 56 | hspec |
54 | hs-source-dirs: src | 57 | hs-source-dirs: src |
55 | extra-libraries: sundials_arkode | 58 | extra-libraries: sundials_arkode, |
59 | sundials_cvode | ||
56 | c-sources: src/helpers.c src/helpers.h | 60 | c-sources: src/helpers.c src/helpers.h |
57 | default-language: Haskell2010 | 61 | default-language: Haskell2010 |
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index 729d35a..16c21c5 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs | |||
@@ -1,6 +1,7 @@ | |||
1 | {-# OPTIONS_GHC -Wall #-} | 1 | {-# OPTIONS_GHC -Wall #-} |
2 | 2 | ||
3 | import Numeric.Sundials.ARKode.ODE | 3 | import qualified Numeric.Sundials.ARKode.ODE as ARK |
4 | import qualified Numeric.Sundials.CVode.ODE as CV | ||
4 | import Numeric.LinearAlgebra | 5 | import Numeric.LinearAlgebra |
5 | 6 | ||
6 | import Plots as P | 7 | import Plots as P |
@@ -80,6 +81,23 @@ _stiffJac _t _v = (1><1) [ lamda ] | |||
80 | where | 81 | where |
81 | lamda = -100.0 | 82 | lamda = -100.0 |
82 | 83 | ||
84 | predatorPrey :: Double -> [Double] -> [Double] | ||
85 | predatorPrey _t v = [ x * a - b * x * y | ||
86 | , d * x * y - c * y - e * y * z | ||
87 | , (-f) * z + g * y * z | ||
88 | ] | ||
89 | where | ||
90 | x = v!!0 | ||
91 | y = v!!1 | ||
92 | z = v!!2 | ||
93 | a = 1.0 | ||
94 | b = 1.0 | ||
95 | c = 1.0 | ||
96 | d = 1.0 | ||
97 | e = 1.0 | ||
98 | f = 1.0 | ||
99 | g = 1.0 | ||
100 | |||
83 | lSaxis :: [[Double]] -> P.Axis B D.V2 Double | 101 | lSaxis :: [[Double]] -> P.Axis B D.V2 Double |
84 | lSaxis xs = P.r2Axis &~ do | 102 | lSaxis xs = P.r2Axis &~ do |
85 | let ts = xs!!0 | 103 | let ts = xs!!0 |
@@ -97,33 +115,37 @@ kSaxis xs = P.r2Axis &~ do | |||
97 | main :: IO () | 115 | main :: IO () |
98 | main = do | 116 | main = do |
99 | 117 | ||
100 | let res1 = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) | 118 | let res1 = ARK.odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) |
101 | renderRasterific "diagrams/brusselator.png" | 119 | renderRasterific "diagrams/brusselator.png" |
102 | (D.dims2D 500.0 500.0) | 120 | (D.dims2D 500.0 500.0) |
103 | (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) | 121 | (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) |
104 | 122 | ||
105 | let res1a = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) | 123 | let res1a = ARK.odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) |
106 | renderRasterific "diagrams/brusselatorA.png" | 124 | renderRasterific "diagrams/brusselatorA.png" |
107 | (D.dims2D 500.0 500.0) | 125 | (D.dims2D 500.0 500.0) |
108 | (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1a)) | 126 | (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1a)) |
109 | 127 | ||
110 | let res2 = odeSolve stiffish [0.0] (fromList [0.0, 0.1 .. 10.0]) | 128 | let res2 = ARK.odeSolve stiffish [0.0] (fromList [0.0, 0.1 .. 10.0]) |
111 | renderRasterific "diagrams/stiffish.png" | 129 | renderRasterific "diagrams/stiffish.png" |
112 | (D.dims2D 500.0 500.0) | 130 | (D.dims2D 500.0 500.0) |
113 | (renderAxis $ kSaxis $ zip [0.0, 0.1 .. 10.0] (concat $ toLists res2)) | 131 | (renderAxis $ kSaxis $ zip [0.0, 0.1 .. 10.0] (concat $ toLists res2)) |
114 | 132 | ||
115 | let res2a = odeSolveV (SDIRK_5_3_4') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) | 133 | let res2a = ARK.odeSolveV (ARK.SDIRK_5_3_4') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) |
116 | 134 | ||
117 | let res2b = odeSolveV (TRBDF2_3_3_2') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) | 135 | let res2b = ARK.odeSolveV (ARK.TRBDF2_3_3_2') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) |
118 | 136 | ||
119 | let maxDiff = maximum $ map abs $ | 137 | let maxDiffA = maximum $ map abs $ |
120 | zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2b)!!0) | 138 | zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2b)!!0) |
121 | 139 | ||
122 | hspec $ describe "Compare results" $ do | 140 | let res2c = CV.odeSolveV (CV.BDF) Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) |
123 | it "for two different RK methods" $ | 141 | |
124 | maxDiff < 1.0e-6 | 142 | let maxDiffB = maximum $ map abs $ |
143 | zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2c)!!0) | ||
125 | 144 | ||
126 | let res3 = odeSolve lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0]) | 145 | let maxDiffC = maximum $ map abs $ |
146 | zipWith (-) ((toLists $ tr res2b)!!0) ((toLists $ tr res2c)!!0) | ||
147 | |||
148 | let res3 = ARK.odeSolve lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0]) | ||
127 | 149 | ||
128 | renderRasterific "diagrams/lorenz.png" | 150 | renderRasterific "diagrams/lorenz.png" |
129 | (D.dims2D 500.0 500.0) | 151 | (D.dims2D 500.0 500.0) |
@@ -136,3 +158,29 @@ main = do | |||
136 | renderRasterific "diagrams/lorenz2.png" | 158 | renderRasterific "diagrams/lorenz2.png" |
137 | (D.dims2D 500.0 500.0) | 159 | (D.dims2D 500.0 500.0) |
138 | (renderAxis $ kSaxis $ zip ((toLists $ tr res3)!!1) ((toLists $ tr res3)!!2)) | 160 | (renderAxis $ kSaxis $ zip ((toLists $ tr res3)!!1) ((toLists $ tr res3)!!2)) |
161 | |||
162 | let res4 = CV.odeSolve predatorPrey [0.5, 1.0, 2.0] (fromList [0.0, 0.01 .. 10.0]) | ||
163 | |||
164 | renderRasterific "diagrams/predatorPrey.png" | ||
165 | (D.dims2D 500.0 500.0) | ||
166 | (renderAxis $ kSaxis $ zip ((toLists $ tr res4)!!0) ((toLists $ tr res4)!!1)) | ||
167 | |||
168 | renderRasterific "diagrams/predatorPrey1.png" | ||
169 | (D.dims2D 500.0 500.0) | ||
170 | (renderAxis $ kSaxis $ zip ((toLists $ tr res4)!!0) ((toLists $ tr res4)!!2)) | ||
171 | |||
172 | renderRasterific "diagrams/predatorPrey2.png" | ||
173 | (D.dims2D 500.0 500.0) | ||
174 | (renderAxis $ kSaxis $ zip ((toLists $ tr res4)!!1) ((toLists $ tr res4)!!2)) | ||
175 | |||
176 | let res4a = ARK.odeSolve predatorPrey [0.5, 1.0, 2.0] (fromList [0.0, 0.01 .. 10.0]) | ||
177 | |||
178 | let maxDiffPpA = maximum $ map abs $ | ||
179 | zipWith (-) ((toLists $ tr res4)!!0) ((toLists $ tr res4a)!!0) | ||
180 | |||
181 | hspec $ describe "Compare results" $ do | ||
182 | it "for SDIRK_5_3_4' and TRBDF2_3_3_2'" $ maxDiffA < 1.0e-6 | ||
183 | it "for SDIRK_5_3_4' and BDF" $ maxDiffB < 1.0e-6 | ||
184 | it "for TRBDF2_3_3_2' and BDF" $ maxDiffC < 1.0e-6 | ||
185 | it "for CV and ARK for the Predator Prey model" $ maxDiffPpA < 1.0e-3 | ||
186 | |||
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index e5a2e4d..fafc237 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -1,5 +1,3 @@ | |||
1 | {-# OPTIONS_GHC -Wall #-} | ||
2 | |||
3 | {-# LANGUAGE QuasiQuotes #-} | 1 | {-# LANGUAGE QuasiQuotes #-} |
4 | {-# LANGUAGE TemplateHaskell #-} | 2 | {-# LANGUAGE TemplateHaskell #-} |
5 | {-# LANGUAGE MultiWayIf #-} | 3 | {-# LANGUAGE MultiWayIf #-} |
@@ -22,8 +20,7 @@ | |||
22 | -- Stability : provisional | 20 | -- Stability : provisional |
23 | -- | 21 | -- |
24 | -- Solution of ordinary differential equation (ODE) initial value problems. | 22 | -- Solution of ordinary differential equation (ODE) initial value problems. |
25 | -- | 23 | -- See <https://computation.llnl.gov/projects/sundials/sundials-software> for more detail. |
26 | -- <https://computation.llnl.gov/projects/sundials/sundials-software> | ||
27 | -- | 24 | -- |
28 | -- A simple example: | 25 | -- A simple example: |
29 | -- | 26 | -- |
@@ -67,6 +64,54 @@ | |||
67 | -- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) | 64 | -- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) |
68 | -- @ | 65 | -- @ |
69 | -- | 66 | -- |
67 | -- With Sundials ARKode, it is possible to retrieve the Butcher tableau for the solver. | ||
68 | -- | ||
69 | -- @ | ||
70 | -- import Numeric.Sundials.ARKode.ODE | ||
71 | -- import Numeric.LinearAlgebra | ||
72 | -- | ||
73 | -- import Data.List (intercalate) | ||
74 | -- | ||
75 | -- import Text.PrettyPrint.HughesPJClass | ||
76 | -- | ||
77 | -- | ||
78 | -- butcherTableauTex :: ButcherTable -> String | ||
79 | -- butcherTableauTex (ButcherTable m c b b2) = | ||
80 | -- render $ | ||
81 | -- vcat [ text ("\n\\begin{array}{c|" ++ (concat $ replicate n "c") ++ "}") | ||
82 | -- , us | ||
83 | -- , text "\\hline" | ||
84 | -- , text bs <+> text "\\\\" | ||
85 | -- , text b2s <+> text "\\\\" | ||
86 | -- , text "\\end{array}" | ||
87 | -- ] | ||
88 | -- where | ||
89 | -- n = rows m | ||
90 | -- rs = toLists m | ||
91 | -- ss = map (\r -> intercalate " & " $ map show r) rs | ||
92 | -- ts = zipWith (\i r -> show i ++ " & " ++ r) (toList c) ss | ||
93 | -- us = vcat $ map (\r -> text r <+> text "\\\\") ts | ||
94 | -- bs = " & " ++ (intercalate " & " $ map show $ toList b) | ||
95 | -- b2s = " & " ++ (intercalate " & " $ map show $ toList b2) | ||
96 | -- | ||
97 | -- main :: IO () | ||
98 | -- main = do | ||
99 | -- | ||
100 | -- let res = butcherTable (SDIRK_2_1_2 undefined) | ||
101 | -- putStrLn $ show res | ||
102 | -- putStrLn $ butcherTableauTex res | ||
103 | -- | ||
104 | -- let resA = butcherTable (KVAERNO_4_2_3 undefined) | ||
105 | -- putStrLn $ show resA | ||
106 | -- putStrLn $ butcherTableauTex resA | ||
107 | -- | ||
108 | -- let resB = butcherTable (SDIRK_5_3_4 undefined) | ||
109 | -- putStrLn $ show resB | ||
110 | -- putStrLn $ butcherTableauTex resB | ||
111 | -- @ | ||
112 | -- | ||
113 | -- Using the code above from the examples gives | ||
114 | -- | ||
70 | -- KVAERNO_4_2_3 | 115 | -- KVAERNO_4_2_3 |
71 | -- | 116 | -- |
72 | -- \[ | 117 | -- \[ |
@@ -116,8 +161,6 @@ module Numeric.Sundials.ARKode.ODE ( odeSolve | |||
116 | , butcherTable | 161 | , butcherTable |
117 | , ODEMethod(..) | 162 | , ODEMethod(..) |
118 | , StepControl(..) | 163 | , StepControl(..) |
119 | , Jacobian | ||
120 | , SundialsDiagnostics(..) | ||
121 | ) where | 164 | ) where |
122 | 165 | ||
123 | import qualified Language.C.Inline as C | 166 | import qualified Language.C.Inline as C |
@@ -126,27 +169,50 @@ import qualified Language.C.Inline.Unsafe as CU | |||
126 | import Data.Monoid ((<>)) | 169 | import Data.Monoid ((<>)) |
127 | import Data.Maybe (isJust) | 170 | import Data.Maybe (isJust) |
128 | 171 | ||
129 | import Foreign.C.Types | 172 | import Foreign.C.Types (CDouble, CInt, CLong) |
130 | import Foreign.Ptr (Ptr) | 173 | import Foreign.Ptr (Ptr) |
131 | import Foreign.ForeignPtr (newForeignPtr_) | 174 | import Foreign.Storable (poke) |
132 | import Foreign.Storable (Storable) | ||
133 | 175 | ||
134 | import qualified Data.Vector.Storable as V | 176 | import qualified Data.Vector.Storable as V |
135 | import qualified Data.Vector.Storable.Mutable as VM | ||
136 | 177 | ||
137 | import Data.Coerce (coerce) | 178 | import Data.Coerce (coerce) |
138 | import System.IO.Unsafe (unsafePerformIO) | 179 | import System.IO.Unsafe (unsafePerformIO) |
139 | import GHC.Generics | 180 | import GHC.Generics (C1, Constructor, (:+:)(..), D1, Rep, Generic, M1(..), |
181 | from, conName) | ||
140 | 182 | ||
141 | import Numeric.LinearAlgebra.Devel (createVector) | 183 | import Numeric.LinearAlgebra.Devel (createVector) |
142 | 184 | ||
143 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), | 185 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows, |
144 | subMatrix, rows, cols, toLists, | 186 | cols, toLists, size, reshape, |
145 | size, subVector) | 187 | subVector, subMatrix, (><)) |
146 | 188 | ||
147 | import qualified Types as T | 189 | import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian, SundialsDiagnostics(..)) |
148 | import Arkode | 190 | import qualified Numeric.Sundials.Arkode as T |
149 | import qualified Arkode as B | 191 | import Numeric.Sundials.Arkode (getDataFromContents, putDataInContents, arkSMax, |
192 | sDIRK_2_1_2, | ||
193 | bILLINGTON_3_3_2, | ||
194 | tRBDF2_3_3_2, | ||
195 | kVAERNO_4_2_3, | ||
196 | aRK324L2SA_DIRK_4_2_3, | ||
197 | cASH_5_2_4, | ||
198 | cASH_5_3_4, | ||
199 | sDIRK_5_3_4, | ||
200 | kVAERNO_5_3_4, | ||
201 | aRK436L2SA_DIRK_6_3_4, | ||
202 | kVAERNO_7_4_5, | ||
203 | aRK548L2SA_DIRK_8_4_5, | ||
204 | hEUN_EULER_2_1_2, | ||
205 | bOGACKI_SHAMPINE_4_2_3, | ||
206 | aRK324L2SA_ERK_4_2_3, | ||
207 | zONNEVELD_5_3_4, | ||
208 | aRK436L2SA_ERK_6_3_4, | ||
209 | sAYFY_ABURUB_6_3_4, | ||
210 | cASH_KARP_6_4_5, | ||
211 | fEHLBERG_6_4_5, | ||
212 | dORMAND_PRINCE_7_4_5, | ||
213 | aRK548L2SA_ERK_8_4_5, | ||
214 | vERNER_8_5_6, | ||
215 | fEHLBERG_13_7_8) | ||
150 | 216 | ||
151 | 217 | ||
152 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | 218 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) |
@@ -162,69 +228,8 @@ C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface | |||
162 | C.include "<sundials/sundials_types.h>" -- definition of type realtype | 228 | C.include "<sundials/sundials_types.h>" -- definition of type realtype |
163 | C.include "<sundials/sundials_math.h>" | 229 | C.include "<sundials/sundials_math.h>" |
164 | C.include "../../../helpers.h" | 230 | C.include "../../../helpers.h" |
165 | C.include "Arkode_hsc.h" | 231 | C.include "Numeric/Sundials/Arkode_hsc.h" |
166 | 232 | ||
167 | |||
168 | getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble) | ||
169 | getDataFromContents len ptr = do | ||
170 | qtr <- B.getContentPtr ptr | ||
171 | rtr <- B.getData qtr | ||
172 | vectorFromC len rtr | ||
173 | |||
174 | -- FIXME: Potentially an instance of Storable | ||
175 | _getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix | ||
176 | _getMatrixDataFromContents ptr = do | ||
177 | qtr <- B.getContentMatrixPtr ptr | ||
178 | rs <- B.getNRows qtr | ||
179 | cs <- B.getNCols qtr | ||
180 | rtr <- B.getMatrixData qtr | ||
181 | vs <- vectorFromC (fromIntegral $ rs * cs) rtr | ||
182 | return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } | ||
183 | |||
184 | putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () | ||
185 | putMatrixDataFromContents mat ptr = do | ||
186 | let rs = T.rows mat | ||
187 | cs = T.cols mat | ||
188 | vs = T.vals mat | ||
189 | qtr <- B.getContentMatrixPtr ptr | ||
190 | B.putNRows rs qtr | ||
191 | B.putNCols cs qtr | ||
192 | rtr <- B.getMatrixData qtr | ||
193 | vectorToC vs (fromIntegral $ rs * cs) rtr | ||
194 | -- FIXME: END | ||
195 | |||
196 | putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () | ||
197 | putDataInContents vec len ptr = do | ||
198 | qtr <- B.getContentPtr ptr | ||
199 | rtr <- B.getData qtr | ||
200 | vectorToC vec len rtr | ||
201 | |||
202 | -- Utils | ||
203 | |||
204 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | ||
205 | vectorFromC len ptr = do | ||
206 | ptr' <- newForeignPtr_ ptr | ||
207 | V.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
208 | |||
209 | vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () | ||
210 | vectorToC vec len ptr = do | ||
211 | ptr' <- newForeignPtr_ ptr | ||
212 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
213 | |||
214 | data SundialsDiagnostics = SundialsDiagnostics { | ||
215 | aRKodeGetNumSteps :: Int | ||
216 | , aRKodeGetNumStepAttempts :: Int | ||
217 | , aRKodeGetNumRhsEvals_fe :: Int | ||
218 | , aRKodeGetNumRhsEvals_fi :: Int | ||
219 | , aRKodeGetNumLinSolvSetups :: Int | ||
220 | , aRKodeGetNumErrTestFails :: Int | ||
221 | , aRKodeGetNumNonlinSolvIters :: Int | ||
222 | , aRKodeGetNumNonlinSolvConvFails :: Int | ||
223 | , aRKDlsGetNumJacEvals :: Int | ||
224 | , aRKDlsGetNumRhsEvals :: Int | ||
225 | } deriving Show | ||
226 | |||
227 | type Jacobian = Double -> Vector Double -> Matrix Double | ||
228 | 233 | ||
229 | -- | Stepping functions | 234 | -- | Stepping functions |
230 | data ODEMethod = SDIRK_2_1_2 Jacobian | 235 | data ODEMethod = SDIRK_2_1_2 Jacobian |
@@ -390,15 +395,9 @@ odeSolveV | |||
390 | -> Vector Double -- ^ desired solution times | 395 | -> Vector Double -- ^ desired solution times |
391 | -> Matrix Double -- ^ solution | 396 | -> Matrix Double -- ^ solution |
392 | odeSolveV meth hi epsAbs epsRel f y0 ts = | 397 | odeSolveV meth hi epsAbs epsRel f y0 ts = |
393 | case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of | 398 | odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts |
394 | Left c -> error $ show c -- FIXME | 399 | where |
395 | -- FIXME: Can we do better than using lists? | 400 | g t x0 = coerce $ f t x0 |
396 | Right (v, _d) -> (nR >< nC) (V.toList v) | ||
397 | where | ||
398 | us = toList ts | ||
399 | nR = length us | ||
400 | nC = size y0 | ||
401 | g t x0 = coerce $ f t x0 | ||
402 | 401 | ||
403 | -- | A version of 'odeSolveV' with reasonable default parameters and | 402 | -- | A version of 'odeSolveV' with reasonable default parameters and |
404 | -- system of equations defined using lists. FIXME: we should say | 403 | -- system of equations defined using lists. FIXME: we should say |
@@ -410,16 +409,11 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y | |||
410 | -> Matrix Double -- ^ solution | 409 | -> Matrix Double -- ^ solution |
411 | odeSolve f y0 ts = | 410 | odeSolve f y0 ts = |
412 | -- FIXME: These tolerances are different from the ones in GSL | 411 | -- FIXME: These tolerances are different from the ones in GSL |
413 | case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of | 412 | odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) |
414 | Left c -> error $ show c -- FIXME | ||
415 | Right (v, _d) -> (nR >< nC) (V.toList v) | ||
416 | where | 413 | where |
417 | us = toList ts | ||
418 | nR = length us | ||
419 | nC = length y0 | ||
420 | g t x0 = V.fromList $ f t (V.toList x0) | 414 | g t x0 = V.fromList $ f t (V.toList x0) |
421 | 415 | ||
422 | odeSolveVWith' :: | 416 | odeSolveVWith :: |
423 | ODEMethod | 417 | ODEMethod |
424 | -> StepControl | 418 | -> StepControl |
425 | -> Maybe Double -- ^ initial step size - by default, ARKode | 419 | -> Maybe Double -- ^ initial step size - by default, ARKode |
@@ -432,16 +426,22 @@ odeSolveVWith' :: | |||
432 | -> V.Vector Double -- ^ Initial conditions | 426 | -> V.Vector Double -- ^ Initial conditions |
433 | -> V.Vector Double -- ^ Desired solution times | 427 | -> V.Vector Double -- ^ Desired solution times |
434 | -> Matrix Double -- ^ Error code or solution | 428 | -> Matrix Double -- ^ Error code or solution |
435 | odeSolveVWith' method control initStepSize f y0 tt = | 429 | odeSolveVWith method control initStepSize f y0 tt = |
436 | case odeSolveVWith method control initStepSize f y0 tt of | 430 | case odeSolveVWith' opts method control initStepSize f y0 tt of |
437 | Left c -> error $ show c -- FIXME | 431 | Left c -> error $ show c -- FIXME |
438 | Right (v, _d) -> (nR >< nC) (V.toList v) | 432 | Right (v, _d) -> v |
439 | where | 433 | where |
440 | nR = V.length tt | 434 | opts = ODEOpts { maxNumSteps = 10000 |
441 | nC = V.length y0 | 435 | , minStep = 1.0e-12 |
436 | , relTol = error "relTol" | ||
437 | , absTols = error "absTol" | ||
438 | , initStep = error "initStep" | ||
439 | , maxFail = 10 | ||
440 | } | ||
442 | 441 | ||
443 | odeSolveVWith :: | 442 | odeSolveVWith' :: |
444 | ODEMethod | 443 | ODEOpts |
444 | -> ODEMethod | ||
445 | -> StepControl | 445 | -> StepControl |
446 | -> Maybe Double -- ^ initial step size - by default, ARKode | 446 | -> Maybe Double -- ^ initial step size - by default, ARKode |
447 | -- estimates the initial step size to be the | 447 | -- estimates the initial step size to be the |
@@ -452,19 +452,21 @@ odeSolveVWith :: | |||
452 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 452 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
453 | -> V.Vector Double -- ^ Initial conditions | 453 | -> V.Vector Double -- ^ Initial conditions |
454 | -> V.Vector Double -- ^ Desired solution times | 454 | -> V.Vector Double -- ^ Desired solution times |
455 | -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution | 455 | -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution |
456 | odeSolveVWith method control initStepSize f y0 tt = | 456 | odeSolveVWith' opts method control initStepSize f y0 tt = |
457 | case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) | 457 | case solveOdeC (fromIntegral $ maxFail opts) |
458 | (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts) | ||
459 | (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) | ||
458 | (coerce f) (coerce y0) (coerce tt) of | 460 | (coerce f) (coerce y0) (coerce tt) of |
459 | Left c -> Left $ fromIntegral c | 461 | Left c -> Left $ fromIntegral c |
460 | Right (v, d) -> Right (coerce v, d) | 462 | Right (v, d) -> Right (reshape l (coerce v), d) |
461 | where | 463 | where |
462 | l = size y0 | 464 | l = size y0 |
463 | scise (X absTol relTol) = coerce (V.replicate l absTol, relTol) | 465 | scise (X aTol rTol) = coerce (V.replicate l aTol, rTol) |
464 | scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol) | 466 | scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol) |
465 | scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol) | 467 | scise (XX' aTol rTol yScale _yDotScale) = coerce (V.replicate l aTol, yScale * rTol) |
466 | -- FIXME; Should we check that the length of ss is correct? | 468 | -- FIXME; Should we check that the length of ss is correct? |
467 | scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol) | 469 | scise (ScXX' aTol rTol yScale _yDotScale ss) = coerce (V.map (* aTol) ss, yScale * rTol) |
468 | jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ | 470 | jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ |
469 | getJacobian method | 471 | getJacobian method |
470 | matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } | 472 | matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } |
@@ -476,6 +478,9 @@ odeSolveVWith method control initStepSize f y0 tt = | |||
476 | 478 | ||
477 | solveOdeC :: | 479 | solveOdeC :: |
478 | CInt -> | 480 | CInt -> |
481 | CLong -> | ||
482 | CDouble -> | ||
483 | CInt -> | ||
479 | Maybe CDouble -> | 484 | Maybe CDouble -> |
480 | (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> | 485 | (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> |
481 | (V.Vector CDouble, CDouble) -> | 486 | (V.Vector CDouble, CDouble) -> |
@@ -483,7 +488,8 @@ solveOdeC :: | |||
483 | -> V.Vector CDouble -- ^ Initial conditions | 488 | -> V.Vector CDouble -- ^ Initial conditions |
484 | -> V.Vector CDouble -- ^ Desired solution times | 489 | -> V.Vector CDouble -- ^ Desired solution times |
485 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution | 490 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution |
486 | solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do | 491 | solveOdeC maxErrTestFails maxNumSteps_ minStep_ method initStepSize |
492 | jacH (aTols, rTol) fun f0 ts = unsafePerformIO $ do | ||
487 | 493 | ||
488 | let isInitStepSize :: CInt | 494 | let isInitStepSize :: CInt |
489 | isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize | 495 | isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize |
@@ -494,14 +500,12 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
494 | -- used :( | 500 | -- used :( |
495 | Nothing -> 0.0 | 501 | Nothing -> 0.0 |
496 | Just x -> x | 502 | Just x -> x |
503 | |||
497 | let dim = V.length f0 | 504 | let dim = V.length f0 |
498 | nEq :: CLong | 505 | nEq :: CLong |
499 | nEq = fromIntegral dim | 506 | nEq = fromIntegral dim |
500 | nTs :: CInt | 507 | nTs :: CInt |
501 | nTs = fromIntegral $ V.length ts | 508 | nTs = fromIntegral $ V.length ts |
502 | -- FIXME: fMut is not actually mutatated | ||
503 | fMut <- V.thaw f0 | ||
504 | tMut <- V.thaw ts | ||
505 | -- FIXME: I believe this gets taken from the ghc heap and so should | 509 | -- FIXME: I believe this gets taken from the ghc heap and so should |
506 | -- be subject to garbage collection. | 510 | -- be subject to garbage collection. |
507 | quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) | 511 | quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) |
@@ -509,7 +513,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
509 | diagnostics :: V.Vector CLong <- createVector 10 -- FIXME | 513 | diagnostics :: V.Vector CLong <- createVector 10 -- FIXME |
510 | diagMut <- V.thaw diagnostics | 514 | diagMut <- V.thaw diagnostics |
511 | -- We need the types that sundials expects. These are tied together | 515 | -- We need the types that sundials expects. These are tied together |
512 | -- in 'Types'. FIXME: The Haskell type is currently empty! | 516 | -- in 'CLangToHaskellTypes'. FIXME: The Haskell type is currently empty! |
513 | let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | 517 | let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt |
514 | funIO x y f _ptr = do | 518 | funIO x y f _ptr = do |
515 | -- Convert the pointer we get from C (y) to a vector, and then | 519 | -- Convert the pointer we get from C (y) to a vector, and then |
@@ -529,7 +533,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
529 | case jacH of | 533 | case jacH of |
530 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" | 534 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" |
531 | Just jacI -> do j <- jacI t <$> getDataFromContents dim y | 535 | Just jacI -> do j <- jacI t <$> getDataFromContents dim y |
532 | putMatrixDataFromContents j jacS | 536 | poke jacS j |
533 | -- FIXME: I don't understand what this comment means | 537 | -- FIXME: I don't understand what this comment means |
534 | -- Unsafe since the function will be called many times. | 538 | -- Unsafe since the function will be called many times. |
535 | [CU.exp| int{ 0 } |] | 539 | [CU.exp| int{ 0 } |] |
@@ -549,7 +553,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
549 | 553 | ||
550 | /* general problem parameters */ | 554 | /* general problem parameters */ |
551 | 555 | ||
552 | realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ | 556 | realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */ |
553 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ | 557 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ |
554 | 558 | ||
555 | /* Initialize data structures */ | 559 | /* Initialize data structures */ |
@@ -558,14 +562,14 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
558 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | 562 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; |
559 | /* Specify initial condition */ | 563 | /* Specify initial condition */ |
560 | for (i = 0; i < NEQ; i++) { | 564 | for (i = 0; i < NEQ; i++) { |
561 | NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; | 565 | NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; |
562 | }; | 566 | }; |
563 | 567 | ||
564 | tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */ | 568 | tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */ |
565 | if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; | 569 | if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; |
566 | /* Specify tolerances */ | 570 | /* Specify tolerances */ |
567 | for (i = 0; i < NEQ; i++) { | 571 | for (i = 0; i < NEQ; i++) { |
568 | NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i]; | 572 | NV_Ith_S(tv,i) = ($vec-ptr:(double *aTols))[i]; |
569 | }; | 573 | }; |
570 | 574 | ||
571 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | 575 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ |
@@ -577,7 +581,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
577 | /* problem as fully implicit and set f_E to NULL and f_I to f. */ | 581 | /* problem as fully implicit and set f_E to NULL and f_I to f. */ |
578 | 582 | ||
579 | /* Here we use the C types defined in helpers.h which tie up with */ | 583 | /* Here we use the C types defined in helpers.h which tie up with */ |
580 | /* the Haskell types defined in Types */ | 584 | /* the Haskell types defined in CLangToHaskellTypes */ |
581 | if ($(int method) < MIN_DIRK_NUM) { | 585 | if ($(int method) < MIN_DIRK_NUM) { |
582 | flag = ARKodeInit(arkode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), NULL, T0, y); | 586 | flag = ARKodeInit(arkode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), NULL, T0, y); |
583 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | 587 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; |
@@ -586,14 +590,15 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
586 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | 590 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; |
587 | } | 591 | } |
588 | 592 | ||
589 | /* FIXME: A hack for initial testing */ | 593 | flag = ARKodeSetMinStep(arkode_mem, $(double minStep_)); |
590 | flag = ARKodeSetMinStep(arkode_mem, 1.0e-12); | ||
591 | if (check_flag(&flag, "ARKodeSetMinStep", 1)) return 1; | 594 | if (check_flag(&flag, "ARKodeSetMinStep", 1)) return 1; |
592 | flag = ARKodeSetMaxNumSteps(arkode_mem, 10000); | 595 | flag = ARKodeSetMaxNumSteps(arkode_mem, $(long int maxNumSteps_)); |
593 | if (check_flag(&flag, "ARKodeSetMaxNumSteps", 1)) return 1; | 596 | if (check_flag(&flag, "ARKodeSetMaxNumSteps", 1)) return 1; |
597 | flag = ARKodeSetMaxErrTestFails(arkode_mem, $(int maxErrTestFails)); | ||
598 | if (check_flag(&flag, "ARKodeSetMaxErrTestFails", 1)) return 1; | ||
594 | 599 | ||
595 | /* Set routines */ | 600 | /* Set routines */ |
596 | flag = ARKodeSVtolerances(arkode_mem, $(double relTol), tv); | 601 | flag = ARKodeSVtolerances(arkode_mem, $(double rTol), tv); |
597 | if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1; | 602 | if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1; |
598 | 603 | ||
599 | /* Initialize dense matrix data structure and solver */ | 604 | /* Initialize dense matrix data structure and solver */ |
@@ -638,7 +643,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
638 | /* Stops when the final time has been reached */ | 643 | /* Stops when the final time has been reached */ |
639 | for (i = 1; i < $(int nTs); i++) { | 644 | for (i = 1; i < $(int nTs); i++) { |
640 | 645 | ||
641 | flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */ | 646 | flag = ARKode(arkode_mem, ($vec-ptr:(double *ts))[i], y, &t, ARK_NORMAL); /* call integrator */ |
642 | if (check_flag(&flag, "ARKode", 1)) break; | 647 | if (check_flag(&flag, "ARKode", 1)) break; |
643 | 648 | ||
644 | /* Store the results for Haskell */ | 649 | /* Store the results for Haskell */ |
@@ -738,7 +743,7 @@ butcherTable method = | |||
738 | case getBT method of | 743 | case getBT method of |
739 | Left c -> error $ show c -- FIXME | 744 | Left c -> error $ show c -- FIXME |
740 | Right (ButcherTable' v w x y, sqp) -> | 745 | Right (ButcherTable' v w x y, sqp) -> |
741 | ButcherTable { am = subMatrix (0, 0) (s, s) $ (B.arkSMax >< B.arkSMax) (V.toList v) | 746 | ButcherTable { am = subMatrix (0, 0) (s, s) $ (arkSMax >< arkSMax) (V.toList v) |
742 | , cv = subVector 0 s w | 747 | , cv = subVector 0 s w |
743 | , bv = subVector 0 s x | 748 | , bv = subVector 0 s x |
744 | , b2v = subVector 0 s y | 749 | , b2v = subVector 0 s y |
@@ -773,11 +778,11 @@ getButcherTable method = unsafePerformIO $ do | |||
773 | 778 | ||
774 | btSQP :: V.Vector CInt <- createVector 3 | 779 | btSQP :: V.Vector CInt <- createVector 3 |
775 | btSQPMut <- V.thaw btSQP | 780 | btSQPMut <- V.thaw btSQP |
776 | btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax) | 781 | btAs :: V.Vector CDouble <- createVector (arkSMax * arkSMax) |
777 | btAsMut <- V.thaw btAs | 782 | btAsMut <- V.thaw btAs |
778 | btCs :: V.Vector CDouble <- createVector B.arkSMax | 783 | btCs :: V.Vector CDouble <- createVector arkSMax |
779 | btBs :: V.Vector CDouble <- createVector B.arkSMax | 784 | btBs :: V.Vector CDouble <- createVector arkSMax |
780 | btB2s :: V.Vector CDouble <- createVector B.arkSMax | 785 | btB2s :: V.Vector CDouble <- createVector arkSMax |
781 | btCsMut <- V.thaw btCs | 786 | btCsMut <- V.thaw btCs |
782 | btBsMut <- V.thaw btBs | 787 | btBsMut <- V.thaw btBs |
783 | btB2sMut <- V.thaw btB2s | 788 | btB2sMut <- V.thaw btB2s |
diff --git a/packages/sundials/src/Arkode.hsc b/packages/sundials/src/Numeric/Sundials/Arkode.hsc index 9db37b5..0850258 100644 --- a/packages/sundials/src/Arkode.hsc +++ b/packages/sundials/src/Numeric/Sundials/Arkode.hsc | |||
@@ -1,7 +1,23 @@ | |||
1 | module Arkode where | 1 | {-# LANGUAGE QuasiQuotes #-} |
2 | {-# LANGUAGE TemplateHaskell #-} | ||
3 | {-# LANGUAGE OverloadedStrings #-} | ||
4 | {-# LANGUAGE EmptyDataDecls #-} | ||
2 | 5 | ||
3 | import Foreign | 6 | module Numeric.Sundials.Arkode where |
4 | import Foreign.C.Types | 7 | |
8 | import Foreign | ||
9 | import Foreign.C.Types | ||
10 | |||
11 | import Language.C.Types as CT | ||
12 | |||
13 | import qualified Data.Vector.Storable as VS | ||
14 | import qualified Data.Vector.Storable.Mutable as VM | ||
15 | |||
16 | import qualified Language.Haskell.TH as TH | ||
17 | import qualified Data.Map as Map | ||
18 | import Language.C.Inline.Context | ||
19 | |||
20 | import qualified Data.Vector.Storable as V | ||
5 | 21 | ||
6 | 22 | ||
7 | #include <stdio.h> | 23 | #include <stdio.h> |
@@ -10,7 +26,76 @@ import Foreign.C.Types | |||
10 | #include <nvector/nvector_serial.h> | 26 | #include <nvector/nvector_serial.h> |
11 | #include <sunmatrix/sunmatrix_dense.h> | 27 | #include <sunmatrix/sunmatrix_dense.h> |
12 | #include <arkode/arkode.h> | 28 | #include <arkode/arkode.h> |
13 | 29 | #include <cvode/cvode.h> | |
30 | |||
31 | |||
32 | data SunVector | ||
33 | data SunMatrix = SunMatrix { rows :: CInt | ||
34 | , cols :: CInt | ||
35 | , vals :: V.Vector CDouble | ||
36 | } | ||
37 | |||
38 | -- | This is true only if configured/ built as 64 bits | ||
39 | type SunIndexType = CLong | ||
40 | |||
41 | sunTypesTable :: Map.Map TypeSpecifier TH.TypeQ | ||
42 | sunTypesTable = Map.fromList | ||
43 | [ | ||
44 | (TypeName "sunindextype", [t| SunIndexType |] ) | ||
45 | , (TypeName "SunVector", [t| SunVector |] ) | ||
46 | , (TypeName "SunMatrix", [t| SunMatrix |] ) | ||
47 | ] | ||
48 | |||
49 | sunCtx :: Context | ||
50 | sunCtx = mempty {ctxTypesTable = sunTypesTable} | ||
51 | |||
52 | getMatrixDataFromContents :: Ptr SunMatrix -> IO SunMatrix | ||
53 | getMatrixDataFromContents ptr = do | ||
54 | qtr <- getContentMatrixPtr ptr | ||
55 | rs <- getNRows qtr | ||
56 | cs <- getNCols qtr | ||
57 | rtr <- getMatrixData qtr | ||
58 | vs <- vectorFromC (fromIntegral $ rs * cs) rtr | ||
59 | return $ SunMatrix { rows = rs, cols = cs, vals = vs } | ||
60 | |||
61 | putMatrixDataFromContents :: SunMatrix -> Ptr SunMatrix -> IO () | ||
62 | putMatrixDataFromContents mat ptr = do | ||
63 | let rs = rows mat | ||
64 | cs = cols mat | ||
65 | vs = vals mat | ||
66 | qtr <- getContentMatrixPtr ptr | ||
67 | putNRows rs qtr | ||
68 | putNCols cs qtr | ||
69 | rtr <- getMatrixData qtr | ||
70 | vectorToC vs (fromIntegral $ rs * cs) rtr | ||
71 | |||
72 | instance Storable SunMatrix where | ||
73 | poke = flip putMatrixDataFromContents | ||
74 | peek = getMatrixDataFromContents | ||
75 | sizeOf _ = error "sizeOf not supported for SunMatrix" | ||
76 | alignment _ = error "alignment not supported for SunMatrix" | ||
77 | |||
78 | vectorFromC :: Storable a => Int -> Ptr a -> IO (VS.Vector a) | ||
79 | vectorFromC len ptr = do | ||
80 | ptr' <- newForeignPtr_ ptr | ||
81 | VS.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
82 | |||
83 | vectorToC :: Storable a => VS.Vector a -> Int -> Ptr a -> IO () | ||
84 | vectorToC vec len ptr = do | ||
85 | ptr' <- newForeignPtr_ ptr | ||
86 | VS.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
87 | |||
88 | getDataFromContents :: Int -> Ptr SunVector -> IO (VS.Vector CDouble) | ||
89 | getDataFromContents len ptr = do | ||
90 | qtr <- getContentPtr ptr | ||
91 | rtr <- getData qtr | ||
92 | vectorFromC len rtr | ||
93 | |||
94 | putDataInContents :: Storable a => VS.Vector a -> Int -> Ptr b -> IO () | ||
95 | putDataInContents vec len ptr = do | ||
96 | qtr <- getContentPtr ptr | ||
97 | rtr <- getData qtr | ||
98 | vectorToC vec len rtr | ||
14 | 99 | ||
15 | #def typedef struct _generic_N_Vector SunVector; | 100 | #def typedef struct _generic_N_Vector SunVector; |
16 | #def typedef struct _N_VectorContent_Serial SunContent; | 101 | #def typedef struct _N_VectorContent_Serial SunContent; |
@@ -40,6 +125,11 @@ getContentPtr ptr = (#peek SunVector, content) ptr | |||
40 | getData :: Storable a => Ptr b -> IO a | 125 | getData :: Storable a => Ptr b -> IO a |
41 | getData ptr = (#peek SunContent, data) ptr | 126 | getData ptr = (#peek SunContent, data) ptr |
42 | 127 | ||
128 | cV_ADAMS :: Int | ||
129 | cV_ADAMS = #const CV_ADAMS | ||
130 | cV_BDF :: Int | ||
131 | cV_BDF = #const CV_BDF | ||
132 | |||
43 | arkSMax :: Int | 133 | arkSMax :: Int |
44 | arkSMax = #const ARK_S_MAX | 134 | arkSMax = #const ARK_S_MAX |
45 | 135 | ||
diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs new file mode 100644 index 0000000..a6f185e --- /dev/null +++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs | |||
@@ -0,0 +1,476 @@ | |||
1 | {-# OPTIONS_GHC -Wall #-} | ||
2 | |||
3 | {-# LANGUAGE QuasiQuotes #-} | ||
4 | {-# LANGUAGE TemplateHaskell #-} | ||
5 | {-# LANGUAGE MultiWayIf #-} | ||
6 | {-# LANGUAGE OverloadedStrings #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | |||
9 | ----------------------------------------------------------------------------- | ||
10 | -- | | ||
11 | -- Module : Numeric.Sundials.CVode.ODE | ||
12 | -- Copyright : Dominic Steinitz 2018, | ||
13 | -- Novadiscovery 2018 | ||
14 | -- License : BSD | ||
15 | -- Maintainer : Dominic Steinitz | ||
16 | -- Stability : provisional | ||
17 | -- | ||
18 | -- Solution of ordinary differential equation (ODE) initial value problems. | ||
19 | -- | ||
20 | -- <https://computation.llnl.gov/projects/sundials/sundials-software> | ||
21 | -- | ||
22 | -- A simple example: | ||
23 | -- | ||
24 | -- <<diagrams/brusselator.png#diagram=brusselator&height=400&width=500>> | ||
25 | -- | ||
26 | -- @ | ||
27 | -- import Numeric.Sundials.CVode.ODE | ||
28 | -- import Numeric.LinearAlgebra | ||
29 | -- | ||
30 | -- import Plots as P | ||
31 | -- import qualified Diagrams.Prelude as D | ||
32 | -- import Diagrams.Backend.Rasterific | ||
33 | -- | ||
34 | -- brusselator :: Double -> [Double] -> [Double] | ||
35 | -- brusselator _t x = [ a - (w + 1) * u + v * u * u | ||
36 | -- , w * u - v * u * u | ||
37 | -- , (b - w) / eps - w * u | ||
38 | -- ] | ||
39 | -- where | ||
40 | -- a = 1.0 | ||
41 | -- b = 3.5 | ||
42 | -- eps = 5.0e-6 | ||
43 | -- u = x !! 0 | ||
44 | -- v = x !! 1 | ||
45 | -- w = x !! 2 | ||
46 | -- | ||
47 | -- lSaxis :: [[Double]] -> P.Axis B D.V2 Double | ||
48 | -- lSaxis xs = P.r2Axis &~ do | ||
49 | -- let ts = xs!!0 | ||
50 | -- us = xs!!1 | ||
51 | -- vs = xs!!2 | ||
52 | -- ws = xs!!3 | ||
53 | -- P.linePlot' $ zip ts us | ||
54 | -- P.linePlot' $ zip ts vs | ||
55 | -- P.linePlot' $ zip ts ws | ||
56 | -- | ||
57 | -- main = do | ||
58 | -- let res1 = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) | ||
59 | -- renderRasterific "diagrams/brusselator.png" | ||
60 | -- (D.dims2D 500.0 500.0) | ||
61 | -- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) | ||
62 | -- @ | ||
63 | -- | ||
64 | ----------------------------------------------------------------------------- | ||
65 | module Numeric.Sundials.CVode.ODE ( odeSolve | ||
66 | , odeSolveV | ||
67 | , odeSolveVWith | ||
68 | , odeSolveVWith' | ||
69 | , ODEMethod(..) | ||
70 | , StepControl(..) | ||
71 | ) where | ||
72 | |||
73 | import qualified Language.C.Inline as C | ||
74 | import qualified Language.C.Inline.Unsafe as CU | ||
75 | |||
76 | import Data.Monoid ((<>)) | ||
77 | import Data.Maybe (isJust) | ||
78 | |||
79 | import Foreign.C.Types (CDouble, CInt, CLong) | ||
80 | import Foreign.Ptr (Ptr) | ||
81 | import Foreign.Storable (poke) | ||
82 | |||
83 | import qualified Data.Vector.Storable as V | ||
84 | |||
85 | import Data.Coerce (coerce) | ||
86 | import System.IO.Unsafe (unsafePerformIO) | ||
87 | |||
88 | import Numeric.LinearAlgebra.Devel (createVector) | ||
89 | |||
90 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows, | ||
91 | cols, toLists, size, reshape) | ||
92 | |||
93 | import Numeric.Sundials.Arkode (cV_ADAMS, cV_BDF, | ||
94 | getDataFromContents, putDataInContents) | ||
95 | import qualified Numeric.Sundials.Arkode as T | ||
96 | import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian, SundialsDiagnostics(..)) | ||
97 | |||
98 | |||
99 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | ||
100 | |||
101 | C.include "<stdlib.h>" | ||
102 | C.include "<stdio.h>" | ||
103 | C.include "<math.h>" | ||
104 | C.include "<cvode/cvode.h>" -- prototypes for CVODE fcts., consts. | ||
105 | C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros | ||
106 | C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix | ||
107 | C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver | ||
108 | C.include "<cvode/cvode_direct.h>" -- access to CVDls interface | ||
109 | C.include "<sundials/sundials_types.h>" -- definition of type realtype | ||
110 | C.include "<sundials/sundials_math.h>" | ||
111 | C.include "../../../helpers.h" | ||
112 | C.include "Numeric/Sundials/Arkode_hsc.h" | ||
113 | |||
114 | |||
115 | -- | Stepping functions | ||
116 | data ODEMethod = ADAMS | ||
117 | | BDF | ||
118 | |||
119 | getMethod :: ODEMethod -> Int | ||
120 | getMethod (ADAMS) = cV_ADAMS | ||
121 | getMethod (BDF) = cV_BDF | ||
122 | |||
123 | getJacobian :: ODEMethod -> Maybe Jacobian | ||
124 | getJacobian _ = Nothing | ||
125 | |||
126 | -- | A version of 'odeSolveVWith' with reasonable default step control. | ||
127 | odeSolveV | ||
128 | :: ODEMethod | ||
129 | -> Maybe Double -- ^ initial step size - by default, CVode | ||
130 | -- estimates the initial step size to be the | ||
131 | -- solution \(h\) of the equation | ||
132 | -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where | ||
133 | -- \(\ddot{y}\) is an estimated value of the | ||
134 | -- second derivative of the solution at \(t_0\) | ||
135 | -> Double -- ^ absolute tolerance for the state vector | ||
136 | -> Double -- ^ relative tolerance for the state vector | ||
137 | -> (Double -> Vector Double -> Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
138 | -> Vector Double -- ^ initial conditions | ||
139 | -> Vector Double -- ^ desired solution times | ||
140 | -> Matrix Double -- ^ solution | ||
141 | odeSolveV meth hi epsAbs epsRel f y0 ts = | ||
142 | odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts | ||
143 | where | ||
144 | g t x0 = coerce $ f t x0 | ||
145 | |||
146 | -- | A version of 'odeSolveV' with reasonable default parameters and | ||
147 | -- system of equations defined using lists. FIXME: we should say | ||
148 | -- something about the fact we could use the Jacobian but don't for | ||
149 | -- compatibility with hmatrix-gsl. | ||
150 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
151 | -> [Double] -- ^ initial conditions | ||
152 | -> Vector Double -- ^ desired solution times | ||
153 | -> Matrix Double -- ^ solution | ||
154 | odeSolve f y0 ts = | ||
155 | -- FIXME: These tolerances are different from the ones in GSL | ||
156 | odeSolveVWith BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) | ||
157 | where | ||
158 | g t x0 = V.fromList $ f t (V.toList x0) | ||
159 | |||
160 | odeSolveVWith :: | ||
161 | ODEMethod | ||
162 | -> StepControl | ||
163 | -> Maybe Double -- ^ initial step size - by default, CVode | ||
164 | -- estimates the initial step size to be the | ||
165 | -- solution \(h\) of the equation | ||
166 | -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where | ||
167 | -- \(\ddot{y}\) is an estimated value of the second | ||
168 | -- derivative of the solution at \(t_0\) | ||
169 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
170 | -> V.Vector Double -- ^ Initial conditions | ||
171 | -> V.Vector Double -- ^ Desired solution times | ||
172 | -> Matrix Double -- ^ Error code or solution | ||
173 | odeSolveVWith method control initStepSize f y0 tt = | ||
174 | case odeSolveVWith' opts method control initStepSize f y0 tt of | ||
175 | Left c -> error $ show c -- FIXME | ||
176 | Right (v, _d) -> v | ||
177 | where | ||
178 | opts = ODEOpts { maxNumSteps = 10000 | ||
179 | , minStep = 1.0e-12 | ||
180 | , relTol = error "relTol" | ||
181 | , absTols = error "absTol" | ||
182 | , initStep = error "initStep" | ||
183 | , maxFail = 10 | ||
184 | } | ||
185 | |||
186 | odeSolveVWith' :: | ||
187 | ODEOpts | ||
188 | -> ODEMethod | ||
189 | -> StepControl | ||
190 | -> Maybe Double -- ^ initial step size - by default, CVode | ||
191 | -- estimates the initial step size to be the | ||
192 | -- solution \(h\) of the equation | ||
193 | -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where | ||
194 | -- \(\ddot{y}\) is an estimated value of the second | ||
195 | -- derivative of the solution at \(t_0\) | ||
196 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
197 | -> V.Vector Double -- ^ Initial conditions | ||
198 | -> V.Vector Double -- ^ Desired solution times | ||
199 | -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution | ||
200 | odeSolveVWith' opts method control initStepSize f y0 tt = | ||
201 | case solveOdeC (fromIntegral $ maxFail opts) | ||
202 | (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts) | ||
203 | (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) | ||
204 | (coerce f) (coerce y0) (coerce tt) of | ||
205 | Left c -> Left $ fromIntegral c | ||
206 | Right (v, d) -> Right (reshape l (coerce v), d) | ||
207 | where | ||
208 | l = size y0 | ||
209 | scise (X aTol rTol) = coerce (V.replicate l aTol, rTol) | ||
210 | scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol) | ||
211 | scise (XX' aTol rTol yScale _yDotScale) = coerce (V.replicate l aTol, yScale * rTol) | ||
212 | -- FIXME; Should we check that the length of ss is correct? | ||
213 | scise (ScXX' aTol rTol yScale _yDotScale ss) = coerce (V.map (* aTol) ss, yScale * rTol) | ||
214 | jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ | ||
215 | getJacobian method | ||
216 | matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } | ||
217 | where | ||
218 | nr = fromIntegral $ rows m | ||
219 | nc = fromIntegral $ cols m | ||
220 | -- FIXME: efficiency | ||
221 | vs = V.fromList $ map coerce $ concat $ toLists m | ||
222 | |||
223 | solveOdeC :: | ||
224 | CInt -> | ||
225 | CLong -> | ||
226 | CDouble -> | ||
227 | CInt -> | ||
228 | Maybe CDouble -> | ||
229 | (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> | ||
230 | (V.Vector CDouble, CDouble) -> | ||
231 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
232 | -> V.Vector CDouble -- ^ Initial conditions | ||
233 | -> V.Vector CDouble -- ^ Desired solution times | ||
234 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution | ||
235 | solveOdeC maxErrTestFails maxNumSteps_ minStep_ method initStepSize | ||
236 | jacH (aTols, rTol) fun f0 ts = | ||
237 | unsafePerformIO $ do | ||
238 | |||
239 | let isInitStepSize :: CInt | ||
240 | isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize | ||
241 | ss :: CDouble | ||
242 | ss = case initStepSize of | ||
243 | -- It would be better to put an error message here but | ||
244 | -- inline-c seems to evaluate this even if it is never | ||
245 | -- used :( | ||
246 | Nothing -> 0.0 | ||
247 | Just x -> x | ||
248 | |||
249 | let dim = V.length f0 | ||
250 | nEq :: CLong | ||
251 | nEq = fromIntegral dim | ||
252 | nTs :: CInt | ||
253 | nTs = fromIntegral $ V.length ts | ||
254 | quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) | ||
255 | qMatMut <- V.thaw quasiMatrixRes | ||
256 | diagnostics :: V.Vector CLong <- createVector 10 -- FIXME | ||
257 | diagMut <- V.thaw diagnostics | ||
258 | -- We need the types that sundials expects. These are tied together | ||
259 | -- in 'CLangToHaskellTypes'. FIXME: The Haskell type is currently empty! | ||
260 | let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | ||
261 | funIO x y f _ptr = do | ||
262 | -- Convert the pointer we get from C (y) to a vector, and then | ||
263 | -- apply the user-supplied function. | ||
264 | fImm <- fun x <$> getDataFromContents dim y | ||
265 | -- Fill in the provided pointer with the resulting vector. | ||
266 | putDataInContents fImm dim f | ||
267 | -- FIXME: I don't understand what this comment means | ||
268 | -- Unsafe since the function will be called many times. | ||
269 | [CU.exp| int{ 0 } |] | ||
270 | let isJac :: CInt | ||
271 | isJac = fromIntegral $ fromEnum $ isJust jacH | ||
272 | jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix -> | ||
273 | Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector -> | ||
274 | IO CInt | ||
275 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do | ||
276 | case jacH of | ||
277 | Nothing -> error "Numeric.Sundials.CVode.ODE: Jacobian not defined" | ||
278 | Just jacI -> do j <- jacI t <$> getDataFromContents dim y | ||
279 | poke jacS j | ||
280 | -- FIXME: I don't understand what this comment means | ||
281 | -- Unsafe since the function will be called many times. | ||
282 | [CU.exp| int{ 0 } |] | ||
283 | |||
284 | res <- [C.block| int { | ||
285 | /* general problem variables */ | ||
286 | |||
287 | int flag; /* reusable error-checking flag */ | ||
288 | int i, j; /* reusable loop indices */ | ||
289 | N_Vector y = NULL; /* empty vector for storing solution */ | ||
290 | N_Vector tv = NULL; /* empty vector for storing absolute tolerances */ | ||
291 | |||
292 | SUNMatrix A = NULL; /* empty matrix for linear solver */ | ||
293 | SUNLinearSolver LS = NULL; /* empty linear solver object */ | ||
294 | void *cvode_mem = NULL; /* empty CVODE memory structure */ | ||
295 | realtype t; | ||
296 | long int nst, nfe, nsetups, nje, nfeLS, nni, ncfn, netf, nge; | ||
297 | |||
298 | /* general problem parameters */ | ||
299 | |||
300 | realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */ | ||
301 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ | ||
302 | |||
303 | /* Initialize data structures */ | ||
304 | |||
305 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | ||
306 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | ||
307 | /* Specify initial condition */ | ||
308 | for (i = 0; i < NEQ; i++) { | ||
309 | NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; | ||
310 | }; | ||
311 | |||
312 | cvode_mem = CVodeCreate($(int method), CV_NEWTON); | ||
313 | if (check_flag((void *)cvode_mem, "CVodeCreate", 0)) return(1); | ||
314 | |||
315 | /* Call CVodeInit to initialize the integrator memory and specify the | ||
316 | * user's right hand side function in y'=f(t,y), the inital time T0, and | ||
317 | * the initial dependent variable vector y. */ | ||
318 | flag = CVodeInit(cvode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); | ||
319 | if (check_flag(&flag, "CVodeInit", 1)) return(1); | ||
320 | |||
321 | tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */ | ||
322 | if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; | ||
323 | /* Specify tolerances */ | ||
324 | for (i = 0; i < NEQ; i++) { | ||
325 | NV_Ith_S(tv,i) = ($vec-ptr:(double *aTols))[i]; | ||
326 | }; | ||
327 | |||
328 | flag = CVodeSetMinStep(cvode_mem, $(double minStep_)); | ||
329 | if (check_flag(&flag, "CVodeSetMinStep", 1)) return 1; | ||
330 | flag = CVodeSetMaxNumSteps(cvode_mem, $(long int maxNumSteps_)); | ||
331 | if (check_flag(&flag, "CVodeSetMaxNumSteps", 1)) return 1; | ||
332 | flag = CVodeSetMaxErrTestFails(cvode_mem, $(int maxErrTestFails)); | ||
333 | if (check_flag(&flag, "CVodeSetMaxErrTestFails", 1)) return 1; | ||
334 | |||
335 | /* Call CVodeSVtolerances to specify the scalar relative tolerance | ||
336 | * and vector absolute tolerances */ | ||
337 | flag = CVodeSVtolerances(cvode_mem, $(double rTol), tv); | ||
338 | if (check_flag(&flag, "CVodeSVtolerances", 1)) return(1); | ||
339 | |||
340 | /* Initialize dense matrix data structure and solver */ | ||
341 | A = SUNDenseMatrix(NEQ, NEQ); | ||
342 | if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1; | ||
343 | LS = SUNDenseLinearSolver(y, A); | ||
344 | if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1; | ||
345 | |||
346 | /* Attach matrix and linear solver */ | ||
347 | flag = CVDlsSetLinearSolver(cvode_mem, LS, A); | ||
348 | if (check_flag(&flag, "CVDlsSetLinearSolver", 1)) return 1; | ||
349 | |||
350 | /* Set the initial step size if there is one */ | ||
351 | if ($(int isInitStepSize)) { | ||
352 | /* FIXME: We could check if the initial step size is 0 */ | ||
353 | /* or even NaN and then throw an error */ | ||
354 | flag = CVodeSetInitStep(cvode_mem, $(double ss)); | ||
355 | if (check_flag(&flag, "CVodeSetInitStep", 1)) return 1; | ||
356 | } | ||
357 | |||
358 | /* Set the Jacobian if there is one */ | ||
359 | if ($(int isJac)) { | ||
360 | flag = CVDlsSetJacFn(cvode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[]))); | ||
361 | if (check_flag(&flag, "CVDlsSetJacFn", 1)) return 1; | ||
362 | } | ||
363 | |||
364 | /* Store initial conditions */ | ||
365 | for (j = 0; j < NEQ; j++) { | ||
366 | ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); | ||
367 | } | ||
368 | |||
369 | /* Main time-stepping loop: calls CVode to perform the integration */ | ||
370 | /* Stops when the final time has been reached */ | ||
371 | for (i = 1; i < $(int nTs); i++) { | ||
372 | |||
373 | flag = CVode(cvode_mem, ($vec-ptr:(double *ts))[i], y, &t, CV_NORMAL); /* call integrator */ | ||
374 | if (check_flag(&flag, "CVode", 1)) break; | ||
375 | |||
376 | /* Store the results for Haskell */ | ||
377 | for (j = 0; j < NEQ; j++) { | ||
378 | ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j); | ||
379 | } | ||
380 | |||
381 | /* unsuccessful solve: break */ | ||
382 | if (flag < 0) { | ||
383 | fprintf(stderr,"Solver failure, stopping integration\n"); | ||
384 | break; | ||
385 | } | ||
386 | } | ||
387 | |||
388 | /* Get some final statistics on how the solve progressed */ | ||
389 | |||
390 | flag = CVodeGetNumSteps(cvode_mem, &nst); | ||
391 | check_flag(&flag, "CVodeGetNumSteps", 1); | ||
392 | ($vec-ptr:(long int *diagMut))[0] = nst; | ||
393 | |||
394 | /* FIXME */ | ||
395 | ($vec-ptr:(long int *diagMut))[1] = 0; | ||
396 | |||
397 | flag = CVodeGetNumRhsEvals(cvode_mem, &nfe); | ||
398 | check_flag(&flag, "CVodeGetNumRhsEvals", 1); | ||
399 | ($vec-ptr:(long int *diagMut))[2] = nfe; | ||
400 | /* FIXME */ | ||
401 | ($vec-ptr:(long int *diagMut))[3] = 0; | ||
402 | |||
403 | flag = CVodeGetNumLinSolvSetups(cvode_mem, &nsetups); | ||
404 | check_flag(&flag, "CVodeGetNumLinSolvSetups", 1); | ||
405 | ($vec-ptr:(long int *diagMut))[4] = nsetups; | ||
406 | |||
407 | flag = CVodeGetNumErrTestFails(cvode_mem, &netf); | ||
408 | check_flag(&flag, "CVodeGetNumErrTestFails", 1); | ||
409 | ($vec-ptr:(long int *diagMut))[5] = netf; | ||
410 | |||
411 | flag = CVodeGetNumNonlinSolvIters(cvode_mem, &nni); | ||
412 | check_flag(&flag, "CVodeGetNumNonlinSolvIters", 1); | ||
413 | ($vec-ptr:(long int *diagMut))[6] = nni; | ||
414 | |||
415 | flag = CVodeGetNumNonlinSolvConvFails(cvode_mem, &ncfn); | ||
416 | check_flag(&flag, "CVodeGetNumNonlinSolvConvFails", 1); | ||
417 | ($vec-ptr:(long int *diagMut))[7] = ncfn; | ||
418 | |||
419 | flag = CVDlsGetNumJacEvals(cvode_mem, &nje); | ||
420 | check_flag(&flag, "CVDlsGetNumJacEvals", 1); | ||
421 | ($vec-ptr:(long int *diagMut))[8] = ncfn; | ||
422 | |||
423 | flag = CVDlsGetNumRhsEvals(cvode_mem, &nfeLS); | ||
424 | check_flag(&flag, "CVDlsGetNumRhsEvals", 1); | ||
425 | ($vec-ptr:(long int *diagMut))[9] = ncfn; | ||
426 | |||
427 | /* Clean up and return */ | ||
428 | |||
429 | N_VDestroy(y); /* Free y vector */ | ||
430 | N_VDestroy(tv); /* Free tv vector */ | ||
431 | CVodeFree(&cvode_mem); /* Free integrator memory */ | ||
432 | SUNLinSolFree(LS); /* Free linear solver */ | ||
433 | SUNMatDestroy(A); /* Free A matrix */ | ||
434 | |||
435 | return flag; | ||
436 | } |] | ||
437 | if res == 0 | ||
438 | then do | ||
439 | preD <- V.freeze diagMut | ||
440 | let d = SundialsDiagnostics (fromIntegral $ preD V.!0) | ||
441 | (fromIntegral $ preD V.!1) | ||
442 | (fromIntegral $ preD V.!2) | ||
443 | (fromIntegral $ preD V.!3) | ||
444 | (fromIntegral $ preD V.!4) | ||
445 | (fromIntegral $ preD V.!5) | ||
446 | (fromIntegral $ preD V.!6) | ||
447 | (fromIntegral $ preD V.!7) | ||
448 | (fromIntegral $ preD V.!8) | ||
449 | (fromIntegral $ preD V.!9) | ||
450 | m <- V.freeze qMatMut | ||
451 | return $ Right (m, d) | ||
452 | else do | ||
453 | return $ Left res | ||
454 | |||
455 | -- | Adaptive step-size control | ||
456 | -- functions. | ||
457 | -- | ||
458 | -- [GSL](https://www.gnu.org/software/gsl/doc/html/ode-initval.html#adaptive-step-size-control) | ||
459 | -- allows the user to control the step size adjustment using | ||
460 | -- \(D_i = \epsilon^{abs}s_i + \epsilon^{rel}(a_{y} |y_i| + a_{dy/dt} h |\dot{y}_i|)\) where | ||
461 | -- \(\epsilon^{abs}\) is the required absolute error, \(\epsilon^{rel}\) | ||
462 | -- is the required relative error, \(s_i\) is a vector of scaling | ||
463 | -- factors, \(a_{y}\) is a scaling factor for the solution \(y\) and | ||
464 | -- \(a_{dydt}\) is a scaling factor for the derivative of the solution \(dy/dt\). | ||
465 | -- | ||
466 | -- [ARKode](https://computation.llnl.gov/projects/sundials/arkode) | ||
467 | -- allows the user to control the step size adjustment using | ||
468 | -- \(\eta^{rel}|y_i| + \eta^{abs}_i\). For compatibility with | ||
469 | -- [hmatrix-gsl](https://hackage.haskell.org/package/hmatrix-gsl), | ||
470 | -- tolerances for \(y\) and \(\dot{y}\) can be specified but the latter have no | ||
471 | -- effect. | ||
472 | data StepControl = X Double Double -- ^ absolute and relative tolerance for \(y\); in GSL terms, \(a_{y} = 1\) and \(a_{dy/dt} = 0\); in ARKode terms, the \(\eta^{abs}_i\) are identical | ||
473 | | X' Double Double -- ^ absolute and relative tolerance for \(\dot{y}\); in GSL terms, \(a_{y} = 0\) and \(a_{dy/dt} = 1\); in ARKode terms, the latter is treated as the relative tolerance for \(y\) so this is the same as specifying 'X' which may be entirely incorrect for the given problem | ||
474 | | XX' Double Double Double Double -- ^ include both via relative tolerance | ||
475 | -- scaling factors \(a_y\), \(a_{{dy}/{dt}}\); in ARKode terms, the latter is ignored and \(\eta^{rel} = a_{y}\epsilon^{rel}\) | ||
476 | | ScXX' Double Double Double Double (Vector Double) -- ^ scale absolute tolerance of \(y_i\); in ARKode terms, \(a_{{dy}/{dt}}\) is ignored, \(\eta^{abs}_i = s_i \epsilon^{abs}\) and \(\eta^{rel} = a_{y}\epsilon^{rel}\) | ||
diff --git a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs new file mode 100644 index 0000000..027d99a --- /dev/null +++ b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs | |||
@@ -0,0 +1,32 @@ | |||
1 | module Numeric.Sundials.ODEOpts where | ||
2 | |||
3 | import Data.Word (Word32) | ||
4 | import qualified Data.Vector.Storable as VS | ||
5 | |||
6 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) | ||
7 | |||
8 | |||
9 | type Jacobian = Double -> Vector Double -> Matrix Double | ||
10 | |||
11 | data ODEOpts = ODEOpts { | ||
12 | maxNumSteps :: Word32 | ||
13 | , minStep :: Double | ||
14 | , relTol :: Double | ||
15 | , absTols :: VS.Vector Double | ||
16 | , initStep :: Maybe Double | ||
17 | , maxFail :: Word32 | ||
18 | } deriving (Read, Show, Eq, Ord) | ||
19 | |||
20 | data SundialsDiagnostics = SundialsDiagnostics { | ||
21 | aRKodeGetNumSteps :: Int | ||
22 | , aRKodeGetNumStepAttempts :: Int | ||
23 | , aRKodeGetNumRhsEvals_fe :: Int | ||
24 | , aRKodeGetNumRhsEvals_fi :: Int | ||
25 | , aRKodeGetNumLinSolvSetups :: Int | ||
26 | , aRKodeGetNumErrTestFails :: Int | ||
27 | , aRKodeGetNumNonlinSolvIters :: Int | ||
28 | , aRKodeGetNumNonlinSolvConvFails :: Int | ||
29 | , aRKDlsGetNumJacEvals :: Int | ||
30 | , aRKDlsGetNumRhsEvals :: Int | ||
31 | } deriving Show | ||
32 | |||
diff --git a/packages/sundials/src/Types.hs b/packages/sundials/src/Types.hs deleted file mode 100644 index 04e4280..0000000 --- a/packages/sundials/src/Types.hs +++ /dev/null | |||
@@ -1,40 +0,0 @@ | |||
1 | {-# OPTIONS_GHC -Wall #-} | ||
2 | |||
3 | {-# LANGUAGE QuasiQuotes #-} | ||
4 | {-# LANGUAGE TemplateHaskell #-} | ||
5 | {-# LANGUAGE MultiWayIf #-} | ||
6 | {-# LANGUAGE OverloadedStrings #-} | ||
7 | {-# LANGUAGE EmptyDataDecls #-} | ||
8 | |||
9 | module Types where | ||
10 | |||
11 | import Foreign.C.Types | ||
12 | |||
13 | import qualified Language.Haskell.TH as TH | ||
14 | import qualified Language.C.Types as CT | ||
15 | import qualified Data.Map as Map | ||
16 | import Language.C.Inline.Context | ||
17 | |||
18 | import qualified Data.Vector.Storable as V | ||
19 | |||
20 | |||
21 | data SunVector | ||
22 | data SunMatrix = SunMatrix { rows :: CInt | ||
23 | , cols :: CInt | ||
24 | , vals :: V.Vector CDouble | ||
25 | } | ||
26 | |||
27 | -- FIXME: Is this true? | ||
28 | type SunIndexType = CLong | ||
29 | |||
30 | sunTypesTable :: Map.Map CT.TypeSpecifier TH.TypeQ | ||
31 | sunTypesTable = Map.fromList | ||
32 | [ | ||
33 | (CT.TypeName "sunindextype", [t| SunIndexType |] ) | ||
34 | , (CT.TypeName "SunVector", [t| SunVector |] ) | ||
35 | , (CT.TypeName "SunMatrix", [t| SunMatrix |] ) | ||
36 | ] | ||
37 | |||
38 | sunCtx :: Context | ||
39 | sunCtx = mempty {ctxTypesTable = sunTypesTable} | ||
40 | |||