summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDominic Steinitz <dominic@steinitz.org>2018-05-02 14:42:43 +0100
committerDominic Steinitz <dominic@steinitz.org>2018-05-02 14:42:43 +0100
commit4ba859636396d211637b5507f19722b6953656a5 (patch)
tree9493c4851e6141a400e6345efe59a07197709f63
parent149dedfc6ec8dea039a4df7ad1d31880820c52eb (diff)
Add more options
-rw-r--r--packages/sundials/src/Main.hs48
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs113
-rw-r--r--packages/sundials/src/Numeric/Sundials/CVode/ODE.hs23
-rw-r--r--packages/sundials/src/Numeric/Sundials/ODEOpts.hs7
4 files changed, 116 insertions, 75 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs
index 85928e2..16c21c5 100644
--- a/packages/sundials/src/Main.hs
+++ b/packages/sundials/src/Main.hs
@@ -81,6 +81,23 @@ _stiffJac _t _v = (1><1) [ lamda ]
81 where 81 where
82 lamda = -100.0 82 lamda = -100.0
83 83
84predatorPrey :: Double -> [Double] -> [Double]
85predatorPrey _t v = [ x * a - b * x * y
86 , d * x * y - c * y - e * y * z
87 , (-f) * z + g * y * z
88 ]
89 where
90 x = v!!0
91 y = v!!1
92 z = v!!2
93 a = 1.0
94 b = 1.0
95 c = 1.0
96 d = 1.0
97 e = 1.0
98 f = 1.0
99 g = 1.0
100
84lSaxis :: [[Double]] -> P.Axis B D.V2 Double 101lSaxis :: [[Double]] -> P.Axis B D.V2 Double
85lSaxis xs = P.r2Axis &~ do 102lSaxis xs = P.r2Axis &~ do
86 let ts = xs!!0 103 let ts = xs!!0
@@ -128,11 +145,6 @@ main = do
128 let maxDiffC = maximum $ map abs $ 145 let maxDiffC = maximum $ map abs $
129 zipWith (-) ((toLists $ tr res2b)!!0) ((toLists $ tr res2c)!!0) 146 zipWith (-) ((toLists $ tr res2b)!!0) ((toLists $ tr res2c)!!0)
130 147
131 hspec $ describe "Compare results" $ do
132 it "for SDIRK_5_3_4' and TRBDF2_3_3_2'" $ maxDiffA < 1.0e-6
133 it "for SDIRK_5_3_4' and BDF" $ maxDiffB < 1.0e-6
134 it "for TRBDF2_3_3_2' and BDF" $ maxDiffC < 1.0e-6
135
136 let res3 = ARK.odeSolve lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0]) 148 let res3 = ARK.odeSolve lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0])
137 149
138 renderRasterific "diagrams/lorenz.png" 150 renderRasterific "diagrams/lorenz.png"
@@ -146,3 +158,29 @@ main = do
146 renderRasterific "diagrams/lorenz2.png" 158 renderRasterific "diagrams/lorenz2.png"
147 (D.dims2D 500.0 500.0) 159 (D.dims2D 500.0 500.0)
148 (renderAxis $ kSaxis $ zip ((toLists $ tr res3)!!1) ((toLists $ tr res3)!!2)) 160 (renderAxis $ kSaxis $ zip ((toLists $ tr res3)!!1) ((toLists $ tr res3)!!2))
161
162 let res4 = CV.odeSolve predatorPrey [0.5, 1.0, 2.0] (fromList [0.0, 0.01 .. 10.0])
163
164 renderRasterific "diagrams/predatorPrey.png"
165 (D.dims2D 500.0 500.0)
166 (renderAxis $ kSaxis $ zip ((toLists $ tr res4)!!0) ((toLists $ tr res4)!!1))
167
168 renderRasterific "diagrams/predatorPrey1.png"
169 (D.dims2D 500.0 500.0)
170 (renderAxis $ kSaxis $ zip ((toLists $ tr res4)!!0) ((toLists $ tr res4)!!2))
171
172 renderRasterific "diagrams/predatorPrey2.png"
173 (D.dims2D 500.0 500.0)
174 (renderAxis $ kSaxis $ zip ((toLists $ tr res4)!!1) ((toLists $ tr res4)!!2))
175
176 let res4a = ARK.odeSolve predatorPrey [0.5, 1.0, 2.0] (fromList [0.0, 0.01 .. 10.0])
177
178 let maxDiffPpA = maximum $ map abs $
179 zipWith (-) ((toLists $ tr res4)!!0) ((toLists $ tr res4a)!!0)
180
181 hspec $ describe "Compare results" $ do
182 it "for SDIRK_5_3_4' and TRBDF2_3_3_2'" $ maxDiffA < 1.0e-6
183 it "for SDIRK_5_3_4' and BDF" $ maxDiffB < 1.0e-6
184 it "for TRBDF2_3_3_2' and BDF" $ maxDiffC < 1.0e-6
185 it "for CV and ARK for the Predator Prey model" $ maxDiffPpA < 1.0e-3
186
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index 13b7eb8..ce46968 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -114,7 +114,6 @@ module Numeric.Sundials.ARKode.ODE ( odeSolve
114 , butcherTable 114 , butcherTable
115 , ODEMethod(..) 115 , ODEMethod(..)
116 , StepControl(..) 116 , StepControl(..)
117 , Jacobian
118 ) where 117 ) where
119 118
120import qualified Language.C.Inline as C 119import qualified Language.C.Inline as C
@@ -136,11 +135,11 @@ import GHC.Generics (C1, Constructor, (:+:)(..), D1, Rep, Generic, M1(
136 135
137import Numeric.LinearAlgebra.Devel (createVector) 136import Numeric.LinearAlgebra.Devel (createVector)
138 137
139import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), 138import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows,
140 subMatrix, rows, cols, toLists, 139 cols, toLists, size, reshape,
141 size, subVector) 140 subVector, subMatrix, (><))
142 141
143import qualified Numeric.Sundials.ODEOpts as SO 142import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian, SundialsDiagnostics(..))
144import qualified Numeric.Sundials.Arkode as T 143import qualified Numeric.Sundials.Arkode as T
145import Numeric.Sundials.Arkode (getDataFromContents, putDataInContents, arkSMax, 144import Numeric.Sundials.Arkode (getDataFromContents, putDataInContents, arkSMax,
146 sDIRK_2_1_2, 145 sDIRK_2_1_2,
@@ -185,8 +184,6 @@ C.include "../../../helpers.h"
185C.include "Numeric/Sundials/Arkode_hsc.h" 184C.include "Numeric/Sundials/Arkode_hsc.h"
186 185
187 186
188type Jacobian = Double -> Vector Double -> Matrix Double
189
190-- | Stepping functions 187-- | Stepping functions
191data ODEMethod = SDIRK_2_1_2 Jacobian 188data ODEMethod = SDIRK_2_1_2 Jacobian
192 | SDIRK_2_1_2' 189 | SDIRK_2_1_2'
@@ -351,15 +348,9 @@ odeSolveV
351 -> Vector Double -- ^ desired solution times 348 -> Vector Double -- ^ desired solution times
352 -> Matrix Double -- ^ solution 349 -> Matrix Double -- ^ solution
353odeSolveV meth hi epsAbs epsRel f y0 ts = 350odeSolveV meth hi epsAbs epsRel f y0 ts =
354 case odeSolveVWith' meth (X epsAbs epsRel) hi g y0 ts of 351 odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts
355 Left c -> error $ show c -- FIXME 352 where
356 -- FIXME: Can we do better than using lists? 353 g t x0 = coerce $ f t x0
357 Right (v, _d) -> (nR >< nC) (V.toList v)
358 where
359 us = toList ts
360 nR = length us
361 nC = size y0
362 g t x0 = coerce $ f t x0
363 354
364-- | A version of 'odeSolveV' with reasonable default parameters and 355-- | A version of 'odeSolveV' with reasonable default parameters and
365-- system of equations defined using lists. FIXME: we should say 356-- system of equations defined using lists. FIXME: we should say
@@ -371,13 +362,8 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y
371 -> Matrix Double -- ^ solution 362 -> Matrix Double -- ^ solution
372odeSolve f y0 ts = 363odeSolve f y0 ts =
373 -- FIXME: These tolerances are different from the ones in GSL 364 -- FIXME: These tolerances are different from the ones in GSL
374 case odeSolveVWith' SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of 365 odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts)
375 Left c -> error $ show c -- FIXME
376 Right (v, _d) -> (nR >< nC) (V.toList v)
377 where 366 where
378 us = toList ts
379 nR = length us
380 nC = length y0
381 g t x0 = V.fromList $ f t (V.toList x0) 367 g t x0 = V.fromList $ f t (V.toList x0)
382 368
383odeSolveVWith :: 369odeSolveVWith ::
@@ -394,15 +380,21 @@ odeSolveVWith ::
394 -> V.Vector Double -- ^ Desired solution times 380 -> V.Vector Double -- ^ Desired solution times
395 -> Matrix Double -- ^ Error code or solution 381 -> Matrix Double -- ^ Error code or solution
396odeSolveVWith method control initStepSize f y0 tt = 382odeSolveVWith method control initStepSize f y0 tt =
397 case odeSolveVWith' method control initStepSize f y0 tt of 383 case odeSolveVWith' opts method control initStepSize f y0 tt of
398 Left c -> error $ show c -- FIXME 384 Left c -> error $ show c -- FIXME
399 Right (v, _d) -> (nR >< nC) (V.toList v) 385 Right (v, _d) -> v
400 where 386 where
401 nR = V.length tt 387 opts = ODEOpts { maxNumSteps = 10000
402 nC = V.length y0 388 , minStep = 1.0e-12
389 , relTol = error "relTol"
390 , absTols = error "absTol"
391 , initStep = error "initStep"
392 , maxFail = 10
393 }
403 394
404odeSolveVWith' :: 395odeSolveVWith' ::
405 ODEMethod 396 ODEOpts
397 -> ODEMethod
406 -> StepControl 398 -> StepControl
407 -> Maybe Double -- ^ initial step size - by default, ARKode 399 -> Maybe Double -- ^ initial step size - by default, ARKode
408 -- estimates the initial step size to be the 400 -- estimates the initial step size to be the
@@ -413,19 +405,21 @@ odeSolveVWith' ::
413 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 405 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
414 -> V.Vector Double -- ^ Initial conditions 406 -> V.Vector Double -- ^ Initial conditions
415 -> V.Vector Double -- ^ Desired solution times 407 -> V.Vector Double -- ^ Desired solution times
416 -> Either Int ((V.Vector Double), SO.SundialsDiagnostics) -- ^ Error code or solution 408 -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution
417odeSolveVWith' method control initStepSize f y0 tt = 409odeSolveVWith' opts method control initStepSize f y0 tt =
418 case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) 410 case solveOdeC (fromIntegral $ maxFail opts)
411 (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts)
412 (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control)
419 (coerce f) (coerce y0) (coerce tt) of 413 (coerce f) (coerce y0) (coerce tt) of
420 Left c -> Left $ fromIntegral c 414 Left c -> Left $ fromIntegral c
421 Right (v, d) -> Right (coerce v, d) 415 Right (v, d) -> Right (reshape l (coerce v), d)
422 where 416 where
423 l = size y0 417 l = size y0
424 scise (X absTol relTol) = coerce (V.replicate l absTol, relTol) 418 scise (X aTol rTol) = coerce (V.replicate l aTol, rTol)
425 scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol) 419 scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol)
426 scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol) 420 scise (XX' aTol rTol yScale _yDotScale) = coerce (V.replicate l aTol, yScale * rTol)
427 -- FIXME; Should we check that the length of ss is correct? 421 -- FIXME; Should we check that the length of ss is correct?
428 scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol) 422 scise (ScXX' aTol rTol yScale _yDotScale ss) = coerce (V.map (* aTol) ss, yScale * rTol)
429 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ 423 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $
430 getJacobian method 424 getJacobian method
431 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } 425 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs }
@@ -437,14 +431,18 @@ odeSolveVWith' method control initStepSize f y0 tt =
437 431
438solveOdeC :: 432solveOdeC ::
439 CInt -> 433 CInt ->
434 CLong ->
435 CDouble ->
436 CInt ->
440 Maybe CDouble -> 437 Maybe CDouble ->
441 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> 438 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) ->
442 (V.Vector CDouble, CDouble) -> 439 (V.Vector CDouble, CDouble) ->
443 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 440 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
444 -> V.Vector CDouble -- ^ Initial conditions 441 -> V.Vector CDouble -- ^ Initial conditions
445 -> V.Vector CDouble -- ^ Desired solution times 442 -> V.Vector CDouble -- ^ Desired solution times
446 -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution 443 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution
447solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do 444solveOdeC maxErrTestFails maxNumSteps_ minStep_ method initStepSize
445 jacH (aTols, rTol) fun f0 ts = unsafePerformIO $ do
448 446
449 let isInitStepSize :: CInt 447 let isInitStepSize :: CInt
450 isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize 448 isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize
@@ -455,14 +453,12 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
455 -- used :( 453 -- used :(
456 Nothing -> 0.0 454 Nothing -> 0.0
457 Just x -> x 455 Just x -> x
456
458 let dim = V.length f0 457 let dim = V.length f0
459 nEq :: CLong 458 nEq :: CLong
460 nEq = fromIntegral dim 459 nEq = fromIntegral dim
461 nTs :: CInt 460 nTs :: CInt
462 nTs = fromIntegral $ V.length ts 461 nTs = fromIntegral $ V.length ts
463 -- FIXME: fMut is not actually mutatated
464 fMut <- V.thaw f0
465 tMut <- V.thaw ts
466 -- FIXME: I believe this gets taken from the ghc heap and so should 462 -- FIXME: I believe this gets taken from the ghc heap and so should
467 -- be subject to garbage collection. 463 -- be subject to garbage collection.
468 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) 464 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs))
@@ -510,7 +506,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
510 506
511 /* general problem parameters */ 507 /* general problem parameters */
512 508
513 realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ 509 realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */
514 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ 510 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */
515 511
516 /* Initialize data structures */ 512 /* Initialize data structures */
@@ -519,14 +515,14 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
519 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; 515 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
520 /* Specify initial condition */ 516 /* Specify initial condition */
521 for (i = 0; i < NEQ; i++) { 517 for (i = 0; i < NEQ; i++) {
522 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; 518 NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i];
523 }; 519 };
524 520
525 tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */ 521 tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */
526 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; 522 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1;
527 /* Specify tolerances */ 523 /* Specify tolerances */
528 for (i = 0; i < NEQ; i++) { 524 for (i = 0; i < NEQ; i++) {
529 NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i]; 525 NV_Ith_S(tv,i) = ($vec-ptr:(double *aTols))[i];
530 }; 526 };
531 527
532 arkode_mem = ARKodeCreate(); /* Create the solver memory */ 528 arkode_mem = ARKodeCreate(); /* Create the solver memory */
@@ -547,14 +543,15 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
547 if (check_flag(&flag, "ARKodeInit", 1)) return 1; 543 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
548 } 544 }
549 545
550 /* FIXME: A hack for initial testing */ 546 flag = ARKodeSetMinStep(arkode_mem, $(double minStep_));
551 flag = ARKodeSetMinStep(arkode_mem, 1.0e-12);
552 if (check_flag(&flag, "ARKodeSetMinStep", 1)) return 1; 547 if (check_flag(&flag, "ARKodeSetMinStep", 1)) return 1;
553 flag = ARKodeSetMaxNumSteps(arkode_mem, 10000); 548 flag = ARKodeSetMaxNumSteps(arkode_mem, $(long int maxNumSteps_));
554 if (check_flag(&flag, "ARKodeSetMaxNumSteps", 1)) return 1; 549 if (check_flag(&flag, "ARKodeSetMaxNumSteps", 1)) return 1;
550 flag = ARKodeSetMaxErrTestFails(arkode_mem, $(int maxErrTestFails));
551 if (check_flag(&flag, "ARKodeSetMaxErrTestFails", 1)) return 1;
555 552
556 /* Set routines */ 553 /* Set routines */
557 flag = ARKodeSVtolerances(arkode_mem, $(double relTol), tv); 554 flag = ARKodeSVtolerances(arkode_mem, $(double rTol), tv);
558 if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1; 555 if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1;
559 556
560 /* Initialize dense matrix data structure and solver */ 557 /* Initialize dense matrix data structure and solver */
@@ -599,7 +596,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
599 /* Stops when the final time has been reached */ 596 /* Stops when the final time has been reached */
600 for (i = 1; i < $(int nTs); i++) { 597 for (i = 1; i < $(int nTs); i++) {
601 598
602 flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */ 599 flag = ARKode(arkode_mem, ($vec-ptr:(double *ts))[i], y, &t, ARK_NORMAL); /* call integrator */
603 if (check_flag(&flag, "ARKode", 1)) break; 600 if (check_flag(&flag, "ARKode", 1)) break;
604 601
605 /* Store the results for Haskell */ 602 /* Store the results for Haskell */
@@ -665,16 +662,16 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
665 if res == 0 662 if res == 0
666 then do 663 then do
667 preD <- V.freeze diagMut 664 preD <- V.freeze diagMut
668 let d = SO.SundialsDiagnostics (fromIntegral $ preD V.!0) 665 let d = SundialsDiagnostics (fromIntegral $ preD V.!0)
669 (fromIntegral $ preD V.!1) 666 (fromIntegral $ preD V.!1)
670 (fromIntegral $ preD V.!2) 667 (fromIntegral $ preD V.!2)
671 (fromIntegral $ preD V.!3) 668 (fromIntegral $ preD V.!3)
672 (fromIntegral $ preD V.!4) 669 (fromIntegral $ preD V.!4)
673 (fromIntegral $ preD V.!5) 670 (fromIntegral $ preD V.!5)
674 (fromIntegral $ preD V.!6) 671 (fromIntegral $ preD V.!6)
675 (fromIntegral $ preD V.!7) 672 (fromIntegral $ preD V.!7)
676 (fromIntegral $ preD V.!8) 673 (fromIntegral $ preD V.!8)
677 (fromIntegral $ preD V.!9) 674 (fromIntegral $ preD V.!9)
678 m <- V.freeze qMatMut 675 m <- V.freeze qMatMut
679 return $ Right (m, d) 676 return $ Right (m, d)
680 else do 677 else do
diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs
index 159fbe2..a6f185e 100644
--- a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs
@@ -68,7 +68,6 @@ module Numeric.Sundials.CVode.ODE ( odeSolve
68 , odeSolveVWith' 68 , odeSolveVWith'
69 , ODEMethod(..) 69 , ODEMethod(..)
70 , StepControl(..) 70 , StepControl(..)
71 , Jacobian
72 ) where 71 ) where
73 72
74import qualified Language.C.Inline as C 73import qualified Language.C.Inline as C
@@ -127,7 +126,7 @@ getJacobian _ = Nothing
127-- | A version of 'odeSolveVWith' with reasonable default step control. 126-- | A version of 'odeSolveVWith' with reasonable default step control.
128odeSolveV 127odeSolveV
129 :: ODEMethod 128 :: ODEMethod
130 -> Maybe Double -- ^ initial step size - by default, ARKode 129 -> Maybe Double -- ^ initial step size - by default, CVode
131 -- estimates the initial step size to be the 130 -- estimates the initial step size to be the
132 -- solution \(h\) of the equation 131 -- solution \(h\) of the equation
133 -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where 132 -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where
@@ -161,7 +160,7 @@ odeSolve f y0 ts =
161odeSolveVWith :: 160odeSolveVWith ::
162 ODEMethod 161 ODEMethod
163 -> StepControl 162 -> StepControl
164 -> Maybe Double -- ^ initial step size - by default, ARKode 163 -> Maybe Double -- ^ initial step size - by default, CVode
165 -- estimates the initial step size to be the 164 -- estimates the initial step size to be the
166 -- solution \(h\) of the equation 165 -- solution \(h\) of the equation
167 -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where 166 -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where
@@ -181,13 +180,14 @@ odeSolveVWith method control initStepSize f y0 tt =
181 , relTol = error "relTol" 180 , relTol = error "relTol"
182 , absTols = error "absTol" 181 , absTols = error "absTol"
183 , initStep = error "initStep" 182 , initStep = error "initStep"
183 , maxFail = 10
184 } 184 }
185 185
186odeSolveVWith' :: 186odeSolveVWith' ::
187 ODEOpts 187 ODEOpts
188 -> ODEMethod 188 -> ODEMethod
189 -> StepControl 189 -> StepControl
190 -> Maybe Double -- ^ initial step size - by default, ARKode 190 -> Maybe Double -- ^ initial step size - by default, CVode
191 -- estimates the initial step size to be the 191 -- estimates the initial step size to be the
192 -- solution \(h\) of the equation 192 -- solution \(h\) of the equation
193 -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where 193 -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where
@@ -198,13 +198,13 @@ odeSolveVWith' ::
198 -> V.Vector Double -- ^ Desired solution times 198 -> V.Vector Double -- ^ Desired solution times
199 -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution 199 -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution
200odeSolveVWith' opts method control initStepSize f y0 tt = 200odeSolveVWith' opts method control initStepSize f y0 tt =
201 case solveOdeC (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts) 201 case solveOdeC (fromIntegral $ maxFail opts)
202 (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts)
202 (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) 203 (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control)
203 (coerce f) (coerce y0) (coerce tt) of 204 (coerce f) (coerce y0) (coerce tt) of
204 Left c -> Left $ fromIntegral c 205 Left c -> Left $ fromIntegral c
205 Right (v, d) -> Right (reshape nC (coerce v), d) 206 Right (v, d) -> Right (reshape l (coerce v), d)
206 where 207 where
207 nC = V.length y0
208 l = size y0 208 l = size y0
209 scise (X aTol rTol) = coerce (V.replicate l aTol, rTol) 209 scise (X aTol rTol) = coerce (V.replicate l aTol, rTol)
210 scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol) 210 scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol)
@@ -221,6 +221,7 @@ odeSolveVWith' opts method control initStepSize f y0 tt =
221 vs = V.fromList $ map coerce $ concat $ toLists m 221 vs = V.fromList $ map coerce $ concat $ toLists m
222 222
223solveOdeC :: 223solveOdeC ::
224 CInt ->
224 CLong -> 225 CLong ->
225 CDouble -> 226 CDouble ->
226 CInt -> 227 CInt ->
@@ -231,7 +232,8 @@ solveOdeC ::
231 -> V.Vector CDouble -- ^ Initial conditions 232 -> V.Vector CDouble -- ^ Initial conditions
232 -> V.Vector CDouble -- ^ Desired solution times 233 -> V.Vector CDouble -- ^ Desired solution times
233 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution 234 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution
234solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts = 235solveOdeC maxErrTestFails maxNumSteps_ minStep_ method initStepSize
236 jacH (aTols, rTol) fun f0 ts =
235 unsafePerformIO $ do 237 unsafePerformIO $ do
236 238
237 let isInitStepSize :: CInt 239 let isInitStepSize :: CInt
@@ -243,6 +245,7 @@ solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts
243 -- used :( 245 -- used :(
244 Nothing -> 0.0 246 Nothing -> 0.0
245 Just x -> x 247 Just x -> x
248
246 let dim = V.length f0 249 let dim = V.length f0
247 nEq :: CLong 250 nEq :: CLong
248 nEq = fromIntegral dim 251 nEq = fromIntegral dim
@@ -271,7 +274,7 @@ solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts
271 IO CInt 274 IO CInt
272 jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do 275 jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do
273 case jacH of 276 case jacH of
274 Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" 277 Nothing -> error "Numeric.Sundials.CVode.ODE: Jacobian not defined"
275 Just jacI -> do j <- jacI t <$> getDataFromContents dim y 278 Just jacI -> do j <- jacI t <$> getDataFromContents dim y
276 poke jacS j 279 poke jacS j
277 -- FIXME: I don't understand what this comment means 280 -- FIXME: I don't understand what this comment means
@@ -326,6 +329,8 @@ solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts
326 if (check_flag(&flag, "CVodeSetMinStep", 1)) return 1; 329 if (check_flag(&flag, "CVodeSetMinStep", 1)) return 1;
327 flag = CVodeSetMaxNumSteps(cvode_mem, $(long int maxNumSteps_)); 330 flag = CVodeSetMaxNumSteps(cvode_mem, $(long int maxNumSteps_));
328 if (check_flag(&flag, "CVodeSetMaxNumSteps", 1)) return 1; 331 if (check_flag(&flag, "CVodeSetMaxNumSteps", 1)) return 1;
332 flag = CVodeSetMaxErrTestFails(cvode_mem, $(int maxErrTestFails));
333 if (check_flag(&flag, "CVodeSetMaxErrTestFails", 1)) return 1;
329 334
330 /* Call CVodeSVtolerances to specify the scalar relative tolerance 335 /* Call CVodeSVtolerances to specify the scalar relative tolerance
331 * and vector absolute tolerances */ 336 * and vector absolute tolerances */
diff --git a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs
index 89f2306..027d99a 100644
--- a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs
+++ b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs
@@ -1,6 +1,6 @@
1module Numeric.Sundials.ODEOpts where 1module Numeric.Sundials.ODEOpts where
2 2
3import Data.Int (Int32) 3import Data.Word (Word32)
4import qualified Data.Vector.Storable as VS 4import qualified Data.Vector.Storable as VS
5 5
6import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) 6import Numeric.LinearAlgebra.HMatrix (Vector, Matrix)
@@ -9,11 +9,12 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix)
9type Jacobian = Double -> Vector Double -> Matrix Double 9type Jacobian = Double -> Vector Double -> Matrix Double
10 10
11data ODEOpts = ODEOpts { 11data ODEOpts = ODEOpts {
12 maxNumSteps :: Int32 12 maxNumSteps :: Word32
13 , minStep :: Double 13 , minStep :: Double
14 , relTol :: Double 14 , relTol :: Double
15 , absTols :: VS.Vector Double 15 , absTols :: VS.Vector Double
16 , initStep :: Double 16 , initStep :: Maybe Double
17 , maxFail :: Word32
17 } deriving (Read, Show, Eq, Ord) 18 } deriving (Read, Show, Eq, Ord)
18 19
19data SundialsDiagnostics = SundialsDiagnostics { 20data SundialsDiagnostics = SundialsDiagnostics {