diff options
author | idontgetoutmuch <dominic@steinitz.org> | 2018-05-03 03:09:25 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-05-03 03:09:25 -0700 |
commit | 1675813d8f540af9832a78c7a7a40bbdf1cec42c (patch) | |
tree | 01cf740b4d93ae2668c5f320f207387661cd29ca /packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |
parent | dd96a98207dbafbf81b4a5f02613963cf5bd4b4c (diff) | |
parent | 686bd51792648dee967c611225cb1a59efa6b1c2 (diff) |
Merge pull request #269 from idontgetoutmuch/feature-cvode
Feature cvode
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs')
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 269 |
1 files changed, 137 insertions, 132 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index e5a2e4d..fafc237 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -1,5 +1,3 @@ | |||
1 | {-# OPTIONS_GHC -Wall #-} | ||
2 | |||
3 | {-# LANGUAGE QuasiQuotes #-} | 1 | {-# LANGUAGE QuasiQuotes #-} |
4 | {-# LANGUAGE TemplateHaskell #-} | 2 | {-# LANGUAGE TemplateHaskell #-} |
5 | {-# LANGUAGE MultiWayIf #-} | 3 | {-# LANGUAGE MultiWayIf #-} |
@@ -22,8 +20,7 @@ | |||
22 | -- Stability : provisional | 20 | -- Stability : provisional |
23 | -- | 21 | -- |
24 | -- Solution of ordinary differential equation (ODE) initial value problems. | 22 | -- Solution of ordinary differential equation (ODE) initial value problems. |
25 | -- | 23 | -- See <https://computation.llnl.gov/projects/sundials/sundials-software> for more detail. |
26 | -- <https://computation.llnl.gov/projects/sundials/sundials-software> | ||
27 | -- | 24 | -- |
28 | -- A simple example: | 25 | -- A simple example: |
29 | -- | 26 | -- |
@@ -67,6 +64,54 @@ | |||
67 | -- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) | 64 | -- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) |
68 | -- @ | 65 | -- @ |
69 | -- | 66 | -- |
67 | -- With Sundials ARKode, it is possible to retrieve the Butcher tableau for the solver. | ||
68 | -- | ||
69 | -- @ | ||
70 | -- import Numeric.Sundials.ARKode.ODE | ||
71 | -- import Numeric.LinearAlgebra | ||
72 | -- | ||
73 | -- import Data.List (intercalate) | ||
74 | -- | ||
75 | -- import Text.PrettyPrint.HughesPJClass | ||
76 | -- | ||
77 | -- | ||
78 | -- butcherTableauTex :: ButcherTable -> String | ||
79 | -- butcherTableauTex (ButcherTable m c b b2) = | ||
80 | -- render $ | ||
81 | -- vcat [ text ("\n\\begin{array}{c|" ++ (concat $ replicate n "c") ++ "}") | ||
82 | -- , us | ||
83 | -- , text "\\hline" | ||
84 | -- , text bs <+> text "\\\\" | ||
85 | -- , text b2s <+> text "\\\\" | ||
86 | -- , text "\\end{array}" | ||
87 | -- ] | ||
88 | -- where | ||
89 | -- n = rows m | ||
90 | -- rs = toLists m | ||
91 | -- ss = map (\r -> intercalate " & " $ map show r) rs | ||
92 | -- ts = zipWith (\i r -> show i ++ " & " ++ r) (toList c) ss | ||
93 | -- us = vcat $ map (\r -> text r <+> text "\\\\") ts | ||
94 | -- bs = " & " ++ (intercalate " & " $ map show $ toList b) | ||
95 | -- b2s = " & " ++ (intercalate " & " $ map show $ toList b2) | ||
96 | -- | ||
97 | -- main :: IO () | ||
98 | -- main = do | ||
99 | -- | ||
100 | -- let res = butcherTable (SDIRK_2_1_2 undefined) | ||
101 | -- putStrLn $ show res | ||
102 | -- putStrLn $ butcherTableauTex res | ||
103 | -- | ||
104 | -- let resA = butcherTable (KVAERNO_4_2_3 undefined) | ||
105 | -- putStrLn $ show resA | ||
106 | -- putStrLn $ butcherTableauTex resA | ||
107 | -- | ||
108 | -- let resB = butcherTable (SDIRK_5_3_4 undefined) | ||
109 | -- putStrLn $ show resB | ||
110 | -- putStrLn $ butcherTableauTex resB | ||
111 | -- @ | ||
112 | -- | ||
113 | -- Using the code above from the examples gives | ||
114 | -- | ||
70 | -- KVAERNO_4_2_3 | 115 | -- KVAERNO_4_2_3 |
71 | -- | 116 | -- |
72 | -- \[ | 117 | -- \[ |
@@ -116,8 +161,6 @@ module Numeric.Sundials.ARKode.ODE ( odeSolve | |||
116 | , butcherTable | 161 | , butcherTable |
117 | , ODEMethod(..) | 162 | , ODEMethod(..) |
118 | , StepControl(..) | 163 | , StepControl(..) |
119 | , Jacobian | ||
120 | , SundialsDiagnostics(..) | ||
121 | ) where | 164 | ) where |
122 | 165 | ||
123 | import qualified Language.C.Inline as C | 166 | import qualified Language.C.Inline as C |
@@ -126,27 +169,50 @@ import qualified Language.C.Inline.Unsafe as CU | |||
126 | import Data.Monoid ((<>)) | 169 | import Data.Monoid ((<>)) |
127 | import Data.Maybe (isJust) | 170 | import Data.Maybe (isJust) |
128 | 171 | ||
129 | import Foreign.C.Types | 172 | import Foreign.C.Types (CDouble, CInt, CLong) |
130 | import Foreign.Ptr (Ptr) | 173 | import Foreign.Ptr (Ptr) |
131 | import Foreign.ForeignPtr (newForeignPtr_) | 174 | import Foreign.Storable (poke) |
132 | import Foreign.Storable (Storable) | ||
133 | 175 | ||
134 | import qualified Data.Vector.Storable as V | 176 | import qualified Data.Vector.Storable as V |
135 | import qualified Data.Vector.Storable.Mutable as VM | ||
136 | 177 | ||
137 | import Data.Coerce (coerce) | 178 | import Data.Coerce (coerce) |
138 | import System.IO.Unsafe (unsafePerformIO) | 179 | import System.IO.Unsafe (unsafePerformIO) |
139 | import GHC.Generics | 180 | import GHC.Generics (C1, Constructor, (:+:)(..), D1, Rep, Generic, M1(..), |
181 | from, conName) | ||
140 | 182 | ||
141 | import Numeric.LinearAlgebra.Devel (createVector) | 183 | import Numeric.LinearAlgebra.Devel (createVector) |
142 | 184 | ||
143 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), | 185 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows, |
144 | subMatrix, rows, cols, toLists, | 186 | cols, toLists, size, reshape, |
145 | size, subVector) | 187 | subVector, subMatrix, (><)) |
146 | 188 | ||
147 | import qualified Types as T | 189 | import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian, SundialsDiagnostics(..)) |
148 | import Arkode | 190 | import qualified Numeric.Sundials.Arkode as T |
149 | import qualified Arkode as B | 191 | import Numeric.Sundials.Arkode (getDataFromContents, putDataInContents, arkSMax, |
192 | sDIRK_2_1_2, | ||
193 | bILLINGTON_3_3_2, | ||
194 | tRBDF2_3_3_2, | ||
195 | kVAERNO_4_2_3, | ||
196 | aRK324L2SA_DIRK_4_2_3, | ||
197 | cASH_5_2_4, | ||
198 | cASH_5_3_4, | ||
199 | sDIRK_5_3_4, | ||
200 | kVAERNO_5_3_4, | ||
201 | aRK436L2SA_DIRK_6_3_4, | ||
202 | kVAERNO_7_4_5, | ||
203 | aRK548L2SA_DIRK_8_4_5, | ||
204 | hEUN_EULER_2_1_2, | ||
205 | bOGACKI_SHAMPINE_4_2_3, | ||
206 | aRK324L2SA_ERK_4_2_3, | ||
207 | zONNEVELD_5_3_4, | ||
208 | aRK436L2SA_ERK_6_3_4, | ||
209 | sAYFY_ABURUB_6_3_4, | ||
210 | cASH_KARP_6_4_5, | ||
211 | fEHLBERG_6_4_5, | ||
212 | dORMAND_PRINCE_7_4_5, | ||
213 | aRK548L2SA_ERK_8_4_5, | ||
214 | vERNER_8_5_6, | ||
215 | fEHLBERG_13_7_8) | ||
150 | 216 | ||
151 | 217 | ||
152 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | 218 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) |
@@ -162,69 +228,8 @@ C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface | |||
162 | C.include "<sundials/sundials_types.h>" -- definition of type realtype | 228 | C.include "<sundials/sundials_types.h>" -- definition of type realtype |
163 | C.include "<sundials/sundials_math.h>" | 229 | C.include "<sundials/sundials_math.h>" |
164 | C.include "../../../helpers.h" | 230 | C.include "../../../helpers.h" |
165 | C.include "Arkode_hsc.h" | 231 | C.include "Numeric/Sundials/Arkode_hsc.h" |
166 | 232 | ||
167 | |||
168 | getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble) | ||
169 | getDataFromContents len ptr = do | ||
170 | qtr <- B.getContentPtr ptr | ||
171 | rtr <- B.getData qtr | ||
172 | vectorFromC len rtr | ||
173 | |||
174 | -- FIXME: Potentially an instance of Storable | ||
175 | _getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix | ||
176 | _getMatrixDataFromContents ptr = do | ||
177 | qtr <- B.getContentMatrixPtr ptr | ||
178 | rs <- B.getNRows qtr | ||
179 | cs <- B.getNCols qtr | ||
180 | rtr <- B.getMatrixData qtr | ||
181 | vs <- vectorFromC (fromIntegral $ rs * cs) rtr | ||
182 | return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } | ||
183 | |||
184 | putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () | ||
185 | putMatrixDataFromContents mat ptr = do | ||
186 | let rs = T.rows mat | ||
187 | cs = T.cols mat | ||
188 | vs = T.vals mat | ||
189 | qtr <- B.getContentMatrixPtr ptr | ||
190 | B.putNRows rs qtr | ||
191 | B.putNCols cs qtr | ||
192 | rtr <- B.getMatrixData qtr | ||
193 | vectorToC vs (fromIntegral $ rs * cs) rtr | ||
194 | -- FIXME: END | ||
195 | |||
196 | putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () | ||
197 | putDataInContents vec len ptr = do | ||
198 | qtr <- B.getContentPtr ptr | ||
199 | rtr <- B.getData qtr | ||
200 | vectorToC vec len rtr | ||
201 | |||
202 | -- Utils | ||
203 | |||
204 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | ||
205 | vectorFromC len ptr = do | ||
206 | ptr' <- newForeignPtr_ ptr | ||
207 | V.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
208 | |||
209 | vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () | ||
210 | vectorToC vec len ptr = do | ||
211 | ptr' <- newForeignPtr_ ptr | ||
212 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
213 | |||
214 | data SundialsDiagnostics = SundialsDiagnostics { | ||
215 | aRKodeGetNumSteps :: Int | ||
216 | , aRKodeGetNumStepAttempts :: Int | ||
217 | , aRKodeGetNumRhsEvals_fe :: Int | ||
218 | , aRKodeGetNumRhsEvals_fi :: Int | ||
219 | , aRKodeGetNumLinSolvSetups :: Int | ||
220 | , aRKodeGetNumErrTestFails :: Int | ||
221 | , aRKodeGetNumNonlinSolvIters :: Int | ||
222 | , aRKodeGetNumNonlinSolvConvFails :: Int | ||
223 | , aRKDlsGetNumJacEvals :: Int | ||
224 | , aRKDlsGetNumRhsEvals :: Int | ||
225 | } deriving Show | ||
226 | |||
227 | type Jacobian = Double -> Vector Double -> Matrix Double | ||
228 | 233 | ||
229 | -- | Stepping functions | 234 | -- | Stepping functions |
230 | data ODEMethod = SDIRK_2_1_2 Jacobian | 235 | data ODEMethod = SDIRK_2_1_2 Jacobian |
@@ -390,15 +395,9 @@ odeSolveV | |||
390 | -> Vector Double -- ^ desired solution times | 395 | -> Vector Double -- ^ desired solution times |
391 | -> Matrix Double -- ^ solution | 396 | -> Matrix Double -- ^ solution |
392 | odeSolveV meth hi epsAbs epsRel f y0 ts = | 397 | odeSolveV meth hi epsAbs epsRel f y0 ts = |
393 | case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of | 398 | odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts |
394 | Left c -> error $ show c -- FIXME | 399 | where |
395 | -- FIXME: Can we do better than using lists? | 400 | g t x0 = coerce $ f t x0 |
396 | Right (v, _d) -> (nR >< nC) (V.toList v) | ||
397 | where | ||
398 | us = toList ts | ||
399 | nR = length us | ||
400 | nC = size y0 | ||
401 | g t x0 = coerce $ f t x0 | ||
402 | 401 | ||
403 | -- | A version of 'odeSolveV' with reasonable default parameters and | 402 | -- | A version of 'odeSolveV' with reasonable default parameters and |
404 | -- system of equations defined using lists. FIXME: we should say | 403 | -- system of equations defined using lists. FIXME: we should say |
@@ -410,16 +409,11 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y | |||
410 | -> Matrix Double -- ^ solution | 409 | -> Matrix Double -- ^ solution |
411 | odeSolve f y0 ts = | 410 | odeSolve f y0 ts = |
412 | -- FIXME: These tolerances are different from the ones in GSL | 411 | -- FIXME: These tolerances are different from the ones in GSL |
413 | case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of | 412 | odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) |
414 | Left c -> error $ show c -- FIXME | ||
415 | Right (v, _d) -> (nR >< nC) (V.toList v) | ||
416 | where | 413 | where |
417 | us = toList ts | ||
418 | nR = length us | ||
419 | nC = length y0 | ||
420 | g t x0 = V.fromList $ f t (V.toList x0) | 414 | g t x0 = V.fromList $ f t (V.toList x0) |
421 | 415 | ||
422 | odeSolveVWith' :: | 416 | odeSolveVWith :: |
423 | ODEMethod | 417 | ODEMethod |
424 | -> StepControl | 418 | -> StepControl |
425 | -> Maybe Double -- ^ initial step size - by default, ARKode | 419 | -> Maybe Double -- ^ initial step size - by default, ARKode |
@@ -432,16 +426,22 @@ odeSolveVWith' :: | |||
432 | -> V.Vector Double -- ^ Initial conditions | 426 | -> V.Vector Double -- ^ Initial conditions |
433 | -> V.Vector Double -- ^ Desired solution times | 427 | -> V.Vector Double -- ^ Desired solution times |
434 | -> Matrix Double -- ^ Error code or solution | 428 | -> Matrix Double -- ^ Error code or solution |
435 | odeSolveVWith' method control initStepSize f y0 tt = | 429 | odeSolveVWith method control initStepSize f y0 tt = |
436 | case odeSolveVWith method control initStepSize f y0 tt of | 430 | case odeSolveVWith' opts method control initStepSize f y0 tt of |
437 | Left c -> error $ show c -- FIXME | 431 | Left c -> error $ show c -- FIXME |
438 | Right (v, _d) -> (nR >< nC) (V.toList v) | 432 | Right (v, _d) -> v |
439 | where | 433 | where |
440 | nR = V.length tt | 434 | opts = ODEOpts { maxNumSteps = 10000 |
441 | nC = V.length y0 | 435 | , minStep = 1.0e-12 |
436 | , relTol = error "relTol" | ||
437 | , absTols = error "absTol" | ||
438 | , initStep = error "initStep" | ||
439 | , maxFail = 10 | ||
440 | } | ||
442 | 441 | ||
443 | odeSolveVWith :: | 442 | odeSolveVWith' :: |
444 | ODEMethod | 443 | ODEOpts |
444 | -> ODEMethod | ||
445 | -> StepControl | 445 | -> StepControl |
446 | -> Maybe Double -- ^ initial step size - by default, ARKode | 446 | -> Maybe Double -- ^ initial step size - by default, ARKode |
447 | -- estimates the initial step size to be the | 447 | -- estimates the initial step size to be the |
@@ -452,19 +452,21 @@ odeSolveVWith :: | |||
452 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 452 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
453 | -> V.Vector Double -- ^ Initial conditions | 453 | -> V.Vector Double -- ^ Initial conditions |
454 | -> V.Vector Double -- ^ Desired solution times | 454 | -> V.Vector Double -- ^ Desired solution times |
455 | -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution | 455 | -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution |
456 | odeSolveVWith method control initStepSize f y0 tt = | 456 | odeSolveVWith' opts method control initStepSize f y0 tt = |
457 | case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) | 457 | case solveOdeC (fromIntegral $ maxFail opts) |
458 | (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts) | ||
459 | (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) | ||
458 | (coerce f) (coerce y0) (coerce tt) of | 460 | (coerce f) (coerce y0) (coerce tt) of |
459 | Left c -> Left $ fromIntegral c | 461 | Left c -> Left $ fromIntegral c |
460 | Right (v, d) -> Right (coerce v, d) | 462 | Right (v, d) -> Right (reshape l (coerce v), d) |
461 | where | 463 | where |
462 | l = size y0 | 464 | l = size y0 |
463 | scise (X absTol relTol) = coerce (V.replicate l absTol, relTol) | 465 | scise (X aTol rTol) = coerce (V.replicate l aTol, rTol) |
464 | scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol) | 466 | scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol) |
465 | scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol) | 467 | scise (XX' aTol rTol yScale _yDotScale) = coerce (V.replicate l aTol, yScale * rTol) |
466 | -- FIXME; Should we check that the length of ss is correct? | 468 | -- FIXME; Should we check that the length of ss is correct? |
467 | scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol) | 469 | scise (ScXX' aTol rTol yScale _yDotScale ss) = coerce (V.map (* aTol) ss, yScale * rTol) |
468 | jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ | 470 | jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ |
469 | getJacobian method | 471 | getJacobian method |
470 | matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } | 472 | matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } |
@@ -476,6 +478,9 @@ odeSolveVWith method control initStepSize f y0 tt = | |||
476 | 478 | ||
477 | solveOdeC :: | 479 | solveOdeC :: |
478 | CInt -> | 480 | CInt -> |
481 | CLong -> | ||
482 | CDouble -> | ||
483 | CInt -> | ||
479 | Maybe CDouble -> | 484 | Maybe CDouble -> |
480 | (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> | 485 | (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> |
481 | (V.Vector CDouble, CDouble) -> | 486 | (V.Vector CDouble, CDouble) -> |
@@ -483,7 +488,8 @@ solveOdeC :: | |||
483 | -> V.Vector CDouble -- ^ Initial conditions | 488 | -> V.Vector CDouble -- ^ Initial conditions |
484 | -> V.Vector CDouble -- ^ Desired solution times | 489 | -> V.Vector CDouble -- ^ Desired solution times |
485 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution | 490 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution |
486 | solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do | 491 | solveOdeC maxErrTestFails maxNumSteps_ minStep_ method initStepSize |
492 | jacH (aTols, rTol) fun f0 ts = unsafePerformIO $ do | ||
487 | 493 | ||
488 | let isInitStepSize :: CInt | 494 | let isInitStepSize :: CInt |
489 | isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize | 495 | isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize |
@@ -494,14 +500,12 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
494 | -- used :( | 500 | -- used :( |
495 | Nothing -> 0.0 | 501 | Nothing -> 0.0 |
496 | Just x -> x | 502 | Just x -> x |
503 | |||
497 | let dim = V.length f0 | 504 | let dim = V.length f0 |
498 | nEq :: CLong | 505 | nEq :: CLong |
499 | nEq = fromIntegral dim | 506 | nEq = fromIntegral dim |
500 | nTs :: CInt | 507 | nTs :: CInt |
501 | nTs = fromIntegral $ V.length ts | 508 | nTs = fromIntegral $ V.length ts |
502 | -- FIXME: fMut is not actually mutatated | ||
503 | fMut <- V.thaw f0 | ||
504 | tMut <- V.thaw ts | ||
505 | -- FIXME: I believe this gets taken from the ghc heap and so should | 509 | -- FIXME: I believe this gets taken from the ghc heap and so should |
506 | -- be subject to garbage collection. | 510 | -- be subject to garbage collection. |
507 | quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) | 511 | quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) |
@@ -509,7 +513,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
509 | diagnostics :: V.Vector CLong <- createVector 10 -- FIXME | 513 | diagnostics :: V.Vector CLong <- createVector 10 -- FIXME |
510 | diagMut <- V.thaw diagnostics | 514 | diagMut <- V.thaw diagnostics |
511 | -- We need the types that sundials expects. These are tied together | 515 | -- We need the types that sundials expects. These are tied together |
512 | -- in 'Types'. FIXME: The Haskell type is currently empty! | 516 | -- in 'CLangToHaskellTypes'. FIXME: The Haskell type is currently empty! |
513 | let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | 517 | let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt |
514 | funIO x y f _ptr = do | 518 | funIO x y f _ptr = do |
515 | -- Convert the pointer we get from C (y) to a vector, and then | 519 | -- Convert the pointer we get from C (y) to a vector, and then |
@@ -529,7 +533,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
529 | case jacH of | 533 | case jacH of |
530 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" | 534 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" |
531 | Just jacI -> do j <- jacI t <$> getDataFromContents dim y | 535 | Just jacI -> do j <- jacI t <$> getDataFromContents dim y |
532 | putMatrixDataFromContents j jacS | 536 | poke jacS j |
533 | -- FIXME: I don't understand what this comment means | 537 | -- FIXME: I don't understand what this comment means |
534 | -- Unsafe since the function will be called many times. | 538 | -- Unsafe since the function will be called many times. |
535 | [CU.exp| int{ 0 } |] | 539 | [CU.exp| int{ 0 } |] |
@@ -549,7 +553,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
549 | 553 | ||
550 | /* general problem parameters */ | 554 | /* general problem parameters */ |
551 | 555 | ||
552 | realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ | 556 | realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */ |
553 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ | 557 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ |
554 | 558 | ||
555 | /* Initialize data structures */ | 559 | /* Initialize data structures */ |
@@ -558,14 +562,14 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
558 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | 562 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; |
559 | /* Specify initial condition */ | 563 | /* Specify initial condition */ |
560 | for (i = 0; i < NEQ; i++) { | 564 | for (i = 0; i < NEQ; i++) { |
561 | NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; | 565 | NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; |
562 | }; | 566 | }; |
563 | 567 | ||
564 | tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */ | 568 | tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */ |
565 | if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; | 569 | if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; |
566 | /* Specify tolerances */ | 570 | /* Specify tolerances */ |
567 | for (i = 0; i < NEQ; i++) { | 571 | for (i = 0; i < NEQ; i++) { |
568 | NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i]; | 572 | NV_Ith_S(tv,i) = ($vec-ptr:(double *aTols))[i]; |
569 | }; | 573 | }; |
570 | 574 | ||
571 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | 575 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ |
@@ -577,7 +581,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
577 | /* problem as fully implicit and set f_E to NULL and f_I to f. */ | 581 | /* problem as fully implicit and set f_E to NULL and f_I to f. */ |
578 | 582 | ||
579 | /* Here we use the C types defined in helpers.h which tie up with */ | 583 | /* Here we use the C types defined in helpers.h which tie up with */ |
580 | /* the Haskell types defined in Types */ | 584 | /* the Haskell types defined in CLangToHaskellTypes */ |
581 | if ($(int method) < MIN_DIRK_NUM) { | 585 | if ($(int method) < MIN_DIRK_NUM) { |
582 | flag = ARKodeInit(arkode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), NULL, T0, y); | 586 | flag = ARKodeInit(arkode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), NULL, T0, y); |
583 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | 587 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; |
@@ -586,14 +590,15 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
586 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | 590 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; |
587 | } | 591 | } |
588 | 592 | ||
589 | /* FIXME: A hack for initial testing */ | 593 | flag = ARKodeSetMinStep(arkode_mem, $(double minStep_)); |
590 | flag = ARKodeSetMinStep(arkode_mem, 1.0e-12); | ||
591 | if (check_flag(&flag, "ARKodeSetMinStep", 1)) return 1; | 594 | if (check_flag(&flag, "ARKodeSetMinStep", 1)) return 1; |
592 | flag = ARKodeSetMaxNumSteps(arkode_mem, 10000); | 595 | flag = ARKodeSetMaxNumSteps(arkode_mem, $(long int maxNumSteps_)); |
593 | if (check_flag(&flag, "ARKodeSetMaxNumSteps", 1)) return 1; | 596 | if (check_flag(&flag, "ARKodeSetMaxNumSteps", 1)) return 1; |
597 | flag = ARKodeSetMaxErrTestFails(arkode_mem, $(int maxErrTestFails)); | ||
598 | if (check_flag(&flag, "ARKodeSetMaxErrTestFails", 1)) return 1; | ||
594 | 599 | ||
595 | /* Set routines */ | 600 | /* Set routines */ |
596 | flag = ARKodeSVtolerances(arkode_mem, $(double relTol), tv); | 601 | flag = ARKodeSVtolerances(arkode_mem, $(double rTol), tv); |
597 | if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1; | 602 | if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1; |
598 | 603 | ||
599 | /* Initialize dense matrix data structure and solver */ | 604 | /* Initialize dense matrix data structure and solver */ |
@@ -638,7 +643,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
638 | /* Stops when the final time has been reached */ | 643 | /* Stops when the final time has been reached */ |
639 | for (i = 1; i < $(int nTs); i++) { | 644 | for (i = 1; i < $(int nTs); i++) { |
640 | 645 | ||
641 | flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */ | 646 | flag = ARKode(arkode_mem, ($vec-ptr:(double *ts))[i], y, &t, ARK_NORMAL); /* call integrator */ |
642 | if (check_flag(&flag, "ARKode", 1)) break; | 647 | if (check_flag(&flag, "ARKode", 1)) break; |
643 | 648 | ||
644 | /* Store the results for Haskell */ | 649 | /* Store the results for Haskell */ |
@@ -738,7 +743,7 @@ butcherTable method = | |||
738 | case getBT method of | 743 | case getBT method of |
739 | Left c -> error $ show c -- FIXME | 744 | Left c -> error $ show c -- FIXME |
740 | Right (ButcherTable' v w x y, sqp) -> | 745 | Right (ButcherTable' v w x y, sqp) -> |
741 | ButcherTable { am = subMatrix (0, 0) (s, s) $ (B.arkSMax >< B.arkSMax) (V.toList v) | 746 | ButcherTable { am = subMatrix (0, 0) (s, s) $ (arkSMax >< arkSMax) (V.toList v) |
742 | , cv = subVector 0 s w | 747 | , cv = subVector 0 s w |
743 | , bv = subVector 0 s x | 748 | , bv = subVector 0 s x |
744 | , b2v = subVector 0 s y | 749 | , b2v = subVector 0 s y |
@@ -773,11 +778,11 @@ getButcherTable method = unsafePerformIO $ do | |||
773 | 778 | ||
774 | btSQP :: V.Vector CInt <- createVector 3 | 779 | btSQP :: V.Vector CInt <- createVector 3 |
775 | btSQPMut <- V.thaw btSQP | 780 | btSQPMut <- V.thaw btSQP |
776 | btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax) | 781 | btAs :: V.Vector CDouble <- createVector (arkSMax * arkSMax) |
777 | btAsMut <- V.thaw btAs | 782 | btAsMut <- V.thaw btAs |
778 | btCs :: V.Vector CDouble <- createVector B.arkSMax | 783 | btCs :: V.Vector CDouble <- createVector arkSMax |
779 | btBs :: V.Vector CDouble <- createVector B.arkSMax | 784 | btBs :: V.Vector CDouble <- createVector arkSMax |
780 | btB2s :: V.Vector CDouble <- createVector B.arkSMax | 785 | btB2s :: V.Vector CDouble <- createVector arkSMax |
781 | btCsMut <- V.thaw btCs | 786 | btCsMut <- V.thaw btCs |
782 | btBsMut <- V.thaw btBs | 787 | btBsMut <- V.thaw btBs |
783 | btB2sMut <- V.thaw btB2s | 788 | btB2sMut <- V.thaw btB2s |