summaryrefslogtreecommitdiff
path: root/packages
diff options
context:
space:
mode:
authorPeter Dobsan <pdobsan@gmail.com>2018-05-03 20:08:33 +0200
committerPeter Dobsan <pdobsan@gmail.com>2018-05-03 20:08:33 +0200
commitcafdc664c01ea7392c81c352b5c5444dc2963531 (patch)
treec6d9a758fa7c36730c0468b393a6dc8c47cbfac2 /packages
parentea1bfea4486f8f2c646f82dabd1ff9a222b68506 (diff)
parent1675813d8f540af9832a78c7a7a40bbdf1cec42c (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'packages')
-rw-r--r--packages/sundials/hmatrix-sundials.cabal18
-rw-r--r--packages/sundials/src/Main.hs72
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs269
-rw-r--r--packages/sundials/src/Numeric/Sundials/Arkode.hsc (renamed from packages/sundials/src/Arkode.hsc)98
-rw-r--r--packages/sundials/src/Numeric/Sundials/CVode/ODE.hs476
-rw-r--r--packages/sundials/src/Numeric/Sundials/ODEOpts.hs32
-rw-r--r--packages/sundials/src/Types.hs40
7 files changed, 810 insertions, 195 deletions
diff --git a/packages/sundials/hmatrix-sundials.cabal b/packages/sundials/hmatrix-sundials.cabal
index 388f1db..cd2be4e 100644
--- a/packages/sundials/hmatrix-sundials.cabal
+++ b/packages/sundials/hmatrix-sundials.cabal
@@ -25,21 +25,24 @@ library
25 template-haskell >=2.12 && <2.13, 25 template-haskell >=2.12 && <2.13,
26 containers >=0.5 && <0.6, 26 containers >=0.5 && <0.6,
27 hmatrix>=0.18 27 hmatrix>=0.18
28 extra-libraries: sundials_arkode 28 extra-libraries: sundials_arkode,
29 sundials_cvode
29 other-extensions: QuasiQuotes 30 other-extensions: QuasiQuotes
30 hs-source-dirs: src 31 hs-source-dirs: src
31 exposed-modules: Numeric.Sundials.ARKode.ODE 32 exposed-modules: Numeric.Sundials.ODEOpts,
32 other-modules: Types, 33 Numeric.Sundials.ARKode.ODE,
33 Arkode 34 Numeric.Sundials.CVode.ODE
35 other-modules: Numeric.Sundials.Arkode
34 c-sources: src/helpers.c src/helpers.h 36 c-sources: src/helpers.c src/helpers.h
35 default-language: Haskell2010 37 default-language: Haskell2010
36 38
37test-suite hmatrix-sundials-testsuite 39test-suite hmatrix-sundials-testsuite
38 type: exitcode-stdio-1.0 40 type: exitcode-stdio-1.0
39 main-is: Main.hs 41 main-is: Main.hs
40 other-modules: Types, 42 other-modules: Numeric.Sundials.ODEOpts,
41 Numeric.Sundials.ARKode.ODE, 43 Numeric.Sundials.ARKode.ODE,
42 Arkode 44 Numeric.Sundials.CVode.ODE,
45 Numeric.Sundials.Arkode
43 build-depends: base >=4.10 && <4.11, 46 build-depends: base >=4.10 && <4.11,
44 inline-c >=0.6 && <0.7, 47 inline-c >=0.6 && <0.7,
45 vector >=0.12 && <0.13, 48 vector >=0.12 && <0.13,
@@ -52,6 +55,7 @@ test-suite hmatrix-sundials-testsuite
52 lens, 55 lens,
53 hspec 56 hspec
54 hs-source-dirs: src 57 hs-source-dirs: src
55 extra-libraries: sundials_arkode 58 extra-libraries: sundials_arkode,
59 sundials_cvode
56 c-sources: src/helpers.c src/helpers.h 60 c-sources: src/helpers.c src/helpers.h
57 default-language: Haskell2010 61 default-language: Haskell2010
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs
index 729d35a..16c21c5 100644
--- a/packages/sundials/src/Main.hs
+++ b/packages/sundials/src/Main.hs
@@ -1,6 +1,7 @@
1{-# OPTIONS_GHC -Wall #-} 1{-# OPTIONS_GHC -Wall #-}
2 2
3import Numeric.Sundials.ARKode.ODE 3import qualified Numeric.Sundials.ARKode.ODE as ARK
4import qualified Numeric.Sundials.CVode.ODE as CV
4import Numeric.LinearAlgebra 5import Numeric.LinearAlgebra
5 6
6import Plots as P 7import Plots as P
@@ -80,6 +81,23 @@ _stiffJac _t _v = (1><1) [ lamda ]
80 where 81 where
81 lamda = -100.0 82 lamda = -100.0
82 83
84predatorPrey :: Double -> [Double] -> [Double]
85predatorPrey _t v = [ x * a - b * x * y
86 , d * x * y - c * y - e * y * z
87 , (-f) * z + g * y * z
88 ]
89 where
90 x = v!!0
91 y = v!!1
92 z = v!!2
93 a = 1.0
94 b = 1.0
95 c = 1.0
96 d = 1.0
97 e = 1.0
98 f = 1.0
99 g = 1.0
100
83lSaxis :: [[Double]] -> P.Axis B D.V2 Double 101lSaxis :: [[Double]] -> P.Axis B D.V2 Double
84lSaxis xs = P.r2Axis &~ do 102lSaxis xs = P.r2Axis &~ do
85 let ts = xs!!0 103 let ts = xs!!0
@@ -97,33 +115,37 @@ kSaxis xs = P.r2Axis &~ do
97main :: IO () 115main :: IO ()
98main = do 116main = do
99 117
100 let res1 = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) 118 let res1 = ARK.odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0])
101 renderRasterific "diagrams/brusselator.png" 119 renderRasterific "diagrams/brusselator.png"
102 (D.dims2D 500.0 500.0) 120 (D.dims2D 500.0 500.0)
103 (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) 121 (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1))
104 122
105 let res1a = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) 123 let res1a = ARK.odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0])
106 renderRasterific "diagrams/brusselatorA.png" 124 renderRasterific "diagrams/brusselatorA.png"
107 (D.dims2D 500.0 500.0) 125 (D.dims2D 500.0 500.0)
108 (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1a)) 126 (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1a))
109 127
110 let res2 = odeSolve stiffish [0.0] (fromList [0.0, 0.1 .. 10.0]) 128 let res2 = ARK.odeSolve stiffish [0.0] (fromList [0.0, 0.1 .. 10.0])
111 renderRasterific "diagrams/stiffish.png" 129 renderRasterific "diagrams/stiffish.png"
112 (D.dims2D 500.0 500.0) 130 (D.dims2D 500.0 500.0)
113 (renderAxis $ kSaxis $ zip [0.0, 0.1 .. 10.0] (concat $ toLists res2)) 131 (renderAxis $ kSaxis $ zip [0.0, 0.1 .. 10.0] (concat $ toLists res2))
114 132
115 let res2a = odeSolveV (SDIRK_5_3_4') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) 133 let res2a = ARK.odeSolveV (ARK.SDIRK_5_3_4') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0])
116 134
117 let res2b = odeSolveV (TRBDF2_3_3_2') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) 135 let res2b = ARK.odeSolveV (ARK.TRBDF2_3_3_2') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0])
118 136
119 let maxDiff = maximum $ map abs $ 137 let maxDiffA = maximum $ map abs $
120 zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2b)!!0) 138 zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2b)!!0)
121 139
122 hspec $ describe "Compare results" $ do 140 let res2c = CV.odeSolveV (CV.BDF) Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0])
123 it "for two different RK methods" $ 141
124 maxDiff < 1.0e-6 142 let maxDiffB = maximum $ map abs $
143 zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2c)!!0)
125 144
126 let res3 = odeSolve lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0]) 145 let maxDiffC = maximum $ map abs $
146 zipWith (-) ((toLists $ tr res2b)!!0) ((toLists $ tr res2c)!!0)
147
148 let res3 = ARK.odeSolve lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0])
127 149
128 renderRasterific "diagrams/lorenz.png" 150 renderRasterific "diagrams/lorenz.png"
129 (D.dims2D 500.0 500.0) 151 (D.dims2D 500.0 500.0)
@@ -136,3 +158,29 @@ main = do
136 renderRasterific "diagrams/lorenz2.png" 158 renderRasterific "diagrams/lorenz2.png"
137 (D.dims2D 500.0 500.0) 159 (D.dims2D 500.0 500.0)
138 (renderAxis $ kSaxis $ zip ((toLists $ tr res3)!!1) ((toLists $ tr res3)!!2)) 160 (renderAxis $ kSaxis $ zip ((toLists $ tr res3)!!1) ((toLists $ tr res3)!!2))
161
162 let res4 = CV.odeSolve predatorPrey [0.5, 1.0, 2.0] (fromList [0.0, 0.01 .. 10.0])
163
164 renderRasterific "diagrams/predatorPrey.png"
165 (D.dims2D 500.0 500.0)
166 (renderAxis $ kSaxis $ zip ((toLists $ tr res4)!!0) ((toLists $ tr res4)!!1))
167
168 renderRasterific "diagrams/predatorPrey1.png"
169 (D.dims2D 500.0 500.0)
170 (renderAxis $ kSaxis $ zip ((toLists $ tr res4)!!0) ((toLists $ tr res4)!!2))
171
172 renderRasterific "diagrams/predatorPrey2.png"
173 (D.dims2D 500.0 500.0)
174 (renderAxis $ kSaxis $ zip ((toLists $ tr res4)!!1) ((toLists $ tr res4)!!2))
175
176 let res4a = ARK.odeSolve predatorPrey [0.5, 1.0, 2.0] (fromList [0.0, 0.01 .. 10.0])
177
178 let maxDiffPpA = maximum $ map abs $
179 zipWith (-) ((toLists $ tr res4)!!0) ((toLists $ tr res4a)!!0)
180
181 hspec $ describe "Compare results" $ do
182 it "for SDIRK_5_3_4' and TRBDF2_3_3_2'" $ maxDiffA < 1.0e-6
183 it "for SDIRK_5_3_4' and BDF" $ maxDiffB < 1.0e-6
184 it "for TRBDF2_3_3_2' and BDF" $ maxDiffC < 1.0e-6
185 it "for CV and ARK for the Predator Prey model" $ maxDiffPpA < 1.0e-3
186
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index e5a2e4d..fafc237 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -1,5 +1,3 @@
1{-# OPTIONS_GHC -Wall #-}
2
3{-# LANGUAGE QuasiQuotes #-} 1{-# LANGUAGE QuasiQuotes #-}
4{-# LANGUAGE TemplateHaskell #-} 2{-# LANGUAGE TemplateHaskell #-}
5{-# LANGUAGE MultiWayIf #-} 3{-# LANGUAGE MultiWayIf #-}
@@ -22,8 +20,7 @@
22-- Stability : provisional 20-- Stability : provisional
23-- 21--
24-- Solution of ordinary differential equation (ODE) initial value problems. 22-- Solution of ordinary differential equation (ODE) initial value problems.
25-- 23-- See <https://computation.llnl.gov/projects/sundials/sundials-software> for more detail.
26-- <https://computation.llnl.gov/projects/sundials/sundials-software>
27-- 24--
28-- A simple example: 25-- A simple example:
29-- 26--
@@ -67,6 +64,54 @@
67-- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) 64-- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1))
68-- @ 65-- @
69-- 66--
67-- With Sundials ARKode, it is possible to retrieve the Butcher tableau for the solver.
68--
69-- @
70-- import Numeric.Sundials.ARKode.ODE
71-- import Numeric.LinearAlgebra
72--
73-- import Data.List (intercalate)
74--
75-- import Text.PrettyPrint.HughesPJClass
76--
77--
78-- butcherTableauTex :: ButcherTable -> String
79-- butcherTableauTex (ButcherTable m c b b2) =
80-- render $
81-- vcat [ text ("\n\\begin{array}{c|" ++ (concat $ replicate n "c") ++ "}")
82-- , us
83-- , text "\\hline"
84-- , text bs <+> text "\\\\"
85-- , text b2s <+> text "\\\\"
86-- , text "\\end{array}"
87-- ]
88-- where
89-- n = rows m
90-- rs = toLists m
91-- ss = map (\r -> intercalate " & " $ map show r) rs
92-- ts = zipWith (\i r -> show i ++ " & " ++ r) (toList c) ss
93-- us = vcat $ map (\r -> text r <+> text "\\\\") ts
94-- bs = " & " ++ (intercalate " & " $ map show $ toList b)
95-- b2s = " & " ++ (intercalate " & " $ map show $ toList b2)
96--
97-- main :: IO ()
98-- main = do
99--
100-- let res = butcherTable (SDIRK_2_1_2 undefined)
101-- putStrLn $ show res
102-- putStrLn $ butcherTableauTex res
103--
104-- let resA = butcherTable (KVAERNO_4_2_3 undefined)
105-- putStrLn $ show resA
106-- putStrLn $ butcherTableauTex resA
107--
108-- let resB = butcherTable (SDIRK_5_3_4 undefined)
109-- putStrLn $ show resB
110-- putStrLn $ butcherTableauTex resB
111-- @
112--
113-- Using the code above from the examples gives
114--
70-- KVAERNO_4_2_3 115-- KVAERNO_4_2_3
71-- 116--
72-- \[ 117-- \[
@@ -116,8 +161,6 @@ module Numeric.Sundials.ARKode.ODE ( odeSolve
116 , butcherTable 161 , butcherTable
117 , ODEMethod(..) 162 , ODEMethod(..)
118 , StepControl(..) 163 , StepControl(..)
119 , Jacobian
120 , SundialsDiagnostics(..)
121 ) where 164 ) where
122 165
123import qualified Language.C.Inline as C 166import qualified Language.C.Inline as C
@@ -126,27 +169,50 @@ import qualified Language.C.Inline.Unsafe as CU
126import Data.Monoid ((<>)) 169import Data.Monoid ((<>))
127import Data.Maybe (isJust) 170import Data.Maybe (isJust)
128 171
129import Foreign.C.Types 172import Foreign.C.Types (CDouble, CInt, CLong)
130import Foreign.Ptr (Ptr) 173import Foreign.Ptr (Ptr)
131import Foreign.ForeignPtr (newForeignPtr_) 174import Foreign.Storable (poke)
132import Foreign.Storable (Storable)
133 175
134import qualified Data.Vector.Storable as V 176import qualified Data.Vector.Storable as V
135import qualified Data.Vector.Storable.Mutable as VM
136 177
137import Data.Coerce (coerce) 178import Data.Coerce (coerce)
138import System.IO.Unsafe (unsafePerformIO) 179import System.IO.Unsafe (unsafePerformIO)
139import GHC.Generics 180import GHC.Generics (C1, Constructor, (:+:)(..), D1, Rep, Generic, M1(..),
181 from, conName)
140 182
141import Numeric.LinearAlgebra.Devel (createVector) 183import Numeric.LinearAlgebra.Devel (createVector)
142 184
143import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), 185import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows,
144 subMatrix, rows, cols, toLists, 186 cols, toLists, size, reshape,
145 size, subVector) 187 subVector, subMatrix, (><))
146 188
147import qualified Types as T 189import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian, SundialsDiagnostics(..))
148import Arkode 190import qualified Numeric.Sundials.Arkode as T
149import qualified Arkode as B 191import Numeric.Sundials.Arkode (getDataFromContents, putDataInContents, arkSMax,
192 sDIRK_2_1_2,
193 bILLINGTON_3_3_2,
194 tRBDF2_3_3_2,
195 kVAERNO_4_2_3,
196 aRK324L2SA_DIRK_4_2_3,
197 cASH_5_2_4,
198 cASH_5_3_4,
199 sDIRK_5_3_4,
200 kVAERNO_5_3_4,
201 aRK436L2SA_DIRK_6_3_4,
202 kVAERNO_7_4_5,
203 aRK548L2SA_DIRK_8_4_5,
204 hEUN_EULER_2_1_2,
205 bOGACKI_SHAMPINE_4_2_3,
206 aRK324L2SA_ERK_4_2_3,
207 zONNEVELD_5_3_4,
208 aRK436L2SA_ERK_6_3_4,
209 sAYFY_ABURUB_6_3_4,
210 cASH_KARP_6_4_5,
211 fEHLBERG_6_4_5,
212 dORMAND_PRINCE_7_4_5,
213 aRK548L2SA_ERK_8_4_5,
214 vERNER_8_5_6,
215 fEHLBERG_13_7_8)
150 216
151 217
152C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) 218C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx)
@@ -162,69 +228,8 @@ C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface
162C.include "<sundials/sundials_types.h>" -- definition of type realtype 228C.include "<sundials/sundials_types.h>" -- definition of type realtype
163C.include "<sundials/sundials_math.h>" 229C.include "<sundials/sundials_math.h>"
164C.include "../../../helpers.h" 230C.include "../../../helpers.h"
165C.include "Arkode_hsc.h" 231C.include "Numeric/Sundials/Arkode_hsc.h"
166 232
167
168getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble)
169getDataFromContents len ptr = do
170 qtr <- B.getContentPtr ptr
171 rtr <- B.getData qtr
172 vectorFromC len rtr
173
174-- FIXME: Potentially an instance of Storable
175_getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix
176_getMatrixDataFromContents ptr = do
177 qtr <- B.getContentMatrixPtr ptr
178 rs <- B.getNRows qtr
179 cs <- B.getNCols qtr
180 rtr <- B.getMatrixData qtr
181 vs <- vectorFromC (fromIntegral $ rs * cs) rtr
182 return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs }
183
184putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO ()
185putMatrixDataFromContents mat ptr = do
186 let rs = T.rows mat
187 cs = T.cols mat
188 vs = T.vals mat
189 qtr <- B.getContentMatrixPtr ptr
190 B.putNRows rs qtr
191 B.putNCols cs qtr
192 rtr <- B.getMatrixData qtr
193 vectorToC vs (fromIntegral $ rs * cs) rtr
194-- FIXME: END
195
196putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO ()
197putDataInContents vec len ptr = do
198 qtr <- B.getContentPtr ptr
199 rtr <- B.getData qtr
200 vectorToC vec len rtr
201
202-- Utils
203
204vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a)
205vectorFromC len ptr = do
206 ptr' <- newForeignPtr_ ptr
207 V.freeze $ VM.unsafeFromForeignPtr0 ptr' len
208
209vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO ()
210vectorToC vec len ptr = do
211 ptr' <- newForeignPtr_ ptr
212 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
213
214data SundialsDiagnostics = SundialsDiagnostics {
215 aRKodeGetNumSteps :: Int
216 , aRKodeGetNumStepAttempts :: Int
217 , aRKodeGetNumRhsEvals_fe :: Int
218 , aRKodeGetNumRhsEvals_fi :: Int
219 , aRKodeGetNumLinSolvSetups :: Int
220 , aRKodeGetNumErrTestFails :: Int
221 , aRKodeGetNumNonlinSolvIters :: Int
222 , aRKodeGetNumNonlinSolvConvFails :: Int
223 , aRKDlsGetNumJacEvals :: Int
224 , aRKDlsGetNumRhsEvals :: Int
225 } deriving Show
226
227type Jacobian = Double -> Vector Double -> Matrix Double
228 233
229-- | Stepping functions 234-- | Stepping functions
230data ODEMethod = SDIRK_2_1_2 Jacobian 235data ODEMethod = SDIRK_2_1_2 Jacobian
@@ -390,15 +395,9 @@ odeSolveV
390 -> Vector Double -- ^ desired solution times 395 -> Vector Double -- ^ desired solution times
391 -> Matrix Double -- ^ solution 396 -> Matrix Double -- ^ solution
392odeSolveV meth hi epsAbs epsRel f y0 ts = 397odeSolveV meth hi epsAbs epsRel f y0 ts =
393 case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of 398 odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts
394 Left c -> error $ show c -- FIXME 399 where
395 -- FIXME: Can we do better than using lists? 400 g t x0 = coerce $ f t x0
396 Right (v, _d) -> (nR >< nC) (V.toList v)
397 where
398 us = toList ts
399 nR = length us
400 nC = size y0
401 g t x0 = coerce $ f t x0
402 401
403-- | A version of 'odeSolveV' with reasonable default parameters and 402-- | A version of 'odeSolveV' with reasonable default parameters and
404-- system of equations defined using lists. FIXME: we should say 403-- system of equations defined using lists. FIXME: we should say
@@ -410,16 +409,11 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y
410 -> Matrix Double -- ^ solution 409 -> Matrix Double -- ^ solution
411odeSolve f y0 ts = 410odeSolve f y0 ts =
412 -- FIXME: These tolerances are different from the ones in GSL 411 -- FIXME: These tolerances are different from the ones in GSL
413 case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of 412 odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts)
414 Left c -> error $ show c -- FIXME
415 Right (v, _d) -> (nR >< nC) (V.toList v)
416 where 413 where
417 us = toList ts
418 nR = length us
419 nC = length y0
420 g t x0 = V.fromList $ f t (V.toList x0) 414 g t x0 = V.fromList $ f t (V.toList x0)
421 415
422odeSolveVWith' :: 416odeSolveVWith ::
423 ODEMethod 417 ODEMethod
424 -> StepControl 418 -> StepControl
425 -> Maybe Double -- ^ initial step size - by default, ARKode 419 -> Maybe Double -- ^ initial step size - by default, ARKode
@@ -432,16 +426,22 @@ odeSolveVWith' ::
432 -> V.Vector Double -- ^ Initial conditions 426 -> V.Vector Double -- ^ Initial conditions
433 -> V.Vector Double -- ^ Desired solution times 427 -> V.Vector Double -- ^ Desired solution times
434 -> Matrix Double -- ^ Error code or solution 428 -> Matrix Double -- ^ Error code or solution
435odeSolveVWith' method control initStepSize f y0 tt = 429odeSolveVWith method control initStepSize f y0 tt =
436 case odeSolveVWith method control initStepSize f y0 tt of 430 case odeSolveVWith' opts method control initStepSize f y0 tt of
437 Left c -> error $ show c -- FIXME 431 Left c -> error $ show c -- FIXME
438 Right (v, _d) -> (nR >< nC) (V.toList v) 432 Right (v, _d) -> v
439 where 433 where
440 nR = V.length tt 434 opts = ODEOpts { maxNumSteps = 10000
441 nC = V.length y0 435 , minStep = 1.0e-12
436 , relTol = error "relTol"
437 , absTols = error "absTol"
438 , initStep = error "initStep"
439 , maxFail = 10
440 }
442 441
443odeSolveVWith :: 442odeSolveVWith' ::
444 ODEMethod 443 ODEOpts
444 -> ODEMethod
445 -> StepControl 445 -> StepControl
446 -> Maybe Double -- ^ initial step size - by default, ARKode 446 -> Maybe Double -- ^ initial step size - by default, ARKode
447 -- estimates the initial step size to be the 447 -- estimates the initial step size to be the
@@ -452,19 +452,21 @@ odeSolveVWith ::
452 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 452 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
453 -> V.Vector Double -- ^ Initial conditions 453 -> V.Vector Double -- ^ Initial conditions
454 -> V.Vector Double -- ^ Desired solution times 454 -> V.Vector Double -- ^ Desired solution times
455 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution 455 -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution
456odeSolveVWith method control initStepSize f y0 tt = 456odeSolveVWith' opts method control initStepSize f y0 tt =
457 case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) 457 case solveOdeC (fromIntegral $ maxFail opts)
458 (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts)
459 (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control)
458 (coerce f) (coerce y0) (coerce tt) of 460 (coerce f) (coerce y0) (coerce tt) of
459 Left c -> Left $ fromIntegral c 461 Left c -> Left $ fromIntegral c
460 Right (v, d) -> Right (coerce v, d) 462 Right (v, d) -> Right (reshape l (coerce v), d)
461 where 463 where
462 l = size y0 464 l = size y0
463 scise (X absTol relTol) = coerce (V.replicate l absTol, relTol) 465 scise (X aTol rTol) = coerce (V.replicate l aTol, rTol)
464 scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol) 466 scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol)
465 scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol) 467 scise (XX' aTol rTol yScale _yDotScale) = coerce (V.replicate l aTol, yScale * rTol)
466 -- FIXME; Should we check that the length of ss is correct? 468 -- FIXME; Should we check that the length of ss is correct?
467 scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol) 469 scise (ScXX' aTol rTol yScale _yDotScale ss) = coerce (V.map (* aTol) ss, yScale * rTol)
468 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ 470 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $
469 getJacobian method 471 getJacobian method
470 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } 472 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs }
@@ -476,6 +478,9 @@ odeSolveVWith method control initStepSize f y0 tt =
476 478
477solveOdeC :: 479solveOdeC ::
478 CInt -> 480 CInt ->
481 CLong ->
482 CDouble ->
483 CInt ->
479 Maybe CDouble -> 484 Maybe CDouble ->
480 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> 485 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) ->
481 (V.Vector CDouble, CDouble) -> 486 (V.Vector CDouble, CDouble) ->
@@ -483,7 +488,8 @@ solveOdeC ::
483 -> V.Vector CDouble -- ^ Initial conditions 488 -> V.Vector CDouble -- ^ Initial conditions
484 -> V.Vector CDouble -- ^ Desired solution times 489 -> V.Vector CDouble -- ^ Desired solution times
485 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution 490 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution
486solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do 491solveOdeC maxErrTestFails maxNumSteps_ minStep_ method initStepSize
492 jacH (aTols, rTol) fun f0 ts = unsafePerformIO $ do
487 493
488 let isInitStepSize :: CInt 494 let isInitStepSize :: CInt
489 isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize 495 isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize
@@ -494,14 +500,12 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
494 -- used :( 500 -- used :(
495 Nothing -> 0.0 501 Nothing -> 0.0
496 Just x -> x 502 Just x -> x
503
497 let dim = V.length f0 504 let dim = V.length f0
498 nEq :: CLong 505 nEq :: CLong
499 nEq = fromIntegral dim 506 nEq = fromIntegral dim
500 nTs :: CInt 507 nTs :: CInt
501 nTs = fromIntegral $ V.length ts 508 nTs = fromIntegral $ V.length ts
502 -- FIXME: fMut is not actually mutatated
503 fMut <- V.thaw f0
504 tMut <- V.thaw ts
505 -- FIXME: I believe this gets taken from the ghc heap and so should 509 -- FIXME: I believe this gets taken from the ghc heap and so should
506 -- be subject to garbage collection. 510 -- be subject to garbage collection.
507 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) 511 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs))
@@ -509,7 +513,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
509 diagnostics :: V.Vector CLong <- createVector 10 -- FIXME 513 diagnostics :: V.Vector CLong <- createVector 10 -- FIXME
510 diagMut <- V.thaw diagnostics 514 diagMut <- V.thaw diagnostics
511 -- We need the types that sundials expects. These are tied together 515 -- We need the types that sundials expects. These are tied together
512 -- in 'Types'. FIXME: The Haskell type is currently empty! 516 -- in 'CLangToHaskellTypes'. FIXME: The Haskell type is currently empty!
513 let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt 517 let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt
514 funIO x y f _ptr = do 518 funIO x y f _ptr = do
515 -- Convert the pointer we get from C (y) to a vector, and then 519 -- Convert the pointer we get from C (y) to a vector, and then
@@ -529,7 +533,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
529 case jacH of 533 case jacH of
530 Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" 534 Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined"
531 Just jacI -> do j <- jacI t <$> getDataFromContents dim y 535 Just jacI -> do j <- jacI t <$> getDataFromContents dim y
532 putMatrixDataFromContents j jacS 536 poke jacS j
533 -- FIXME: I don't understand what this comment means 537 -- FIXME: I don't understand what this comment means
534 -- Unsafe since the function will be called many times. 538 -- Unsafe since the function will be called many times.
535 [CU.exp| int{ 0 } |] 539 [CU.exp| int{ 0 } |]
@@ -549,7 +553,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
549 553
550 /* general problem parameters */ 554 /* general problem parameters */
551 555
552 realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ 556 realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */
553 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ 557 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */
554 558
555 /* Initialize data structures */ 559 /* Initialize data structures */
@@ -558,14 +562,14 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
558 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; 562 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
559 /* Specify initial condition */ 563 /* Specify initial condition */
560 for (i = 0; i < NEQ; i++) { 564 for (i = 0; i < NEQ; i++) {
561 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; 565 NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i];
562 }; 566 };
563 567
564 tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */ 568 tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */
565 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; 569 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1;
566 /* Specify tolerances */ 570 /* Specify tolerances */
567 for (i = 0; i < NEQ; i++) { 571 for (i = 0; i < NEQ; i++) {
568 NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i]; 572 NV_Ith_S(tv,i) = ($vec-ptr:(double *aTols))[i];
569 }; 573 };
570 574
571 arkode_mem = ARKodeCreate(); /* Create the solver memory */ 575 arkode_mem = ARKodeCreate(); /* Create the solver memory */
@@ -577,7 +581,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
577 /* problem as fully implicit and set f_E to NULL and f_I to f. */ 581 /* problem as fully implicit and set f_E to NULL and f_I to f. */
578 582
579 /* Here we use the C types defined in helpers.h which tie up with */ 583 /* Here we use the C types defined in helpers.h which tie up with */
580 /* the Haskell types defined in Types */ 584 /* the Haskell types defined in CLangToHaskellTypes */
581 if ($(int method) < MIN_DIRK_NUM) { 585 if ($(int method) < MIN_DIRK_NUM) {
582 flag = ARKodeInit(arkode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), NULL, T0, y); 586 flag = ARKodeInit(arkode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), NULL, T0, y);
583 if (check_flag(&flag, "ARKodeInit", 1)) return 1; 587 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
@@ -586,14 +590,15 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
586 if (check_flag(&flag, "ARKodeInit", 1)) return 1; 590 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
587 } 591 }
588 592
589 /* FIXME: A hack for initial testing */ 593 flag = ARKodeSetMinStep(arkode_mem, $(double minStep_));
590 flag = ARKodeSetMinStep(arkode_mem, 1.0e-12);
591 if (check_flag(&flag, "ARKodeSetMinStep", 1)) return 1; 594 if (check_flag(&flag, "ARKodeSetMinStep", 1)) return 1;
592 flag = ARKodeSetMaxNumSteps(arkode_mem, 10000); 595 flag = ARKodeSetMaxNumSteps(arkode_mem, $(long int maxNumSteps_));
593 if (check_flag(&flag, "ARKodeSetMaxNumSteps", 1)) return 1; 596 if (check_flag(&flag, "ARKodeSetMaxNumSteps", 1)) return 1;
597 flag = ARKodeSetMaxErrTestFails(arkode_mem, $(int maxErrTestFails));
598 if (check_flag(&flag, "ARKodeSetMaxErrTestFails", 1)) return 1;
594 599
595 /* Set routines */ 600 /* Set routines */
596 flag = ARKodeSVtolerances(arkode_mem, $(double relTol), tv); 601 flag = ARKodeSVtolerances(arkode_mem, $(double rTol), tv);
597 if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1; 602 if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1;
598 603
599 /* Initialize dense matrix data structure and solver */ 604 /* Initialize dense matrix data structure and solver */
@@ -638,7 +643,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
638 /* Stops when the final time has been reached */ 643 /* Stops when the final time has been reached */
639 for (i = 1; i < $(int nTs); i++) { 644 for (i = 1; i < $(int nTs); i++) {
640 645
641 flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */ 646 flag = ARKode(arkode_mem, ($vec-ptr:(double *ts))[i], y, &t, ARK_NORMAL); /* call integrator */
642 if (check_flag(&flag, "ARKode", 1)) break; 647 if (check_flag(&flag, "ARKode", 1)) break;
643 648
644 /* Store the results for Haskell */ 649 /* Store the results for Haskell */
@@ -738,7 +743,7 @@ butcherTable method =
738 case getBT method of 743 case getBT method of
739 Left c -> error $ show c -- FIXME 744 Left c -> error $ show c -- FIXME
740 Right (ButcherTable' v w x y, sqp) -> 745 Right (ButcherTable' v w x y, sqp) ->
741 ButcherTable { am = subMatrix (0, 0) (s, s) $ (B.arkSMax >< B.arkSMax) (V.toList v) 746 ButcherTable { am = subMatrix (0, 0) (s, s) $ (arkSMax >< arkSMax) (V.toList v)
742 , cv = subVector 0 s w 747 , cv = subVector 0 s w
743 , bv = subVector 0 s x 748 , bv = subVector 0 s x
744 , b2v = subVector 0 s y 749 , b2v = subVector 0 s y
@@ -773,11 +778,11 @@ getButcherTable method = unsafePerformIO $ do
773 778
774 btSQP :: V.Vector CInt <- createVector 3 779 btSQP :: V.Vector CInt <- createVector 3
775 btSQPMut <- V.thaw btSQP 780 btSQPMut <- V.thaw btSQP
776 btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax) 781 btAs :: V.Vector CDouble <- createVector (arkSMax * arkSMax)
777 btAsMut <- V.thaw btAs 782 btAsMut <- V.thaw btAs
778 btCs :: V.Vector CDouble <- createVector B.arkSMax 783 btCs :: V.Vector CDouble <- createVector arkSMax
779 btBs :: V.Vector CDouble <- createVector B.arkSMax 784 btBs :: V.Vector CDouble <- createVector arkSMax
780 btB2s :: V.Vector CDouble <- createVector B.arkSMax 785 btB2s :: V.Vector CDouble <- createVector arkSMax
781 btCsMut <- V.thaw btCs 786 btCsMut <- V.thaw btCs
782 btBsMut <- V.thaw btBs 787 btBsMut <- V.thaw btBs
783 btB2sMut <- V.thaw btB2s 788 btB2sMut <- V.thaw btB2s
diff --git a/packages/sundials/src/Arkode.hsc b/packages/sundials/src/Numeric/Sundials/Arkode.hsc
index 9db37b5..0850258 100644
--- a/packages/sundials/src/Arkode.hsc
+++ b/packages/sundials/src/Numeric/Sundials/Arkode.hsc
@@ -1,7 +1,23 @@
1module Arkode where 1{-# LANGUAGE QuasiQuotes #-}
2{-# LANGUAGE TemplateHaskell #-}
3{-# LANGUAGE OverloadedStrings #-}
4{-# LANGUAGE EmptyDataDecls #-}
2 5
3import Foreign 6module Numeric.Sundials.Arkode where
4import Foreign.C.Types 7
8import Foreign
9import Foreign.C.Types
10
11import Language.C.Types as CT
12
13import qualified Data.Vector.Storable as VS
14import qualified Data.Vector.Storable.Mutable as VM
15
16import qualified Language.Haskell.TH as TH
17import qualified Data.Map as Map
18import Language.C.Inline.Context
19
20import qualified Data.Vector.Storable as V
5 21
6 22
7#include <stdio.h> 23#include <stdio.h>
@@ -10,7 +26,76 @@ import Foreign.C.Types
10#include <nvector/nvector_serial.h> 26#include <nvector/nvector_serial.h>
11#include <sunmatrix/sunmatrix_dense.h> 27#include <sunmatrix/sunmatrix_dense.h>
12#include <arkode/arkode.h> 28#include <arkode/arkode.h>
13 29#include <cvode/cvode.h>
30
31
32data SunVector
33data SunMatrix = SunMatrix { rows :: CInt
34 , cols :: CInt
35 , vals :: V.Vector CDouble
36 }
37
38-- | This is true only if configured/ built as 64 bits
39type SunIndexType = CLong
40
41sunTypesTable :: Map.Map TypeSpecifier TH.TypeQ
42sunTypesTable = Map.fromList
43 [
44 (TypeName "sunindextype", [t| SunIndexType |] )
45 , (TypeName "SunVector", [t| SunVector |] )
46 , (TypeName "SunMatrix", [t| SunMatrix |] )
47 ]
48
49sunCtx :: Context
50sunCtx = mempty {ctxTypesTable = sunTypesTable}
51
52getMatrixDataFromContents :: Ptr SunMatrix -> IO SunMatrix
53getMatrixDataFromContents ptr = do
54 qtr <- getContentMatrixPtr ptr
55 rs <- getNRows qtr
56 cs <- getNCols qtr
57 rtr <- getMatrixData qtr
58 vs <- vectorFromC (fromIntegral $ rs * cs) rtr
59 return $ SunMatrix { rows = rs, cols = cs, vals = vs }
60
61putMatrixDataFromContents :: SunMatrix -> Ptr SunMatrix -> IO ()
62putMatrixDataFromContents mat ptr = do
63 let rs = rows mat
64 cs = cols mat
65 vs = vals mat
66 qtr <- getContentMatrixPtr ptr
67 putNRows rs qtr
68 putNCols cs qtr
69 rtr <- getMatrixData qtr
70 vectorToC vs (fromIntegral $ rs * cs) rtr
71
72instance Storable SunMatrix where
73 poke = flip putMatrixDataFromContents
74 peek = getMatrixDataFromContents
75 sizeOf _ = error "sizeOf not supported for SunMatrix"
76 alignment _ = error "alignment not supported for SunMatrix"
77
78vectorFromC :: Storable a => Int -> Ptr a -> IO (VS.Vector a)
79vectorFromC len ptr = do
80 ptr' <- newForeignPtr_ ptr
81 VS.freeze $ VM.unsafeFromForeignPtr0 ptr' len
82
83vectorToC :: Storable a => VS.Vector a -> Int -> Ptr a -> IO ()
84vectorToC vec len ptr = do
85 ptr' <- newForeignPtr_ ptr
86 VS.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
87
88getDataFromContents :: Int -> Ptr SunVector -> IO (VS.Vector CDouble)
89getDataFromContents len ptr = do
90 qtr <- getContentPtr ptr
91 rtr <- getData qtr
92 vectorFromC len rtr
93
94putDataInContents :: Storable a => VS.Vector a -> Int -> Ptr b -> IO ()
95putDataInContents vec len ptr = do
96 qtr <- getContentPtr ptr
97 rtr <- getData qtr
98 vectorToC vec len rtr
14 99
15#def typedef struct _generic_N_Vector SunVector; 100#def typedef struct _generic_N_Vector SunVector;
16#def typedef struct _N_VectorContent_Serial SunContent; 101#def typedef struct _N_VectorContent_Serial SunContent;
@@ -40,6 +125,11 @@ getContentPtr ptr = (#peek SunVector, content) ptr
40getData :: Storable a => Ptr b -> IO a 125getData :: Storable a => Ptr b -> IO a
41getData ptr = (#peek SunContent, data) ptr 126getData ptr = (#peek SunContent, data) ptr
42 127
128cV_ADAMS :: Int
129cV_ADAMS = #const CV_ADAMS
130cV_BDF :: Int
131cV_BDF = #const CV_BDF
132
43arkSMax :: Int 133arkSMax :: Int
44arkSMax = #const ARK_S_MAX 134arkSMax = #const ARK_S_MAX
45 135
diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs
new file mode 100644
index 0000000..a6f185e
--- /dev/null
+++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs
@@ -0,0 +1,476 @@
1{-# OPTIONS_GHC -Wall #-}
2
3{-# LANGUAGE QuasiQuotes #-}
4{-# LANGUAGE TemplateHaskell #-}
5{-# LANGUAGE MultiWayIf #-}
6{-# LANGUAGE OverloadedStrings #-}
7{-# LANGUAGE ScopedTypeVariables #-}
8
9-----------------------------------------------------------------------------
10-- |
11-- Module : Numeric.Sundials.CVode.ODE
12-- Copyright : Dominic Steinitz 2018,
13-- Novadiscovery 2018
14-- License : BSD
15-- Maintainer : Dominic Steinitz
16-- Stability : provisional
17--
18-- Solution of ordinary differential equation (ODE) initial value problems.
19--
20-- <https://computation.llnl.gov/projects/sundials/sundials-software>
21--
22-- A simple example:
23--
24-- <<diagrams/brusselator.png#diagram=brusselator&height=400&width=500>>
25--
26-- @
27-- import Numeric.Sundials.CVode.ODE
28-- import Numeric.LinearAlgebra
29--
30-- import Plots as P
31-- import qualified Diagrams.Prelude as D
32-- import Diagrams.Backend.Rasterific
33--
34-- brusselator :: Double -> [Double] -> [Double]
35-- brusselator _t x = [ a - (w + 1) * u + v * u * u
36-- , w * u - v * u * u
37-- , (b - w) / eps - w * u
38-- ]
39-- where
40-- a = 1.0
41-- b = 3.5
42-- eps = 5.0e-6
43-- u = x !! 0
44-- v = x !! 1
45-- w = x !! 2
46--
47-- lSaxis :: [[Double]] -> P.Axis B D.V2 Double
48-- lSaxis xs = P.r2Axis &~ do
49-- let ts = xs!!0
50-- us = xs!!1
51-- vs = xs!!2
52-- ws = xs!!3
53-- P.linePlot' $ zip ts us
54-- P.linePlot' $ zip ts vs
55-- P.linePlot' $ zip ts ws
56--
57-- main = do
58-- let res1 = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0])
59-- renderRasterific "diagrams/brusselator.png"
60-- (D.dims2D 500.0 500.0)
61-- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1))
62-- @
63--
64-----------------------------------------------------------------------------
65module Numeric.Sundials.CVode.ODE ( odeSolve
66 , odeSolveV
67 , odeSolveVWith
68 , odeSolveVWith'
69 , ODEMethod(..)
70 , StepControl(..)
71 ) where
72
73import qualified Language.C.Inline as C
74import qualified Language.C.Inline.Unsafe as CU
75
76import Data.Monoid ((<>))
77import Data.Maybe (isJust)
78
79import Foreign.C.Types (CDouble, CInt, CLong)
80import Foreign.Ptr (Ptr)
81import Foreign.Storable (poke)
82
83import qualified Data.Vector.Storable as V
84
85import Data.Coerce (coerce)
86import System.IO.Unsafe (unsafePerformIO)
87
88import Numeric.LinearAlgebra.Devel (createVector)
89
90import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows,
91 cols, toLists, size, reshape)
92
93import Numeric.Sundials.Arkode (cV_ADAMS, cV_BDF,
94 getDataFromContents, putDataInContents)
95import qualified Numeric.Sundials.Arkode as T
96import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian, SundialsDiagnostics(..))
97
98
99C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx)
100
101C.include "<stdlib.h>"
102C.include "<stdio.h>"
103C.include "<math.h>"
104C.include "<cvode/cvode.h>" -- prototypes for CVODE fcts., consts.
105C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros
106C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix
107C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver
108C.include "<cvode/cvode_direct.h>" -- access to CVDls interface
109C.include "<sundials/sundials_types.h>" -- definition of type realtype
110C.include "<sundials/sundials_math.h>"
111C.include "../../../helpers.h"
112C.include "Numeric/Sundials/Arkode_hsc.h"
113
114
115-- | Stepping functions
116data ODEMethod = ADAMS
117 | BDF
118
119getMethod :: ODEMethod -> Int
120getMethod (ADAMS) = cV_ADAMS
121getMethod (BDF) = cV_BDF
122
123getJacobian :: ODEMethod -> Maybe Jacobian
124getJacobian _ = Nothing
125
126-- | A version of 'odeSolveVWith' with reasonable default step control.
127odeSolveV
128 :: ODEMethod
129 -> Maybe Double -- ^ initial step size - by default, CVode
130 -- estimates the initial step size to be the
131 -- solution \(h\) of the equation
132 -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where
133 -- \(\ddot{y}\) is an estimated value of the
134 -- second derivative of the solution at \(t_0\)
135 -> Double -- ^ absolute tolerance for the state vector
136 -> Double -- ^ relative tolerance for the state vector
137 -> (Double -> Vector Double -> Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
138 -> Vector Double -- ^ initial conditions
139 -> Vector Double -- ^ desired solution times
140 -> Matrix Double -- ^ solution
141odeSolveV meth hi epsAbs epsRel f y0 ts =
142 odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts
143 where
144 g t x0 = coerce $ f t x0
145
146-- | A version of 'odeSolveV' with reasonable default parameters and
147-- system of equations defined using lists. FIXME: we should say
148-- something about the fact we could use the Jacobian but don't for
149-- compatibility with hmatrix-gsl.
150odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
151 -> [Double] -- ^ initial conditions
152 -> Vector Double -- ^ desired solution times
153 -> Matrix Double -- ^ solution
154odeSolve f y0 ts =
155 -- FIXME: These tolerances are different from the ones in GSL
156 odeSolveVWith BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts)
157 where
158 g t x0 = V.fromList $ f t (V.toList x0)
159
160odeSolveVWith ::
161 ODEMethod
162 -> StepControl
163 -> Maybe Double -- ^ initial step size - by default, CVode
164 -- estimates the initial step size to be the
165 -- solution \(h\) of the equation
166 -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where
167 -- \(\ddot{y}\) is an estimated value of the second
168 -- derivative of the solution at \(t_0\)
169 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
170 -> V.Vector Double -- ^ Initial conditions
171 -> V.Vector Double -- ^ Desired solution times
172 -> Matrix Double -- ^ Error code or solution
173odeSolveVWith method control initStepSize f y0 tt =
174 case odeSolveVWith' opts method control initStepSize f y0 tt of
175 Left c -> error $ show c -- FIXME
176 Right (v, _d) -> v
177 where
178 opts = ODEOpts { maxNumSteps = 10000
179 , minStep = 1.0e-12
180 , relTol = error "relTol"
181 , absTols = error "absTol"
182 , initStep = error "initStep"
183 , maxFail = 10
184 }
185
186odeSolveVWith' ::
187 ODEOpts
188 -> ODEMethod
189 -> StepControl
190 -> Maybe Double -- ^ initial step size - by default, CVode
191 -- estimates the initial step size to be the
192 -- solution \(h\) of the equation
193 -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where
194 -- \(\ddot{y}\) is an estimated value of the second
195 -- derivative of the solution at \(t_0\)
196 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
197 -> V.Vector Double -- ^ Initial conditions
198 -> V.Vector Double -- ^ Desired solution times
199 -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution
200odeSolveVWith' opts method control initStepSize f y0 tt =
201 case solveOdeC (fromIntegral $ maxFail opts)
202 (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts)
203 (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control)
204 (coerce f) (coerce y0) (coerce tt) of
205 Left c -> Left $ fromIntegral c
206 Right (v, d) -> Right (reshape l (coerce v), d)
207 where
208 l = size y0
209 scise (X aTol rTol) = coerce (V.replicate l aTol, rTol)
210 scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol)
211 scise (XX' aTol rTol yScale _yDotScale) = coerce (V.replicate l aTol, yScale * rTol)
212 -- FIXME; Should we check that the length of ss is correct?
213 scise (ScXX' aTol rTol yScale _yDotScale ss) = coerce (V.map (* aTol) ss, yScale * rTol)
214 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $
215 getJacobian method
216 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs }
217 where
218 nr = fromIntegral $ rows m
219 nc = fromIntegral $ cols m
220 -- FIXME: efficiency
221 vs = V.fromList $ map coerce $ concat $ toLists m
222
223solveOdeC ::
224 CInt ->
225 CLong ->
226 CDouble ->
227 CInt ->
228 Maybe CDouble ->
229 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) ->
230 (V.Vector CDouble, CDouble) ->
231 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
232 -> V.Vector CDouble -- ^ Initial conditions
233 -> V.Vector CDouble -- ^ Desired solution times
234 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution
235solveOdeC maxErrTestFails maxNumSteps_ minStep_ method initStepSize
236 jacH (aTols, rTol) fun f0 ts =
237 unsafePerformIO $ do
238
239 let isInitStepSize :: CInt
240 isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize
241 ss :: CDouble
242 ss = case initStepSize of
243 -- It would be better to put an error message here but
244 -- inline-c seems to evaluate this even if it is never
245 -- used :(
246 Nothing -> 0.0
247 Just x -> x
248
249 let dim = V.length f0
250 nEq :: CLong
251 nEq = fromIntegral dim
252 nTs :: CInt
253 nTs = fromIntegral $ V.length ts
254 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs))
255 qMatMut <- V.thaw quasiMatrixRes
256 diagnostics :: V.Vector CLong <- createVector 10 -- FIXME
257 diagMut <- V.thaw diagnostics
258 -- We need the types that sundials expects. These are tied together
259 -- in 'CLangToHaskellTypes'. FIXME: The Haskell type is currently empty!
260 let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt
261 funIO x y f _ptr = do
262 -- Convert the pointer we get from C (y) to a vector, and then
263 -- apply the user-supplied function.
264 fImm <- fun x <$> getDataFromContents dim y
265 -- Fill in the provided pointer with the resulting vector.
266 putDataInContents fImm dim f
267 -- FIXME: I don't understand what this comment means
268 -- Unsafe since the function will be called many times.
269 [CU.exp| int{ 0 } |]
270 let isJac :: CInt
271 isJac = fromIntegral $ fromEnum $ isJust jacH
272 jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix ->
273 Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector ->
274 IO CInt
275 jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do
276 case jacH of
277 Nothing -> error "Numeric.Sundials.CVode.ODE: Jacobian not defined"
278 Just jacI -> do j <- jacI t <$> getDataFromContents dim y
279 poke jacS j
280 -- FIXME: I don't understand what this comment means
281 -- Unsafe since the function will be called many times.
282 [CU.exp| int{ 0 } |]
283
284 res <- [C.block| int {
285 /* general problem variables */
286
287 int flag; /* reusable error-checking flag */
288 int i, j; /* reusable loop indices */
289 N_Vector y = NULL; /* empty vector for storing solution */
290 N_Vector tv = NULL; /* empty vector for storing absolute tolerances */
291
292 SUNMatrix A = NULL; /* empty matrix for linear solver */
293 SUNLinearSolver LS = NULL; /* empty linear solver object */
294 void *cvode_mem = NULL; /* empty CVODE memory structure */
295 realtype t;
296 long int nst, nfe, nsetups, nje, nfeLS, nni, ncfn, netf, nge;
297
298 /* general problem parameters */
299
300 realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */
301 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */
302
303 /* Initialize data structures */
304
305 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
306 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
307 /* Specify initial condition */
308 for (i = 0; i < NEQ; i++) {
309 NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i];
310 };
311
312 cvode_mem = CVodeCreate($(int method), CV_NEWTON);
313 if (check_flag((void *)cvode_mem, "CVodeCreate", 0)) return(1);
314
315 /* Call CVodeInit to initialize the integrator memory and specify the
316 * user's right hand side function in y'=f(t,y), the inital time T0, and
317 * the initial dependent variable vector y. */
318 flag = CVodeInit(cvode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y);
319 if (check_flag(&flag, "CVodeInit", 1)) return(1);
320
321 tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */
322 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1;
323 /* Specify tolerances */
324 for (i = 0; i < NEQ; i++) {
325 NV_Ith_S(tv,i) = ($vec-ptr:(double *aTols))[i];
326 };
327
328 flag = CVodeSetMinStep(cvode_mem, $(double minStep_));
329 if (check_flag(&flag, "CVodeSetMinStep", 1)) return 1;
330 flag = CVodeSetMaxNumSteps(cvode_mem, $(long int maxNumSteps_));
331 if (check_flag(&flag, "CVodeSetMaxNumSteps", 1)) return 1;
332 flag = CVodeSetMaxErrTestFails(cvode_mem, $(int maxErrTestFails));
333 if (check_flag(&flag, "CVodeSetMaxErrTestFails", 1)) return 1;
334
335 /* Call CVodeSVtolerances to specify the scalar relative tolerance
336 * and vector absolute tolerances */
337 flag = CVodeSVtolerances(cvode_mem, $(double rTol), tv);
338 if (check_flag(&flag, "CVodeSVtolerances", 1)) return(1);
339
340 /* Initialize dense matrix data structure and solver */
341 A = SUNDenseMatrix(NEQ, NEQ);
342 if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1;
343 LS = SUNDenseLinearSolver(y, A);
344 if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1;
345
346 /* Attach matrix and linear solver */
347 flag = CVDlsSetLinearSolver(cvode_mem, LS, A);
348 if (check_flag(&flag, "CVDlsSetLinearSolver", 1)) return 1;
349
350 /* Set the initial step size if there is one */
351 if ($(int isInitStepSize)) {
352 /* FIXME: We could check if the initial step size is 0 */
353 /* or even NaN and then throw an error */
354 flag = CVodeSetInitStep(cvode_mem, $(double ss));
355 if (check_flag(&flag, "CVodeSetInitStep", 1)) return 1;
356 }
357
358 /* Set the Jacobian if there is one */
359 if ($(int isJac)) {
360 flag = CVDlsSetJacFn(cvode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[])));
361 if (check_flag(&flag, "CVDlsSetJacFn", 1)) return 1;
362 }
363
364 /* Store initial conditions */
365 for (j = 0; j < NEQ; j++) {
366 ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j);
367 }
368
369 /* Main time-stepping loop: calls CVode to perform the integration */
370 /* Stops when the final time has been reached */
371 for (i = 1; i < $(int nTs); i++) {
372
373 flag = CVode(cvode_mem, ($vec-ptr:(double *ts))[i], y, &t, CV_NORMAL); /* call integrator */
374 if (check_flag(&flag, "CVode", 1)) break;
375
376 /* Store the results for Haskell */
377 for (j = 0; j < NEQ; j++) {
378 ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j);
379 }
380
381 /* unsuccessful solve: break */
382 if (flag < 0) {
383 fprintf(stderr,"Solver failure, stopping integration\n");
384 break;
385 }
386 }
387
388 /* Get some final statistics on how the solve progressed */
389
390 flag = CVodeGetNumSteps(cvode_mem, &nst);
391 check_flag(&flag, "CVodeGetNumSteps", 1);
392 ($vec-ptr:(long int *diagMut))[0] = nst;
393
394 /* FIXME */
395 ($vec-ptr:(long int *diagMut))[1] = 0;
396
397 flag = CVodeGetNumRhsEvals(cvode_mem, &nfe);
398 check_flag(&flag, "CVodeGetNumRhsEvals", 1);
399 ($vec-ptr:(long int *diagMut))[2] = nfe;
400 /* FIXME */
401 ($vec-ptr:(long int *diagMut))[3] = 0;
402
403 flag = CVodeGetNumLinSolvSetups(cvode_mem, &nsetups);
404 check_flag(&flag, "CVodeGetNumLinSolvSetups", 1);
405 ($vec-ptr:(long int *diagMut))[4] = nsetups;
406
407 flag = CVodeGetNumErrTestFails(cvode_mem, &netf);
408 check_flag(&flag, "CVodeGetNumErrTestFails", 1);
409 ($vec-ptr:(long int *diagMut))[5] = netf;
410
411 flag = CVodeGetNumNonlinSolvIters(cvode_mem, &nni);
412 check_flag(&flag, "CVodeGetNumNonlinSolvIters", 1);
413 ($vec-ptr:(long int *diagMut))[6] = nni;
414
415 flag = CVodeGetNumNonlinSolvConvFails(cvode_mem, &ncfn);
416 check_flag(&flag, "CVodeGetNumNonlinSolvConvFails", 1);
417 ($vec-ptr:(long int *diagMut))[7] = ncfn;
418
419 flag = CVDlsGetNumJacEvals(cvode_mem, &nje);
420 check_flag(&flag, "CVDlsGetNumJacEvals", 1);
421 ($vec-ptr:(long int *diagMut))[8] = ncfn;
422
423 flag = CVDlsGetNumRhsEvals(cvode_mem, &nfeLS);
424 check_flag(&flag, "CVDlsGetNumRhsEvals", 1);
425 ($vec-ptr:(long int *diagMut))[9] = ncfn;
426
427 /* Clean up and return */
428
429 N_VDestroy(y); /* Free y vector */
430 N_VDestroy(tv); /* Free tv vector */
431 CVodeFree(&cvode_mem); /* Free integrator memory */
432 SUNLinSolFree(LS); /* Free linear solver */
433 SUNMatDestroy(A); /* Free A matrix */
434
435 return flag;
436 } |]
437 if res == 0
438 then do
439 preD <- V.freeze diagMut
440 let d = SundialsDiagnostics (fromIntegral $ preD V.!0)
441 (fromIntegral $ preD V.!1)
442 (fromIntegral $ preD V.!2)
443 (fromIntegral $ preD V.!3)
444 (fromIntegral $ preD V.!4)
445 (fromIntegral $ preD V.!5)
446 (fromIntegral $ preD V.!6)
447 (fromIntegral $ preD V.!7)
448 (fromIntegral $ preD V.!8)
449 (fromIntegral $ preD V.!9)
450 m <- V.freeze qMatMut
451 return $ Right (m, d)
452 else do
453 return $ Left res
454
455-- | Adaptive step-size control
456-- functions.
457--
458-- [GSL](https://www.gnu.org/software/gsl/doc/html/ode-initval.html#adaptive-step-size-control)
459-- allows the user to control the step size adjustment using
460-- \(D_i = \epsilon^{abs}s_i + \epsilon^{rel}(a_{y} |y_i| + a_{dy/dt} h |\dot{y}_i|)\) where
461-- \(\epsilon^{abs}\) is the required absolute error, \(\epsilon^{rel}\)
462-- is the required relative error, \(s_i\) is a vector of scaling
463-- factors, \(a_{y}\) is a scaling factor for the solution \(y\) and
464-- \(a_{dydt}\) is a scaling factor for the derivative of the solution \(dy/dt\).
465--
466-- [ARKode](https://computation.llnl.gov/projects/sundials/arkode)
467-- allows the user to control the step size adjustment using
468-- \(\eta^{rel}|y_i| + \eta^{abs}_i\). For compatibility with
469-- [hmatrix-gsl](https://hackage.haskell.org/package/hmatrix-gsl),
470-- tolerances for \(y\) and \(\dot{y}\) can be specified but the latter have no
471-- effect.
472data StepControl = X Double Double -- ^ absolute and relative tolerance for \(y\); in GSL terms, \(a_{y} = 1\) and \(a_{dy/dt} = 0\); in ARKode terms, the \(\eta^{abs}_i\) are identical
473 | X' Double Double -- ^ absolute and relative tolerance for \(\dot{y}\); in GSL terms, \(a_{y} = 0\) and \(a_{dy/dt} = 1\); in ARKode terms, the latter is treated as the relative tolerance for \(y\) so this is the same as specifying 'X' which may be entirely incorrect for the given problem
474 | XX' Double Double Double Double -- ^ include both via relative tolerance
475 -- scaling factors \(a_y\), \(a_{{dy}/{dt}}\); in ARKode terms, the latter is ignored and \(\eta^{rel} = a_{y}\epsilon^{rel}\)
476 | ScXX' Double Double Double Double (Vector Double) -- ^ scale absolute tolerance of \(y_i\); in ARKode terms, \(a_{{dy}/{dt}}\) is ignored, \(\eta^{abs}_i = s_i \epsilon^{abs}\) and \(\eta^{rel} = a_{y}\epsilon^{rel}\)
diff --git a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs
new file mode 100644
index 0000000..027d99a
--- /dev/null
+++ b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs
@@ -0,0 +1,32 @@
1module Numeric.Sundials.ODEOpts where
2
3import Data.Word (Word32)
4import qualified Data.Vector.Storable as VS
5
6import Numeric.LinearAlgebra.HMatrix (Vector, Matrix)
7
8
9type Jacobian = Double -> Vector Double -> Matrix Double
10
11data ODEOpts = ODEOpts {
12 maxNumSteps :: Word32
13 , minStep :: Double
14 , relTol :: Double
15 , absTols :: VS.Vector Double
16 , initStep :: Maybe Double
17 , maxFail :: Word32
18 } deriving (Read, Show, Eq, Ord)
19
20data SundialsDiagnostics = SundialsDiagnostics {
21 aRKodeGetNumSteps :: Int
22 , aRKodeGetNumStepAttempts :: Int
23 , aRKodeGetNumRhsEvals_fe :: Int
24 , aRKodeGetNumRhsEvals_fi :: Int
25 , aRKodeGetNumLinSolvSetups :: Int
26 , aRKodeGetNumErrTestFails :: Int
27 , aRKodeGetNumNonlinSolvIters :: Int
28 , aRKodeGetNumNonlinSolvConvFails :: Int
29 , aRKDlsGetNumJacEvals :: Int
30 , aRKDlsGetNumRhsEvals :: Int
31 } deriving Show
32
diff --git a/packages/sundials/src/Types.hs b/packages/sundials/src/Types.hs
deleted file mode 100644
index 04e4280..0000000
--- a/packages/sundials/src/Types.hs
+++ /dev/null
@@ -1,40 +0,0 @@
1{-# OPTIONS_GHC -Wall #-}
2
3{-# LANGUAGE QuasiQuotes #-}
4{-# LANGUAGE TemplateHaskell #-}
5{-# LANGUAGE MultiWayIf #-}
6{-# LANGUAGE OverloadedStrings #-}
7{-# LANGUAGE EmptyDataDecls #-}
8
9module Types where
10
11import Foreign.C.Types
12
13import qualified Language.Haskell.TH as TH
14import qualified Language.C.Types as CT
15import qualified Data.Map as Map
16import Language.C.Inline.Context
17
18import qualified Data.Vector.Storable as V
19
20
21data SunVector
22data SunMatrix = SunMatrix { rows :: CInt
23 , cols :: CInt
24 , vals :: V.Vector CDouble
25 }
26
27-- FIXME: Is this true?
28type SunIndexType = CLong
29
30sunTypesTable :: Map.Map CT.TypeSpecifier TH.TypeQ
31sunTypesTable = Map.fromList
32 [
33 (CT.TypeName "sunindextype", [t| SunIndexType |] )
34 , (CT.TypeName "SunVector", [t| SunVector |] )
35 , (CT.TypeName "SunMatrix", [t| SunMatrix |] )
36 ]
37
38sunCtx :: Context
39sunCtx = mempty {ctxTypesTable = sunTypesTable}
40