From 59a0413a83a9bcee93e3f0761cae6fdda2a98933 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Thu, 5 Apr 2018 18:28:56 +0100 Subject: Get closer to the hmatrix-gsl interface --- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 42 ++++++++++++---------- 1 file changed, 24 insertions(+), 18 deletions(-) (limited to 'packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs') diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 0973c82..4270a13 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -185,14 +185,22 @@ data SundialsDiagnostics = SundialsDiagnostics { type Jacobian = Double -> Vector Double -> Matrix Double -- | Stepping functions -data ODEMethod = SDIRK_2_1_2 Jacobian +data ODEMethod = SDIRK_2_1_2 Jacobian | KVAERNO_4_2_3 Jacobian - | SDIRK_5_3_4 Jacobian + | SDIRK_5_3_4 Jacobian + | SDIRK_5_3_4' getMethod :: ODEMethod -> Int -getMethod (SDIRK_2_1_2 _) = sDIRK_2_1_2 +getMethod (SDIRK_2_1_2 _) = sDIRK_2_1_2 getMethod (KVAERNO_4_2_3 _) = kVAERNO_4_2_3 -getMethod (SDIRK_5_3_4 _) = sDIRK_5_3_4 +getMethod (SDIRK_5_3_4 _) = sDIRK_5_3_4 +getMethod (SDIRK_5_3_4' ) = sDIRK_5_3_4 + +getJacobian :: ODEMethod -> Maybe Jacobian +getJacobian (SDIRK_2_1_2 j) = Just j +getJacobian (KVAERNO_4_2_3 j) = Just j +getJacobian (SDIRK_5_3_4 j) = Just j +getJacobian (SDIRK_5_3_4' ) = Nothing -- | A version of 'odeSolveVWith' with reasonable default step control. odeSolveV @@ -215,7 +223,7 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y -> Vector Double -- ^ desired solution times -> Matrix Double -- ^ solution odeSolve f y0 ts = - case odeSolveVWith (SDIRK_5_3_4 undefined) Nothing 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of + case odeSolveVWith SDIRK_5_3_4' 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of Left c -> error $ show c -- FIXME Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) where @@ -225,13 +233,12 @@ odeSolve f y0 ts = g t x0 = V.fromList $ f t (V.toList x0) odeSolve' :: ODEMethod - -> (Double -> Vector Double -> Matrix Double) -> (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> [Double] -- ^ initial conditions -> Vector Double -- ^ desired solution times -> Matrix Double -- ^ solution -odeSolve' method jac f y0 ts = - case odeSolveVWith method (pure jac') 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of +odeSolve' method f y0 ts = + case odeSolveVWith method 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of Left c -> error $ show c -- FIXME Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) where @@ -239,30 +246,29 @@ odeSolve' method jac f y0 ts = nR = length us nC = length y0 g t x0 = V.fromList $ f t (V.toList x0) - jac' t v = foo $ jac t (V.fromList $ toList v) - foo m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } - where - nr = fromIntegral $ rows m - nc = fromIntegral $ cols m - vs = V.fromList $ map coerce $ concat $ toLists m odeSolveVWith :: ODEMethod - -> (Maybe (Double -> V.Vector Double -> T.SunMatrix)) -> Double -> Double -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector Double -- ^ Initial conditions -> V.Vector Double -- ^ Desired solution times -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution -odeSolveVWith method jac relTol absTol f y0 tt = +odeSolveVWith method relTol absTol f y0 tt = case solveOdeC (fromIntegral $ getMethod method) jacH (CDouble relTol) (CDouble absTol) (coerce f) (coerce y0) (coerce tt) of Left c -> Left $ fromIntegral c Right (v, d) -> Right (coerce v, d) where - jacH :: Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix) - jacH = fmap (\g -> (\t v -> g (coerce t) (coerce v))) jac + -- FIXME: Can we do better than going via a list? + jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce $ V.fromList $ toList v)) $ + getJacobian method + matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } + where + nr = fromIntegral $ rows m + nc = fromIntegral $ cols m + vs = V.fromList $ map coerce $ concat $ toLists m solveOdeC :: CInt -> -- cgit v1.2.3