diff options
Diffstat (limited to 'packages/sundials')
-rw-r--r-- | packages/sundials/hmatrix-sundials.cabal | 7 | ||||
-rw-r--r-- | packages/sundials/src/Main.hs | 54 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 43 |
3 files changed, 84 insertions, 20 deletions
diff --git a/packages/sundials/hmatrix-sundials.cabal b/packages/sundials/hmatrix-sundials.cabal index 331e6c4..762537e 100644 --- a/packages/sundials/hmatrix-sundials.cabal +++ b/packages/sundials/hmatrix-sundials.cabal | |||
@@ -33,14 +33,17 @@ library | |||
33 | 33 | ||
34 | executable sundials | 34 | executable sundials |
35 | main-is: Main.hs | 35 | main-is: Main.hs |
36 | other-modules: Types | 36 | other-modules: Types, Numeric.Sundials.Arkode.ODE |
37 | other-extensions: QuasiQuotes, TemplateHaskell, MultiWayIf, OverloadedStrings | 37 | other-extensions: QuasiQuotes, TemplateHaskell, MultiWayIf, OverloadedStrings |
38 | build-depends: base >=4.10 && <4.11, | 38 | build-depends: base >=4.10 && <4.11, |
39 | inline-c >=0.6 && <0.7, | 39 | inline-c >=0.6 && <0.7, |
40 | vector >=0.12 && <0.13, | 40 | vector >=0.12 && <0.13, |
41 | template-haskell >=2.12 && <2.13, | 41 | template-haskell >=2.12 && <2.13, |
42 | containers >=0.5 && <0.6, | 42 | containers >=0.5 && <0.6, |
43 | hmatrix>=0.18 | 43 | hmatrix>=0.18, |
44 | plots, | ||
45 | diagrams-lib, | ||
46 | diagrams-rasterific | ||
44 | hs-source-dirs: src | 47 | hs-source-dirs: src |
45 | default-language: Haskell2010 | 48 | default-language: Haskell2010 |
46 | extra-libraries: sundials_arkode | 49 | extra-libraries: sundials_arkode |
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index 2a561c4..5e51372 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs | |||
@@ -2,28 +2,58 @@ | |||
2 | 2 | ||
3 | import qualified Data.Vector.Storable as V | 3 | import qualified Data.Vector.Storable as V |
4 | import Numeric.Sundials.Arkode.ODE | 4 | import Numeric.Sundials.Arkode.ODE |
5 | import Numeric.LinearAlgebra | ||
5 | 6 | ||
6 | brusselator :: Double -> V.Vector Double -> V.Vector Double | 7 | import Plots as P |
7 | brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 | 8 | import qualified Diagrams.Prelude as D |
8 | , w * u - v * u^2 | 9 | import Diagrams.Backend.Rasterific |
9 | , (b - w) / eps - w * u | 10 | |
10 | ] | 11 | import Control.Lens |
12 | import Data.List (zip4) | ||
13 | |||
14 | |||
15 | brusselator _t x = [ a - (w + 1) * u + v * u^2 | ||
16 | , w * u - v * u^2 | ||
17 | , (b - w) / eps - w * u | ||
18 | ] | ||
11 | where | 19 | where |
12 | a = 1.0 | 20 | a = 1.0 |
13 | b = 3.5 | 21 | b = 3.5 |
14 | eps = 5.0e-6 | 22 | eps = 5.0e-6 |
15 | u = x V.! 0 | 23 | u = x !! 0 |
16 | v = x V.! 1 | 24 | v = x !! 1 |
17 | w = x V.! 2 | 25 | w = x !! 2 |
18 | 26 | ||
19 | stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] | 27 | stiffish t v = [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] |
20 | where | 28 | where |
21 | lamda = -100.0 | 29 | lamda = -100.0 |
22 | u = v V.! 0 | 30 | u = v !! 0 |
31 | |||
32 | lSaxis :: [[Double]] -> P.Axis B D.V2 Double | ||
33 | lSaxis xs = P.r2Axis &~ do | ||
34 | let ts = xs!!0 | ||
35 | us = xs!!1 | ||
36 | vs = xs!!2 | ||
37 | ws = xs!!3 | ||
38 | P.linePlot' $ zip ts us | ||
39 | P.linePlot' $ zip ts vs | ||
40 | P.linePlot' $ zip ts ws | ||
41 | |||
42 | kSaxis :: [(Double, Double)] -> P.Axis B D.V2 Double | ||
43 | kSaxis xs = P.r2Axis &~ do | ||
44 | P.linePlot' xs | ||
23 | 45 | ||
24 | main :: IO () | 46 | main :: IO () |
25 | main = do | 47 | main = do |
26 | let res = solveOde brusselator (V.fromList [1.2, 3.1, 3.0]) (V.fromList [0.0, 1.0 .. 10.0]) | 48 | let res = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) |
27 | putStrLn $ show res | 49 | putStrLn $ show res |
28 | let res = solveOde stiffish (V.fromList [1.0]) (V.fromList [0.0, 0.1 .. 10.0]) | 50 | renderRasterific "diagrams/brusselator.png" |
51 | (D.dims2D 500.0 500.0) | ||
52 | (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res)) | ||
53 | |||
54 | let res = odeSolve stiffish [0.0] (fromList [0.0, 0.1 .. 10.0]) | ||
29 | putStrLn $ show res | 55 | putStrLn $ show res |
56 | renderRasterific "diagrams/stiffish.png" | ||
57 | (D.dims2D 500.0 500.0) | ||
58 | (renderAxis $ kSaxis $ zip [0.0, 0.1 .. 10.0] (concat $ toLists res)) | ||
59 | |||
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 630827c..f432951 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -6,7 +6,9 @@ | |||
6 | {-# LANGUAGE OverloadedStrings #-} | 6 | {-# LANGUAGE OverloadedStrings #-} |
7 | {-# LANGUAGE ScopedTypeVariables #-} | 7 | {-# LANGUAGE ScopedTypeVariables #-} |
8 | 8 | ||
9 | module Numeric.Sundials.Arkode.ODE ( solveOde ) where | 9 | module Numeric.Sundials.Arkode.ODE ( solveOde |
10 | , odeSolve | ||
11 | ) where | ||
10 | 12 | ||
11 | import qualified Language.C.Inline as C | 13 | import qualified Language.C.Inline as C |
12 | import qualified Language.C.Inline.Unsafe as CU | 14 | import qualified Language.C.Inline.Unsafe as CU |
@@ -26,13 +28,14 @@ import System.IO.Unsafe (unsafePerformIO) | |||
26 | 28 | ||
27 | import Numeric.LinearAlgebra.Devel (createVector) | 29 | import Numeric.LinearAlgebra.Devel (createVector) |
28 | 30 | ||
29 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) | 31 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><)) |
30 | 32 | ||
31 | import qualified Types as T | 33 | import qualified Types as T |
32 | 34 | ||
33 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | 35 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) |
34 | 36 | ||
35 | -- C includes | 37 | -- C includes |
38 | C.include "<stdlib.h>" | ||
36 | C.include "<stdio.h>" | 39 | C.include "<stdio.h>" |
37 | C.include "<math.h>" | 40 | C.include "<math.h>" |
38 | C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts. | 41 | C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts. |
@@ -96,7 +99,14 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y | |||
96 | -> [Double] -- ^ initial conditions | 99 | -> [Double] -- ^ initial conditions |
97 | -> Vector Double -- ^ desired solution times | 100 | -> Vector Double -- ^ desired solution times |
98 | -> Matrix Double -- ^ solution | 101 | -> Matrix Double -- ^ solution |
99 | odeSolve = undefined | 102 | odeSolve f y0 ts = case solveOde g (V.fromList y0) (V.fromList $ toList ts) of |
103 | Left c -> error $ show c -- FIXME | ||
104 | Right (v, _) -> (nR >< nC) (V.toList v) | ||
105 | where | ||
106 | us = toList ts | ||
107 | nR = length us | ||
108 | nC = length y0 | ||
109 | g t x0 = V.fromList $ f t (V.toList x0) | ||
100 | 110 | ||
101 | solveOde :: | 111 | solveOde :: |
102 | (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 112 | (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
@@ -146,13 +156,12 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
146 | SUNMatrix A = NULL; /* empty matrix for linear solver */ | 156 | SUNMatrix A = NULL; /* empty matrix for linear solver */ |
147 | SUNLinearSolver LS = NULL; /* empty linear solver object */ | 157 | SUNLinearSolver LS = NULL; /* empty linear solver object */ |
148 | void *arkode_mem = NULL; /* empty ARKode memory structure */ | 158 | void *arkode_mem = NULL; /* empty ARKode memory structure */ |
149 | FILE *UFID; | ||
150 | realtype t; | 159 | realtype t; |
151 | long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; | 160 | long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; |
152 | 161 | ||
153 | /* general problem parameters */ | 162 | /* general problem parameters */ |
154 | realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ | 163 | realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ |
155 | realtype Tf = RCONST(($vec-ptr:(double *tMut))[$(int nTs) - 1]); /* final time */ | 164 | |
156 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ | 165 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ |
157 | realtype reltol = 1.0e-6; /* tolerances */ | 166 | realtype reltol = 1.0e-6; /* tolerances */ |
158 | realtype abstol = 1.0e-10; | 167 | realtype abstol = 1.0e-10; |
@@ -198,7 +207,29 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
198 | for (j = 0; j < NEQ; j++) { | 207 | for (j = 0; j < NEQ; j++) { |
199 | ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); | 208 | ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); |
200 | } | 209 | } |
201 | 210 | ||
211 | flag = ARKodeSetIRKTableNum(arkode_mem, KVAERNO_4_2_3); | ||
212 | if (check_flag(&flag, "ARKode", 1)) return 1; | ||
213 | |||
214 | int s, q, p; | ||
215 | realtype *ai = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype)); | ||
216 | realtype *ae = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype)); | ||
217 | realtype *ci = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
218 | realtype *ce = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
219 | realtype *bi = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
220 | realtype *be = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
221 | realtype *b2i = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
222 | realtype *b2e = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
223 | flag = ARKodeGetCurrentButcherTables(arkode_mem, &s, &q, &p, ai, ae, ci, ce, bi, be, b2i, b2e); | ||
224 | if (check_flag(&flag, "ARKode", 1)) return 1; | ||
225 | fprintf(stderr, "s = %d, q = %d, p = %d\n", s, q, p); | ||
226 | for (i = 0; i < s; i++) { | ||
227 | for (j = 0; j < s; j++) { | ||
228 | fprintf(stderr, "ai[%d,%d] = %f", i, j, ai[i * ARK_S_MAX + j]); | ||
229 | } | ||
230 | fprintf(stderr, "\n"); | ||
231 | } | ||
232 | |||
202 | /* Main time-stepping loop: calls ARKode to perform the integration */ | 233 | /* Main time-stepping loop: calls ARKode to perform the integration */ |
203 | /* Stops when the final time has been reached */ | 234 | /* Stops when the final time has been reached */ |
204 | for (i = 1; i < $(int nTs); i++) { | 235 | for (i = 1; i < $(int nTs); i++) { |