diff options
Diffstat (limited to 'packages/sundials')
-rw-r--r-- | packages/sundials/ChangeLog.md | 5 | ||||
-rw-r--r-- | packages/sundials/LICENSE | 30 | ||||
-rw-r--r-- | packages/sundials/README.md | 8 | ||||
-rw-r--r-- | packages/sundials/Setup.hs | 2 | ||||
-rw-r--r-- | packages/sundials/diagrams/brusselator.png | bin | 0 -> 27362 bytes | |||
-rw-r--r-- | packages/sundials/hmatrix-sundials.cabal | 57 | ||||
-rw-r--r-- | packages/sundials/src/Arkode.hsc | 114 | ||||
-rw-r--r-- | packages/sundials/src/Main.hs | 138 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 898 | ||||
-rw-r--r-- | packages/sundials/src/Types.hs | 40 | ||||
-rw-r--r-- | packages/sundials/src/helpers.c | 44 | ||||
-rw-r--r-- | packages/sundials/src/helpers.h | 9 |
12 files changed, 1345 insertions, 0 deletions
diff --git a/packages/sundials/ChangeLog.md b/packages/sundials/ChangeLog.md new file mode 100644 index 0000000..7b15777 --- /dev/null +++ b/packages/sundials/ChangeLog.md | |||
@@ -0,0 +1,5 @@ | |||
1 | # Revision history for hmatrix-sundials | ||
2 | |||
3 | ## 0.1.0.0 -- 2018-04-21 | ||
4 | |||
5 | * First version. Released on an unsuspecting world. Just Runge-Kutta methods to start with. | ||
diff --git a/packages/sundials/LICENSE b/packages/sundials/LICENSE new file mode 100644 index 0000000..a162e98 --- /dev/null +++ b/packages/sundials/LICENSE | |||
@@ -0,0 +1,30 @@ | |||
1 | Copyright (c) 2018, Dominic Steinitz, Novadiscovery | ||
2 | |||
3 | All rights reserved. | ||
4 | |||
5 | Redistribution and use in source and binary forms, with or without | ||
6 | modification, are permitted provided that the following conditions are met: | ||
7 | |||
8 | * Redistributions of source code must retain the above copyright | ||
9 | notice, this list of conditions and the following disclaimer. | ||
10 | |||
11 | * Redistributions in binary form must reproduce the above | ||
12 | copyright notice, this list of conditions and the following | ||
13 | disclaimer in the documentation and/or other materials provided | ||
14 | with the distribution. | ||
15 | |||
16 | * Neither the name of Dominic Steinitz nor the names of other | ||
17 | contributors may be used to endorse or promote products derived | ||
18 | from this software without specific prior written permission. | ||
19 | |||
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | ||
21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | ||
22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | ||
23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | ||
24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | ||
25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | ||
26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | ||
27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | ||
28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
diff --git a/packages/sundials/README.md b/packages/sundials/README.md new file mode 100644 index 0000000..2fac5c2 --- /dev/null +++ b/packages/sundials/README.md | |||
@@ -0,0 +1,8 @@ | |||
1 | Currently only an interface to the Runge-Kutta methods: | ||
2 | [ARKode](https://computation.llnl.gov/projects/sundials/arkode) | ||
3 | |||
4 | The interface is almost certainly going to change. Sundials gives a | ||
5 | rich set of "combinators" for controlling the solution of your problem | ||
6 | and reporting on how it performed. The idea is to initially mimic | ||
7 | hmatrix-gsl and add extra, richer functions but ultimately upgrade the | ||
8 | whole interface both for sundials and for gsl. | ||
diff --git a/packages/sundials/Setup.hs b/packages/sundials/Setup.hs new file mode 100644 index 0000000..9a994af --- /dev/null +++ b/packages/sundials/Setup.hs | |||
@@ -0,0 +1,2 @@ | |||
1 | import Distribution.Simple | ||
2 | main = defaultMain | ||
diff --git a/packages/sundials/diagrams/brusselator.png b/packages/sundials/diagrams/brusselator.png new file mode 100644 index 0000000..740cacb --- /dev/null +++ b/packages/sundials/diagrams/brusselator.png | |||
Binary files differ | |||
diff --git a/packages/sundials/hmatrix-sundials.cabal b/packages/sundials/hmatrix-sundials.cabal new file mode 100644 index 0000000..388f1db --- /dev/null +++ b/packages/sundials/hmatrix-sundials.cabal | |||
@@ -0,0 +1,57 @@ | |||
1 | name: hmatrix-sundials | ||
2 | version: 0.19.0.0 | ||
3 | synopsis: hmatrix interface to sundials | ||
4 | description: An interface to the solving suite SUNDIALS. Currently, it | ||
5 | mimics the solving interface in hmstrix-gsl but | ||
6 | provides more diagnostic information and the | ||
7 | Butcher Tableaux (for Runge-Kutta methods). | ||
8 | homepage: https://github.com/idontgetoutmuch/hmatrix/tree/sundials | ||
9 | license: BSD3 | ||
10 | license-file: LICENSE | ||
11 | author: Dominic Steinitz | ||
12 | maintainer: dominic@steinitz.org | ||
13 | copyright: Dominic Steinitz 2018, Novadiscovery 2018 | ||
14 | category: Math | ||
15 | build-type: Simple | ||
16 | extra-source-files: ChangeLog.md, README.md, diagrams/*.png | ||
17 | extra-doc-files: diagrams/*.png | ||
18 | cabal-version: >=1.18 | ||
19 | |||
20 | |||
21 | library | ||
22 | build-depends: base >=4.10 && <4.11, | ||
23 | inline-c >=0.6 && <0.7, | ||
24 | vector >=0.12 && <0.13, | ||
25 | template-haskell >=2.12 && <2.13, | ||
26 | containers >=0.5 && <0.6, | ||
27 | hmatrix>=0.18 | ||
28 | extra-libraries: sundials_arkode | ||
29 | other-extensions: QuasiQuotes | ||
30 | hs-source-dirs: src | ||
31 | exposed-modules: Numeric.Sundials.ARKode.ODE | ||
32 | other-modules: Types, | ||
33 | Arkode | ||
34 | c-sources: src/helpers.c src/helpers.h | ||
35 | default-language: Haskell2010 | ||
36 | |||
37 | test-suite hmatrix-sundials-testsuite | ||
38 | type: exitcode-stdio-1.0 | ||
39 | main-is: Main.hs | ||
40 | other-modules: Types, | ||
41 | Numeric.Sundials.ARKode.ODE, | ||
42 | Arkode | ||
43 | build-depends: base >=4.10 && <4.11, | ||
44 | inline-c >=0.6 && <0.7, | ||
45 | vector >=0.12 && <0.13, | ||
46 | template-haskell >=2.12 && <2.13, | ||
47 | containers >=0.5 && <0.6, | ||
48 | hmatrix>=0.18, | ||
49 | plots, | ||
50 | diagrams-lib, | ||
51 | diagrams-rasterific, | ||
52 | lens, | ||
53 | hspec | ||
54 | hs-source-dirs: src | ||
55 | extra-libraries: sundials_arkode | ||
56 | c-sources: src/helpers.c src/helpers.h | ||
57 | default-language: Haskell2010 | ||
diff --git a/packages/sundials/src/Arkode.hsc b/packages/sundials/src/Arkode.hsc new file mode 100644 index 0000000..9db37b5 --- /dev/null +++ b/packages/sundials/src/Arkode.hsc | |||
@@ -0,0 +1,114 @@ | |||
1 | module Arkode where | ||
2 | |||
3 | import Foreign | ||
4 | import Foreign.C.Types | ||
5 | |||
6 | |||
7 | #include <stdio.h> | ||
8 | #include <sundials/sundials_nvector.h> | ||
9 | #include <sundials/sundials_matrix.h> | ||
10 | #include <nvector/nvector_serial.h> | ||
11 | #include <sunmatrix/sunmatrix_dense.h> | ||
12 | #include <arkode/arkode.h> | ||
13 | |||
14 | |||
15 | #def typedef struct _generic_N_Vector SunVector; | ||
16 | #def typedef struct _N_VectorContent_Serial SunContent; | ||
17 | |||
18 | #def typedef struct _generic_SUNMatrix SunMatrix; | ||
19 | #def typedef struct _SUNMatrixContent_Dense SunMatrixContent; | ||
20 | |||
21 | getContentMatrixPtr :: Storable a => Ptr b -> IO a | ||
22 | getContentMatrixPtr ptr = (#peek SunMatrix, content) ptr | ||
23 | |||
24 | getNRows :: Ptr b -> IO CInt | ||
25 | getNRows ptr = (#peek SunMatrixContent, M) ptr | ||
26 | putNRows :: CInt -> Ptr b -> IO () | ||
27 | putNRows nr ptr = (#poke SunMatrixContent, M) ptr nr | ||
28 | |||
29 | getNCols :: Ptr b -> IO CInt | ||
30 | getNCols ptr = (#peek SunMatrixContent, N) ptr | ||
31 | putNCols :: CInt -> Ptr b -> IO () | ||
32 | putNCols nc ptr = (#poke SunMatrixContent, N) ptr nc | ||
33 | |||
34 | getMatrixData :: Storable a => Ptr b -> IO a | ||
35 | getMatrixData ptr = (#peek SunMatrixContent, data) ptr | ||
36 | |||
37 | getContentPtr :: Storable a => Ptr b -> IO a | ||
38 | getContentPtr ptr = (#peek SunVector, content) ptr | ||
39 | |||
40 | getData :: Storable a => Ptr b -> IO a | ||
41 | getData ptr = (#peek SunContent, data) ptr | ||
42 | |||
43 | arkSMax :: Int | ||
44 | arkSMax = #const ARK_S_MAX | ||
45 | |||
46 | mIN_DIRK_NUM, mAX_DIRK_NUM :: Int | ||
47 | mIN_DIRK_NUM = #const MIN_DIRK_NUM | ||
48 | mAX_DIRK_NUM = #const MAX_DIRK_NUM | ||
49 | |||
50 | -- FIXME: We could just use inline-c instead | ||
51 | |||
52 | -- Butcher table accessors -- implicit | ||
53 | sDIRK_2_1_2 :: Int | ||
54 | sDIRK_2_1_2 = #const SDIRK_2_1_2 | ||
55 | bILLINGTON_3_3_2 :: Int | ||
56 | bILLINGTON_3_3_2 = #const BILLINGTON_3_3_2 | ||
57 | tRBDF2_3_3_2 :: Int | ||
58 | tRBDF2_3_3_2 = #const TRBDF2_3_3_2 | ||
59 | kVAERNO_4_2_3 :: Int | ||
60 | kVAERNO_4_2_3 = #const KVAERNO_4_2_3 | ||
61 | aRK324L2SA_DIRK_4_2_3 :: Int | ||
62 | aRK324L2SA_DIRK_4_2_3 = #const ARK324L2SA_DIRK_4_2_3 | ||
63 | cASH_5_2_4 :: Int | ||
64 | cASH_5_2_4 = #const CASH_5_2_4 | ||
65 | cASH_5_3_4 :: Int | ||
66 | cASH_5_3_4 = #const CASH_5_3_4 | ||
67 | sDIRK_5_3_4 :: Int | ||
68 | sDIRK_5_3_4 = #const SDIRK_5_3_4 | ||
69 | kVAERNO_5_3_4 :: Int | ||
70 | kVAERNO_5_3_4 = #const KVAERNO_5_3_4 | ||
71 | aRK436L2SA_DIRK_6_3_4 :: Int | ||
72 | aRK436L2SA_DIRK_6_3_4 = #const ARK436L2SA_DIRK_6_3_4 | ||
73 | kVAERNO_7_4_5 :: Int | ||
74 | kVAERNO_7_4_5 = #const KVAERNO_7_4_5 | ||
75 | aRK548L2SA_DIRK_8_4_5 :: Int | ||
76 | aRK548L2SA_DIRK_8_4_5 = #const ARK548L2SA_DIRK_8_4_5 | ||
77 | |||
78 | -- #define DEFAULT_DIRK_2 SDIRK_2_1_2 | ||
79 | -- #define DEFAULT_DIRK_3 ARK324L2SA_DIRK_4_2_3 | ||
80 | -- #define DEFAULT_DIRK_4 SDIRK_5_3_4 | ||
81 | -- #define DEFAULT_DIRK_5 ARK548L2SA_DIRK_8_4_5 | ||
82 | |||
83 | -- Butcher table accessors -- explicit | ||
84 | hEUN_EULER_2_1_2 :: Int | ||
85 | hEUN_EULER_2_1_2 = #const HEUN_EULER_2_1_2 | ||
86 | bOGACKI_SHAMPINE_4_2_3 :: Int | ||
87 | bOGACKI_SHAMPINE_4_2_3 = #const BOGACKI_SHAMPINE_4_2_3 | ||
88 | aRK324L2SA_ERK_4_2_3 :: Int | ||
89 | aRK324L2SA_ERK_4_2_3 = #const ARK324L2SA_ERK_4_2_3 | ||
90 | zONNEVELD_5_3_4 :: Int | ||
91 | zONNEVELD_5_3_4 = #const ZONNEVELD_5_3_4 | ||
92 | aRK436L2SA_ERK_6_3_4 :: Int | ||
93 | aRK436L2SA_ERK_6_3_4 = #const ARK436L2SA_ERK_6_3_4 | ||
94 | sAYFY_ABURUB_6_3_4 :: Int | ||
95 | sAYFY_ABURUB_6_3_4 = #const SAYFY_ABURUB_6_3_4 | ||
96 | cASH_KARP_6_4_5 :: Int | ||
97 | cASH_KARP_6_4_5 = #const CASH_KARP_6_4_5 | ||
98 | fEHLBERG_6_4_5 :: Int | ||
99 | fEHLBERG_6_4_5 = #const FEHLBERG_6_4_5 | ||
100 | dORMAND_PRINCE_7_4_5 :: Int | ||
101 | dORMAND_PRINCE_7_4_5 = #const DORMAND_PRINCE_7_4_5 | ||
102 | aRK548L2SA_ERK_8_4_5 :: Int | ||
103 | aRK548L2SA_ERK_8_4_5 = #const ARK548L2SA_ERK_8_4_5 | ||
104 | vERNER_8_5_6 :: Int | ||
105 | vERNER_8_5_6 = #const VERNER_8_5_6 | ||
106 | fEHLBERG_13_7_8 :: Int | ||
107 | fEHLBERG_13_7_8 = #const FEHLBERG_13_7_8 | ||
108 | |||
109 | -- #define DEFAULT_ERK_2 HEUN_EULER_2_1_2 | ||
110 | -- #define DEFAULT_ERK_3 BOGACKI_SHAMPINE_4_2_3 | ||
111 | -- #define DEFAULT_ERK_4 ZONNEVELD_5_3_4 | ||
112 | -- #define DEFAULT_ERK_5 CASH_KARP_6_4_5 | ||
113 | -- #define DEFAULT_ERK_6 VERNER_8_5_6 | ||
114 | -- #define DEFAULT_ERK_8 FEHLBERG_13_7_8 | ||
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs new file mode 100644 index 0000000..729d35a --- /dev/null +++ b/packages/sundials/src/Main.hs | |||
@@ -0,0 +1,138 @@ | |||
1 | {-# OPTIONS_GHC -Wall #-} | ||
2 | |||
3 | import Numeric.Sundials.ARKode.ODE | ||
4 | import Numeric.LinearAlgebra | ||
5 | |||
6 | import Plots as P | ||
7 | import qualified Diagrams.Prelude as D | ||
8 | import Diagrams.Backend.Rasterific | ||
9 | |||
10 | import Control.Lens | ||
11 | |||
12 | import Test.Hspec | ||
13 | |||
14 | |||
15 | lorenz :: Double -> [Double] -> [Double] | ||
16 | lorenz _t u = [ sigma * (y - x) | ||
17 | , x * (rho - z) - y | ||
18 | , x * y - beta * z | ||
19 | ] | ||
20 | where | ||
21 | rho = 28.0 | ||
22 | sigma = 10.0 | ||
23 | beta = 8.0 / 3.0 | ||
24 | x = u !! 0 | ||
25 | y = u !! 1 | ||
26 | z = u !! 2 | ||
27 | |||
28 | _lorenzJac :: Double -> Vector Double -> Matrix Double | ||
29 | _lorenzJac _t u = (3><3) [ (-sigma), rho - z, y | ||
30 | , sigma , -1.0 , x | ||
31 | , 0.0 , (-x) , (-beta) | ||
32 | ] | ||
33 | where | ||
34 | rho = 28.0 | ||
35 | sigma = 10.0 | ||
36 | beta = 8.0 / 3.0 | ||
37 | x = u ! 0 | ||
38 | y = u ! 1 | ||
39 | z = u ! 2 | ||
40 | |||
41 | brusselator :: Double -> [Double] -> [Double] | ||
42 | brusselator _t x = [ a - (w + 1) * u + v * u * u | ||
43 | , w * u - v * u * u | ||
44 | , (b - w) / eps - w * u | ||
45 | ] | ||
46 | where | ||
47 | a = 1.0 | ||
48 | b = 3.5 | ||
49 | eps = 5.0e-6 | ||
50 | u = x !! 0 | ||
51 | v = x !! 1 | ||
52 | w = x !! 2 | ||
53 | |||
54 | _brussJac :: Double -> Vector Double -> Matrix Double | ||
55 | _brussJac _t x = (3><3) [ (-(w + 1.0)) + 2.0 * u * v, w - 2.0 * u * v, (-w) | ||
56 | , u * u , (-(u * u)) , 0.0 | ||
57 | , (-u) , u , (-1.0) / eps - u | ||
58 | ] | ||
59 | where | ||
60 | y = toList x | ||
61 | u = y !! 0 | ||
62 | v = y !! 1 | ||
63 | w = y !! 2 | ||
64 | eps = 5.0e-6 | ||
65 | |||
66 | stiffish :: Double -> [Double] -> [Double] | ||
67 | stiffish t v = [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] | ||
68 | where | ||
69 | lamda = -100.0 | ||
70 | u = v !! 0 | ||
71 | |||
72 | stiffishV :: Double -> Vector Double -> Vector Double | ||
73 | stiffishV t v = fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] | ||
74 | where | ||
75 | lamda = -100.0 | ||
76 | u = v ! 0 | ||
77 | |||
78 | _stiffJac :: Double -> Vector Double -> Matrix Double | ||
79 | _stiffJac _t _v = (1><1) [ lamda ] | ||
80 | where | ||
81 | lamda = -100.0 | ||
82 | |||
83 | lSaxis :: [[Double]] -> P.Axis B D.V2 Double | ||
84 | lSaxis xs = P.r2Axis &~ do | ||
85 | let ts = xs!!0 | ||
86 | us = xs!!1 | ||
87 | vs = xs!!2 | ||
88 | ws = xs!!3 | ||
89 | P.linePlot' $ zip ts us | ||
90 | P.linePlot' $ zip ts vs | ||
91 | P.linePlot' $ zip ts ws | ||
92 | |||
93 | kSaxis :: [(Double, Double)] -> P.Axis B D.V2 Double | ||
94 | kSaxis xs = P.r2Axis &~ do | ||
95 | P.linePlot' xs | ||
96 | |||
97 | main :: IO () | ||
98 | main = do | ||
99 | |||
100 | let res1 = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) | ||
101 | renderRasterific "diagrams/brusselator.png" | ||
102 | (D.dims2D 500.0 500.0) | ||
103 | (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) | ||
104 | |||
105 | let res1a = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) | ||
106 | renderRasterific "diagrams/brusselatorA.png" | ||
107 | (D.dims2D 500.0 500.0) | ||
108 | (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1a)) | ||
109 | |||
110 | let res2 = odeSolve stiffish [0.0] (fromList [0.0, 0.1 .. 10.0]) | ||
111 | renderRasterific "diagrams/stiffish.png" | ||
112 | (D.dims2D 500.0 500.0) | ||
113 | (renderAxis $ kSaxis $ zip [0.0, 0.1 .. 10.0] (concat $ toLists res2)) | ||
114 | |||
115 | let res2a = odeSolveV (SDIRK_5_3_4') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) | ||
116 | |||
117 | let res2b = odeSolveV (TRBDF2_3_3_2') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) | ||
118 | |||
119 | let maxDiff = maximum $ map abs $ | ||
120 | zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2b)!!0) | ||
121 | |||
122 | hspec $ describe "Compare results" $ do | ||
123 | it "for two different RK methods" $ | ||
124 | maxDiff < 1.0e-6 | ||
125 | |||
126 | let res3 = odeSolve lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0]) | ||
127 | |||
128 | renderRasterific "diagrams/lorenz.png" | ||
129 | (D.dims2D 500.0 500.0) | ||
130 | (renderAxis $ kSaxis $ zip ((toLists $ tr res3)!!0) ((toLists $ tr res3)!!1)) | ||
131 | |||
132 | renderRasterific "diagrams/lorenz1.png" | ||
133 | (D.dims2D 500.0 500.0) | ||
134 | (renderAxis $ kSaxis $ zip ((toLists $ tr res3)!!0) ((toLists $ tr res3)!!2)) | ||
135 | |||
136 | renderRasterific "diagrams/lorenz2.png" | ||
137 | (D.dims2D 500.0 500.0) | ||
138 | (renderAxis $ kSaxis $ zip ((toLists $ tr res3)!!1) ((toLists $ tr res3)!!2)) | ||
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs new file mode 100644 index 0000000..e5a2e4d --- /dev/null +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -0,0 +1,898 @@ | |||
1 | {-# OPTIONS_GHC -Wall #-} | ||
2 | |||
3 | {-# LANGUAGE QuasiQuotes #-} | ||
4 | {-# LANGUAGE TemplateHaskell #-} | ||
5 | {-# LANGUAGE MultiWayIf #-} | ||
6 | {-# LANGUAGE OverloadedStrings #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | {-# LANGUAGE DeriveGeneric #-} | ||
9 | {-# LANGUAGE TypeOperators #-} | ||
10 | {-# LANGUAGE KindSignatures #-} | ||
11 | {-# LANGUAGE TypeSynonymInstances #-} | ||
12 | {-# LANGUAGE FlexibleInstances #-} | ||
13 | {-# LANGUAGE FlexibleContexts #-} | ||
14 | |||
15 | ----------------------------------------------------------------------------- | ||
16 | -- | | ||
17 | -- Module : Numeric.Sundials.ARKode.ODE | ||
18 | -- Copyright : Dominic Steinitz 2018, | ||
19 | -- Novadiscovery 2018 | ||
20 | -- License : BSD | ||
21 | -- Maintainer : Dominic Steinitz | ||
22 | -- Stability : provisional | ||
23 | -- | ||
24 | -- Solution of ordinary differential equation (ODE) initial value problems. | ||
25 | -- | ||
26 | -- <https://computation.llnl.gov/projects/sundials/sundials-software> | ||
27 | -- | ||
28 | -- A simple example: | ||
29 | -- | ||
30 | -- <<diagrams/brusselator.png#diagram=brusselator&height=400&width=500>> | ||
31 | -- | ||
32 | -- @ | ||
33 | -- import Numeric.Sundials.ARKode.ODE | ||
34 | -- import Numeric.LinearAlgebra | ||
35 | -- | ||
36 | -- import Plots as P | ||
37 | -- import qualified Diagrams.Prelude as D | ||
38 | -- import Diagrams.Backend.Rasterific | ||
39 | -- | ||
40 | -- brusselator :: Double -> [Double] -> [Double] | ||
41 | -- brusselator _t x = [ a - (w + 1) * u + v * u * u | ||
42 | -- , w * u - v * u * u | ||
43 | -- , (b - w) / eps - w * u | ||
44 | -- ] | ||
45 | -- where | ||
46 | -- a = 1.0 | ||
47 | -- b = 3.5 | ||
48 | -- eps = 5.0e-6 | ||
49 | -- u = x !! 0 | ||
50 | -- v = x !! 1 | ||
51 | -- w = x !! 2 | ||
52 | -- | ||
53 | -- lSaxis :: [[Double]] -> P.Axis B D.V2 Double | ||
54 | -- lSaxis xs = P.r2Axis &~ do | ||
55 | -- let ts = xs!!0 | ||
56 | -- us = xs!!1 | ||
57 | -- vs = xs!!2 | ||
58 | -- ws = xs!!3 | ||
59 | -- P.linePlot' $ zip ts us | ||
60 | -- P.linePlot' $ zip ts vs | ||
61 | -- P.linePlot' $ zip ts ws | ||
62 | -- | ||
63 | -- main = do | ||
64 | -- let res1 = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) | ||
65 | -- renderRasterific "diagrams/brusselator.png" | ||
66 | -- (D.dims2D 500.0 500.0) | ||
67 | -- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) | ||
68 | -- @ | ||
69 | -- | ||
70 | -- KVAERNO_4_2_3 | ||
71 | -- | ||
72 | -- \[ | ||
73 | -- \begin{array}{c|cccc} | ||
74 | -- 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\ | ||
75 | -- 0.871733043 & 0.4358665215 & 0.4358665215 & 0.0 & 0.0 \\ | ||
76 | -- 1.0 & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\ | ||
77 | -- 1.0 & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\ | ||
78 | -- \hline | ||
79 | -- & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\ | ||
80 | -- & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\ | ||
81 | -- \end{array} | ||
82 | -- \] | ||
83 | -- | ||
84 | -- SDIRK_2_1_2 | ||
85 | -- | ||
86 | -- \[ | ||
87 | -- \begin{array}{c|cc} | ||
88 | -- 1.0 & 1.0 & 0.0 \\ | ||
89 | -- 0.0 & -1.0 & 1.0 \\ | ||
90 | -- \hline | ||
91 | -- & 0.5 & 0.5 \\ | ||
92 | -- & 1.0 & 0.0 \\ | ||
93 | -- \end{array} | ||
94 | -- \] | ||
95 | -- | ||
96 | -- SDIRK_5_3_4 | ||
97 | -- | ||
98 | -- \[ | ||
99 | -- \begin{array}{c|ccccc} | ||
100 | -- 0.25 & 0.25 & 0.0 & 0.0 & 0.0 & 0.0 \\ | ||
101 | -- 0.75 & 0.5 & 0.25 & 0.0 & 0.0 & 0.0 \\ | ||
102 | -- 0.55 & 0.34 & -4.0e-2 & 0.25 & 0.0 & 0.0 \\ | ||
103 | -- 0.5 & 0.2727941176470588 & -5.036764705882353e-2 & 2.7573529411764705e-2 & 0.25 & 0.0 \\ | ||
104 | -- 1.0 & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\ | ||
105 | -- \hline | ||
106 | -- & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\ | ||
107 | -- & 1.2291666666666667 & -0.17708333333333334 & 7.03125 & -7.083333333333333 & 0.0 \\ | ||
108 | -- \end{array} | ||
109 | -- \] | ||
110 | ----------------------------------------------------------------------------- | ||
111 | module Numeric.Sundials.ARKode.ODE ( odeSolve | ||
112 | , odeSolveV | ||
113 | , odeSolveVWith | ||
114 | , odeSolveVWith' | ||
115 | , ButcherTable(..) | ||
116 | , butcherTable | ||
117 | , ODEMethod(..) | ||
118 | , StepControl(..) | ||
119 | , Jacobian | ||
120 | , SundialsDiagnostics(..) | ||
121 | ) where | ||
122 | |||
123 | import qualified Language.C.Inline as C | ||
124 | import qualified Language.C.Inline.Unsafe as CU | ||
125 | |||
126 | import Data.Monoid ((<>)) | ||
127 | import Data.Maybe (isJust) | ||
128 | |||
129 | import Foreign.C.Types | ||
130 | import Foreign.Ptr (Ptr) | ||
131 | import Foreign.ForeignPtr (newForeignPtr_) | ||
132 | import Foreign.Storable (Storable) | ||
133 | |||
134 | import qualified Data.Vector.Storable as V | ||
135 | import qualified Data.Vector.Storable.Mutable as VM | ||
136 | |||
137 | import Data.Coerce (coerce) | ||
138 | import System.IO.Unsafe (unsafePerformIO) | ||
139 | import GHC.Generics | ||
140 | |||
141 | import Numeric.LinearAlgebra.Devel (createVector) | ||
142 | |||
143 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), | ||
144 | subMatrix, rows, cols, toLists, | ||
145 | size, subVector) | ||
146 | |||
147 | import qualified Types as T | ||
148 | import Arkode | ||
149 | import qualified Arkode as B | ||
150 | |||
151 | |||
152 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | ||
153 | |||
154 | C.include "<stdlib.h>" | ||
155 | C.include "<stdio.h>" | ||
156 | C.include "<math.h>" | ||
157 | C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts. | ||
158 | C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros | ||
159 | C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix | ||
160 | C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver | ||
161 | C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface | ||
162 | C.include "<sundials/sundials_types.h>" -- definition of type realtype | ||
163 | C.include "<sundials/sundials_math.h>" | ||
164 | C.include "../../../helpers.h" | ||
165 | C.include "Arkode_hsc.h" | ||
166 | |||
167 | |||
168 | getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble) | ||
169 | getDataFromContents len ptr = do | ||
170 | qtr <- B.getContentPtr ptr | ||
171 | rtr <- B.getData qtr | ||
172 | vectorFromC len rtr | ||
173 | |||
174 | -- FIXME: Potentially an instance of Storable | ||
175 | _getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix | ||
176 | _getMatrixDataFromContents ptr = do | ||
177 | qtr <- B.getContentMatrixPtr ptr | ||
178 | rs <- B.getNRows qtr | ||
179 | cs <- B.getNCols qtr | ||
180 | rtr <- B.getMatrixData qtr | ||
181 | vs <- vectorFromC (fromIntegral $ rs * cs) rtr | ||
182 | return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } | ||
183 | |||
184 | putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () | ||
185 | putMatrixDataFromContents mat ptr = do | ||
186 | let rs = T.rows mat | ||
187 | cs = T.cols mat | ||
188 | vs = T.vals mat | ||
189 | qtr <- B.getContentMatrixPtr ptr | ||
190 | B.putNRows rs qtr | ||
191 | B.putNCols cs qtr | ||
192 | rtr <- B.getMatrixData qtr | ||
193 | vectorToC vs (fromIntegral $ rs * cs) rtr | ||
194 | -- FIXME: END | ||
195 | |||
196 | putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () | ||
197 | putDataInContents vec len ptr = do | ||
198 | qtr <- B.getContentPtr ptr | ||
199 | rtr <- B.getData qtr | ||
200 | vectorToC vec len rtr | ||
201 | |||
202 | -- Utils | ||
203 | |||
204 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | ||
205 | vectorFromC len ptr = do | ||
206 | ptr' <- newForeignPtr_ ptr | ||
207 | V.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
208 | |||
209 | vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () | ||
210 | vectorToC vec len ptr = do | ||
211 | ptr' <- newForeignPtr_ ptr | ||
212 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
213 | |||
214 | data SundialsDiagnostics = SundialsDiagnostics { | ||
215 | aRKodeGetNumSteps :: Int | ||
216 | , aRKodeGetNumStepAttempts :: Int | ||
217 | , aRKodeGetNumRhsEvals_fe :: Int | ||
218 | , aRKodeGetNumRhsEvals_fi :: Int | ||
219 | , aRKodeGetNumLinSolvSetups :: Int | ||
220 | , aRKodeGetNumErrTestFails :: Int | ||
221 | , aRKodeGetNumNonlinSolvIters :: Int | ||
222 | , aRKodeGetNumNonlinSolvConvFails :: Int | ||
223 | , aRKDlsGetNumJacEvals :: Int | ||
224 | , aRKDlsGetNumRhsEvals :: Int | ||
225 | } deriving Show | ||
226 | |||
227 | type Jacobian = Double -> Vector Double -> Matrix Double | ||
228 | |||
229 | -- | Stepping functions | ||
230 | data ODEMethod = SDIRK_2_1_2 Jacobian | ||
231 | | SDIRK_2_1_2' | ||
232 | | BILLINGTON_3_3_2 Jacobian | ||
233 | | BILLINGTON_3_3_2' | ||
234 | | TRBDF2_3_3_2 Jacobian | ||
235 | | TRBDF2_3_3_2' | ||
236 | | KVAERNO_4_2_3 Jacobian | ||
237 | | KVAERNO_4_2_3' | ||
238 | | ARK324L2SA_DIRK_4_2_3 Jacobian | ||
239 | | ARK324L2SA_DIRK_4_2_3' | ||
240 | | CASH_5_2_4 Jacobian | ||
241 | | CASH_5_2_4' | ||
242 | | CASH_5_3_4 Jacobian | ||
243 | | CASH_5_3_4' | ||
244 | | SDIRK_5_3_4 Jacobian | ||
245 | | SDIRK_5_3_4' | ||
246 | | KVAERNO_5_3_4 Jacobian | ||
247 | | KVAERNO_5_3_4' | ||
248 | | ARK436L2SA_DIRK_6_3_4 Jacobian | ||
249 | | ARK436L2SA_DIRK_6_3_4' | ||
250 | | KVAERNO_7_4_5 Jacobian | ||
251 | | KVAERNO_7_4_5' | ||
252 | | ARK548L2SA_DIRK_8_4_5 Jacobian | ||
253 | | ARK548L2SA_DIRK_8_4_5' | ||
254 | | HEUN_EULER_2_1_2 Jacobian | ||
255 | | HEUN_EULER_2_1_2' | ||
256 | | BOGACKI_SHAMPINE_4_2_3 Jacobian | ||
257 | | BOGACKI_SHAMPINE_4_2_3' | ||
258 | | ARK324L2SA_ERK_4_2_3 Jacobian | ||
259 | | ARK324L2SA_ERK_4_2_3' | ||
260 | | ZONNEVELD_5_3_4 Jacobian | ||
261 | | ZONNEVELD_5_3_4' | ||
262 | | ARK436L2SA_ERK_6_3_4 Jacobian | ||
263 | | ARK436L2SA_ERK_6_3_4' | ||
264 | | SAYFY_ABURUB_6_3_4 Jacobian | ||
265 | | SAYFY_ABURUB_6_3_4' | ||
266 | | CASH_KARP_6_4_5 Jacobian | ||
267 | | CASH_KARP_6_4_5' | ||
268 | | FEHLBERG_6_4_5 Jacobian | ||
269 | | FEHLBERG_6_4_5' | ||
270 | | DORMAND_PRINCE_7_4_5 Jacobian | ||
271 | | DORMAND_PRINCE_7_4_5' | ||
272 | | ARK548L2SA_ERK_8_4_5 Jacobian | ||
273 | | ARK548L2SA_ERK_8_4_5' | ||
274 | | VERNER_8_5_6 Jacobian | ||
275 | | VERNER_8_5_6' | ||
276 | | FEHLBERG_13_7_8 Jacobian | ||
277 | | FEHLBERG_13_7_8' | ||
278 | deriving Generic | ||
279 | |||
280 | constrName :: (HasConstructor (Rep a), Generic a)=> a -> String | ||
281 | constrName = genericConstrName . from | ||
282 | |||
283 | class HasConstructor (f :: * -> *) where | ||
284 | genericConstrName :: f x -> String | ||
285 | |||
286 | instance HasConstructor f => HasConstructor (D1 c f) where | ||
287 | genericConstrName (M1 x) = genericConstrName x | ||
288 | |||
289 | instance (HasConstructor x, HasConstructor y) => HasConstructor (x :+: y) where | ||
290 | genericConstrName (L1 l) = genericConstrName l | ||
291 | genericConstrName (R1 r) = genericConstrName r | ||
292 | |||
293 | instance Constructor c => HasConstructor (C1 c f) where | ||
294 | genericConstrName x = conName x | ||
295 | |||
296 | instance Show ODEMethod where | ||
297 | show x = constrName x | ||
298 | |||
299 | -- FIXME: We can probably do better here with generics | ||
300 | getMethod :: ODEMethod -> Int | ||
301 | getMethod (SDIRK_2_1_2 _) = sDIRK_2_1_2 | ||
302 | getMethod (SDIRK_2_1_2') = sDIRK_2_1_2 | ||
303 | getMethod (BILLINGTON_3_3_2 _) = bILLINGTON_3_3_2 | ||
304 | getMethod (BILLINGTON_3_3_2') = bILLINGTON_3_3_2 | ||
305 | getMethod (TRBDF2_3_3_2 _) = tRBDF2_3_3_2 | ||
306 | getMethod (TRBDF2_3_3_2') = tRBDF2_3_3_2 | ||
307 | getMethod (KVAERNO_4_2_3 _) = kVAERNO_4_2_3 | ||
308 | getMethod (KVAERNO_4_2_3') = kVAERNO_4_2_3 | ||
309 | getMethod (ARK324L2SA_DIRK_4_2_3 _) = aRK324L2SA_DIRK_4_2_3 | ||
310 | getMethod (ARK324L2SA_DIRK_4_2_3') = aRK324L2SA_DIRK_4_2_3 | ||
311 | getMethod (CASH_5_2_4 _) = cASH_5_2_4 | ||
312 | getMethod (CASH_5_2_4') = cASH_5_2_4 | ||
313 | getMethod (CASH_5_3_4 _) = cASH_5_3_4 | ||
314 | getMethod (CASH_5_3_4') = cASH_5_3_4 | ||
315 | getMethod (SDIRK_5_3_4 _) = sDIRK_5_3_4 | ||
316 | getMethod (SDIRK_5_3_4') = sDIRK_5_3_4 | ||
317 | getMethod (KVAERNO_5_3_4 _) = kVAERNO_5_3_4 | ||
318 | getMethod (KVAERNO_5_3_4') = kVAERNO_5_3_4 | ||
319 | getMethod (ARK436L2SA_DIRK_6_3_4 _) = aRK436L2SA_DIRK_6_3_4 | ||
320 | getMethod (ARK436L2SA_DIRK_6_3_4') = aRK436L2SA_DIRK_6_3_4 | ||
321 | getMethod (KVAERNO_7_4_5 _) = kVAERNO_7_4_5 | ||
322 | getMethod (KVAERNO_7_4_5') = kVAERNO_7_4_5 | ||
323 | getMethod (ARK548L2SA_DIRK_8_4_5 _) = aRK548L2SA_DIRK_8_4_5 | ||
324 | getMethod (ARK548L2SA_DIRK_8_4_5') = aRK548L2SA_DIRK_8_4_5 | ||
325 | getMethod (HEUN_EULER_2_1_2 _) = hEUN_EULER_2_1_2 | ||
326 | getMethod (HEUN_EULER_2_1_2') = hEUN_EULER_2_1_2 | ||
327 | getMethod (BOGACKI_SHAMPINE_4_2_3 _) = bOGACKI_SHAMPINE_4_2_3 | ||
328 | getMethod (BOGACKI_SHAMPINE_4_2_3') = bOGACKI_SHAMPINE_4_2_3 | ||
329 | getMethod (ARK324L2SA_ERK_4_2_3 _) = aRK324L2SA_ERK_4_2_3 | ||
330 | getMethod (ARK324L2SA_ERK_4_2_3') = aRK324L2SA_ERK_4_2_3 | ||
331 | getMethod (ZONNEVELD_5_3_4 _) = zONNEVELD_5_3_4 | ||
332 | getMethod (ZONNEVELD_5_3_4') = zONNEVELD_5_3_4 | ||
333 | getMethod (ARK436L2SA_ERK_6_3_4 _) = aRK436L2SA_ERK_6_3_4 | ||
334 | getMethod (ARK436L2SA_ERK_6_3_4') = aRK436L2SA_ERK_6_3_4 | ||
335 | getMethod (SAYFY_ABURUB_6_3_4 _) = sAYFY_ABURUB_6_3_4 | ||
336 | getMethod (SAYFY_ABURUB_6_3_4') = sAYFY_ABURUB_6_3_4 | ||
337 | getMethod (CASH_KARP_6_4_5 _) = cASH_KARP_6_4_5 | ||
338 | getMethod (CASH_KARP_6_4_5') = cASH_KARP_6_4_5 | ||
339 | getMethod (FEHLBERG_6_4_5 _) = fEHLBERG_6_4_5 | ||
340 | getMethod (FEHLBERG_6_4_5' ) = fEHLBERG_6_4_5 | ||
341 | getMethod (DORMAND_PRINCE_7_4_5 _) = dORMAND_PRINCE_7_4_5 | ||
342 | getMethod (DORMAND_PRINCE_7_4_5') = dORMAND_PRINCE_7_4_5 | ||
343 | getMethod (ARK548L2SA_ERK_8_4_5 _) = aRK548L2SA_ERK_8_4_5 | ||
344 | getMethod (ARK548L2SA_ERK_8_4_5') = aRK548L2SA_ERK_8_4_5 | ||
345 | getMethod (VERNER_8_5_6 _) = vERNER_8_5_6 | ||
346 | getMethod (VERNER_8_5_6') = vERNER_8_5_6 | ||
347 | getMethod (FEHLBERG_13_7_8 _) = fEHLBERG_13_7_8 | ||
348 | getMethod (FEHLBERG_13_7_8') = fEHLBERG_13_7_8 | ||
349 | |||
350 | getJacobian :: ODEMethod -> Maybe Jacobian | ||
351 | getJacobian (SDIRK_2_1_2 j) = Just j | ||
352 | getJacobian (BILLINGTON_3_3_2 j) = Just j | ||
353 | getJacobian (TRBDF2_3_3_2 j) = Just j | ||
354 | getJacobian (KVAERNO_4_2_3 j) = Just j | ||
355 | getJacobian (ARK324L2SA_DIRK_4_2_3 j) = Just j | ||
356 | getJacobian (CASH_5_2_4 j) = Just j | ||
357 | getJacobian (CASH_5_3_4 j) = Just j | ||
358 | getJacobian (SDIRK_5_3_4 j) = Just j | ||
359 | getJacobian (KVAERNO_5_3_4 j) = Just j | ||
360 | getJacobian (ARK436L2SA_DIRK_6_3_4 j) = Just j | ||
361 | getJacobian (KVAERNO_7_4_5 j) = Just j | ||
362 | getJacobian (ARK548L2SA_DIRK_8_4_5 j) = Just j | ||
363 | getJacobian (HEUN_EULER_2_1_2 j) = Just j | ||
364 | getJacobian (BOGACKI_SHAMPINE_4_2_3 j) = Just j | ||
365 | getJacobian (ARK324L2SA_ERK_4_2_3 j) = Just j | ||
366 | getJacobian (ZONNEVELD_5_3_4 j) = Just j | ||
367 | getJacobian (ARK436L2SA_ERK_6_3_4 j) = Just j | ||
368 | getJacobian (SAYFY_ABURUB_6_3_4 j) = Just j | ||
369 | getJacobian (CASH_KARP_6_4_5 j) = Just j | ||
370 | getJacobian (FEHLBERG_6_4_5 j) = Just j | ||
371 | getJacobian (DORMAND_PRINCE_7_4_5 j) = Just j | ||
372 | getJacobian (ARK548L2SA_ERK_8_4_5 j) = Just j | ||
373 | getJacobian (VERNER_8_5_6 j) = Just j | ||
374 | getJacobian (FEHLBERG_13_7_8 j) = Just j | ||
375 | getJacobian _ = Nothing | ||
376 | |||
377 | -- | A version of 'odeSolveVWith' with reasonable default step control. | ||
378 | odeSolveV | ||
379 | :: ODEMethod | ||
380 | -> Maybe Double -- ^ initial step size - by default, ARKode | ||
381 | -- estimates the initial step size to be the | ||
382 | -- solution \(h\) of the equation | ||
383 | -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where | ||
384 | -- \(\ddot{y}\) is an estimated value of the | ||
385 | -- second derivative of the solution at \(t_0\) | ||
386 | -> Double -- ^ absolute tolerance for the state vector | ||
387 | -> Double -- ^ relative tolerance for the state vector | ||
388 | -> (Double -> Vector Double -> Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
389 | -> Vector Double -- ^ initial conditions | ||
390 | -> Vector Double -- ^ desired solution times | ||
391 | -> Matrix Double -- ^ solution | ||
392 | odeSolveV meth hi epsAbs epsRel f y0 ts = | ||
393 | case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of | ||
394 | Left c -> error $ show c -- FIXME | ||
395 | -- FIXME: Can we do better than using lists? | ||
396 | Right (v, _d) -> (nR >< nC) (V.toList v) | ||
397 | where | ||
398 | us = toList ts | ||
399 | nR = length us | ||
400 | nC = size y0 | ||
401 | g t x0 = coerce $ f t x0 | ||
402 | |||
403 | -- | A version of 'odeSolveV' with reasonable default parameters and | ||
404 | -- system of equations defined using lists. FIXME: we should say | ||
405 | -- something about the fact we could use the Jacobian but don't for | ||
406 | -- compatibility with hmatrix-gsl. | ||
407 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
408 | -> [Double] -- ^ initial conditions | ||
409 | -> Vector Double -- ^ desired solution times | ||
410 | -> Matrix Double -- ^ solution | ||
411 | odeSolve f y0 ts = | ||
412 | -- FIXME: These tolerances are different from the ones in GSL | ||
413 | case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of | ||
414 | Left c -> error $ show c -- FIXME | ||
415 | Right (v, _d) -> (nR >< nC) (V.toList v) | ||
416 | where | ||
417 | us = toList ts | ||
418 | nR = length us | ||
419 | nC = length y0 | ||
420 | g t x0 = V.fromList $ f t (V.toList x0) | ||
421 | |||
422 | odeSolveVWith' :: | ||
423 | ODEMethod | ||
424 | -> StepControl | ||
425 | -> Maybe Double -- ^ initial step size - by default, ARKode | ||
426 | -- estimates the initial step size to be the | ||
427 | -- solution \(h\) of the equation | ||
428 | -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where | ||
429 | -- \(\ddot{y}\) is an estimated value of the second | ||
430 | -- derivative of the solution at \(t_0\) | ||
431 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
432 | -> V.Vector Double -- ^ Initial conditions | ||
433 | -> V.Vector Double -- ^ Desired solution times | ||
434 | -> Matrix Double -- ^ Error code or solution | ||
435 | odeSolveVWith' method control initStepSize f y0 tt = | ||
436 | case odeSolveVWith method control initStepSize f y0 tt of | ||
437 | Left c -> error $ show c -- FIXME | ||
438 | Right (v, _d) -> (nR >< nC) (V.toList v) | ||
439 | where | ||
440 | nR = V.length tt | ||
441 | nC = V.length y0 | ||
442 | |||
443 | odeSolveVWith :: | ||
444 | ODEMethod | ||
445 | -> StepControl | ||
446 | -> Maybe Double -- ^ initial step size - by default, ARKode | ||
447 | -- estimates the initial step size to be the | ||
448 | -- solution \(h\) of the equation | ||
449 | -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where | ||
450 | -- \(\ddot{y}\) is an estimated value of the second | ||
451 | -- derivative of the solution at \(t_0\) | ||
452 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
453 | -> V.Vector Double -- ^ Initial conditions | ||
454 | -> V.Vector Double -- ^ Desired solution times | ||
455 | -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution | ||
456 | odeSolveVWith method control initStepSize f y0 tt = | ||
457 | case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) | ||
458 | (coerce f) (coerce y0) (coerce tt) of | ||
459 | Left c -> Left $ fromIntegral c | ||
460 | Right (v, d) -> Right (coerce v, d) | ||
461 | where | ||
462 | l = size y0 | ||
463 | scise (X absTol relTol) = coerce (V.replicate l absTol, relTol) | ||
464 | scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol) | ||
465 | scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol) | ||
466 | -- FIXME; Should we check that the length of ss is correct? | ||
467 | scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol) | ||
468 | jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ | ||
469 | getJacobian method | ||
470 | matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } | ||
471 | where | ||
472 | nr = fromIntegral $ rows m | ||
473 | nc = fromIntegral $ cols m | ||
474 | -- FIXME: efficiency | ||
475 | vs = V.fromList $ map coerce $ concat $ toLists m | ||
476 | |||
477 | solveOdeC :: | ||
478 | CInt -> | ||
479 | Maybe CDouble -> | ||
480 | (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> | ||
481 | (V.Vector CDouble, CDouble) -> | ||
482 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
483 | -> V.Vector CDouble -- ^ Initial conditions | ||
484 | -> V.Vector CDouble -- ^ Desired solution times | ||
485 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution | ||
486 | solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do | ||
487 | |||
488 | let isInitStepSize :: CInt | ||
489 | isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize | ||
490 | ss :: CDouble | ||
491 | ss = case initStepSize of | ||
492 | -- It would be better to put an error message here but | ||
493 | -- inline-c seems to evaluate this even if it is never | ||
494 | -- used :( | ||
495 | Nothing -> 0.0 | ||
496 | Just x -> x | ||
497 | let dim = V.length f0 | ||
498 | nEq :: CLong | ||
499 | nEq = fromIntegral dim | ||
500 | nTs :: CInt | ||
501 | nTs = fromIntegral $ V.length ts | ||
502 | -- FIXME: fMut is not actually mutatated | ||
503 | fMut <- V.thaw f0 | ||
504 | tMut <- V.thaw ts | ||
505 | -- FIXME: I believe this gets taken from the ghc heap and so should | ||
506 | -- be subject to garbage collection. | ||
507 | quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) | ||
508 | qMatMut <- V.thaw quasiMatrixRes | ||
509 | diagnostics :: V.Vector CLong <- createVector 10 -- FIXME | ||
510 | diagMut <- V.thaw diagnostics | ||
511 | -- We need the types that sundials expects. These are tied together | ||
512 | -- in 'Types'. FIXME: The Haskell type is currently empty! | ||
513 | let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | ||
514 | funIO x y f _ptr = do | ||
515 | -- Convert the pointer we get from C (y) to a vector, and then | ||
516 | -- apply the user-supplied function. | ||
517 | fImm <- fun x <$> getDataFromContents dim y | ||
518 | -- Fill in the provided pointer with the resulting vector. | ||
519 | putDataInContents fImm dim f | ||
520 | -- FIXME: I don't understand what this comment means | ||
521 | -- Unsafe since the function will be called many times. | ||
522 | [CU.exp| int{ 0 } |] | ||
523 | let isJac :: CInt | ||
524 | isJac = fromIntegral $ fromEnum $ isJust jacH | ||
525 | jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix -> | ||
526 | Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector -> | ||
527 | IO CInt | ||
528 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do | ||
529 | case jacH of | ||
530 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" | ||
531 | Just jacI -> do j <- jacI t <$> getDataFromContents dim y | ||
532 | putMatrixDataFromContents j jacS | ||
533 | -- FIXME: I don't understand what this comment means | ||
534 | -- Unsafe since the function will be called many times. | ||
535 | [CU.exp| int{ 0 } |] | ||
536 | |||
537 | res <- [C.block| int { | ||
538 | /* general problem variables */ | ||
539 | |||
540 | int flag; /* reusable error-checking flag */ | ||
541 | int i, j; /* reusable loop indices */ | ||
542 | N_Vector y = NULL; /* empty vector for storing solution */ | ||
543 | N_Vector tv = NULL; /* empty vector for storing absolute tolerances */ | ||
544 | SUNMatrix A = NULL; /* empty matrix for linear solver */ | ||
545 | SUNLinearSolver LS = NULL; /* empty linear solver object */ | ||
546 | void *arkode_mem = NULL; /* empty ARKode memory structure */ | ||
547 | realtype t; | ||
548 | long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; | ||
549 | |||
550 | /* general problem parameters */ | ||
551 | |||
552 | realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ | ||
553 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ | ||
554 | |||
555 | /* Initialize data structures */ | ||
556 | |||
557 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | ||
558 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | ||
559 | /* Specify initial condition */ | ||
560 | for (i = 0; i < NEQ; i++) { | ||
561 | NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; | ||
562 | }; | ||
563 | |||
564 | tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */ | ||
565 | if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; | ||
566 | /* Specify tolerances */ | ||
567 | for (i = 0; i < NEQ; i++) { | ||
568 | NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i]; | ||
569 | }; | ||
570 | |||
571 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | ||
572 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; | ||
573 | |||
574 | /* Call ARKodeInit to initialize the integrator memory and specify the */ | ||
575 | /* right-hand side function in y'=f(t,y), the inital time T0, and */ | ||
576 | /* the initial dependent variable vector y. Note: we treat the */ | ||
577 | /* problem as fully implicit and set f_E to NULL and f_I to f. */ | ||
578 | |||
579 | /* Here we use the C types defined in helpers.h which tie up with */ | ||
580 | /* the Haskell types defined in Types */ | ||
581 | if ($(int method) < MIN_DIRK_NUM) { | ||
582 | flag = ARKodeInit(arkode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), NULL, T0, y); | ||
583 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | ||
584 | } else { | ||
585 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); | ||
586 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | ||
587 | } | ||
588 | |||
589 | /* FIXME: A hack for initial testing */ | ||
590 | flag = ARKodeSetMinStep(arkode_mem, 1.0e-12); | ||
591 | if (check_flag(&flag, "ARKodeSetMinStep", 1)) return 1; | ||
592 | flag = ARKodeSetMaxNumSteps(arkode_mem, 10000); | ||
593 | if (check_flag(&flag, "ARKodeSetMaxNumSteps", 1)) return 1; | ||
594 | |||
595 | /* Set routines */ | ||
596 | flag = ARKodeSVtolerances(arkode_mem, $(double relTol), tv); | ||
597 | if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1; | ||
598 | |||
599 | /* Initialize dense matrix data structure and solver */ | ||
600 | A = SUNDenseMatrix(NEQ, NEQ); | ||
601 | if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1; | ||
602 | LS = SUNDenseLinearSolver(y, A); | ||
603 | if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1; | ||
604 | |||
605 | /* Attach matrix and linear solver */ | ||
606 | flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); | ||
607 | if (check_flag(&flag, "ARKDlsSetLinearSolver", 1)) return 1; | ||
608 | |||
609 | /* Set the initial step size if there is one */ | ||
610 | if ($(int isInitStepSize)) { | ||
611 | /* FIXME: We could check if the initial step size is 0 */ | ||
612 | /* or even NaN and then throw an error */ | ||
613 | flag = ARKodeSetInitStep(arkode_mem, $(double ss)); | ||
614 | if (check_flag(&flag, "ARKodeSetInitStep", 1)) return 1; | ||
615 | } | ||
616 | |||
617 | /* Set the Jacobian if there is one */ | ||
618 | if ($(int isJac)) { | ||
619 | flag = ARKDlsSetJacFn(arkode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[]))); | ||
620 | if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1; | ||
621 | } | ||
622 | |||
623 | /* Store initial conditions */ | ||
624 | for (j = 0; j < NEQ; j++) { | ||
625 | ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); | ||
626 | } | ||
627 | |||
628 | /* Explicitly set the method */ | ||
629 | if ($(int method) >= MIN_DIRK_NUM) { | ||
630 | flag = ARKodeSetIRKTableNum(arkode_mem, $(int method)); | ||
631 | if (check_flag(&flag, "ARKodeSetIRKTableNum", 1)) return 1; | ||
632 | } else { | ||
633 | flag = ARKodeSetERKTableNum(arkode_mem, $(int method)); | ||
634 | if (check_flag(&flag, "ARKodeSetERKTableNum", 1)) return 1; | ||
635 | } | ||
636 | |||
637 | /* Main time-stepping loop: calls ARKode to perform the integration */ | ||
638 | /* Stops when the final time has been reached */ | ||
639 | for (i = 1; i < $(int nTs); i++) { | ||
640 | |||
641 | flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */ | ||
642 | if (check_flag(&flag, "ARKode", 1)) break; | ||
643 | |||
644 | /* Store the results for Haskell */ | ||
645 | for (j = 0; j < NEQ; j++) { | ||
646 | ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j); | ||
647 | } | ||
648 | |||
649 | /* unsuccessful solve: break */ | ||
650 | if (flag < 0) { | ||
651 | fprintf(stderr,"Solver failure, stopping integration\n"); | ||
652 | break; | ||
653 | } | ||
654 | } | ||
655 | |||
656 | /* Get some final statistics on how the solve progressed */ | ||
657 | |||
658 | flag = ARKodeGetNumSteps(arkode_mem, &nst); | ||
659 | check_flag(&flag, "ARKodeGetNumSteps", 1); | ||
660 | ($vec-ptr:(long int *diagMut))[0] = nst; | ||
661 | |||
662 | flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a); | ||
663 | check_flag(&flag, "ARKodeGetNumStepAttempts", 1); | ||
664 | ($vec-ptr:(long int *diagMut))[1] = nst_a; | ||
665 | |||
666 | flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi); | ||
667 | check_flag(&flag, "ARKodeGetNumRhsEvals", 1); | ||
668 | ($vec-ptr:(long int *diagMut))[2] = nfe; | ||
669 | ($vec-ptr:(long int *diagMut))[3] = nfi; | ||
670 | |||
671 | flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups); | ||
672 | check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1); | ||
673 | ($vec-ptr:(long int *diagMut))[4] = nsetups; | ||
674 | |||
675 | flag = ARKodeGetNumErrTestFails(arkode_mem, &netf); | ||
676 | check_flag(&flag, "ARKodeGetNumErrTestFails", 1); | ||
677 | ($vec-ptr:(long int *diagMut))[5] = netf; | ||
678 | |||
679 | flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni); | ||
680 | check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1); | ||
681 | ($vec-ptr:(long int *diagMut))[6] = nni; | ||
682 | |||
683 | flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn); | ||
684 | check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1); | ||
685 | ($vec-ptr:(long int *diagMut))[7] = ncfn; | ||
686 | |||
687 | flag = ARKDlsGetNumJacEvals(arkode_mem, &nje); | ||
688 | check_flag(&flag, "ARKDlsGetNumJacEvals", 1); | ||
689 | ($vec-ptr:(long int *diagMut))[8] = ncfn; | ||
690 | |||
691 | flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS); | ||
692 | check_flag(&flag, "ARKDlsGetNumRhsEvals", 1); | ||
693 | ($vec-ptr:(long int *diagMut))[9] = ncfn; | ||
694 | |||
695 | /* Clean up and return */ | ||
696 | N_VDestroy(y); /* Free y vector */ | ||
697 | N_VDestroy(tv); /* Free tv vector */ | ||
698 | ARKodeFree(&arkode_mem); /* Free integrator memory */ | ||
699 | SUNLinSolFree(LS); /* Free linear solver */ | ||
700 | SUNMatDestroy(A); /* Free A matrix */ | ||
701 | |||
702 | return flag; | ||
703 | } |] | ||
704 | if res == 0 | ||
705 | then do | ||
706 | preD <- V.freeze diagMut | ||
707 | let d = SundialsDiagnostics (fromIntegral $ preD V.!0) | ||
708 | (fromIntegral $ preD V.!1) | ||
709 | (fromIntegral $ preD V.!2) | ||
710 | (fromIntegral $ preD V.!3) | ||
711 | (fromIntegral $ preD V.!4) | ||
712 | (fromIntegral $ preD V.!5) | ||
713 | (fromIntegral $ preD V.!6) | ||
714 | (fromIntegral $ preD V.!7) | ||
715 | (fromIntegral $ preD V.!8) | ||
716 | (fromIntegral $ preD V.!9) | ||
717 | m <- V.freeze qMatMut | ||
718 | return $ Right (m, d) | ||
719 | else do | ||
720 | return $ Left res | ||
721 | |||
722 | data ButcherTable = ButcherTable { am :: Matrix Double | ||
723 | , cv :: Vector Double | ||
724 | , bv :: Vector Double | ||
725 | , b2v :: Vector Double | ||
726 | } | ||
727 | deriving Show | ||
728 | |||
729 | data ButcherTable' a = ButcherTable' { am' :: V.Vector a | ||
730 | , cv' :: V.Vector a | ||
731 | , bv' :: V.Vector a | ||
732 | , b2v' :: V.Vector a | ||
733 | } | ||
734 | deriving Show | ||
735 | |||
736 | butcherTable :: ODEMethod -> ButcherTable | ||
737 | butcherTable method = | ||
738 | case getBT method of | ||
739 | Left c -> error $ show c -- FIXME | ||
740 | Right (ButcherTable' v w x y, sqp) -> | ||
741 | ButcherTable { am = subMatrix (0, 0) (s, s) $ (B.arkSMax >< B.arkSMax) (V.toList v) | ||
742 | , cv = subVector 0 s w | ||
743 | , bv = subVector 0 s x | ||
744 | , b2v = subVector 0 s y | ||
745 | } | ||
746 | where | ||
747 | s = fromIntegral $ sqp V.! 0 | ||
748 | |||
749 | getBT :: ODEMethod -> Either Int (ButcherTable' Double, V.Vector Int) | ||
750 | getBT method = case getButcherTable method of | ||
751 | Left c -> | ||
752 | Left $ fromIntegral c | ||
753 | Right (ButcherTable' a b c d, sqp) -> | ||
754 | Right $ ( ButcherTable' (coerce a) (coerce b) (coerce c) (coerce d) | ||
755 | , V.map fromIntegral sqp ) | ||
756 | |||
757 | getButcherTable :: ODEMethod | ||
758 | -> Either CInt (ButcherTable' CDouble, V.Vector CInt) | ||
759 | getButcherTable method = unsafePerformIO $ do | ||
760 | -- ARKode seems to want an ODE in order to set and then get the | ||
761 | -- Butcher tableau so here's one to keep it happy | ||
762 | let funI :: CDouble -> V.Vector CDouble -> V.Vector CDouble | ||
763 | funI _t ys = V.fromList [ ys V.! 0 ] | ||
764 | let funE :: CDouble -> V.Vector CDouble -> V.Vector CDouble | ||
765 | funE _t ys = V.fromList [ ys V.! 0 ] | ||
766 | f0 = V.fromList [ 1.0 ] | ||
767 | ts = V.fromList [ 0.0 ] | ||
768 | dim = V.length f0 | ||
769 | nEq :: CLong | ||
770 | nEq = fromIntegral dim | ||
771 | mN :: CInt | ||
772 | mN = fromIntegral $ getMethod method | ||
773 | |||
774 | btSQP :: V.Vector CInt <- createVector 3 | ||
775 | btSQPMut <- V.thaw btSQP | ||
776 | btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax) | ||
777 | btAsMut <- V.thaw btAs | ||
778 | btCs :: V.Vector CDouble <- createVector B.arkSMax | ||
779 | btBs :: V.Vector CDouble <- createVector B.arkSMax | ||
780 | btB2s :: V.Vector CDouble <- createVector B.arkSMax | ||
781 | btCsMut <- V.thaw btCs | ||
782 | btBsMut <- V.thaw btBs | ||
783 | btB2sMut <- V.thaw btB2s | ||
784 | let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | ||
785 | funIOI x y f _ptr = do | ||
786 | fImm <- funI x <$> getDataFromContents dim y | ||
787 | putDataInContents fImm dim f | ||
788 | -- FIXME: I don't understand what this comment means | ||
789 | -- Unsafe since the function will be called many times. | ||
790 | [CU.exp| int{ 0 } |] | ||
791 | let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | ||
792 | funIOE x y f _ptr = do | ||
793 | fImm <- funE x <$> getDataFromContents dim y | ||
794 | putDataInContents fImm dim f | ||
795 | -- FIXME: I don't understand what this comment means | ||
796 | -- Unsafe since the function will be called many times. | ||
797 | [CU.exp| int{ 0 } |] | ||
798 | res <- [C.block| int { | ||
799 | /* general problem variables */ | ||
800 | |||
801 | int flag; /* reusable error-checking flag */ | ||
802 | N_Vector y = NULL; /* empty vector for storing solution */ | ||
803 | void *arkode_mem = NULL; /* empty ARKode memory structure */ | ||
804 | int i, j; /* reusable loop indices */ | ||
805 | |||
806 | /* general problem parameters */ | ||
807 | |||
808 | realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */ | ||
809 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars */ | ||
810 | |||
811 | /* Initialize data structures */ | ||
812 | |||
813 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | ||
814 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | ||
815 | /* Specify initial condition */ | ||
816 | for (i = 0; i < NEQ; i++) { | ||
817 | NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; | ||
818 | }; | ||
819 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | ||
820 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; | ||
821 | |||
822 | flag = ARKodeInit(arkode_mem, $fun:(int (* funIOE) (double t, SunVector y[], SunVector dydt[], void * params)), $fun:(int (* funIOI) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); | ||
823 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | ||
824 | |||
825 | if ($(int mN) >= MIN_DIRK_NUM) { | ||
826 | flag = ARKodeSetIRKTableNum(arkode_mem, $(int mN)); | ||
827 | if (check_flag(&flag, "ARKodeSetIRKTableNum", 1)) return 1; | ||
828 | } else { | ||
829 | flag = ARKodeSetERKTableNum(arkode_mem, $(int mN)); | ||
830 | if (check_flag(&flag, "ARKodeSetERKTableNum", 1)) return 1; | ||
831 | } | ||
832 | |||
833 | int s, q, p; | ||
834 | realtype *ai = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype)); | ||
835 | realtype *ae = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype)); | ||
836 | realtype *ci = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
837 | realtype *ce = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
838 | realtype *bi = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
839 | realtype *be = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
840 | realtype *b2i = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
841 | realtype *b2e = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
842 | flag = ARKodeGetCurrentButcherTables(arkode_mem, &s, &q, &p, ai, ae, ci, ce, bi, be, b2i, b2e); | ||
843 | if (check_flag(&flag, "ARKode", 1)) return 1; | ||
844 | $vec-ptr:(int *btSQPMut)[0] = s; | ||
845 | $vec-ptr:(int *btSQPMut)[1] = q; | ||
846 | $vec-ptr:(int *btSQPMut)[2] = p; | ||
847 | for (i = 0; i < s; i++) { | ||
848 | for (j = 0; j < s; j++) { | ||
849 | /* FIXME: double should be realtype */ | ||
850 | ($vec-ptr:(double *btAsMut))[i * ARK_S_MAX + j] = ai[i * ARK_S_MAX + j]; | ||
851 | } | ||
852 | } | ||
853 | |||
854 | for (i = 0; i < s; i++) { | ||
855 | ($vec-ptr:(double *btCsMut))[i] = ci[i]; | ||
856 | ($vec-ptr:(double *btBsMut))[i] = bi[i]; | ||
857 | ($vec-ptr:(double *btB2sMut))[i] = b2i[i]; | ||
858 | } | ||
859 | |||
860 | /* Clean up and return */ | ||
861 | N_VDestroy(y); /* Free y vector */ | ||
862 | ARKodeFree(&arkode_mem); /* Free integrator memory */ | ||
863 | |||
864 | return flag; | ||
865 | } |] | ||
866 | if res == 0 | ||
867 | then do | ||
868 | x <- V.freeze btAsMut | ||
869 | y <- V.freeze btSQPMut | ||
870 | z <- V.freeze btCsMut | ||
871 | u <- V.freeze btBsMut | ||
872 | v <- V.freeze btB2sMut | ||
873 | return $ Right (ButcherTable' { am' = x, cv' = z, bv' = u, b2v' = v }, y) | ||
874 | else do | ||
875 | return $ Left res | ||
876 | |||
877 | -- | Adaptive step-size control | ||
878 | -- functions. | ||
879 | -- | ||
880 | -- [GSL](https://www.gnu.org/software/gsl/doc/html/ode-initval.html#adaptive-step-size-control) | ||
881 | -- allows the user to control the step size adjustment using | ||
882 | -- \(D_i = \epsilon^{abs}s_i + \epsilon^{rel}(a_{y} |y_i| + a_{dy/dt} h |\dot{y}_i|)\) where | ||
883 | -- \(\epsilon^{abs}\) is the required absolute error, \(\epsilon^{rel}\) | ||
884 | -- is the required relative error, \(s_i\) is a vector of scaling | ||
885 | -- factors, \(a_{y}\) is a scaling factor for the solution \(y\) and | ||
886 | -- \(a_{dydt}\) is a scaling factor for the derivative of the solution \(dy/dt\). | ||
887 | -- | ||
888 | -- [ARKode](https://computation.llnl.gov/projects/sundials/arkode) | ||
889 | -- allows the user to control the step size adjustment using | ||
890 | -- \(\eta^{rel}|y_i| + \eta^{abs}_i\). For compatibility with | ||
891 | -- [hmatrix-gsl](https://hackage.haskell.org/package/hmatrix-gsl), | ||
892 | -- tolerances for \(y\) and \(\dot{y}\) can be specified but the latter have no | ||
893 | -- effect. | ||
894 | data StepControl = X Double Double -- ^ absolute and relative tolerance for \(y\); in GSL terms, \(a_{y} = 1\) and \(a_{dy/dt} = 0\); in ARKode terms, the \(\eta^{abs}_i\) are identical | ||
895 | | X' Double Double -- ^ absolute and relative tolerance for \(\dot{y}\); in GSL terms, \(a_{y} = 0\) and \(a_{dy/dt} = 1\); in ARKode terms, the latter is treated as the relative tolerance for \(y\) so this is the same as specifying 'X' which may be entirely incorrect for the given problem | ||
896 | | XX' Double Double Double Double -- ^ include both via relative tolerance | ||
897 | -- scaling factors \(a_y\), \(a_{{dy}/{dt}}\); in ARKode terms, the latter is ignored and \(\eta^{rel} = a_{y}\epsilon^{rel}\) | ||
898 | | ScXX' Double Double Double Double (Vector Double) -- ^ scale absolute tolerance of \(y_i\); in ARKode terms, \(a_{{dy}/{dt}}\) is ignored, \(\eta^{abs}_i = s_i \epsilon^{abs}\) and \(\eta^{rel} = a_{y}\epsilon^{rel}\) | ||
diff --git a/packages/sundials/src/Types.hs b/packages/sundials/src/Types.hs new file mode 100644 index 0000000..04e4280 --- /dev/null +++ b/packages/sundials/src/Types.hs | |||
@@ -0,0 +1,40 @@ | |||
1 | {-# OPTIONS_GHC -Wall #-} | ||
2 | |||
3 | {-# LANGUAGE QuasiQuotes #-} | ||
4 | {-# LANGUAGE TemplateHaskell #-} | ||
5 | {-# LANGUAGE MultiWayIf #-} | ||
6 | {-# LANGUAGE OverloadedStrings #-} | ||
7 | {-# LANGUAGE EmptyDataDecls #-} | ||
8 | |||
9 | module Types where | ||
10 | |||
11 | import Foreign.C.Types | ||
12 | |||
13 | import qualified Language.Haskell.TH as TH | ||
14 | import qualified Language.C.Types as CT | ||
15 | import qualified Data.Map as Map | ||
16 | import Language.C.Inline.Context | ||
17 | |||
18 | import qualified Data.Vector.Storable as V | ||
19 | |||
20 | |||
21 | data SunVector | ||
22 | data SunMatrix = SunMatrix { rows :: CInt | ||
23 | , cols :: CInt | ||
24 | , vals :: V.Vector CDouble | ||
25 | } | ||
26 | |||
27 | -- FIXME: Is this true? | ||
28 | type SunIndexType = CLong | ||
29 | |||
30 | sunTypesTable :: Map.Map CT.TypeSpecifier TH.TypeQ | ||
31 | sunTypesTable = Map.fromList | ||
32 | [ | ||
33 | (CT.TypeName "sunindextype", [t| SunIndexType |] ) | ||
34 | , (CT.TypeName "SunVector", [t| SunVector |] ) | ||
35 | , (CT.TypeName "SunMatrix", [t| SunMatrix |] ) | ||
36 | ] | ||
37 | |||
38 | sunCtx :: Context | ||
39 | sunCtx = mempty {ctxTypesTable = sunTypesTable} | ||
40 | |||
diff --git a/packages/sundials/src/helpers.c b/packages/sundials/src/helpers.c new file mode 100644 index 0000000..f0ca592 --- /dev/null +++ b/packages/sundials/src/helpers.c | |||
@@ -0,0 +1,44 @@ | |||
1 | #include <stdio.h> | ||
2 | #include <math.h> | ||
3 | #include <arkode/arkode.h> /* prototypes for ARKODE fcts., consts. */ | ||
4 | #include <nvector/nvector_serial.h> /* serial N_Vector types, fcts., macros */ | ||
5 | #include <sunmatrix/sunmatrix_dense.h> /* access to dense SUNMatrix */ | ||
6 | #include <sunlinsol/sunlinsol_dense.h> /* access to dense SUNLinearSolver */ | ||
7 | #include <arkode/arkode_direct.h> /* access to ARKDls interface */ | ||
8 | #include <sundials/sundials_types.h> /* definition of type realtype */ | ||
9 | #include <sundials/sundials_math.h> | ||
10 | |||
11 | /* Check function return value... | ||
12 | opt == 0 means SUNDIALS function allocates memory so check if | ||
13 | returned NULL pointer | ||
14 | opt == 1 means SUNDIALS function returns a flag so check if | ||
15 | flag >= 0 | ||
16 | opt == 2 means function allocates memory so check if returned | ||
17 | NULL pointer | ||
18 | */ | ||
19 | int check_flag(void *flagvalue, const char *funcname, int opt) | ||
20 | { | ||
21 | int *errflag; | ||
22 | |||
23 | /* Check if SUNDIALS function returned NULL pointer - no memory allocated */ | ||
24 | if (opt == 0 && flagvalue == NULL) { | ||
25 | fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n", | ||
26 | funcname); | ||
27 | return 1; } | ||
28 | |||
29 | /* Check if flag < 0 */ | ||
30 | else if (opt == 1) { | ||
31 | errflag = (int *) flagvalue; | ||
32 | if (*errflag < 0) { | ||
33 | fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with flag = %d\n\n", | ||
34 | funcname, *errflag); | ||
35 | return 1; }} | ||
36 | |||
37 | /* Check if function returned NULL pointer - no memory allocated */ | ||
38 | else if (opt == 2 && flagvalue == NULL) { | ||
39 | fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n", | ||
40 | funcname); | ||
41 | return 1; } | ||
42 | |||
43 | return 0; | ||
44 | } | ||
diff --git a/packages/sundials/src/helpers.h b/packages/sundials/src/helpers.h new file mode 100644 index 0000000..3d8fbc0 --- /dev/null +++ b/packages/sundials/src/helpers.h | |||
@@ -0,0 +1,9 @@ | |||
1 | /* Check function return value... | ||
2 | opt == 0 means SUNDIALS function allocates memory so check if | ||
3 | returned NULL pointer | ||
4 | opt == 1 means SUNDIALS function returns a flag so check if | ||
5 | flag >= 0 | ||
6 | opt == 2 means function allocates memory so check if returned | ||
7 | NULL pointer | ||
8 | */ | ||
9 | int check_flag(void *flagvalue, const char *funcname, int opt); | ||