From 6fcc1b01cecc88f1a8eb1608667368c7e72048aa Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Thu, 5 Apr 2018 12:10:29 +0100 Subject: Start of mirroring hmatrix-gsl ODE module --- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 145 ++++++++++++++++----- 1 file changed, 112 insertions(+), 33 deletions(-) diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 8358954..0973c82 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -6,8 +6,34 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} +----------------------------------------------------------------------------- -- | --- Module: Numeric.Sundials.ARKode +-- Module : Numeric.Sundials.ARKode +-- Copyright : Dominic Steinitz 2018, +-- Novadiscovery 2018 +-- License : BSD +-- Maintainer : Dominic Steinitz +-- Stability : provisional +-- +-- Solution of ordinary differential equation (ODE) initial value problems. +-- +-- +-- +-- A simple example: +-- +-- @ +-- import Numeric.Sundials.ARKode +-- import Numeric.LinearAlgebra +-- import Graphics.Plot(mplot) +-- +-- xdot t [x,v] = [v, -0.95*x - 0.1*v] +-- +-- ts = linspace 100 (0,20 :: Double) +-- +-- sol = odeSolve xdot [10,0] ts +-- +-- main = mplot (ts : toColumns sol) +-- @ -- -- KVAERNO_4_2_3 -- @@ -29,19 +55,34 @@ -- \end{array} -- \] -- -module Numeric.Sundials.ARKode.ODE ( solveOde - , odeSolve +-- SDIRK_5_3_4 +-- +-- \[ +-- \begin{array}{c|ccccc} +-- c_1 & 0.25 & 0.0 & 0.0 & 0.0 & 0.0 \\ +-- c_2 & 0.5 & 0.25 & 0.0 & 0.0 & 0.0 \\ +-- c_3 & 0.34 & -4.0e-2 & 0.25 & 0.0 & 0.0 \\ +-- c_4 & 0.2727941176470588 & -5.036764705882353e-2 & 2.7573529411764705e-2 & 0.25 & 0.0 \\ +-- c_5 & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\ +-- \end{array} +-- \] +----------------------------------------------------------------------------- +module Numeric.Sundials.ARKode.ODE ( odeSolve + , odeSolveV + , odeSolveVWith + , odeSolve' , getButcherTable , getBT , btGet , ODEMethod(..) - , odeSolveV + , StepControl(..) ) where import qualified Language.C.Inline as C import qualified Language.C.Inline.Unsafe as CU import Data.Monoid ((<>)) +import Data.Maybe (isJust) import Foreign.C.Types import Foreign.Ptr (Ptr) @@ -60,9 +101,11 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix, rows, cols, toLists) import qualified Types as T -import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3) +import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3, sDIRK_5_3_4) import qualified Arkode as B +import Debug.Trace + C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) @@ -139,14 +182,17 @@ data SundialsDiagnostics = SundialsDiagnostics { , _aRKDlsGetNumRhsEvals :: Int } deriving Show +type Jacobian = Double -> Vector Double -> Matrix Double + -- | Stepping functions -data ODEMethod = SDIRK_2_1_2 - | KVAERNO_4_2_3 +data ODEMethod = SDIRK_2_1_2 Jacobian + | KVAERNO_4_2_3 Jacobian + | SDIRK_5_3_4 Jacobian -instance Enum ODEMethod where - fromEnum SDIRK_2_1_2 = sDIRK_2_1_2 - fromEnum KVAERNO_4_2_3 = kVAERNO_4_2_3 - toEnum _ = error "toEnum not defined for ODEMethod" +getMethod :: ODEMethod -> Int +getMethod (SDIRK_2_1_2 _) = sDIRK_2_1_2 +getMethod (KVAERNO_4_2_3 _) = kVAERNO_4_2_3 +getMethod (SDIRK_5_3_4 _) = sDIRK_5_3_4 -- | A version of 'odeSolveVWith' with reasonable default step control. odeSolveV @@ -160,16 +206,34 @@ odeSolveV -> Matrix Double -- ^ solution odeSolveV _meth _hi _epsAbs _epsRel = undefined -odeSolve :: ODEMethod +-- | A version of 'odeSolveV' with reasonable default parameters and +-- system of equations defined using lists. FIXME: we should say +-- something about the fact we could use the Jacobian but don't for +-- compatibility with hmatrix-gsl. +odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) + -> [Double] -- ^ initial conditions + -> Vector Double -- ^ desired solution times + -> Matrix Double -- ^ solution +odeSolve f y0 ts = + case odeSolveVWith (SDIRK_5_3_4 undefined) Nothing 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of + Left c -> error $ show c -- FIXME + Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) + where + us = toList ts + nR = length us + nC = length y0 + g t x0 = V.fromList $ f t (V.toList x0) + +odeSolve' :: ODEMethod -> (Double -> Vector Double -> Matrix Double) -> (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> [Double] -- ^ initial conditions -> Vector Double -- ^ desired solution times -> Matrix Double -- ^ solution -odeSolve method jac f y0 ts = - case solveOde method jac' 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of +odeSolve' method jac f y0 ts = + case odeSolveVWith method (pure jac') 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of Left c -> error $ show c -- FIXME - Right (v, _) -> (nR >< nC) (V.toList v) + Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) where us = toList ts nR = length us @@ -182,26 +246,27 @@ odeSolve method jac f y0 ts = nc = fromIntegral $ cols m vs = V.fromList $ map coerce $ concat $ toLists m -solveOde :: +odeSolveVWith :: ODEMethod - -> (Double -> V.Vector Double -> T.SunMatrix) + -> (Maybe (Double -> V.Vector Double -> T.SunMatrix)) -> Double -> Double -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector Double -- ^ Initial conditions -> V.Vector Double -- ^ Desired solution times -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution -solveOde method jac relTol absTol f y0 tt = - case solveOdeC (fromIntegral $ fromEnum method) jacH (CDouble relTol) (CDouble absTol) +odeSolveVWith method jac relTol absTol f y0 tt = + case solveOdeC (fromIntegral $ getMethod method) jacH (CDouble relTol) (CDouble absTol) (coerce f) (coerce y0) (coerce tt) of Left c -> Left $ fromIntegral c Right (v, d) -> Right (coerce v, d) where - jacH t v = jac (coerce t) (coerce v) + jacH :: Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix) + jacH = fmap (\g -> (\t v -> g (coerce t) (coerce v))) jac solveOdeC :: CInt -> - (CDouble -> V.Vector CDouble -> T.SunMatrix) -> + (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> CDouble -> CDouble -> (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) @@ -235,15 +300,19 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] - let jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix -> + let isJac :: CInt + isJac = fromIntegral $ fromEnum $ isJust jacH + jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix -> Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector -> IO CInt jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do - j <- jacH t <$> getDataFromContents dim y - putMatrixDataFromContents j jacS - -- FIXME: I don't understand what this comment means - -- Unsafe since the function will be called many times. - [CU.exp| int{ 0 } |] + case jacH of + Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" + Just jacI -> do j <- jacI t <$> getDataFromContents dim y + putMatrixDataFromContents j jacS + -- FIXME: I don't understand what this comment means + -- Unsafe since the function will be called many times. + [CU.exp| int{ 0 } |] res <- [C.block| int { /* general problem variables */ @@ -292,9 +361,10 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do /* Linear solver interface */ flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ - flag = ARKDlsSetJacFn(arkode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[]))); - if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1; - + if ($(int isJac)) { + flag = ARKDlsSetJacFn(arkode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[]))); + if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1; + } /* Store initial conditions */ for (j = 0; j < NEQ; j++) { @@ -389,9 +459,10 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do btGet :: ODEMethod -> Matrix Double btGet method = case getBT method of Left c -> error $ show c -- FIXME - -- FIXME - Right (v, _sqp) -> subMatrix (0, 0) (2, 2) $ + Right (v, sqp) -> subMatrix (0, 0) (s, s) $ (B.arkSMax >< B.arkSMax) (V.toList v) + where + s = fromIntegral $ sqp V.! 0 getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int) getBT method = case getButcherTable method of @@ -410,7 +481,7 @@ getButcherTable method = unsafePerformIO $ do nEq :: CLong nEq = fromIntegral dim mN :: CInt - mN = fromIntegral $ fromEnum method + mN = fromIntegral $ getMethod method -- FIXME: I believe these gets taken from the ghc heap and so should -- be subject to garbage collection. @@ -493,3 +564,11 @@ getButcherTable method = unsafePerformIO $ do return $ Right (x, y) else do return $ Left res + +-- | Adaptive step-size control functions. FIXME: It may not be +-- possible to scale the tolerances for the derivatives in sundials so +-- for now we ignore them and emit a warning. +data StepControl = X Double Double -- ^ abs. and rel. tolerance for x(t) + | X' Double Double -- ^ abs. and rel. tolerance for x'(t) + | XX' Double Double Double Double -- ^ include both via rel. tolerance scaling factors a_x, a_x' + | ScXX' Double Double Double Double (Vector Double) -- ^ scale abs. tolerance of x(t) components -- cgit v1.2.3