From 1b64b28dfccf2cb9539cdb4344cd7ecb1c1d0a1d Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Wed, 21 Mar 2018 08:24:05 +0000 Subject: A library function + move towards the original hmatrix interface --- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 45 ++++++++-------------- 1 file changed, 16 insertions(+), 29 deletions(-) (limited to 'packages/sundials/src/Numeric/Sundials') diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 58acef3..9de20b6 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -6,7 +6,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} -module Numeric.Sundials.Arkode.ODE ( solveOdeC ) where +module Numeric.Sundials.Arkode.ODE ( solveOde ) where import qualified Language.C.Inline as C import qualified Language.C.Inline.Unsafe as CU @@ -76,32 +76,21 @@ vectorToC vec len ptr = do ptr' <- newForeignPtr_ ptr V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec -stiffish :: Double -> V.Vector Double -> V.Vector Double -stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] - where - u = v V.! 0 - lamda = -100.0 - -brusselator :: Double -> V.Vector Double -> V.Vector Double -brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 - , w * u - v * u^2 - , (b - w) / eps - w * u - ] - where - a = 1.0 - b = 3.5 - eps = 5.0e-6 - u = x V.! 0 - v = x V.! 1 - w = x V.! 2 - - odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> [Double] -- ^ initial conditions -> Vector Double -- ^ desired solution times -> Matrix Double -- ^ solution odeSolve = undefined +solveOde :: + (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) + -> V.Vector Double -- ^ Initial conditions + -> V.Vector Double -- ^ Desired solution times + -> Either Int (V.Vector Double) -- ^ Error code or solution +solveOde f y0 tt = case solveOdeC (coerce f) (coerce y0) (coerce tt) of + Left c -> Left $ fromIntegral c + Right v -> Right $ coerce v + solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector CDouble -- ^ Initial conditions @@ -111,7 +100,10 @@ solveOdeC fun f0 ts = unsafePerformIO $ do let dim = V.length f0 nEq :: CLong nEq = fromIntegral dim + nTs :: CInt + nTs = fromIntegral $ V.length ts fMut <- V.thaw f0 + tMut <- V.thaw ts -- We need the types that sundials expects. These are tied together -- in 'Types'. The Haskell type is currently empty! let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt @@ -136,8 +128,8 @@ solveOdeC fun f0 ts = unsafePerformIO $ do long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; /* general problem parameters */ - realtype T0 = RCONST(0.0); /* initial time */ - realtype Tf = RCONST(10.0); /* final time */ + realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ + realtype Tf = RCONST(($vec-ptr:(double *tMut))[$(int nTs) - 1]); /* final time */ realtype dTout = RCONST(1.0); /* time between outputs */ sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ realtype reltol = 1.0e-6; /* tolerances */ @@ -193,7 +185,7 @@ solveOdeC fun f0 ts = unsafePerformIO $ do tout = T0+dTout; printf(" t u\n"); printf(" ---------------------\n"); - while (Tf - t > 1.0e-15) { + for (i = 0; i < $(int nTs); i++) { flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */ if (check_flag(&flag, "ARKode", 1)) break; @@ -258,8 +250,3 @@ solveOdeC fun f0 ts = unsafePerformIO $ do return $ Right v else do return $ Left res - -main :: IO () -main = do - let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0]) undefined - putStrLn $ show res -- cgit v1.2.3