From e8e631c7dbc7ea34b51dcce5fef5e2ec620f9458 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Thu, 26 Apr 2018 13:53:56 +0100 Subject: Refactor CVODE --- .../sundials/src/Numeric/Sundials/CVode/ODE.hs | 75 ++++++++++------------ packages/sundials/src/Numeric/Sundials/ODEOpts.hs | 4 ++ 2 files changed, 38 insertions(+), 41 deletions(-) diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs index d7a2b53..0871f9b 100644 --- a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs @@ -87,11 +87,12 @@ import System.IO.Unsafe (unsafePerformIO) import Numeric.LinearAlgebra.Devel (createVector) -import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), - rows, cols, toLists, size) +import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows, + cols, toLists, size, reshape) import qualified Types as T import Arkode (cV_ADAMS, cV_BDF) +import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian) import qualified Numeric.Sundials.ODEOpts as SO @@ -111,8 +112,6 @@ C.include "../../../helpers.h" C.include "Arkode_hsc.h" -type Jacobian = Double -> Vector Double -> Matrix Double - -- | Stepping functions data ODEMethod = ADAMS | BDF @@ -140,14 +139,8 @@ odeSolveV -> Vector Double -- ^ desired solution times -> Matrix Double -- ^ solution odeSolveV meth hi epsAbs epsRel f y0 ts = - case odeSolveVWith' meth (X epsAbs epsRel) hi g y0 ts of - Left c -> error $ show c -- FIXME - -- FIXME: Can we do better than using lists? - Right (v, _d) -> (nR >< nC) (V.toList v) + odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts where - us = toList ts - nR = length us - nC = size y0 g t x0 = coerce $ f t x0 -- | A version of 'odeSolveV' with reasonable default parameters and @@ -160,13 +153,8 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y -> Matrix Double -- ^ solution odeSolve f y0 ts = -- FIXME: These tolerances are different from the ones in GSL - case odeSolveVWith' BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of - Left c -> error $ show c -- FIXME - Right (v, _d) -> (nR >< nC) (V.toList v) + odeSolveVWith BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) where - us = toList ts - nR = length us - nC = length y0 g t x0 = V.fromList $ f t (V.toList x0) odeSolveVWith :: @@ -183,15 +171,20 @@ odeSolveVWith :: -> V.Vector Double -- ^ Desired solution times -> Matrix Double -- ^ Error code or solution odeSolveVWith method control initStepSize f y0 tt = - case odeSolveVWith' method control initStepSize f y0 tt of + case odeSolveVWith' opts method control initStepSize f y0 tt of Left c -> error $ show c -- FIXME - Right (v, _d) -> (nR >< nC) (V.toList v) + Right (v, _d) -> v where - nR = V.length tt - nC = V.length y0 + opts = ODEOpts { maxNumSteps = 10000 + , minStep = 1.0e-12 + , relTol = error "relTol" + , absTols = error "absTol" + , initStep = error "initStep" + } odeSolveVWith' :: - ODEMethod + ODEOpts + -> ODEMethod -> StepControl -> Maybe Double -- ^ initial step size - by default, ARKode -- estimates the initial step size to be the @@ -202,19 +195,21 @@ odeSolveVWith' :: -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector Double -- ^ Initial conditions -> V.Vector Double -- ^ Desired solution times - -> Either Int ((V.Vector Double), SO.SundialsDiagnostics) -- ^ Error code or solution -odeSolveVWith' method control initStepSize f y0 tt = - case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) + -> Either Int (Matrix Double, SO.SundialsDiagnostics) -- ^ Error code or solution +odeSolveVWith' opts method control initStepSize f y0 tt = + case solveOdeC (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts) + (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) (coerce f) (coerce y0) (coerce tt) of Left c -> Left $ fromIntegral c - Right (v, d) -> Right (coerce v, d) + Right (v, d) -> Right (reshape nC (coerce v), d) where + nC = V.length y0 l = size y0 - scise (X absTol relTol) = coerce (V.replicate l absTol, relTol) - scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol) - scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol) + scise (X aTol rTol) = coerce (V.replicate l aTol, rTol) + scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol) + scise (XX' aTol rTol yScale _yDotScale) = coerce (V.replicate l aTol, yScale * rTol) -- FIXME; Should we check that the length of ss is correct? - scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol) + scise (ScXX' aTol rTol yScale _yDotScale ss) = coerce (V.map (* aTol) ss, yScale * rTol) jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ getJacobian method matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } @@ -225,6 +220,8 @@ odeSolveVWith' method control initStepSize f y0 tt = vs = V.fromList $ map coerce $ concat $ toLists m solveOdeC :: + CLong -> + CDouble -> CInt -> Maybe CDouble -> (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> @@ -233,7 +230,8 @@ solveOdeC :: -> V.Vector CDouble -- ^ Initial conditions -> V.Vector CDouble -- ^ Desired solution times -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution -solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do +solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts = + unsafePerformIO $ do let isInitStepSize :: CInt isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize @@ -249,10 +247,6 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO nEq = fromIntegral dim nTs :: CInt nTs = fromIntegral $ V.length ts - -- FIXME: tMut is not actually mutatated? - tMut <- V.thaw ts - -- FIXME: I believe this gets taken from the ghc heap and so should - -- be subject to garbage collection. quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) qMatMut <- V.thaw quasiMatrixRes diagnostics :: V.Vector CLong <- createVector 10 -- FIXME @@ -324,18 +318,17 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; /* Specify tolerances */ for (i = 0; i < NEQ; i++) { - NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i]; + NV_Ith_S(tv,i) = ($vec-ptr:(double *aTols))[i]; }; - /* FIXME: A hack for initial testing */ - flag = CVodeSetMinStep(cvode_mem, 1.0e-12); + flag = CVodeSetMinStep(cvode_mem, $(double minStep_)); if (check_flag(&flag, "CVodeSetMinStep", 1)) return 1; - flag = CVodeSetMaxNumSteps(cvode_mem, 10000); + flag = CVodeSetMaxNumSteps(cvode_mem, $(long int maxNumSteps_)); if (check_flag(&flag, "CVodeSetMaxNumSteps", 1)) return 1; /* Call CVodeSVtolerances to specify the scalar relative tolerance * and vector absolute tolerances */ - flag = CVodeSVtolerances(cvode_mem, $(double relTol), tv); + flag = CVodeSVtolerances(cvode_mem, $(double rTol), tv); if (check_flag(&flag, "CVodeSVtolerances", 1)) return(1); /* Initialize dense matrix data structure and solver */ @@ -371,7 +364,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO /* Stops when the final time has been reached */ for (i = 1; i < $(int nTs); i++) { - flag = CVode(cvode_mem, ($vec-ptr:(double *tMut))[i], y, &t, CV_NORMAL); /* call integrator */ + flag = CVode(cvode_mem, ($vec-ptr:(double *ts))[i], y, &t, CV_NORMAL); /* call integrator */ if (check_flag(&flag, "CVode", 1)) break; /* Store the results for Haskell */ diff --git a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs index e924292..538b474 100644 --- a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs +++ b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs @@ -8,9 +8,13 @@ import Foreign.C.Types import qualified Data.Vector.Storable as VS import qualified Data.Vector.Storable.Mutable as VM +import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) + import qualified Types as T import qualified Arkode as B +type Jacobian = Double -> Vector Double -> Matrix Double + data ODEOpts = ODEOpts { maxNumSteps :: Int32 , minStep :: Double -- cgit v1.2.3