summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs')
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs113
1 files changed, 55 insertions, 58 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index 13b7eb8..ce46968 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -114,7 +114,6 @@ module Numeric.Sundials.ARKode.ODE ( odeSolve
114 , butcherTable 114 , butcherTable
115 , ODEMethod(..) 115 , ODEMethod(..)
116 , StepControl(..) 116 , StepControl(..)
117 , Jacobian
118 ) where 117 ) where
119 118
120import qualified Language.C.Inline as C 119import qualified Language.C.Inline as C
@@ -136,11 +135,11 @@ import GHC.Generics (C1, Constructor, (:+:)(..), D1, Rep, Generic, M1(
136 135
137import Numeric.LinearAlgebra.Devel (createVector) 136import Numeric.LinearAlgebra.Devel (createVector)
138 137
139import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), 138import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows,
140 subMatrix, rows, cols, toLists, 139 cols, toLists, size, reshape,
141 size, subVector) 140 subVector, subMatrix, (><))
142 141
143import qualified Numeric.Sundials.ODEOpts as SO 142import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian, SundialsDiagnostics(..))
144import qualified Numeric.Sundials.Arkode as T 143import qualified Numeric.Sundials.Arkode as T
145import Numeric.Sundials.Arkode (getDataFromContents, putDataInContents, arkSMax, 144import Numeric.Sundials.Arkode (getDataFromContents, putDataInContents, arkSMax,
146 sDIRK_2_1_2, 145 sDIRK_2_1_2,
@@ -185,8 +184,6 @@ C.include "../../../helpers.h"
185C.include "Numeric/Sundials/Arkode_hsc.h" 184C.include "Numeric/Sundials/Arkode_hsc.h"
186 185
187 186
188type Jacobian = Double -> Vector Double -> Matrix Double
189
190-- | Stepping functions 187-- | Stepping functions
191data ODEMethod = SDIRK_2_1_2 Jacobian 188data ODEMethod = SDIRK_2_1_2 Jacobian
192 | SDIRK_2_1_2' 189 | SDIRK_2_1_2'
@@ -351,15 +348,9 @@ odeSolveV
351 -> Vector Double -- ^ desired solution times 348 -> Vector Double -- ^ desired solution times
352 -> Matrix Double -- ^ solution 349 -> Matrix Double -- ^ solution
353odeSolveV meth hi epsAbs epsRel f y0 ts = 350odeSolveV meth hi epsAbs epsRel f y0 ts =
354 case odeSolveVWith' meth (X epsAbs epsRel) hi g y0 ts of 351 odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts
355 Left c -> error $ show c -- FIXME 352 where
356 -- FIXME: Can we do better than using lists? 353 g t x0 = coerce $ f t x0
357 Right (v, _d) -> (nR >< nC) (V.toList v)
358 where
359 us = toList ts
360 nR = length us
361 nC = size y0
362 g t x0 = coerce $ f t x0
363 354
364-- | A version of 'odeSolveV' with reasonable default parameters and 355-- | A version of 'odeSolveV' with reasonable default parameters and
365-- system of equations defined using lists. FIXME: we should say 356-- system of equations defined using lists. FIXME: we should say
@@ -371,13 +362,8 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y
371 -> Matrix Double -- ^ solution 362 -> Matrix Double -- ^ solution
372odeSolve f y0 ts = 363odeSolve f y0 ts =
373 -- FIXME: These tolerances are different from the ones in GSL 364 -- FIXME: These tolerances are different from the ones in GSL
374 case odeSolveVWith' SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of 365 odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts)
375 Left c -> error $ show c -- FIXME
376 Right (v, _d) -> (nR >< nC) (V.toList v)
377 where 366 where
378 us = toList ts
379 nR = length us
380 nC = length y0
381 g t x0 = V.fromList $ f t (V.toList x0) 367 g t x0 = V.fromList $ f t (V.toList x0)
382 368
383odeSolveVWith :: 369odeSolveVWith ::
@@ -394,15 +380,21 @@ odeSolveVWith ::
394 -> V.Vector Double -- ^ Desired solution times 380 -> V.Vector Double -- ^ Desired solution times
395 -> Matrix Double -- ^ Error code or solution 381 -> Matrix Double -- ^ Error code or solution
396odeSolveVWith method control initStepSize f y0 tt = 382odeSolveVWith method control initStepSize f y0 tt =
397 case odeSolveVWith' method control initStepSize f y0 tt of 383 case odeSolveVWith' opts method control initStepSize f y0 tt of
398 Left c -> error $ show c -- FIXME 384 Left c -> error $ show c -- FIXME
399 Right (v, _d) -> (nR >< nC) (V.toList v) 385 Right (v, _d) -> v
400 where 386 where
401 nR = V.length tt 387 opts = ODEOpts { maxNumSteps = 10000
402 nC = V.length y0 388 , minStep = 1.0e-12
389 , relTol = error "relTol"
390 , absTols = error "absTol"
391 , initStep = error "initStep"
392 , maxFail = 10
393 }
403 394
404odeSolveVWith' :: 395odeSolveVWith' ::
405 ODEMethod 396 ODEOpts
397 -> ODEMethod
406 -> StepControl 398 -> StepControl
407 -> Maybe Double -- ^ initial step size - by default, ARKode 399 -> Maybe Double -- ^ initial step size - by default, ARKode
408 -- estimates the initial step size to be the 400 -- estimates the initial step size to be the
@@ -413,19 +405,21 @@ odeSolveVWith' ::
413 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 405 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
414 -> V.Vector Double -- ^ Initial conditions 406 -> V.Vector Double -- ^ Initial conditions
415 -> V.Vector Double -- ^ Desired solution times 407 -> V.Vector Double -- ^ Desired solution times
416 -> Either Int ((V.Vector Double), SO.SundialsDiagnostics) -- ^ Error code or solution 408 -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution
417odeSolveVWith' method control initStepSize f y0 tt = 409odeSolveVWith' opts method control initStepSize f y0 tt =
418 case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) 410 case solveOdeC (fromIntegral $ maxFail opts)
411 (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts)
412 (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control)
419 (coerce f) (coerce y0) (coerce tt) of 413 (coerce f) (coerce y0) (coerce tt) of
420 Left c -> Left $ fromIntegral c 414 Left c -> Left $ fromIntegral c
421 Right (v, d) -> Right (coerce v, d) 415 Right (v, d) -> Right (reshape l (coerce v), d)
422 where 416 where
423 l = size y0 417 l = size y0
424 scise (X absTol relTol) = coerce (V.replicate l absTol, relTol) 418 scise (X aTol rTol) = coerce (V.replicate l aTol, rTol)
425 scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol) 419 scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol)
426 scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol) 420 scise (XX' aTol rTol yScale _yDotScale) = coerce (V.replicate l aTol, yScale * rTol)
427 -- FIXME; Should we check that the length of ss is correct? 421 -- FIXME; Should we check that the length of ss is correct?
428 scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol) 422 scise (ScXX' aTol rTol yScale _yDotScale ss) = coerce (V.map (* aTol) ss, yScale * rTol)
429 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ 423 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $
430 getJacobian method 424 getJacobian method
431 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } 425 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs }
@@ -437,14 +431,18 @@ odeSolveVWith' method control initStepSize f y0 tt =
437 431
438solveOdeC :: 432solveOdeC ::
439 CInt -> 433 CInt ->
434 CLong ->
435 CDouble ->
436 CInt ->
440 Maybe CDouble -> 437 Maybe CDouble ->
441 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> 438 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) ->
442 (V.Vector CDouble, CDouble) -> 439 (V.Vector CDouble, CDouble) ->
443 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 440 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
444 -> V.Vector CDouble -- ^ Initial conditions 441 -> V.Vector CDouble -- ^ Initial conditions
445 -> V.Vector CDouble -- ^ Desired solution times 442 -> V.Vector CDouble -- ^ Desired solution times
446 -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution 443 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution
447solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do 444solveOdeC maxErrTestFails maxNumSteps_ minStep_ method initStepSize
445 jacH (aTols, rTol) fun f0 ts = unsafePerformIO $ do
448 446
449 let isInitStepSize :: CInt 447 let isInitStepSize :: CInt
450 isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize 448 isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize
@@ -455,14 +453,12 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
455 -- used :( 453 -- used :(
456 Nothing -> 0.0 454 Nothing -> 0.0
457 Just x -> x 455 Just x -> x
456
458 let dim = V.length f0 457 let dim = V.length f0
459 nEq :: CLong 458 nEq :: CLong
460 nEq = fromIntegral dim 459 nEq = fromIntegral dim
461 nTs :: CInt 460 nTs :: CInt
462 nTs = fromIntegral $ V.length ts 461 nTs = fromIntegral $ V.length ts
463 -- FIXME: fMut is not actually mutatated
464 fMut <- V.thaw f0
465 tMut <- V.thaw ts
466 -- FIXME: I believe this gets taken from the ghc heap and so should 462 -- FIXME: I believe this gets taken from the ghc heap and so should
467 -- be subject to garbage collection. 463 -- be subject to garbage collection.
468 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) 464 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs))
@@ -510,7 +506,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
510 506
511 /* general problem parameters */ 507 /* general problem parameters */
512 508
513 realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ 509 realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */
514 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ 510 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */
515 511
516 /* Initialize data structures */ 512 /* Initialize data structures */
@@ -519,14 +515,14 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
519 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; 515 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
520 /* Specify initial condition */ 516 /* Specify initial condition */
521 for (i = 0; i < NEQ; i++) { 517 for (i = 0; i < NEQ; i++) {
522 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; 518 NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i];
523 }; 519 };
524 520
525 tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */ 521 tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */
526 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; 522 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1;
527 /* Specify tolerances */ 523 /* Specify tolerances */
528 for (i = 0; i < NEQ; i++) { 524 for (i = 0; i < NEQ; i++) {
529 NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i]; 525 NV_Ith_S(tv,i) = ($vec-ptr:(double *aTols))[i];
530 }; 526 };
531 527
532 arkode_mem = ARKodeCreate(); /* Create the solver memory */ 528 arkode_mem = ARKodeCreate(); /* Create the solver memory */
@@ -547,14 +543,15 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
547 if (check_flag(&flag, "ARKodeInit", 1)) return 1; 543 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
548 } 544 }
549 545
550 /* FIXME: A hack for initial testing */ 546 flag = ARKodeSetMinStep(arkode_mem, $(double minStep_));
551 flag = ARKodeSetMinStep(arkode_mem, 1.0e-12);
552 if (check_flag(&flag, "ARKodeSetMinStep", 1)) return 1; 547 if (check_flag(&flag, "ARKodeSetMinStep", 1)) return 1;
553 flag = ARKodeSetMaxNumSteps(arkode_mem, 10000); 548 flag = ARKodeSetMaxNumSteps(arkode_mem, $(long int maxNumSteps_));
554 if (check_flag(&flag, "ARKodeSetMaxNumSteps", 1)) return 1; 549 if (check_flag(&flag, "ARKodeSetMaxNumSteps", 1)) return 1;
550 flag = ARKodeSetMaxErrTestFails(arkode_mem, $(int maxErrTestFails));
551 if (check_flag(&flag, "ARKodeSetMaxErrTestFails", 1)) return 1;
555 552
556 /* Set routines */ 553 /* Set routines */
557 flag = ARKodeSVtolerances(arkode_mem, $(double relTol), tv); 554 flag = ARKodeSVtolerances(arkode_mem, $(double rTol), tv);
558 if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1; 555 if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1;
559 556
560 /* Initialize dense matrix data structure and solver */ 557 /* Initialize dense matrix data structure and solver */
@@ -599,7 +596,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
599 /* Stops when the final time has been reached */ 596 /* Stops when the final time has been reached */
600 for (i = 1; i < $(int nTs); i++) { 597 for (i = 1; i < $(int nTs); i++) {
601 598
602 flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */ 599 flag = ARKode(arkode_mem, ($vec-ptr:(double *ts))[i], y, &t, ARK_NORMAL); /* call integrator */
603 if (check_flag(&flag, "ARKode", 1)) break; 600 if (check_flag(&flag, "ARKode", 1)) break;
604 601
605 /* Store the results for Haskell */ 602 /* Store the results for Haskell */
@@ -665,16 +662,16 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
665 if res == 0 662 if res == 0
666 then do 663 then do
667 preD <- V.freeze diagMut 664 preD <- V.freeze diagMut
668 let d = SO.SundialsDiagnostics (fromIntegral $ preD V.!0) 665 let d = SundialsDiagnostics (fromIntegral $ preD V.!0)
669 (fromIntegral $ preD V.!1) 666 (fromIntegral $ preD V.!1)
670 (fromIntegral $ preD V.!2) 667 (fromIntegral $ preD V.!2)
671 (fromIntegral $ preD V.!3) 668 (fromIntegral $ preD V.!3)
672 (fromIntegral $ preD V.!4) 669 (fromIntegral $ preD V.!4)
673 (fromIntegral $ preD V.!5) 670 (fromIntegral $ preD V.!5)
674 (fromIntegral $ preD V.!6) 671 (fromIntegral $ preD V.!6)
675 (fromIntegral $ preD V.!7) 672 (fromIntegral $ preD V.!7)
676 (fromIntegral $ preD V.!8) 673 (fromIntegral $ preD V.!8)
677 (fromIntegral $ preD V.!9) 674 (fromIntegral $ preD V.!9)
678 m <- V.freeze qMatMut 675 m <- V.freeze qMatMut
679 return $ Right (m, d) 676 return $ Right (m, d)
680 else do 677 else do