From 9f571e009bc46c26334be8b6a635db1e1d5b0341 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Mon, 23 Apr 2018 16:19:22 +0100 Subject: Start of support for CVODE --- .../sundials/src/Numeric/Sundials/CVode/ODE.hs | 456 +++++++++++++++++++++ 1 file changed, 456 insertions(+) create mode 100644 packages/sundials/src/Numeric/Sundials/CVode/ODE.hs diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs new file mode 100644 index 0000000..f75d91f --- /dev/null +++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs @@ -0,0 +1,456 @@ +{-# OPTIONS_GHC -Wall #-} + +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} + +----------------------------------------------------------------------------- +-- | +-- Module : Numeric.Sundials.CVode.ODE +-- Copyright : Dominic Steinitz 2018, +-- Novadiscovery 2018 +-- License : BSD +-- Maintainer : Dominic Steinitz +-- Stability : provisional +-- +-- Solution of ordinary differential equation (ODE) initial value problems. +-- +-- +-- +-- A simple example: +-- +-- <> +-- +-- @ +-- import Numeric.Sundials.CVode.ODE +-- import Numeric.LinearAlgebra +-- +-- import Plots as P +-- import qualified Diagrams.Prelude as D +-- import Diagrams.Backend.Rasterific +-- +-- brusselator :: Double -> [Double] -> [Double] +-- brusselator _t x = [ a - (w + 1) * u + v * u * u +-- , w * u - v * u * u +-- , (b - w) / eps - w * u +-- ] +-- where +-- a = 1.0 +-- b = 3.5 +-- eps = 5.0e-6 +-- u = x !! 0 +-- v = x !! 1 +-- w = x !! 2 +-- +-- lSaxis :: [[Double]] -> P.Axis B D.V2 Double +-- lSaxis xs = P.r2Axis &~ do +-- let ts = xs!!0 +-- us = xs!!1 +-- vs = xs!!2 +-- ws = xs!!3 +-- P.linePlot' $ zip ts us +-- P.linePlot' $ zip ts vs +-- P.linePlot' $ zip ts ws +-- +-- main = do +-- let res1 = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) +-- renderRasterific "diagrams/brusselator.png" +-- (D.dims2D 500.0 500.0) +-- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) +-- @ +-- +-- KVAERNO_4_2_3 +-- +-- \[ +-- \begin{array}{c|cccc} +-- 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\ +-- 0.871733043 & 0.4358665215 & 0.4358665215 & 0.0 & 0.0 \\ +-- 1.0 & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\ +-- 1.0 & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\ +-- \hline +-- & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\ +-- & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\ +-- \end{array} +-- \] +-- +-- SDIRK_2_1_2 +-- +-- \[ +-- \begin{array}{c|cc} +-- 1.0 & 1.0 & 0.0 \\ +-- 0.0 & -1.0 & 1.0 \\ +-- \hline +-- & 0.5 & 0.5 \\ +-- & 1.0 & 0.0 \\ +-- \end{array} +-- \] +-- +-- SDIRK_5_3_4 +-- +-- \[ +-- \begin{array}{c|ccccc} +-- 0.25 & 0.25 & 0.0 & 0.0 & 0.0 & 0.0 \\ +-- 0.75 & 0.5 & 0.25 & 0.0 & 0.0 & 0.0 \\ +-- 0.55 & 0.34 & -4.0e-2 & 0.25 & 0.0 & 0.0 \\ +-- 0.5 & 0.2727941176470588 & -5.036764705882353e-2 & 2.7573529411764705e-2 & 0.25 & 0.0 \\ +-- 1.0 & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\ +-- \hline +-- & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\ +-- & 1.2291666666666667 & -0.17708333333333334 & 7.03125 & -7.083333333333333 & 0.0 \\ +-- \end{array} +-- \] +----------------------------------------------------------------------------- +module Numeric.Sundials.CVode.ODE ( odeSolve + , odeSolveV + , odeSolveVWith + , odeSolveVWith' + , ODEMethod(..) + , StepControl(..) + , Jacobian + , SundialsDiagnostics(..) + ) where + +import qualified Language.C.Inline as C +import qualified Language.C.Inline.Unsafe as CU + +import Data.Monoid ((<>)) +import Data.Maybe (isJust) + +import Foreign.C.Types +import Foreign.Ptr (Ptr) +import Foreign.ForeignPtr (newForeignPtr_) +import Foreign.Storable (Storable) + +import qualified Data.Vector.Storable as V +import qualified Data.Vector.Storable.Mutable as VM + +import Data.Coerce (coerce) +import System.IO.Unsafe (unsafePerformIO) + +import Numeric.LinearAlgebra.Devel (createVector) + +import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), + subMatrix, rows, cols, toLists, + size, subVector) + +import qualified Types as T +import Arkode (cV_ADAMS, cV_BDF) +import qualified Arkode as B + + +C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) + +C.include "" +C.include "" +C.include "" +C.include "" -- prototypes for CVODE fcts., consts. +C.include "" -- serial N_Vector types, fcts., macros +C.include "" -- access to dense SUNMatrix +C.include "" -- access to dense SUNLinearSolver +C.include "" -- access to CVDls interface +C.include "" -- definition of type realtype +C.include "" +C.include "../../../helpers.h" +C.include "Arkode_hsc.h" + + +getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble) +getDataFromContents len ptr = do + qtr <- B.getContentPtr ptr + rtr <- B.getData qtr + vectorFromC len rtr + +-- FIXME: Potentially an instance of Storable +_getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix +_getMatrixDataFromContents ptr = do + qtr <- B.getContentMatrixPtr ptr + rs <- B.getNRows qtr + cs <- B.getNCols qtr + rtr <- B.getMatrixData qtr + vs <- vectorFromC (fromIntegral $ rs * cs) rtr + return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } + +putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () +putMatrixDataFromContents mat ptr = do + let rs = T.rows mat + cs = T.cols mat + vs = T.vals mat + qtr <- B.getContentMatrixPtr ptr + B.putNRows rs qtr + B.putNCols cs qtr + rtr <- B.getMatrixData qtr + vectorToC vs (fromIntegral $ rs * cs) rtr +-- FIXME: END + +putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () +putDataInContents vec len ptr = do + qtr <- B.getContentPtr ptr + rtr <- B.getData qtr + vectorToC vec len rtr + +-- Utils + +vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) +vectorFromC len ptr = do + ptr' <- newForeignPtr_ ptr + V.freeze $ VM.unsafeFromForeignPtr0 ptr' len + +vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () +vectorToC vec len ptr = do + ptr' <- newForeignPtr_ ptr + V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec + +data SundialsDiagnostics = SundialsDiagnostics { + aRKodeGetNumSteps :: Int + , aRKodeGetNumStepAttempts :: Int + , aRKodeGetNumRhsEvals_fe :: Int + , aRKodeGetNumRhsEvals_fi :: Int + , aRKodeGetNumLinSolvSetups :: Int + , aRKodeGetNumErrTestFails :: Int + , aRKodeGetNumNonlinSolvIters :: Int + , aRKodeGetNumNonlinSolvConvFails :: Int + , aRKDlsGetNumJacEvals :: Int + , aRKDlsGetNumRhsEvals :: Int + } deriving Show + +type Jacobian = Double -> Vector Double -> Matrix Double + +-- | Stepping functions +data ODEMethod = ADAMS + | BDF + +getMethod :: ODEMethod -> Int +getMethod (ADAMS) = cV_ADAMS +getMethod (BDF) = cV_BDF + +getJacobian :: ODEMethod -> Maybe Jacobian +getJacobian _ = Nothing + +-- | A version of 'odeSolveVWith' with reasonable default step control. +odeSolveV + :: ODEMethod + -> Maybe Double -- ^ initial step size - by default, ARKode + -- estimates the initial step size to be the + -- solution \(h\) of the equation + -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where + -- \(\ddot{y}\) is an estimated value of the + -- second derivative of the solution at \(t_0\) + -> Double -- ^ absolute tolerance for the state vector + -> Double -- ^ relative tolerance for the state vector + -> (Double -> Vector Double -> Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) + -> Vector Double -- ^ initial conditions + -> Vector Double -- ^ desired solution times + -> Matrix Double -- ^ solution +odeSolveV meth hi epsAbs epsRel f y0 ts = + case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of + Left c -> error $ show c -- FIXME + -- FIXME: Can we do better than using lists? + Right (v, d) -> (nR >< nC) (V.toList v) + where + us = toList ts + nR = length us + nC = size y0 + g t x0 = coerce $ f t x0 + +-- | A version of 'odeSolveV' with reasonable default parameters and +-- system of equations defined using lists. FIXME: we should say +-- something about the fact we could use the Jacobian but don't for +-- compatibility with hmatrix-gsl. +odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) + -> [Double] -- ^ initial conditions + -> Vector Double -- ^ desired solution times + -> Matrix Double -- ^ solution +odeSolve f y0 ts = + -- FIXME: These tolerances are different from the ones in GSL + case odeSolveVWith BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of + Left c -> error $ show c -- FIXME + Right (v, d) -> (nR >< nC) (V.toList v) + where + us = toList ts + nR = length us + nC = length y0 + g t x0 = V.fromList $ f t (V.toList x0) + +odeSolveVWith' :: + ODEMethod + -> StepControl + -> Maybe Double -- ^ initial step size - by default, ARKode + -- estimates the initial step size to be the + -- solution \(h\) of the equation + -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where + -- \(\ddot{y}\) is an estimated value of the second + -- derivative of the solution at \(t_0\) + -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) + -> V.Vector Double -- ^ Initial conditions + -> V.Vector Double -- ^ Desired solution times + -> Matrix Double -- ^ Error code or solution +odeSolveVWith' method control initStepSize f y0 tt = + case odeSolveVWith method control initStepSize f y0 tt of + Left c -> error $ show c -- FIXME + Right (v, _d) -> (nR >< nC) (V.toList v) + where + nR = V.length tt + nC = V.length y0 + +odeSolveVWith :: + ODEMethod + -> StepControl + -> Maybe Double -- ^ initial step size - by default, ARKode + -- estimates the initial step size to be the + -- solution \(h\) of the equation + -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where + -- \(\ddot{y}\) is an estimated value of the second + -- derivative of the solution at \(t_0\) + -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) + -> V.Vector Double -- ^ Initial conditions + -> V.Vector Double -- ^ Desired solution times + -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution +odeSolveVWith method control initStepSize f y0 tt = + case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) + (coerce f) (coerce y0) (coerce tt) of + Left c -> Left $ fromIntegral c + Right (v, d) -> Right (coerce v, d) + where + l = size y0 + scise (X absTol relTol) = coerce (V.replicate l absTol, relTol) + scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol) + scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol) + -- FIXME; Should we check that the length of ss is correct? + scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol) + jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ + getJacobian method + matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } + where + nr = fromIntegral $ rows m + nc = fromIntegral $ cols m + -- FIXME: efficiency + vs = V.fromList $ map coerce $ concat $ toLists m + +solveOdeC :: + CInt -> + Maybe CDouble -> + (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> + (V.Vector CDouble, CDouble) -> + (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) + -> V.Vector CDouble -- ^ Initial conditions + -> V.Vector CDouble -- ^ Desired solution times + -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution +solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do + + let isInitStepSize :: CInt + isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize + ss :: CDouble + ss = case initStepSize of + -- It would be better to put an error message here but + -- inline-c seems to evaluate this even if it is never + -- used :( + Nothing -> 0.0 + Just x -> x + let dim = V.length f0 + nEq :: CLong + nEq = fromIntegral dim + nTs :: CInt + nTs = fromIntegral $ V.length ts + -- FIXME: fMut is not actually mutatated + fMut <- V.thaw f0 + tMut <- V.thaw ts + -- FIXME: I believe this gets taken from the ghc heap and so should + -- be subject to garbage collection. + -- quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) + -- qMatMut <- V.thaw quasiMatrixRes + diagnostics :: V.Vector CLong <- createVector 10 -- FIXME + diagMut <- V.thaw diagnostics + -- We need the types that sundials expects. These are tied together + -- in 'Types'. FIXME: The Haskell type is currently empty! + let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt + funIO x y f _ptr = do + -- Convert the pointer we get from C (y) to a vector, and then + -- apply the user-supplied function. + fImm <- fun x <$> getDataFromContents dim y + -- Fill in the provided pointer with the resulting vector. + putDataInContents fImm dim f + -- FIXME: I don't understand what this comment means + -- Unsafe since the function will be called many times. + [CU.exp| int{ 0 } |] + let isJac :: CInt + isJac = fromIntegral $ fromEnum $ isJust jacH + jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix -> + Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector -> + IO CInt + jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do + case jacH of + Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" + Just jacI -> do j <- jacI t <$> getDataFromContents dim y + putMatrixDataFromContents j jacS + -- FIXME: I don't understand what this comment means + -- Unsafe since the function will be called many times. + [CU.exp| int{ 0 } |] + + res <- [C.block| int { + /* general problem variables */ + + int flag; /* reusable error-checking flag */ + int i, j; /* reusable loop indices */ + N_Vector y = NULL; /* empty vector for storing solution */ + void *cvode_mem = NULL; /* empty CVODE memory structure */ + + /* general problem parameters */ + + realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */ + sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ + + /* Initialize data structures */ + + y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ + if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; + /* Specify initial condition */ + for (i = 0; i < NEQ; i++) { + NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; + }; + + cvode_mem = CVodeCreate(CV_BDF, CV_NEWTON); + if (check_flag((void *)cvode_mem, "CVodeCreate", 0)) return(1); + + /* Call CVodeInit to initialize the integrator memory and specify the + * user's right hand side function in y'=f(t,y), the inital time T0, and + * the initial dependent variable vector y. */ + flag = CVodeInit(cvode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); + if (check_flag(&flag, "CVodeInit", 1)) return(1); + + /* Clean up and return */ + + N_VDestroy(y); /* Free y vector */ + CVodeFree(&cvode_mem); /* Free integrator memory */ + + return flag; + } |] + if res == 0 + then do + return $ Left res + else do + return $ Left res + +-- | Adaptive step-size control +-- functions. +-- +-- [GSL](https://www.gnu.org/software/gsl/doc/html/ode-initval.html#adaptive-step-size-control) +-- allows the user to control the step size adjustment using +-- \(D_i = \epsilon^{abs}s_i + \epsilon^{rel}(a_{y} |y_i| + a_{dy/dt} h |\dot{y}_i|)\) where +-- \(\epsilon^{abs}\) is the required absolute error, \(\epsilon^{rel}\) +-- is the required relative error, \(s_i\) is a vector of scaling +-- factors, \(a_{y}\) is a scaling factor for the solution \(y\) and +-- \(a_{dydt}\) is a scaling factor for the derivative of the solution \(dy/dt\). +-- +-- [ARKode](https://computation.llnl.gov/projects/sundials/arkode) +-- allows the user to control the step size adjustment using +-- \(\eta^{rel}|y_i| + \eta^{abs}_i\). For compatibility with +-- [hmatrix-gsl](https://hackage.haskell.org/package/hmatrix-gsl), +-- tolerances for \(y\) and \(\dot{y}\) can be specified but the latter have no +-- effect. +data StepControl = X Double Double -- ^ absolute and relative tolerance for \(y\); in GSL terms, \(a_{y} = 1\) and \(a_{dy/dt} = 0\); in ARKode terms, the \(\eta^{abs}_i\) are identical + | X' Double Double -- ^ absolute and relative tolerance for \(\dot{y}\); in GSL terms, \(a_{y} = 0\) and \(a_{dy/dt} = 1\); in ARKode terms, the latter is treated as the relative tolerance for \(y\) so this is the same as specifying 'X' which may be entirely incorrect for the given problem + | XX' Double Double Double Double -- ^ include both via relative tolerance + -- scaling factors \(a_y\), \(a_{{dy}/{dt}}\); in ARKode terms, the latter is ignored and \(\eta^{rel} = a_{y}\epsilon^{rel}\) + | ScXX' Double Double Double Double (Vector Double) -- ^ scale absolute tolerance of \(y_i\); in ARKode terms, \(a_{{dy}/{dt}}\) is ignored, \(\eta^{abs}_i = s_i \epsilon^{abs}\) and \(\eta^{rel} = a_{y}\epsilon^{rel}\) -- cgit v1.2.3 From 79962d2141f356b6a8018d767e49db162a146405 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Mon, 23 Apr 2018 16:20:33 +0100 Subject: Ancilliary files for the start of CVODE support --- packages/sundials/hmatrix-sundials.cabal | 10 +++++++--- packages/sundials/src/Arkode.hsc | 6 ++++++ packages/sundials/src/Main.hs | 18 +++++++++++------- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/packages/sundials/hmatrix-sundials.cabal b/packages/sundials/hmatrix-sundials.cabal index 388f1db..4cc02c6 100644 --- a/packages/sundials/hmatrix-sundials.cabal +++ b/packages/sundials/hmatrix-sundials.cabal @@ -25,10 +25,12 @@ library template-haskell >=2.12 && <2.13, containers >=0.5 && <0.6, hmatrix>=0.18 - extra-libraries: sundials_arkode + extra-libraries: sundials_arkode, + sundials_cvode other-extensions: QuasiQuotes hs-source-dirs: src - exposed-modules: Numeric.Sundials.ARKode.ODE + exposed-modules: Numeric.Sundials.ARKode.ODE, + Numeric.Sundials.CVode.ODE other-modules: Types, Arkode c-sources: src/helpers.c src/helpers.h @@ -39,6 +41,7 @@ test-suite hmatrix-sundials-testsuite main-is: Main.hs other-modules: Types, Numeric.Sundials.ARKode.ODE, + Numeric.Sundials.CVode.ODE, Arkode build-depends: base >=4.10 && <4.11, inline-c >=0.6 && <0.7, @@ -52,6 +55,7 @@ test-suite hmatrix-sundials-testsuite lens, hspec hs-source-dirs: src - extra-libraries: sundials_arkode + extra-libraries: sundials_arkode, + sundials_cvode c-sources: src/helpers.c src/helpers.h default-language: Haskell2010 diff --git a/packages/sundials/src/Arkode.hsc b/packages/sundials/src/Arkode.hsc index 9db37b5..558ce9e 100644 --- a/packages/sundials/src/Arkode.hsc +++ b/packages/sundials/src/Arkode.hsc @@ -10,6 +10,7 @@ import Foreign.C.Types #include #include #include +#include #def typedef struct _generic_N_Vector SunVector; @@ -40,6 +41,11 @@ getContentPtr ptr = (#peek SunVector, content) ptr getData :: Storable a => Ptr b -> IO a getData ptr = (#peek SunContent, data) ptr +cV_ADAMS :: Int +cV_ADAMS = #const CV_ADAMS +cV_BDF :: Int +cV_BDF = #const CV_BDF + arkSMax :: Int arkSMax = #const ARK_S_MAX diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index 729d35a..3904b09 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs @@ -1,6 +1,7 @@ {-# OPTIONS_GHC -Wall #-} -import Numeric.Sundials.ARKode.ODE +import qualified Numeric.Sundials.ARKode.ODE as ARK +import qualified Numeric.Sundials.CVode.ODE as CV import Numeric.LinearAlgebra import Plots as P @@ -97,24 +98,24 @@ kSaxis xs = P.r2Axis &~ do main :: IO () main = do - let res1 = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) + let res1 = ARK.odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) renderRasterific "diagrams/brusselator.png" (D.dims2D 500.0 500.0) (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) - let res1a = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) + let res1a = ARK.odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) renderRasterific "diagrams/brusselatorA.png" (D.dims2D 500.0 500.0) (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1a)) - let res2 = odeSolve stiffish [0.0] (fromList [0.0, 0.1 .. 10.0]) + let res2 = ARK.odeSolve stiffish [0.0] (fromList [0.0, 0.1 .. 10.0]) renderRasterific "diagrams/stiffish.png" (D.dims2D 500.0 500.0) (renderAxis $ kSaxis $ zip [0.0, 0.1 .. 10.0] (concat $ toLists res2)) - let res2a = odeSolveV (SDIRK_5_3_4') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) + let res2a = ARK.odeSolveV (ARK.SDIRK_5_3_4') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) - let res2b = odeSolveV (TRBDF2_3_3_2') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) + let res2b = ARK.odeSolveV (ARK.TRBDF2_3_3_2') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) let maxDiff = maximum $ map abs $ zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2b)!!0) @@ -123,7 +124,10 @@ main = do it "for two different RK methods" $ maxDiff < 1.0e-6 - let res3 = odeSolve lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0]) + let res2c = CV.odeSolveV (CV.BDF) Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) + putStrLn $ show res2c + + let res3 = ARK.odeSolve lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0]) renderRasterific "diagrams/lorenz.png" (D.dims2D 500.0 500.0) -- cgit v1.2.3 From c73f86f64a60209a50b9c4cc3b137726955f2df7 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Tue, 24 Apr 2018 11:53:50 +0100 Subject: CVODE now supported somewhat --- packages/sundials/src/Main.hs | 20 ++- .../sundials/src/Numeric/Sundials/CVode/ODE.hs | 144 +++++++++++++++++++-- 2 files changed, 147 insertions(+), 17 deletions(-) diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index 3904b09..85928e2 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs @@ -117,15 +117,21 @@ main = do let res2b = ARK.odeSolveV (ARK.TRBDF2_3_3_2') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) - let maxDiff = maximum $ map abs $ - zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2b)!!0) - - hspec $ describe "Compare results" $ do - it "for two different RK methods" $ - maxDiff < 1.0e-6 + let maxDiffA = maximum $ map abs $ + zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2b)!!0) let res2c = CV.odeSolveV (CV.BDF) Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) - putStrLn $ show res2c + + let maxDiffB = maximum $ map abs $ + zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2c)!!0) + + let maxDiffC = maximum $ map abs $ + zipWith (-) ((toLists $ tr res2b)!!0) ((toLists $ tr res2c)!!0) + + hspec $ describe "Compare results" $ do + it "for SDIRK_5_3_4' and TRBDF2_3_3_2'" $ maxDiffA < 1.0e-6 + it "for SDIRK_5_3_4' and BDF" $ maxDiffB < 1.0e-6 + it "for TRBDF2_3_3_2' and BDF" $ maxDiffC < 1.0e-6 let res3 = ARK.odeSolve lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0]) diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs index f75d91f..abe1bfe 100644 --- a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs @@ -132,8 +132,7 @@ import System.IO.Unsafe (unsafePerformIO) import Numeric.LinearAlgebra.Devel (createVector) import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), - subMatrix, rows, cols, toLists, - size, subVector) + rows, cols, toLists, size) import qualified Types as T import Arkode (cV_ADAMS, cV_BDF) @@ -247,7 +246,7 @@ odeSolveV meth hi epsAbs epsRel f y0 ts = case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of Left c -> error $ show c -- FIXME -- FIXME: Can we do better than using lists? - Right (v, d) -> (nR >< nC) (V.toList v) + Right (v, _d) -> (nR >< nC) (V.toList v) where us = toList ts nR = length us @@ -266,7 +265,7 @@ odeSolve f y0 ts = -- FIXME: These tolerances are different from the ones in GSL case odeSolveVWith BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of Left c -> error $ show c -- FIXME - Right (v, d) -> (nR >< nC) (V.toList v) + Right (v, _d) -> (nR >< nC) (V.toList v) where us = toList ts nR = length us @@ -353,13 +352,12 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO nEq = fromIntegral dim nTs :: CInt nTs = fromIntegral $ V.length ts - -- FIXME: fMut is not actually mutatated - fMut <- V.thaw f0 + -- FIXME: tMut is not actually mutatated? tMut <- V.thaw ts -- FIXME: I believe this gets taken from the ghc heap and so should -- be subject to garbage collection. - -- quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) - -- qMatMut <- V.thaw quasiMatrixRes + quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) + qMatMut <- V.thaw quasiMatrixRes diagnostics :: V.Vector CLong <- createVector 10 -- FIXME diagMut <- V.thaw diagnostics -- We need the types that sundials expects. These are tied together @@ -394,7 +392,13 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO int flag; /* reusable error-checking flag */ int i, j; /* reusable loop indices */ N_Vector y = NULL; /* empty vector for storing solution */ + N_Vector tv = NULL; /* empty vector for storing absolute tolerances */ + + SUNMatrix A = NULL; /* empty matrix for linear solver */ + SUNLinearSolver LS = NULL; /* empty linear solver object */ void *cvode_mem = NULL; /* empty CVODE memory structure */ + realtype t; + long int nst, nfe, nsetups, nje, nfeLS, nni, ncfn, netf, nge; /* general problem parameters */ @@ -410,7 +414,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; }; - cvode_mem = CVodeCreate(CV_BDF, CV_NEWTON); + cvode_mem = CVodeCreate($(int method), CV_NEWTON); if (check_flag((void *)cvode_mem, "CVodeCreate", 0)) return(1); /* Call CVodeInit to initialize the integrator memory and specify the @@ -419,16 +423,136 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO flag = CVodeInit(cvode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); if (check_flag(&flag, "CVodeInit", 1)) return(1); + tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */ + if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; + /* Specify tolerances */ + for (i = 0; i < NEQ; i++) { + NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i]; + }; + + /* FIXME: A hack for initial testing */ + flag = CVodeSetMinStep(cvode_mem, 1.0e-12); + if (check_flag(&flag, "CVodeSetMinStep", 1)) return 1; + flag = CVodeSetMaxNumSteps(cvode_mem, 10000); + if (check_flag(&flag, "CVodeSetMaxNumSteps", 1)) return 1; + + /* Call CVodeSVtolerances to specify the scalar relative tolerance + * and vector absolute tolerances */ + flag = CVodeSVtolerances(cvode_mem, $(double relTol), tv); + if (check_flag(&flag, "CVodeSVtolerances", 1)) return(1); + + /* Initialize dense matrix data structure and solver */ + A = SUNDenseMatrix(NEQ, NEQ); + if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1; + LS = SUNDenseLinearSolver(y, A); + if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1; + + /* Attach matrix and linear solver */ + flag = CVDlsSetLinearSolver(cvode_mem, LS, A); + if (check_flag(&flag, "CVDlsSetLinearSolver", 1)) return 1; + + /* Set the initial step size if there is one */ + if ($(int isInitStepSize)) { + /* FIXME: We could check if the initial step size is 0 */ + /* or even NaN and then throw an error */ + flag = CVodeSetInitStep(cvode_mem, $(double ss)); + if (check_flag(&flag, "CVodeSetInitStep", 1)) return 1; + } + + /* Set the Jacobian if there is one */ + if ($(int isJac)) { + flag = CVDlsSetJacFn(cvode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[]))); + if (check_flag(&flag, "CVDlsSetJacFn", 1)) return 1; + } + + /* Store initial conditions */ + for (j = 0; j < NEQ; j++) { + ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); + } + + /* Main time-stepping loop: calls CVode to perform the integration */ + /* Stops when the final time has been reached */ + for (i = 1; i < $(int nTs); i++) { + + flag = CVode(cvode_mem, ($vec-ptr:(double *tMut))[i], y, &t, CV_NORMAL); /* call integrator */ + if (check_flag(&flag, "CVode", 1)) break; + + /* Store the results for Haskell */ + for (j = 0; j < NEQ; j++) { + ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j); + } + + /* unsuccessful solve: break */ + if (flag < 0) { + fprintf(stderr,"Solver failure, stopping integration\n"); + break; + } + } + + /* Get some final statistics on how the solve progressed */ + + flag = CVodeGetNumSteps(cvode_mem, &nst); + check_flag(&flag, "CVodeGetNumSteps", 1); + ($vec-ptr:(long int *diagMut))[0] = nst; + + /* FIXME */ + ($vec-ptr:(long int *diagMut))[1] = 0; + + flag = CVodeGetNumRhsEvals(cvode_mem, &nfe); + check_flag(&flag, "CVodeGetNumRhsEvals", 1); + ($vec-ptr:(long int *diagMut))[2] = nfe; + /* FIXME */ + ($vec-ptr:(long int *diagMut))[3] = 0; + + flag = CVodeGetNumLinSolvSetups(cvode_mem, &nsetups); + check_flag(&flag, "CVodeGetNumLinSolvSetups", 1); + ($vec-ptr:(long int *diagMut))[4] = nsetups; + + flag = CVodeGetNumErrTestFails(cvode_mem, &netf); + check_flag(&flag, "CVodeGetNumErrTestFails", 1); + ($vec-ptr:(long int *diagMut))[5] = netf; + + flag = CVodeGetNumNonlinSolvIters(cvode_mem, &nni); + check_flag(&flag, "CVodeGetNumNonlinSolvIters", 1); + ($vec-ptr:(long int *diagMut))[6] = nni; + + flag = CVodeGetNumNonlinSolvConvFails(cvode_mem, &ncfn); + check_flag(&flag, "CVodeGetNumNonlinSolvConvFails", 1); + ($vec-ptr:(long int *diagMut))[7] = ncfn; + + flag = CVDlsGetNumJacEvals(cvode_mem, &nje); + check_flag(&flag, "CVDlsGetNumJacEvals", 1); + ($vec-ptr:(long int *diagMut))[8] = ncfn; + + flag = CVDlsGetNumRhsEvals(cvode_mem, &nfeLS); + check_flag(&flag, "CVDlsGetNumRhsEvals", 1); + ($vec-ptr:(long int *diagMut))[9] = ncfn; + /* Clean up and return */ N_VDestroy(y); /* Free y vector */ + N_VDestroy(tv); /* Free tv vector */ CVodeFree(&cvode_mem); /* Free integrator memory */ + SUNLinSolFree(LS); /* Free linear solver */ + SUNMatDestroy(A); /* Free A matrix */ return flag; } |] if res == 0 then do - return $ Left res + preD <- V.freeze diagMut + let d = SundialsDiagnostics (fromIntegral $ preD V.!0) + (fromIntegral $ preD V.!1) + (fromIntegral $ preD V.!2) + (fromIntegral $ preD V.!3) + (fromIntegral $ preD V.!4) + (fromIntegral $ preD V.!5) + (fromIntegral $ preD V.!6) + (fromIntegral $ preD V.!7) + (fromIntegral $ preD V.!8) + (fromIntegral $ preD V.!9) + m <- V.freeze qMatMut + return $ Right (m, d) else do return $ Left res -- cgit v1.2.3 From 729eb192cf77d4cddf33d2724b4409ab7d828921 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Wed, 25 Apr 2018 16:25:20 +0100 Subject: Pull out common code and start to follow gsl naming convention --- packages/sundials/hmatrix-sundials.cabal | 4 +- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 123 +++++------------ .../sundials/src/Numeric/Sundials/CVode/ODE.hs | 153 ++++----------------- packages/sundials/src/Numeric/Sundials/ODEOpts.hs | 78 +++++++++++ 4 files changed, 137 insertions(+), 221 deletions(-) create mode 100644 packages/sundials/src/Numeric/Sundials/ODEOpts.hs diff --git a/packages/sundials/hmatrix-sundials.cabal b/packages/sundials/hmatrix-sundials.cabal index 4cc02c6..b7fa0fe 100644 --- a/packages/sundials/hmatrix-sundials.cabal +++ b/packages/sundials/hmatrix-sundials.cabal @@ -29,7 +29,8 @@ library sundials_cvode other-extensions: QuasiQuotes hs-source-dirs: src - exposed-modules: Numeric.Sundials.ARKode.ODE, + exposed-modules: Numeric.Sundials.ODEOpts, + Numeric.Sundials.ARKode.ODE, Numeric.Sundials.CVode.ODE other-modules: Types, Arkode @@ -40,6 +41,7 @@ test-suite hmatrix-sundials-testsuite type: exitcode-stdio-1.0 main-is: Main.hs other-modules: Types, + Numeric.Sundials.ODEOpts, Numeric.Sundials.ARKode.ODE, Numeric.Sundials.CVode.ODE, Arkode diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index e5a2e4d..8b713c6 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -117,7 +117,6 @@ module Numeric.Sundials.ARKode.ODE ( odeSolve , ODEMethod(..) , StepControl(..) , Jacobian - , SundialsDiagnostics(..) ) where import qualified Language.C.Inline as C @@ -126,17 +125,15 @@ import qualified Language.C.Inline.Unsafe as CU import Data.Monoid ((<>)) import Data.Maybe (isJust) -import Foreign.C.Types +import Foreign.C.Types (CDouble, CInt, CLong) import Foreign.Ptr (Ptr) -import Foreign.ForeignPtr (newForeignPtr_) -import Foreign.Storable (Storable) import qualified Data.Vector.Storable as V -import qualified Data.Vector.Storable.Mutable as VM import Data.Coerce (coerce) import System.IO.Unsafe (unsafePerformIO) -import GHC.Generics +import GHC.Generics (C1, Constructor, (:+:)(..), D1, Rep, Generic, M1(..), + from, conName) import Numeric.LinearAlgebra.Devel (createVector) @@ -147,6 +144,7 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), import qualified Types as T import Arkode import qualified Arkode as B +import qualified Numeric.Sundials.ODEOpts as SO C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) @@ -165,65 +163,6 @@ C.include "../../../helpers.h" C.include "Arkode_hsc.h" -getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble) -getDataFromContents len ptr = do - qtr <- B.getContentPtr ptr - rtr <- B.getData qtr - vectorFromC len rtr - --- FIXME: Potentially an instance of Storable -_getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix -_getMatrixDataFromContents ptr = do - qtr <- B.getContentMatrixPtr ptr - rs <- B.getNRows qtr - cs <- B.getNCols qtr - rtr <- B.getMatrixData qtr - vs <- vectorFromC (fromIntegral $ rs * cs) rtr - return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } - -putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () -putMatrixDataFromContents mat ptr = do - let rs = T.rows mat - cs = T.cols mat - vs = T.vals mat - qtr <- B.getContentMatrixPtr ptr - B.putNRows rs qtr - B.putNCols cs qtr - rtr <- B.getMatrixData qtr - vectorToC vs (fromIntegral $ rs * cs) rtr --- FIXME: END - -putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () -putDataInContents vec len ptr = do - qtr <- B.getContentPtr ptr - rtr <- B.getData qtr - vectorToC vec len rtr - --- Utils - -vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) -vectorFromC len ptr = do - ptr' <- newForeignPtr_ ptr - V.freeze $ VM.unsafeFromForeignPtr0 ptr' len - -vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () -vectorToC vec len ptr = do - ptr' <- newForeignPtr_ ptr - V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec - -data SundialsDiagnostics = SundialsDiagnostics { - aRKodeGetNumSteps :: Int - , aRKodeGetNumStepAttempts :: Int - , aRKodeGetNumRhsEvals_fe :: Int - , aRKodeGetNumRhsEvals_fi :: Int - , aRKodeGetNumLinSolvSetups :: Int - , aRKodeGetNumErrTestFails :: Int - , aRKodeGetNumNonlinSolvIters :: Int - , aRKodeGetNumNonlinSolvConvFails :: Int - , aRKDlsGetNumJacEvals :: Int - , aRKDlsGetNumRhsEvals :: Int - } deriving Show - type Jacobian = Double -> Vector Double -> Matrix Double -- | Stepping functions @@ -390,7 +329,7 @@ odeSolveV -> Vector Double -- ^ desired solution times -> Matrix Double -- ^ solution odeSolveV meth hi epsAbs epsRel f y0 ts = - case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of + case odeSolveVWith' meth (X epsAbs epsRel) hi g y0 ts of Left c -> error $ show c -- FIXME -- FIXME: Can we do better than using lists? Right (v, _d) -> (nR >< nC) (V.toList v) @@ -410,7 +349,7 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y -> Matrix Double -- ^ solution odeSolve f y0 ts = -- FIXME: These tolerances are different from the ones in GSL - case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of + case odeSolveVWith' SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of Left c -> error $ show c -- FIXME Right (v, _d) -> (nR >< nC) (V.toList v) where @@ -419,7 +358,7 @@ odeSolve f y0 ts = nC = length y0 g t x0 = V.fromList $ f t (V.toList x0) -odeSolveVWith' :: +odeSolveVWith :: ODEMethod -> StepControl -> Maybe Double -- ^ initial step size - by default, ARKode @@ -432,15 +371,15 @@ odeSolveVWith' :: -> V.Vector Double -- ^ Initial conditions -> V.Vector Double -- ^ Desired solution times -> Matrix Double -- ^ Error code or solution -odeSolveVWith' method control initStepSize f y0 tt = - case odeSolveVWith method control initStepSize f y0 tt of +odeSolveVWith method control initStepSize f y0 tt = + case odeSolveVWith' method control initStepSize f y0 tt of Left c -> error $ show c -- FIXME Right (v, _d) -> (nR >< nC) (V.toList v) where nR = V.length tt nC = V.length y0 -odeSolveVWith :: +odeSolveVWith' :: ODEMethod -> StepControl -> Maybe Double -- ^ initial step size - by default, ARKode @@ -452,8 +391,8 @@ odeSolveVWith :: -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector Double -- ^ Initial conditions -> V.Vector Double -- ^ Desired solution times - -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution -odeSolveVWith method control initStepSize f y0 tt = + -> Either Int ((V.Vector Double), SO.SundialsDiagnostics) -- ^ Error code or solution +odeSolveVWith' method control initStepSize f y0 tt = case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) (coerce f) (coerce y0) (coerce tt) of Left c -> Left $ fromIntegral c @@ -482,7 +421,7 @@ solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector CDouble -- ^ Initial conditions -> V.Vector CDouble -- ^ Desired solution times - -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution + -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do let isInitStepSize :: CInt @@ -514,9 +453,9 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO funIO x y f _ptr = do -- Convert the pointer we get from C (y) to a vector, and then -- apply the user-supplied function. - fImm <- fun x <$> getDataFromContents dim y + fImm <- fun x <$> SO.getDataFromContents dim y -- Fill in the provided pointer with the resulting vector. - putDataInContents fImm dim f + SO.putDataInContents fImm dim f -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] @@ -528,8 +467,8 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do case jacH of Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" - Just jacI -> do j <- jacI t <$> getDataFromContents dim y - putMatrixDataFromContents j jacS + Just jacI -> do j <- jacI t <$> SO.getDataFromContents dim y + SO.putMatrixDataFromContents j jacS -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] @@ -704,16 +643,16 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO if res == 0 then do preD <- V.freeze diagMut - let d = SundialsDiagnostics (fromIntegral $ preD V.!0) - (fromIntegral $ preD V.!1) - (fromIntegral $ preD V.!2) - (fromIntegral $ preD V.!3) - (fromIntegral $ preD V.!4) - (fromIntegral $ preD V.!5) - (fromIntegral $ preD V.!6) - (fromIntegral $ preD V.!7) - (fromIntegral $ preD V.!8) - (fromIntegral $ preD V.!9) + let d = SO.SundialsDiagnostics (fromIntegral $ preD V.!0) + (fromIntegral $ preD V.!1) + (fromIntegral $ preD V.!2) + (fromIntegral $ preD V.!3) + (fromIntegral $ preD V.!4) + (fromIntegral $ preD V.!5) + (fromIntegral $ preD V.!6) + (fromIntegral $ preD V.!7) + (fromIntegral $ preD V.!8) + (fromIntegral $ preD V.!9) m <- V.freeze qMatMut return $ Right (m, d) else do @@ -783,15 +722,15 @@ getButcherTable method = unsafePerformIO $ do btB2sMut <- V.thaw btB2s let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt funIOI x y f _ptr = do - fImm <- funI x <$> getDataFromContents dim y - putDataInContents fImm dim f + fImm <- funI x <$> SO.getDataFromContents dim y + SO.putDataInContents fImm dim f -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt funIOE x y f _ptr = do - fImm <- funE x <$> getDataFromContents dim y - putDataInContents fImm dim f + fImm <- funE x <$> SO.getDataFromContents dim y + SO.putDataInContents fImm dim f -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs index abe1bfe..d7a2b53 100644 --- a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs @@ -61,46 +61,6 @@ -- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) -- @ -- --- KVAERNO_4_2_3 --- --- \[ --- \begin{array}{c|cccc} --- 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\ --- 0.871733043 & 0.4358665215 & 0.4358665215 & 0.0 & 0.0 \\ --- 1.0 & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\ --- 1.0 & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\ --- \hline --- & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\ --- & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\ --- \end{array} --- \] --- --- SDIRK_2_1_2 --- --- \[ --- \begin{array}{c|cc} --- 1.0 & 1.0 & 0.0 \\ --- 0.0 & -1.0 & 1.0 \\ --- \hline --- & 0.5 & 0.5 \\ --- & 1.0 & 0.0 \\ --- \end{array} --- \] --- --- SDIRK_5_3_4 --- --- \[ --- \begin{array}{c|ccccc} --- 0.25 & 0.25 & 0.0 & 0.0 & 0.0 & 0.0 \\ --- 0.75 & 0.5 & 0.25 & 0.0 & 0.0 & 0.0 \\ --- 0.55 & 0.34 & -4.0e-2 & 0.25 & 0.0 & 0.0 \\ --- 0.5 & 0.2727941176470588 & -5.036764705882353e-2 & 2.7573529411764705e-2 & 0.25 & 0.0 \\ --- 1.0 & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\ --- \hline --- & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\ --- & 1.2291666666666667 & -0.17708333333333334 & 7.03125 & -7.083333333333333 & 0.0 \\ --- \end{array} --- \] ----------------------------------------------------------------------------- module Numeric.Sundials.CVode.ODE ( odeSolve , odeSolveV @@ -109,7 +69,6 @@ module Numeric.Sundials.CVode.ODE ( odeSolve , ODEMethod(..) , StepControl(..) , Jacobian - , SundialsDiagnostics(..) ) where import qualified Language.C.Inline as C @@ -118,13 +77,10 @@ import qualified Language.C.Inline.Unsafe as CU import Data.Monoid ((<>)) import Data.Maybe (isJust) -import Foreign.C.Types +import Foreign.C.Types (CDouble, CInt, CLong) import Foreign.Ptr (Ptr) -import Foreign.ForeignPtr (newForeignPtr_) -import Foreign.Storable (Storable) import qualified Data.Vector.Storable as V -import qualified Data.Vector.Storable.Mutable as VM import Data.Coerce (coerce) import System.IO.Unsafe (unsafePerformIO) @@ -136,7 +92,7 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), import qualified Types as T import Arkode (cV_ADAMS, cV_BDF) -import qualified Arkode as B +import qualified Numeric.Sundials.ODEOpts as SO C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) @@ -155,65 +111,6 @@ C.include "../../../helpers.h" C.include "Arkode_hsc.h" -getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble) -getDataFromContents len ptr = do - qtr <- B.getContentPtr ptr - rtr <- B.getData qtr - vectorFromC len rtr - --- FIXME: Potentially an instance of Storable -_getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix -_getMatrixDataFromContents ptr = do - qtr <- B.getContentMatrixPtr ptr - rs <- B.getNRows qtr - cs <- B.getNCols qtr - rtr <- B.getMatrixData qtr - vs <- vectorFromC (fromIntegral $ rs * cs) rtr - return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } - -putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () -putMatrixDataFromContents mat ptr = do - let rs = T.rows mat - cs = T.cols mat - vs = T.vals mat - qtr <- B.getContentMatrixPtr ptr - B.putNRows rs qtr - B.putNCols cs qtr - rtr <- B.getMatrixData qtr - vectorToC vs (fromIntegral $ rs * cs) rtr --- FIXME: END - -putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () -putDataInContents vec len ptr = do - qtr <- B.getContentPtr ptr - rtr <- B.getData qtr - vectorToC vec len rtr - --- Utils - -vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) -vectorFromC len ptr = do - ptr' <- newForeignPtr_ ptr - V.freeze $ VM.unsafeFromForeignPtr0 ptr' len - -vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () -vectorToC vec len ptr = do - ptr' <- newForeignPtr_ ptr - V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec - -data SundialsDiagnostics = SundialsDiagnostics { - aRKodeGetNumSteps :: Int - , aRKodeGetNumStepAttempts :: Int - , aRKodeGetNumRhsEvals_fe :: Int - , aRKodeGetNumRhsEvals_fi :: Int - , aRKodeGetNumLinSolvSetups :: Int - , aRKodeGetNumErrTestFails :: Int - , aRKodeGetNumNonlinSolvIters :: Int - , aRKodeGetNumNonlinSolvConvFails :: Int - , aRKDlsGetNumJacEvals :: Int - , aRKDlsGetNumRhsEvals :: Int - } deriving Show - type Jacobian = Double -> Vector Double -> Matrix Double -- | Stepping functions @@ -243,7 +140,7 @@ odeSolveV -> Vector Double -- ^ desired solution times -> Matrix Double -- ^ solution odeSolveV meth hi epsAbs epsRel f y0 ts = - case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of + case odeSolveVWith' meth (X epsAbs epsRel) hi g y0 ts of Left c -> error $ show c -- FIXME -- FIXME: Can we do better than using lists? Right (v, _d) -> (nR >< nC) (V.toList v) @@ -263,7 +160,7 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y -> Matrix Double -- ^ solution odeSolve f y0 ts = -- FIXME: These tolerances are different from the ones in GSL - case odeSolveVWith BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of + case odeSolveVWith' BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of Left c -> error $ show c -- FIXME Right (v, _d) -> (nR >< nC) (V.toList v) where @@ -272,7 +169,7 @@ odeSolve f y0 ts = nC = length y0 g t x0 = V.fromList $ f t (V.toList x0) -odeSolveVWith' :: +odeSolveVWith :: ODEMethod -> StepControl -> Maybe Double -- ^ initial step size - by default, ARKode @@ -285,15 +182,15 @@ odeSolveVWith' :: -> V.Vector Double -- ^ Initial conditions -> V.Vector Double -- ^ Desired solution times -> Matrix Double -- ^ Error code or solution -odeSolveVWith' method control initStepSize f y0 tt = - case odeSolveVWith method control initStepSize f y0 tt of +odeSolveVWith method control initStepSize f y0 tt = + case odeSolveVWith' method control initStepSize f y0 tt of Left c -> error $ show c -- FIXME Right (v, _d) -> (nR >< nC) (V.toList v) where nR = V.length tt nC = V.length y0 -odeSolveVWith :: +odeSolveVWith' :: ODEMethod -> StepControl -> Maybe Double -- ^ initial step size - by default, ARKode @@ -305,8 +202,8 @@ odeSolveVWith :: -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector Double -- ^ Initial conditions -> V.Vector Double -- ^ Desired solution times - -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution -odeSolveVWith method control initStepSize f y0 tt = + -> Either Int ((V.Vector Double), SO.SundialsDiagnostics) -- ^ Error code or solution +odeSolveVWith' method control initStepSize f y0 tt = case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) (coerce f) (coerce y0) (coerce tt) of Left c -> Left $ fromIntegral c @@ -335,7 +232,7 @@ solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector CDouble -- ^ Initial conditions -> V.Vector CDouble -- ^ Desired solution times - -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution + -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do let isInitStepSize :: CInt @@ -366,9 +263,9 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO funIO x y f _ptr = do -- Convert the pointer we get from C (y) to a vector, and then -- apply the user-supplied function. - fImm <- fun x <$> getDataFromContents dim y + fImm <- fun x <$> SO.getDataFromContents dim y -- Fill in the provided pointer with the resulting vector. - putDataInContents fImm dim f + SO.putDataInContents fImm dim f -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] @@ -380,8 +277,8 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do case jacH of Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" - Just jacI -> do j <- jacI t <$> getDataFromContents dim y - putMatrixDataFromContents j jacS + Just jacI -> do j <- jacI t <$> SO.getDataFromContents dim y + SO.putMatrixDataFromContents j jacS -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] @@ -541,16 +438,16 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO if res == 0 then do preD <- V.freeze diagMut - let d = SundialsDiagnostics (fromIntegral $ preD V.!0) - (fromIntegral $ preD V.!1) - (fromIntegral $ preD V.!2) - (fromIntegral $ preD V.!3) - (fromIntegral $ preD V.!4) - (fromIntegral $ preD V.!5) - (fromIntegral $ preD V.!6) - (fromIntegral $ preD V.!7) - (fromIntegral $ preD V.!8) - (fromIntegral $ preD V.!9) + let d = SO.SundialsDiagnostics (fromIntegral $ preD V.!0) + (fromIntegral $ preD V.!1) + (fromIntegral $ preD V.!2) + (fromIntegral $ preD V.!3) + (fromIntegral $ preD V.!4) + (fromIntegral $ preD V.!5) + (fromIntegral $ preD V.!6) + (fromIntegral $ preD V.!7) + (fromIntegral $ preD V.!8) + (fromIntegral $ preD V.!9) m <- V.freeze qMatMut return $ Right (m, d) else do diff --git a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs new file mode 100644 index 0000000..e924292 --- /dev/null +++ b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs @@ -0,0 +1,78 @@ +module Numeric.Sundials.ODEOpts where + +import Data.Int (Int32) +import Foreign.Ptr (Ptr) +import Foreign.Storable as FS +import Foreign.ForeignPtr as FF +import Foreign.C.Types +import qualified Data.Vector.Storable as VS +import qualified Data.Vector.Storable.Mutable as VM + +import qualified Types as T +import qualified Arkode as B + +data ODEOpts = ODEOpts { + maxNumSteps :: Int32 + , minStep :: Double + , relTol :: Double + , absTols :: VS.Vector Double + , initStep :: Double + } deriving (Read, Show, Eq, Ord) + +-- FIXME: Potentially an instance of Storable +_getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix +_getMatrixDataFromContents ptr = do + qtr <- B.getContentMatrixPtr ptr + rs <- B.getNRows qtr + cs <- B.getNCols qtr + rtr <- B.getMatrixData qtr + vs <- vectorFromC (fromIntegral $ rs * cs) rtr + return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } + +putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () +putMatrixDataFromContents mat ptr = do + let rs = T.rows mat + cs = T.cols mat + vs = T.vals mat + qtr <- B.getContentMatrixPtr ptr + B.putNRows rs qtr + B.putNCols cs qtr + rtr <- B.getMatrixData qtr + vectorToC vs (fromIntegral $ rs * cs) rtr +-- FIXME: END + +vectorFromC :: Storable a => Int -> Ptr a -> IO (VS.Vector a) +vectorFromC len ptr = do + ptr' <- newForeignPtr_ ptr + VS.freeze $ VM.unsafeFromForeignPtr0 ptr' len + +vectorToC :: Storable a => VS.Vector a -> Int -> Ptr a -> IO () +vectorToC vec len ptr = do + ptr' <- newForeignPtr_ ptr + VS.copy (VM.unsafeFromForeignPtr0 ptr' len) vec + +getDataFromContents :: Int -> Ptr T.SunVector -> IO (VS.Vector CDouble) +getDataFromContents len ptr = do + qtr <- B.getContentPtr ptr + rtr <- B.getData qtr + vectorFromC len rtr + +putDataInContents :: Storable a => VS.Vector a -> Int -> Ptr b -> IO () +putDataInContents vec len ptr = do + qtr <- B.getContentPtr ptr + rtr <- B.getData qtr + vectorToC vec len rtr + +data SundialsDiagnostics = SundialsDiagnostics { + aRKodeGetNumSteps :: Int + , aRKodeGetNumStepAttempts :: Int + , aRKodeGetNumRhsEvals_fe :: Int + , aRKodeGetNumRhsEvals_fi :: Int + , aRKodeGetNumLinSolvSetups :: Int + , aRKodeGetNumErrTestFails :: Int + , aRKodeGetNumNonlinSolvIters :: Int + , aRKodeGetNumNonlinSolvConvFails :: Int + , aRKDlsGetNumJacEvals :: Int + , aRKDlsGetNumRhsEvals :: Int + } deriving Show + -- cgit v1.2.3 From e8e631c7dbc7ea34b51dcce5fef5e2ec620f9458 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Thu, 26 Apr 2018 13:53:56 +0100 Subject: Refactor CVODE --- .../sundials/src/Numeric/Sundials/CVode/ODE.hs | 75 ++++++++++------------ packages/sundials/src/Numeric/Sundials/ODEOpts.hs | 4 ++ 2 files changed, 38 insertions(+), 41 deletions(-) diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs index d7a2b53..0871f9b 100644 --- a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs @@ -87,11 +87,12 @@ import System.IO.Unsafe (unsafePerformIO) import Numeric.LinearAlgebra.Devel (createVector) -import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), - rows, cols, toLists, size) +import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows, + cols, toLists, size, reshape) import qualified Types as T import Arkode (cV_ADAMS, cV_BDF) +import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian) import qualified Numeric.Sundials.ODEOpts as SO @@ -111,8 +112,6 @@ C.include "../../../helpers.h" C.include "Arkode_hsc.h" -type Jacobian = Double -> Vector Double -> Matrix Double - -- | Stepping functions data ODEMethod = ADAMS | BDF @@ -140,14 +139,8 @@ odeSolveV -> Vector Double -- ^ desired solution times -> Matrix Double -- ^ solution odeSolveV meth hi epsAbs epsRel f y0 ts = - case odeSolveVWith' meth (X epsAbs epsRel) hi g y0 ts of - Left c -> error $ show c -- FIXME - -- FIXME: Can we do better than using lists? - Right (v, _d) -> (nR >< nC) (V.toList v) + odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts where - us = toList ts - nR = length us - nC = size y0 g t x0 = coerce $ f t x0 -- | A version of 'odeSolveV' with reasonable default parameters and @@ -160,13 +153,8 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y -> Matrix Double -- ^ solution odeSolve f y0 ts = -- FIXME: These tolerances are different from the ones in GSL - case odeSolveVWith' BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of - Left c -> error $ show c -- FIXME - Right (v, _d) -> (nR >< nC) (V.toList v) + odeSolveVWith BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) where - us = toList ts - nR = length us - nC = length y0 g t x0 = V.fromList $ f t (V.toList x0) odeSolveVWith :: @@ -183,15 +171,20 @@ odeSolveVWith :: -> V.Vector Double -- ^ Desired solution times -> Matrix Double -- ^ Error code or solution odeSolveVWith method control initStepSize f y0 tt = - case odeSolveVWith' method control initStepSize f y0 tt of + case odeSolveVWith' opts method control initStepSize f y0 tt of Left c -> error $ show c -- FIXME - Right (v, _d) -> (nR >< nC) (V.toList v) + Right (v, _d) -> v where - nR = V.length tt - nC = V.length y0 + opts = ODEOpts { maxNumSteps = 10000 + , minStep = 1.0e-12 + , relTol = error "relTol" + , absTols = error "absTol" + , initStep = error "initStep" + } odeSolveVWith' :: - ODEMethod + ODEOpts + -> ODEMethod -> StepControl -> Maybe Double -- ^ initial step size - by default, ARKode -- estimates the initial step size to be the @@ -202,19 +195,21 @@ odeSolveVWith' :: -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector Double -- ^ Initial conditions -> V.Vector Double -- ^ Desired solution times - -> Either Int ((V.Vector Double), SO.SundialsDiagnostics) -- ^ Error code or solution -odeSolveVWith' method control initStepSize f y0 tt = - case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) + -> Either Int (Matrix Double, SO.SundialsDiagnostics) -- ^ Error code or solution +odeSolveVWith' opts method control initStepSize f y0 tt = + case solveOdeC (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts) + (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) (coerce f) (coerce y0) (coerce tt) of Left c -> Left $ fromIntegral c - Right (v, d) -> Right (coerce v, d) + Right (v, d) -> Right (reshape nC (coerce v), d) where + nC = V.length y0 l = size y0 - scise (X absTol relTol) = coerce (V.replicate l absTol, relTol) - scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol) - scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol) + scise (X aTol rTol) = coerce (V.replicate l aTol, rTol) + scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol) + scise (XX' aTol rTol yScale _yDotScale) = coerce (V.replicate l aTol, yScale * rTol) -- FIXME; Should we check that the length of ss is correct? - scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol) + scise (ScXX' aTol rTol yScale _yDotScale ss) = coerce (V.map (* aTol) ss, yScale * rTol) jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ getJacobian method matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } @@ -225,6 +220,8 @@ odeSolveVWith' method control initStepSize f y0 tt = vs = V.fromList $ map coerce $ concat $ toLists m solveOdeC :: + CLong -> + CDouble -> CInt -> Maybe CDouble -> (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> @@ -233,7 +230,8 @@ solveOdeC :: -> V.Vector CDouble -- ^ Initial conditions -> V.Vector CDouble -- ^ Desired solution times -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution -solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do +solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts = + unsafePerformIO $ do let isInitStepSize :: CInt isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize @@ -249,10 +247,6 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO nEq = fromIntegral dim nTs :: CInt nTs = fromIntegral $ V.length ts - -- FIXME: tMut is not actually mutatated? - tMut <- V.thaw ts - -- FIXME: I believe this gets taken from the ghc heap and so should - -- be subject to garbage collection. quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) qMatMut <- V.thaw quasiMatrixRes diagnostics :: V.Vector CLong <- createVector 10 -- FIXME @@ -324,18 +318,17 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; /* Specify tolerances */ for (i = 0; i < NEQ; i++) { - NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i]; + NV_Ith_S(tv,i) = ($vec-ptr:(double *aTols))[i]; }; - /* FIXME: A hack for initial testing */ - flag = CVodeSetMinStep(cvode_mem, 1.0e-12); + flag = CVodeSetMinStep(cvode_mem, $(double minStep_)); if (check_flag(&flag, "CVodeSetMinStep", 1)) return 1; - flag = CVodeSetMaxNumSteps(cvode_mem, 10000); + flag = CVodeSetMaxNumSteps(cvode_mem, $(long int maxNumSteps_)); if (check_flag(&flag, "CVodeSetMaxNumSteps", 1)) return 1; /* Call CVodeSVtolerances to specify the scalar relative tolerance * and vector absolute tolerances */ - flag = CVodeSVtolerances(cvode_mem, $(double relTol), tv); + flag = CVodeSVtolerances(cvode_mem, $(double rTol), tv); if (check_flag(&flag, "CVodeSVtolerances", 1)) return(1); /* Initialize dense matrix data structure and solver */ @@ -371,7 +364,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO /* Stops when the final time has been reached */ for (i = 1; i < $(int nTs); i++) { - flag = CVode(cvode_mem, ($vec-ptr:(double *tMut))[i], y, &t, CV_NORMAL); /* call integrator */ + flag = CVode(cvode_mem, ($vec-ptr:(double *ts))[i], y, &t, CV_NORMAL); /* call integrator */ if (check_flag(&flag, "CVode", 1)) break; /* Store the results for Haskell */ diff --git a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs index e924292..538b474 100644 --- a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs +++ b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs @@ -8,9 +8,13 @@ import Foreign.C.Types import qualified Data.Vector.Storable as VS import qualified Data.Vector.Storable.Mutable as VM +import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) + import qualified Types as T import qualified Arkode as B +type Jacobian = Double -> Vector Double -> Matrix Double + data ODEOpts = ODEOpts { maxNumSteps :: Int32 , minStep :: Double -- cgit v1.2.3 From 0e12d0aa99adbf83d5a80211a2f9fd13e4880901 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Fri, 27 Apr 2018 09:06:41 +0100 Subject: Start of better naming --- packages/sundials/hmatrix-sundials.cabal | 8 +- packages/sundials/src/Arkode.hsc | 120 --------------------- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 14 ++- packages/sundials/src/Numeric/Sundials/Arkode.hsc | 120 +++++++++++++++++++++ .../src/Numeric/Sundials/CLangToHaskellTypes.hs | 37 +++++++ .../sundials/src/Numeric/Sundials/CVode/ODE.hs | 8 +- packages/sundials/src/Numeric/Sundials/ODEOpts.hs | 4 +- packages/sundials/src/Types.hs | 40 ------- 8 files changed, 173 insertions(+), 178 deletions(-) delete mode 100644 packages/sundials/src/Arkode.hsc create mode 100644 packages/sundials/src/Numeric/Sundials/Arkode.hsc create mode 100644 packages/sundials/src/Numeric/Sundials/CLangToHaskellTypes.hs delete mode 100644 packages/sundials/src/Types.hs diff --git a/packages/sundials/hmatrix-sundials.cabal b/packages/sundials/hmatrix-sundials.cabal index b7fa0fe..234bb9c 100644 --- a/packages/sundials/hmatrix-sundials.cabal +++ b/packages/sundials/hmatrix-sundials.cabal @@ -32,19 +32,19 @@ library exposed-modules: Numeric.Sundials.ODEOpts, Numeric.Sundials.ARKode.ODE, Numeric.Sundials.CVode.ODE - other-modules: Types, - Arkode + other-modules: Numeric.Sundials.CLangToHaskellTypes, + Numeric.Sundials.Arkode c-sources: src/helpers.c src/helpers.h default-language: Haskell2010 test-suite hmatrix-sundials-testsuite type: exitcode-stdio-1.0 main-is: Main.hs - other-modules: Types, + other-modules: Numeric.Sundials.CLangToHaskellTypes, Numeric.Sundials.ODEOpts, Numeric.Sundials.ARKode.ODE, Numeric.Sundials.CVode.ODE, - Arkode + Numeric.Sundials.Arkode build-depends: base >=4.10 && <4.11, inline-c >=0.6 && <0.7, vector >=0.12 && <0.13, diff --git a/packages/sundials/src/Arkode.hsc b/packages/sundials/src/Arkode.hsc deleted file mode 100644 index 558ce9e..0000000 --- a/packages/sundials/src/Arkode.hsc +++ /dev/null @@ -1,120 +0,0 @@ -module Arkode where - -import Foreign -import Foreign.C.Types - - -#include -#include -#include -#include -#include -#include -#include - - -#def typedef struct _generic_N_Vector SunVector; -#def typedef struct _N_VectorContent_Serial SunContent; - -#def typedef struct _generic_SUNMatrix SunMatrix; -#def typedef struct _SUNMatrixContent_Dense SunMatrixContent; - -getContentMatrixPtr :: Storable a => Ptr b -> IO a -getContentMatrixPtr ptr = (#peek SunMatrix, content) ptr - -getNRows :: Ptr b -> IO CInt -getNRows ptr = (#peek SunMatrixContent, M) ptr -putNRows :: CInt -> Ptr b -> IO () -putNRows nr ptr = (#poke SunMatrixContent, M) ptr nr - -getNCols :: Ptr b -> IO CInt -getNCols ptr = (#peek SunMatrixContent, N) ptr -putNCols :: CInt -> Ptr b -> IO () -putNCols nc ptr = (#poke SunMatrixContent, N) ptr nc - -getMatrixData :: Storable a => Ptr b -> IO a -getMatrixData ptr = (#peek SunMatrixContent, data) ptr - -getContentPtr :: Storable a => Ptr b -> IO a -getContentPtr ptr = (#peek SunVector, content) ptr - -getData :: Storable a => Ptr b -> IO a -getData ptr = (#peek SunContent, data) ptr - -cV_ADAMS :: Int -cV_ADAMS = #const CV_ADAMS -cV_BDF :: Int -cV_BDF = #const CV_BDF - -arkSMax :: Int -arkSMax = #const ARK_S_MAX - -mIN_DIRK_NUM, mAX_DIRK_NUM :: Int -mIN_DIRK_NUM = #const MIN_DIRK_NUM -mAX_DIRK_NUM = #const MAX_DIRK_NUM - --- FIXME: We could just use inline-c instead - --- Butcher table accessors -- implicit -sDIRK_2_1_2 :: Int -sDIRK_2_1_2 = #const SDIRK_2_1_2 -bILLINGTON_3_3_2 :: Int -bILLINGTON_3_3_2 = #const BILLINGTON_3_3_2 -tRBDF2_3_3_2 :: Int -tRBDF2_3_3_2 = #const TRBDF2_3_3_2 -kVAERNO_4_2_3 :: Int -kVAERNO_4_2_3 = #const KVAERNO_4_2_3 -aRK324L2SA_DIRK_4_2_3 :: Int -aRK324L2SA_DIRK_4_2_3 = #const ARK324L2SA_DIRK_4_2_3 -cASH_5_2_4 :: Int -cASH_5_2_4 = #const CASH_5_2_4 -cASH_5_3_4 :: Int -cASH_5_3_4 = #const CASH_5_3_4 -sDIRK_5_3_4 :: Int -sDIRK_5_3_4 = #const SDIRK_5_3_4 -kVAERNO_5_3_4 :: Int -kVAERNO_5_3_4 = #const KVAERNO_5_3_4 -aRK436L2SA_DIRK_6_3_4 :: Int -aRK436L2SA_DIRK_6_3_4 = #const ARK436L2SA_DIRK_6_3_4 -kVAERNO_7_4_5 :: Int -kVAERNO_7_4_5 = #const KVAERNO_7_4_5 -aRK548L2SA_DIRK_8_4_5 :: Int -aRK548L2SA_DIRK_8_4_5 = #const ARK548L2SA_DIRK_8_4_5 - --- #define DEFAULT_DIRK_2 SDIRK_2_1_2 --- #define DEFAULT_DIRK_3 ARK324L2SA_DIRK_4_2_3 --- #define DEFAULT_DIRK_4 SDIRK_5_3_4 --- #define DEFAULT_DIRK_5 ARK548L2SA_DIRK_8_4_5 - --- Butcher table accessors -- explicit -hEUN_EULER_2_1_2 :: Int -hEUN_EULER_2_1_2 = #const HEUN_EULER_2_1_2 -bOGACKI_SHAMPINE_4_2_3 :: Int -bOGACKI_SHAMPINE_4_2_3 = #const BOGACKI_SHAMPINE_4_2_3 -aRK324L2SA_ERK_4_2_3 :: Int -aRK324L2SA_ERK_4_2_3 = #const ARK324L2SA_ERK_4_2_3 -zONNEVELD_5_3_4 :: Int -zONNEVELD_5_3_4 = #const ZONNEVELD_5_3_4 -aRK436L2SA_ERK_6_3_4 :: Int -aRK436L2SA_ERK_6_3_4 = #const ARK436L2SA_ERK_6_3_4 -sAYFY_ABURUB_6_3_4 :: Int -sAYFY_ABURUB_6_3_4 = #const SAYFY_ABURUB_6_3_4 -cASH_KARP_6_4_5 :: Int -cASH_KARP_6_4_5 = #const CASH_KARP_6_4_5 -fEHLBERG_6_4_5 :: Int -fEHLBERG_6_4_5 = #const FEHLBERG_6_4_5 -dORMAND_PRINCE_7_4_5 :: Int -dORMAND_PRINCE_7_4_5 = #const DORMAND_PRINCE_7_4_5 -aRK548L2SA_ERK_8_4_5 :: Int -aRK548L2SA_ERK_8_4_5 = #const ARK548L2SA_ERK_8_4_5 -vERNER_8_5_6 :: Int -vERNER_8_5_6 = #const VERNER_8_5_6 -fEHLBERG_13_7_8 :: Int -fEHLBERG_13_7_8 = #const FEHLBERG_13_7_8 - --- #define DEFAULT_ERK_2 HEUN_EULER_2_1_2 --- #define DEFAULT_ERK_3 BOGACKI_SHAMPINE_4_2_3 --- #define DEFAULT_ERK_4 ZONNEVELD_5_3_4 --- #define DEFAULT_ERK_5 CASH_KARP_6_4_5 --- #define DEFAULT_ERK_6 VERNER_8_5_6 --- #define DEFAULT_ERK_8 FEHLBERG_13_7_8 diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 8b713c6..a8d418b 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -1,5 +1,3 @@ -{-# OPTIONS_GHC -Wall #-} - {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE MultiWayIf #-} @@ -141,9 +139,9 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix, rows, cols, toLists, size, subVector) -import qualified Types as T -import Arkode -import qualified Arkode as B +import qualified Numeric.Sundials.CLangToHaskellTypes as T +import Numeric.Sundials.Arkode +import qualified Numeric.Sundials.Arkode as B import qualified Numeric.Sundials.ODEOpts as SO @@ -160,7 +158,7 @@ C.include "" -- access to ARKDls interface C.include "" -- definition of type realtype C.include "" C.include "../../../helpers.h" -C.include "Arkode_hsc.h" +C.include "Numeric/Sundials/Arkode_hsc.h" type Jacobian = Double -> Vector Double -> Matrix Double @@ -448,7 +446,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO diagnostics :: V.Vector CLong <- createVector 10 -- FIXME diagMut <- V.thaw diagnostics -- We need the types that sundials expects. These are tied together - -- in 'Types'. FIXME: The Haskell type is currently empty! + -- in 'CLangToHaskellTypes'. FIXME: The Haskell type is currently empty! let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt funIO x y f _ptr = do -- Convert the pointer we get from C (y) to a vector, and then @@ -516,7 +514,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO /* problem as fully implicit and set f_E to NULL and f_I to f. */ /* Here we use the C types defined in helpers.h which tie up with */ - /* the Haskell types defined in Types */ + /* the Haskell types defined in CLangToHaskellTypes */ if ($(int method) < MIN_DIRK_NUM) { flag = ARKodeInit(arkode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), NULL, T0, y); if (check_flag(&flag, "ARKodeInit", 1)) return 1; diff --git a/packages/sundials/src/Numeric/Sundials/Arkode.hsc b/packages/sundials/src/Numeric/Sundials/Arkode.hsc new file mode 100644 index 0000000..f5e5dc1 --- /dev/null +++ b/packages/sundials/src/Numeric/Sundials/Arkode.hsc @@ -0,0 +1,120 @@ +|module Numeric.Sundials.Arkode where + +import Foreign +import Foreign.C.Types + + +#include +#include +#include +#include +#include +#include +#include + + +#def typedef struct _generic_N_Vector SunVector; +#def typedef struct _N_VectorContent_Serial SunContent; + +#def typedef struct _generic_SUNMatrix SunMatrix; +#def typedef struct _SUNMatrixContent_Dense SunMatrixContent; + +getContentMatrixPtr :: Storable a => Ptr b -> IO a +getContentMatrixPtr ptr = (#peek SunMatrix, content) ptr + +getNRows :: Ptr b -> IO CInt +getNRows ptr = (#peek SunMatrixContent, M) ptr +putNRows :: CInt -> Ptr b -> IO () +putNRows nr ptr = (#poke SunMatrixContent, M) ptr nr + +getNCols :: Ptr b -> IO CInt +getNCols ptr = (#peek SunMatrixContent, N) ptr +putNCols :: CInt -> Ptr b -> IO () +putNCols nc ptr = (#poke SunMatrixContent, N) ptr nc + +getMatrixData :: Storable a => Ptr b -> IO a +getMatrixData ptr = (#peek SunMatrixContent, data) ptr + +getContentPtr :: Storable a => Ptr b -> IO a +getContentPtr ptr = (#peek SunVector, content) ptr + +getData :: Storable a => Ptr b -> IO a +getData ptr = (#peek SunContent, data) ptr + +cV_ADAMS :: Int +cV_ADAMS = #const CV_ADAMS +cV_BDF :: Int +cV_BDF = #const CV_BDF + +arkSMax :: Int +arkSMax = #const ARK_S_MAX + +mIN_DIRK_NUM, mAX_DIRK_NUM :: Int +mIN_DIRK_NUM = #const MIN_DIRK_NUM +mAX_DIRK_NUM = #const MAX_DIRK_NUM + +-- FIXME: We could just use inline-c instead + +-- Butcher table accessors -- implicit +sDIRK_2_1_2 :: Int +sDIRK_2_1_2 = #const SDIRK_2_1_2 +bILLINGTON_3_3_2 :: Int +bILLINGTON_3_3_2 = #const BILLINGTON_3_3_2 +tRBDF2_3_3_2 :: Int +tRBDF2_3_3_2 = #const TRBDF2_3_3_2 +kVAERNO_4_2_3 :: Int +kVAERNO_4_2_3 = #const KVAERNO_4_2_3 +aRK324L2SA_DIRK_4_2_3 :: Int +aRK324L2SA_DIRK_4_2_3 = #const ARK324L2SA_DIRK_4_2_3 +cASH_5_2_4 :: Int +cASH_5_2_4 = #const CASH_5_2_4 +cASH_5_3_4 :: Int +cASH_5_3_4 = #const CASH_5_3_4 +sDIRK_5_3_4 :: Int +sDIRK_5_3_4 = #const SDIRK_5_3_4 +kVAERNO_5_3_4 :: Int +kVAERNO_5_3_4 = #const KVAERNO_5_3_4 +aRK436L2SA_DIRK_6_3_4 :: Int +aRK436L2SA_DIRK_6_3_4 = #const ARK436L2SA_DIRK_6_3_4 +kVAERNO_7_4_5 :: Int +kVAERNO_7_4_5 = #const KVAERNO_7_4_5 +aRK548L2SA_DIRK_8_4_5 :: Int +aRK548L2SA_DIRK_8_4_5 = #const ARK548L2SA_DIRK_8_4_5 + +-- #define DEFAULT_DIRK_2 SDIRK_2_1_2 +-- #define DEFAULT_DIRK_3 ARK324L2SA_DIRK_4_2_3 +-- #define DEFAULT_DIRK_4 SDIRK_5_3_4 +-- #define DEFAULT_DIRK_5 ARK548L2SA_DIRK_8_4_5 + +-- Butcher table accessors -- explicit +hEUN_EULER_2_1_2 :: Int +hEUN_EULER_2_1_2 = #const HEUN_EULER_2_1_2 +bOGACKI_SHAMPINE_4_2_3 :: Int +bOGACKI_SHAMPINE_4_2_3 = #const BOGACKI_SHAMPINE_4_2_3 +aRK324L2SA_ERK_4_2_3 :: Int +aRK324L2SA_ERK_4_2_3 = #const ARK324L2SA_ERK_4_2_3 +zONNEVELD_5_3_4 :: Int +zONNEVELD_5_3_4 = #const ZONNEVELD_5_3_4 +aRK436L2SA_ERK_6_3_4 :: Int +aRK436L2SA_ERK_6_3_4 = #const ARK436L2SA_ERK_6_3_4 +sAYFY_ABURUB_6_3_4 :: Int +sAYFY_ABURUB_6_3_4 = #const SAYFY_ABURUB_6_3_4 +cASH_KARP_6_4_5 :: Int +cASH_KARP_6_4_5 = #const CASH_KARP_6_4_5 +fEHLBERG_6_4_5 :: Int +fEHLBERG_6_4_5 = #const FEHLBERG_6_4_5 +dORMAND_PRINCE_7_4_5 :: Int +dORMAND_PRINCE_7_4_5 = #const DORMAND_PRINCE_7_4_5 +aRK548L2SA_ERK_8_4_5 :: Int +aRK548L2SA_ERK_8_4_5 = #const ARK548L2SA_ERK_8_4_5 +vERNER_8_5_6 :: Int +vERNER_8_5_6 = #const VERNER_8_5_6 +fEHLBERG_13_7_8 :: Int +fEHLBERG_13_7_8 = #const FEHLBERG_13_7_8 + +-- #define DEFAULT_ERK_2 HEUN_EULER_2_1_2 +-- #define DEFAULT_ERK_3 BOGACKI_SHAMPINE_4_2_3 +-- #define DEFAULT_ERK_4 ZONNEVELD_5_3_4 +-- #define DEFAULT_ERK_5 CASH_KARP_6_4_5 +-- #define DEFAULT_ERK_6 VERNER_8_5_6 +-- #define DEFAULT_ERK_8 FEHLBERG_13_7_8 diff --git a/packages/sundials/src/Numeric/Sundials/CLangToHaskellTypes.hs b/packages/sundials/src/Numeric/Sundials/CLangToHaskellTypes.hs new file mode 100644 index 0000000..0908cbe --- /dev/null +++ b/packages/sundials/src/Numeric/Sundials/CLangToHaskellTypes.hs @@ -0,0 +1,37 @@ +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE EmptyDataDecls #-} + +module Numeric.Sundials.CLangToHaskellTypes where + +import Foreign.C.Types + +import qualified Language.Haskell.TH as TH +import qualified Language.C.Types as CT +import qualified Data.Map as Map +import Language.C.Inline.Context + +import qualified Data.Vector.Storable as V + + +data SunVector +data SunMatrix = SunMatrix { rows :: CInt + , cols :: CInt + , vals :: V.Vector CDouble + } + +-- | This is true only if configured/ built as 64 bits +type SunIndexType = CLong + +sunTypesTable :: Map.Map CT.TypeSpecifier TH.TypeQ +sunTypesTable = Map.fromList + [ + (CT.TypeName "sunindextype", [t| SunIndexType |] ) + , (CT.TypeName "SunVector", [t| SunVector |] ) + , (CT.TypeName "SunMatrix", [t| SunMatrix |] ) + ] + +sunCtx :: Context +sunCtx = mempty {ctxTypesTable = sunTypesTable} + diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs index 0871f9b..1cd072f 100644 --- a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs @@ -90,8 +90,8 @@ import Numeric.LinearAlgebra.Devel (createVector) import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows, cols, toLists, size, reshape) -import qualified Types as T -import Arkode (cV_ADAMS, cV_BDF) +import qualified Numeric.Sundials.CLangToHaskellTypes as T +import Numeric.Sundials.Arkode (cV_ADAMS, cV_BDF) import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian) import qualified Numeric.Sundials.ODEOpts as SO @@ -109,7 +109,7 @@ C.include "" -- access to CVDls interface C.include "" -- definition of type realtype C.include "" C.include "../../../helpers.h" -C.include "Arkode_hsc.h" +C.include "Numeric/Sundials/Arkode_hsc.h" -- | Stepping functions @@ -252,7 +252,7 @@ solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts diagnostics :: V.Vector CLong <- createVector 10 -- FIXME diagMut <- V.thaw diagnostics -- We need the types that sundials expects. These are tied together - -- in 'Types'. FIXME: The Haskell type is currently empty! + -- in 'CLangToHaskellTypes'. FIXME: The Haskell type is currently empty! let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt funIO x y f _ptr = do -- Convert the pointer we get from C (y) to a vector, and then diff --git a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs index 538b474..56dc12c 100644 --- a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs +++ b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs @@ -10,8 +10,8 @@ import qualified Data.Vector.Storable.Mutable as VM import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) -import qualified Types as T -import qualified Arkode as B +import qualified Numeric.Sundials.CLangToHaskellTypes as T +import qualified Numeric.Sundials.Arkode as B type Jacobian = Double -> Vector Double -> Matrix Double diff --git a/packages/sundials/src/Types.hs b/packages/sundials/src/Types.hs deleted file mode 100644 index 04e4280..0000000 --- a/packages/sundials/src/Types.hs +++ /dev/null @@ -1,40 +0,0 @@ -{-# OPTIONS_GHC -Wall #-} - -{-# LANGUAGE QuasiQuotes #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE EmptyDataDecls #-} - -module Types where - -import Foreign.C.Types - -import qualified Language.Haskell.TH as TH -import qualified Language.C.Types as CT -import qualified Data.Map as Map -import Language.C.Inline.Context - -import qualified Data.Vector.Storable as V - - -data SunVector -data SunMatrix = SunMatrix { rows :: CInt - , cols :: CInt - , vals :: V.Vector CDouble - } - --- FIXME: Is this true? -type SunIndexType = CLong - -sunTypesTable :: Map.Map CT.TypeSpecifier TH.TypeQ -sunTypesTable = Map.fromList - [ - (CT.TypeName "sunindextype", [t| SunIndexType |] ) - , (CT.TypeName "SunVector", [t| SunVector |] ) - , (CT.TypeName "SunMatrix", [t| SunMatrix |] ) - ] - -sunCtx :: Context -sunCtx = mempty {ctxTypesTable = sunTypesTable} - -- cgit v1.2.3 From d48892298d5e87ec12b29adc69af2fdd59f4c8bd Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Fri, 27 Apr 2018 09:17:16 +0100 Subject: Fix typo --- packages/sundials/src/Numeric/Sundials/Arkode.hsc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/sundials/src/Numeric/Sundials/Arkode.hsc b/packages/sundials/src/Numeric/Sundials/Arkode.hsc index f5e5dc1..1700cdf 100644 --- a/packages/sundials/src/Numeric/Sundials/Arkode.hsc +++ b/packages/sundials/src/Numeric/Sundials/Arkode.hsc @@ -1,4 +1,4 @@ -|module Numeric.Sundials.Arkode where +module Numeric.Sundials.Arkode where import Foreign import Foreign.C.Types -- cgit v1.2.3 From 149dedfc6ec8dea039a4df7ad1d31880820c52eb Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Fri, 27 Apr 2018 14:19:59 +0100 Subject: More restructuring --- packages/sundials/hmatrix-sundials.cabal | 6 +- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 56 ++++++++++---- packages/sundials/src/Numeric/Sundials/Arkode.hsc | 88 +++++++++++++++++++++- .../src/Numeric/Sundials/CLangToHaskellTypes.hs | 37 --------- .../sundials/src/Numeric/Sundials/CVode/ODE.hs | 41 +++++----- packages/sundials/src/Numeric/Sundials/ODEOpts.hs | 51 ------------- 6 files changed, 149 insertions(+), 130 deletions(-) delete mode 100644 packages/sundials/src/Numeric/Sundials/CLangToHaskellTypes.hs diff --git a/packages/sundials/hmatrix-sundials.cabal b/packages/sundials/hmatrix-sundials.cabal index 234bb9c..cd2be4e 100644 --- a/packages/sundials/hmatrix-sundials.cabal +++ b/packages/sundials/hmatrix-sundials.cabal @@ -32,16 +32,14 @@ library exposed-modules: Numeric.Sundials.ODEOpts, Numeric.Sundials.ARKode.ODE, Numeric.Sundials.CVode.ODE - other-modules: Numeric.Sundials.CLangToHaskellTypes, - Numeric.Sundials.Arkode + other-modules: Numeric.Sundials.Arkode c-sources: src/helpers.c src/helpers.h default-language: Haskell2010 test-suite hmatrix-sundials-testsuite type: exitcode-stdio-1.0 main-is: Main.hs - other-modules: Numeric.Sundials.CLangToHaskellTypes, - Numeric.Sundials.ODEOpts, + other-modules: Numeric.Sundials.ODEOpts, Numeric.Sundials.ARKode.ODE, Numeric.Sundials.CVode.ODE, Numeric.Sundials.Arkode diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index a8d418b..13b7eb8 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -125,6 +125,7 @@ import Data.Maybe (isJust) import Foreign.C.Types (CDouble, CInt, CLong) import Foreign.Ptr (Ptr) +import Foreign.Storable (poke) import qualified Data.Vector.Storable as V @@ -139,10 +140,33 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix, rows, cols, toLists, size, subVector) -import qualified Numeric.Sundials.CLangToHaskellTypes as T -import Numeric.Sundials.Arkode -import qualified Numeric.Sundials.Arkode as B import qualified Numeric.Sundials.ODEOpts as SO +import qualified Numeric.Sundials.Arkode as T +import Numeric.Sundials.Arkode (getDataFromContents, putDataInContents, arkSMax, + sDIRK_2_1_2, + bILLINGTON_3_3_2, + tRBDF2_3_3_2, + kVAERNO_4_2_3, + aRK324L2SA_DIRK_4_2_3, + cASH_5_2_4, + cASH_5_3_4, + sDIRK_5_3_4, + kVAERNO_5_3_4, + aRK436L2SA_DIRK_6_3_4, + kVAERNO_7_4_5, + aRK548L2SA_DIRK_8_4_5, + hEUN_EULER_2_1_2, + bOGACKI_SHAMPINE_4_2_3, + aRK324L2SA_ERK_4_2_3, + zONNEVELD_5_3_4, + aRK436L2SA_ERK_6_3_4, + sAYFY_ABURUB_6_3_4, + cASH_KARP_6_4_5, + fEHLBERG_6_4_5, + dORMAND_PRINCE_7_4_5, + aRK548L2SA_ERK_8_4_5, + vERNER_8_5_6, + fEHLBERG_13_7_8) C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) @@ -451,9 +475,9 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO funIO x y f _ptr = do -- Convert the pointer we get from C (y) to a vector, and then -- apply the user-supplied function. - fImm <- fun x <$> SO.getDataFromContents dim y + fImm <- fun x <$> getDataFromContents dim y -- Fill in the provided pointer with the resulting vector. - SO.putDataInContents fImm dim f + putDataInContents fImm dim f -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] @@ -465,8 +489,8 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do case jacH of Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" - Just jacI -> do j <- jacI t <$> SO.getDataFromContents dim y - SO.putMatrixDataFromContents j jacS + Just jacI -> do j <- jacI t <$> getDataFromContents dim y + poke jacS j -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] @@ -675,7 +699,7 @@ butcherTable method = case getBT method of Left c -> error $ show c -- FIXME Right (ButcherTable' v w x y, sqp) -> - ButcherTable { am = subMatrix (0, 0) (s, s) $ (B.arkSMax >< B.arkSMax) (V.toList v) + ButcherTable { am = subMatrix (0, 0) (s, s) $ (arkSMax >< arkSMax) (V.toList v) , cv = subVector 0 s w , bv = subVector 0 s x , b2v = subVector 0 s y @@ -710,25 +734,25 @@ getButcherTable method = unsafePerformIO $ do btSQP :: V.Vector CInt <- createVector 3 btSQPMut <- V.thaw btSQP - btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax) + btAs :: V.Vector CDouble <- createVector (arkSMax * arkSMax) btAsMut <- V.thaw btAs - btCs :: V.Vector CDouble <- createVector B.arkSMax - btBs :: V.Vector CDouble <- createVector B.arkSMax - btB2s :: V.Vector CDouble <- createVector B.arkSMax + btCs :: V.Vector CDouble <- createVector arkSMax + btBs :: V.Vector CDouble <- createVector arkSMax + btB2s :: V.Vector CDouble <- createVector arkSMax btCsMut <- V.thaw btCs btBsMut <- V.thaw btBs btB2sMut <- V.thaw btB2s let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt funIOI x y f _ptr = do - fImm <- funI x <$> SO.getDataFromContents dim y - SO.putDataInContents fImm dim f + fImm <- funI x <$> getDataFromContents dim y + putDataInContents fImm dim f -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt funIOE x y f _ptr = do - fImm <- funE x <$> SO.getDataFromContents dim y - SO.putDataInContents fImm dim f + fImm <- funE x <$> getDataFromContents dim y + putDataInContents fImm dim f -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] diff --git a/packages/sundials/src/Numeric/Sundials/Arkode.hsc b/packages/sundials/src/Numeric/Sundials/Arkode.hsc index 1700cdf..0850258 100644 --- a/packages/sundials/src/Numeric/Sundials/Arkode.hsc +++ b/packages/sundials/src/Numeric/Sundials/Arkode.hsc @@ -1,7 +1,23 @@ +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE EmptyDataDecls #-} + module Numeric.Sundials.Arkode where -import Foreign -import Foreign.C.Types +import Foreign +import Foreign.C.Types + +import Language.C.Types as CT + +import qualified Data.Vector.Storable as VS +import qualified Data.Vector.Storable.Mutable as VM + +import qualified Language.Haskell.TH as TH +import qualified Data.Map as Map +import Language.C.Inline.Context + +import qualified Data.Vector.Storable as V #include @@ -13,6 +29,74 @@ import Foreign.C.Types #include +data SunVector +data SunMatrix = SunMatrix { rows :: CInt + , cols :: CInt + , vals :: V.Vector CDouble + } + +-- | This is true only if configured/ built as 64 bits +type SunIndexType = CLong + +sunTypesTable :: Map.Map TypeSpecifier TH.TypeQ +sunTypesTable = Map.fromList + [ + (TypeName "sunindextype", [t| SunIndexType |] ) + , (TypeName "SunVector", [t| SunVector |] ) + , (TypeName "SunMatrix", [t| SunMatrix |] ) + ] + +sunCtx :: Context +sunCtx = mempty {ctxTypesTable = sunTypesTable} + +getMatrixDataFromContents :: Ptr SunMatrix -> IO SunMatrix +getMatrixDataFromContents ptr = do + qtr <- getContentMatrixPtr ptr + rs <- getNRows qtr + cs <- getNCols qtr + rtr <- getMatrixData qtr + vs <- vectorFromC (fromIntegral $ rs * cs) rtr + return $ SunMatrix { rows = rs, cols = cs, vals = vs } + +putMatrixDataFromContents :: SunMatrix -> Ptr SunMatrix -> IO () +putMatrixDataFromContents mat ptr = do + let rs = rows mat + cs = cols mat + vs = vals mat + qtr <- getContentMatrixPtr ptr + putNRows rs qtr + putNCols cs qtr + rtr <- getMatrixData qtr + vectorToC vs (fromIntegral $ rs * cs) rtr + +instance Storable SunMatrix where + poke = flip putMatrixDataFromContents + peek = getMatrixDataFromContents + sizeOf _ = error "sizeOf not supported for SunMatrix" + alignment _ = error "alignment not supported for SunMatrix" + +vectorFromC :: Storable a => Int -> Ptr a -> IO (VS.Vector a) +vectorFromC len ptr = do + ptr' <- newForeignPtr_ ptr + VS.freeze $ VM.unsafeFromForeignPtr0 ptr' len + +vectorToC :: Storable a => VS.Vector a -> Int -> Ptr a -> IO () +vectorToC vec len ptr = do + ptr' <- newForeignPtr_ ptr + VS.copy (VM.unsafeFromForeignPtr0 ptr' len) vec + +getDataFromContents :: Int -> Ptr SunVector -> IO (VS.Vector CDouble) +getDataFromContents len ptr = do + qtr <- getContentPtr ptr + rtr <- getData qtr + vectorFromC len rtr + +putDataInContents :: Storable a => VS.Vector a -> Int -> Ptr b -> IO () +putDataInContents vec len ptr = do + qtr <- getContentPtr ptr + rtr <- getData qtr + vectorToC vec len rtr + #def typedef struct _generic_N_Vector SunVector; #def typedef struct _N_VectorContent_Serial SunContent; diff --git a/packages/sundials/src/Numeric/Sundials/CLangToHaskellTypes.hs b/packages/sundials/src/Numeric/Sundials/CLangToHaskellTypes.hs deleted file mode 100644 index 0908cbe..0000000 --- a/packages/sundials/src/Numeric/Sundials/CLangToHaskellTypes.hs +++ /dev/null @@ -1,37 +0,0 @@ -{-# LANGUAGE QuasiQuotes #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE EmptyDataDecls #-} - -module Numeric.Sundials.CLangToHaskellTypes where - -import Foreign.C.Types - -import qualified Language.Haskell.TH as TH -import qualified Language.C.Types as CT -import qualified Data.Map as Map -import Language.C.Inline.Context - -import qualified Data.Vector.Storable as V - - -data SunVector -data SunMatrix = SunMatrix { rows :: CInt - , cols :: CInt - , vals :: V.Vector CDouble - } - --- | This is true only if configured/ built as 64 bits -type SunIndexType = CLong - -sunTypesTable :: Map.Map CT.TypeSpecifier TH.TypeQ -sunTypesTable = Map.fromList - [ - (CT.TypeName "sunindextype", [t| SunIndexType |] ) - , (CT.TypeName "SunVector", [t| SunVector |] ) - , (CT.TypeName "SunMatrix", [t| SunMatrix |] ) - ] - -sunCtx :: Context -sunCtx = mempty {ctxTypesTable = sunTypesTable} - diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs index 1cd072f..159fbe2 100644 --- a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs @@ -79,6 +79,7 @@ import Data.Maybe (isJust) import Foreign.C.Types (CDouble, CInt, CLong) import Foreign.Ptr (Ptr) +import Foreign.Storable (poke) import qualified Data.Vector.Storable as V @@ -90,10 +91,10 @@ import Numeric.LinearAlgebra.Devel (createVector) import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows, cols, toLists, size, reshape) -import qualified Numeric.Sundials.CLangToHaskellTypes as T -import Numeric.Sundials.Arkode (cV_ADAMS, cV_BDF) -import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian) -import qualified Numeric.Sundials.ODEOpts as SO +import Numeric.Sundials.Arkode (cV_ADAMS, cV_BDF, + getDataFromContents, putDataInContents) +import qualified Numeric.Sundials.Arkode as T +import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian, SundialsDiagnostics(..)) C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) @@ -195,7 +196,7 @@ odeSolveVWith' :: -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector Double -- ^ Initial conditions -> V.Vector Double -- ^ Desired solution times - -> Either Int (Matrix Double, SO.SundialsDiagnostics) -- ^ Error code or solution + -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution odeSolveVWith' opts method control initStepSize f y0 tt = case solveOdeC (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts) (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) @@ -229,7 +230,7 @@ solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector CDouble -- ^ Initial conditions -> V.Vector CDouble -- ^ Desired solution times - -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution + -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts = unsafePerformIO $ do @@ -257,9 +258,9 @@ solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts funIO x y f _ptr = do -- Convert the pointer we get from C (y) to a vector, and then -- apply the user-supplied function. - fImm <- fun x <$> SO.getDataFromContents dim y + fImm <- fun x <$> getDataFromContents dim y -- Fill in the provided pointer with the resulting vector. - SO.putDataInContents fImm dim f + putDataInContents fImm dim f -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] @@ -271,8 +272,8 @@ solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do case jacH of Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" - Just jacI -> do j <- jacI t <$> SO.getDataFromContents dim y - SO.putMatrixDataFromContents j jacS + Just jacI -> do j <- jacI t <$> getDataFromContents dim y + poke jacS j -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] @@ -431,16 +432,16 @@ solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts if res == 0 then do preD <- V.freeze diagMut - let d = SO.SundialsDiagnostics (fromIntegral $ preD V.!0) - (fromIntegral $ preD V.!1) - (fromIntegral $ preD V.!2) - (fromIntegral $ preD V.!3) - (fromIntegral $ preD V.!4) - (fromIntegral $ preD V.!5) - (fromIntegral $ preD V.!6) - (fromIntegral $ preD V.!7) - (fromIntegral $ preD V.!8) - (fromIntegral $ preD V.!9) + let d = SundialsDiagnostics (fromIntegral $ preD V.!0) + (fromIntegral $ preD V.!1) + (fromIntegral $ preD V.!2) + (fromIntegral $ preD V.!3) + (fromIntegral $ preD V.!4) + (fromIntegral $ preD V.!5) + (fromIntegral $ preD V.!6) + (fromIntegral $ preD V.!7) + (fromIntegral $ preD V.!8) + (fromIntegral $ preD V.!9) m <- V.freeze qMatMut return $ Right (m, d) else do diff --git a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs index 56dc12c..89f2306 100644 --- a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs +++ b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs @@ -1,17 +1,10 @@ module Numeric.Sundials.ODEOpts where import Data.Int (Int32) -import Foreign.Ptr (Ptr) -import Foreign.Storable as FS -import Foreign.ForeignPtr as FF -import Foreign.C.Types import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VM import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) -import qualified Numeric.Sundials.CLangToHaskellTypes as T -import qualified Numeric.Sundials.Arkode as B type Jacobian = Double -> Vector Double -> Matrix Double @@ -23,50 +16,6 @@ data ODEOpts = ODEOpts { , initStep :: Double } deriving (Read, Show, Eq, Ord) --- FIXME: Potentially an instance of Storable -_getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix -_getMatrixDataFromContents ptr = do - qtr <- B.getContentMatrixPtr ptr - rs <- B.getNRows qtr - cs <- B.getNCols qtr - rtr <- B.getMatrixData qtr - vs <- vectorFromC (fromIntegral $ rs * cs) rtr - return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } - -putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () -putMatrixDataFromContents mat ptr = do - let rs = T.rows mat - cs = T.cols mat - vs = T.vals mat - qtr <- B.getContentMatrixPtr ptr - B.putNRows rs qtr - B.putNCols cs qtr - rtr <- B.getMatrixData qtr - vectorToC vs (fromIntegral $ rs * cs) rtr --- FIXME: END - -vectorFromC :: Storable a => Int -> Ptr a -> IO (VS.Vector a) -vectorFromC len ptr = do - ptr' <- newForeignPtr_ ptr - VS.freeze $ VM.unsafeFromForeignPtr0 ptr' len - -vectorToC :: Storable a => VS.Vector a -> Int -> Ptr a -> IO () -vectorToC vec len ptr = do - ptr' <- newForeignPtr_ ptr - VS.copy (VM.unsafeFromForeignPtr0 ptr' len) vec - -getDataFromContents :: Int -> Ptr T.SunVector -> IO (VS.Vector CDouble) -getDataFromContents len ptr = do - qtr <- B.getContentPtr ptr - rtr <- B.getData qtr - vectorFromC len rtr - -putDataInContents :: Storable a => VS.Vector a -> Int -> Ptr b -> IO () -putDataInContents vec len ptr = do - qtr <- B.getContentPtr ptr - rtr <- B.getData qtr - vectorToC vec len rtr - data SundialsDiagnostics = SundialsDiagnostics { aRKodeGetNumSteps :: Int , aRKodeGetNumStepAttempts :: Int -- cgit v1.2.3 From 4ba859636396d211637b5507f19722b6953656a5 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Wed, 2 May 2018 14:42:43 +0100 Subject: Add more options --- packages/sundials/src/Main.hs | 48 ++++++++- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 113 ++++++++++----------- .../sundials/src/Numeric/Sundials/CVode/ODE.hs | 23 +++-- packages/sundials/src/Numeric/Sundials/ODEOpts.hs | 7 +- 4 files changed, 116 insertions(+), 75 deletions(-) diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index 85928e2..16c21c5 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs @@ -81,6 +81,23 @@ _stiffJac _t _v = (1><1) [ lamda ] where lamda = -100.0 +predatorPrey :: Double -> [Double] -> [Double] +predatorPrey _t v = [ x * a - b * x * y + , d * x * y - c * y - e * y * z + , (-f) * z + g * y * z + ] + where + x = v!!0 + y = v!!1 + z = v!!2 + a = 1.0 + b = 1.0 + c = 1.0 + d = 1.0 + e = 1.0 + f = 1.0 + g = 1.0 + lSaxis :: [[Double]] -> P.Axis B D.V2 Double lSaxis xs = P.r2Axis &~ do let ts = xs!!0 @@ -128,11 +145,6 @@ main = do let maxDiffC = maximum $ map abs $ zipWith (-) ((toLists $ tr res2b)!!0) ((toLists $ tr res2c)!!0) - hspec $ describe "Compare results" $ do - it "for SDIRK_5_3_4' and TRBDF2_3_3_2'" $ maxDiffA < 1.0e-6 - it "for SDIRK_5_3_4' and BDF" $ maxDiffB < 1.0e-6 - it "for TRBDF2_3_3_2' and BDF" $ maxDiffC < 1.0e-6 - let res3 = ARK.odeSolve lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0]) renderRasterific "diagrams/lorenz.png" @@ -146,3 +158,29 @@ main = do renderRasterific "diagrams/lorenz2.png" (D.dims2D 500.0 500.0) (renderAxis $ kSaxis $ zip ((toLists $ tr res3)!!1) ((toLists $ tr res3)!!2)) + + let res4 = CV.odeSolve predatorPrey [0.5, 1.0, 2.0] (fromList [0.0, 0.01 .. 10.0]) + + renderRasterific "diagrams/predatorPrey.png" + (D.dims2D 500.0 500.0) + (renderAxis $ kSaxis $ zip ((toLists $ tr res4)!!0) ((toLists $ tr res4)!!1)) + + renderRasterific "diagrams/predatorPrey1.png" + (D.dims2D 500.0 500.0) + (renderAxis $ kSaxis $ zip ((toLists $ tr res4)!!0) ((toLists $ tr res4)!!2)) + + renderRasterific "diagrams/predatorPrey2.png" + (D.dims2D 500.0 500.0) + (renderAxis $ kSaxis $ zip ((toLists $ tr res4)!!1) ((toLists $ tr res4)!!2)) + + let res4a = ARK.odeSolve predatorPrey [0.5, 1.0, 2.0] (fromList [0.0, 0.01 .. 10.0]) + + let maxDiffPpA = maximum $ map abs $ + zipWith (-) ((toLists $ tr res4)!!0) ((toLists $ tr res4a)!!0) + + hspec $ describe "Compare results" $ do + it "for SDIRK_5_3_4' and TRBDF2_3_3_2'" $ maxDiffA < 1.0e-6 + it "for SDIRK_5_3_4' and BDF" $ maxDiffB < 1.0e-6 + it "for TRBDF2_3_3_2' and BDF" $ maxDiffC < 1.0e-6 + it "for CV and ARK for the Predator Prey model" $ maxDiffPpA < 1.0e-3 + diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 13b7eb8..ce46968 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -114,7 +114,6 @@ module Numeric.Sundials.ARKode.ODE ( odeSolve , butcherTable , ODEMethod(..) , StepControl(..) - , Jacobian ) where import qualified Language.C.Inline as C @@ -136,11 +135,11 @@ import GHC.Generics (C1, Constructor, (:+:)(..), D1, Rep, Generic, M1( import Numeric.LinearAlgebra.Devel (createVector) -import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), - subMatrix, rows, cols, toLists, - size, subVector) +import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows, + cols, toLists, size, reshape, + subVector, subMatrix, (><)) -import qualified Numeric.Sundials.ODEOpts as SO +import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian, SundialsDiagnostics(..)) import qualified Numeric.Sundials.Arkode as T import Numeric.Sundials.Arkode (getDataFromContents, putDataInContents, arkSMax, sDIRK_2_1_2, @@ -185,8 +184,6 @@ C.include "../../../helpers.h" C.include "Numeric/Sundials/Arkode_hsc.h" -type Jacobian = Double -> Vector Double -> Matrix Double - -- | Stepping functions data ODEMethod = SDIRK_2_1_2 Jacobian | SDIRK_2_1_2' @@ -351,15 +348,9 @@ odeSolveV -> Vector Double -- ^ desired solution times -> Matrix Double -- ^ solution odeSolveV meth hi epsAbs epsRel f y0 ts = - case odeSolveVWith' meth (X epsAbs epsRel) hi g y0 ts of - Left c -> error $ show c -- FIXME - -- FIXME: Can we do better than using lists? - Right (v, _d) -> (nR >< nC) (V.toList v) - where - us = toList ts - nR = length us - nC = size y0 - g t x0 = coerce $ f t x0 + odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts + where + g t x0 = coerce $ f t x0 -- | A version of 'odeSolveV' with reasonable default parameters and -- system of equations defined using lists. FIXME: we should say @@ -371,13 +362,8 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y -> Matrix Double -- ^ solution odeSolve f y0 ts = -- FIXME: These tolerances are different from the ones in GSL - case odeSolveVWith' SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of - Left c -> error $ show c -- FIXME - Right (v, _d) -> (nR >< nC) (V.toList v) + odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) where - us = toList ts - nR = length us - nC = length y0 g t x0 = V.fromList $ f t (V.toList x0) odeSolveVWith :: @@ -394,15 +380,21 @@ odeSolveVWith :: -> V.Vector Double -- ^ Desired solution times -> Matrix Double -- ^ Error code or solution odeSolveVWith method control initStepSize f y0 tt = - case odeSolveVWith' method control initStepSize f y0 tt of + case odeSolveVWith' opts method control initStepSize f y0 tt of Left c -> error $ show c -- FIXME - Right (v, _d) -> (nR >< nC) (V.toList v) + Right (v, _d) -> v where - nR = V.length tt - nC = V.length y0 + opts = ODEOpts { maxNumSteps = 10000 + , minStep = 1.0e-12 + , relTol = error "relTol" + , absTols = error "absTol" + , initStep = error "initStep" + , maxFail = 10 + } odeSolveVWith' :: - ODEMethod + ODEOpts + -> ODEMethod -> StepControl -> Maybe Double -- ^ initial step size - by default, ARKode -- estimates the initial step size to be the @@ -413,19 +405,21 @@ odeSolveVWith' :: -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector Double -- ^ Initial conditions -> V.Vector Double -- ^ Desired solution times - -> Either Int ((V.Vector Double), SO.SundialsDiagnostics) -- ^ Error code or solution -odeSolveVWith' method control initStepSize f y0 tt = - case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) + -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution +odeSolveVWith' opts method control initStepSize f y0 tt = + case solveOdeC (fromIntegral $ maxFail opts) + (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts) + (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) (coerce f) (coerce y0) (coerce tt) of Left c -> Left $ fromIntegral c - Right (v, d) -> Right (coerce v, d) + Right (v, d) -> Right (reshape l (coerce v), d) where l = size y0 - scise (X absTol relTol) = coerce (V.replicate l absTol, relTol) - scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol) - scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol) + scise (X aTol rTol) = coerce (V.replicate l aTol, rTol) + scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol) + scise (XX' aTol rTol yScale _yDotScale) = coerce (V.replicate l aTol, yScale * rTol) -- FIXME; Should we check that the length of ss is correct? - scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol) + scise (ScXX' aTol rTol yScale _yDotScale ss) = coerce (V.map (* aTol) ss, yScale * rTol) jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ getJacobian method matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } @@ -436,6 +430,9 @@ odeSolveVWith' method control initStepSize f y0 tt = vs = V.fromList $ map coerce $ concat $ toLists m solveOdeC :: + CInt -> + CLong -> + CDouble -> CInt -> Maybe CDouble -> (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> @@ -443,8 +440,9 @@ solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector CDouble -- ^ Initial conditions -> V.Vector CDouble -- ^ Desired solution times - -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution -solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do + -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution +solveOdeC maxErrTestFails maxNumSteps_ minStep_ method initStepSize + jacH (aTols, rTol) fun f0 ts = unsafePerformIO $ do let isInitStepSize :: CInt isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize @@ -455,14 +453,12 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO -- used :( Nothing -> 0.0 Just x -> x + let dim = V.length f0 nEq :: CLong nEq = fromIntegral dim nTs :: CInt nTs = fromIntegral $ V.length ts - -- FIXME: fMut is not actually mutatated - fMut <- V.thaw f0 - tMut <- V.thaw ts -- FIXME: I believe this gets taken from the ghc heap and so should -- be subject to garbage collection. quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) @@ -510,7 +506,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO /* general problem parameters */ - realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ + realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */ sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ /* Initialize data structures */ @@ -519,14 +515,14 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; /* Specify initial condition */ for (i = 0; i < NEQ; i++) { - NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; + NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; }; tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */ if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; /* Specify tolerances */ for (i = 0; i < NEQ; i++) { - NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i]; + NV_Ith_S(tv,i) = ($vec-ptr:(double *aTols))[i]; }; arkode_mem = ARKodeCreate(); /* Create the solver memory */ @@ -547,14 +543,15 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO if (check_flag(&flag, "ARKodeInit", 1)) return 1; } - /* FIXME: A hack for initial testing */ - flag = ARKodeSetMinStep(arkode_mem, 1.0e-12); + flag = ARKodeSetMinStep(arkode_mem, $(double minStep_)); if (check_flag(&flag, "ARKodeSetMinStep", 1)) return 1; - flag = ARKodeSetMaxNumSteps(arkode_mem, 10000); + flag = ARKodeSetMaxNumSteps(arkode_mem, $(long int maxNumSteps_)); if (check_flag(&flag, "ARKodeSetMaxNumSteps", 1)) return 1; + flag = ARKodeSetMaxErrTestFails(arkode_mem, $(int maxErrTestFails)); + if (check_flag(&flag, "ARKodeSetMaxErrTestFails", 1)) return 1; /* Set routines */ - flag = ARKodeSVtolerances(arkode_mem, $(double relTol), tv); + flag = ARKodeSVtolerances(arkode_mem, $(double rTol), tv); if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1; /* Initialize dense matrix data structure and solver */ @@ -599,7 +596,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO /* Stops when the final time has been reached */ for (i = 1; i < $(int nTs); i++) { - flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */ + flag = ARKode(arkode_mem, ($vec-ptr:(double *ts))[i], y, &t, ARK_NORMAL); /* call integrator */ if (check_flag(&flag, "ARKode", 1)) break; /* Store the results for Haskell */ @@ -665,16 +662,16 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO if res == 0 then do preD <- V.freeze diagMut - let d = SO.SundialsDiagnostics (fromIntegral $ preD V.!0) - (fromIntegral $ preD V.!1) - (fromIntegral $ preD V.!2) - (fromIntegral $ preD V.!3) - (fromIntegral $ preD V.!4) - (fromIntegral $ preD V.!5) - (fromIntegral $ preD V.!6) - (fromIntegral $ preD V.!7) - (fromIntegral $ preD V.!8) - (fromIntegral $ preD V.!9) + let d = SundialsDiagnostics (fromIntegral $ preD V.!0) + (fromIntegral $ preD V.!1) + (fromIntegral $ preD V.!2) + (fromIntegral $ preD V.!3) + (fromIntegral $ preD V.!4) + (fromIntegral $ preD V.!5) + (fromIntegral $ preD V.!6) + (fromIntegral $ preD V.!7) + (fromIntegral $ preD V.!8) + (fromIntegral $ preD V.!9) m <- V.freeze qMatMut return $ Right (m, d) else do diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs index 159fbe2..a6f185e 100644 --- a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs @@ -68,7 +68,6 @@ module Numeric.Sundials.CVode.ODE ( odeSolve , odeSolveVWith' , ODEMethod(..) , StepControl(..) - , Jacobian ) where import qualified Language.C.Inline as C @@ -127,7 +126,7 @@ getJacobian _ = Nothing -- | A version of 'odeSolveVWith' with reasonable default step control. odeSolveV :: ODEMethod - -> Maybe Double -- ^ initial step size - by default, ARKode + -> Maybe Double -- ^ initial step size - by default, CVode -- estimates the initial step size to be the -- solution \(h\) of the equation -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where @@ -161,7 +160,7 @@ odeSolve f y0 ts = odeSolveVWith :: ODEMethod -> StepControl - -> Maybe Double -- ^ initial step size - by default, ARKode + -> Maybe Double -- ^ initial step size - by default, CVode -- estimates the initial step size to be the -- solution \(h\) of the equation -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where @@ -181,13 +180,14 @@ odeSolveVWith method control initStepSize f y0 tt = , relTol = error "relTol" , absTols = error "absTol" , initStep = error "initStep" + , maxFail = 10 } odeSolveVWith' :: ODEOpts -> ODEMethod -> StepControl - -> Maybe Double -- ^ initial step size - by default, ARKode + -> Maybe Double -- ^ initial step size - by default, CVode -- estimates the initial step size to be the -- solution \(h\) of the equation -- \(\|\frac{h^2\ddot{y}}{2}\| = 1\), where @@ -198,13 +198,13 @@ odeSolveVWith' :: -> V.Vector Double -- ^ Desired solution times -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution odeSolveVWith' opts method control initStepSize f y0 tt = - case solveOdeC (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts) + case solveOdeC (fromIntegral $ maxFail opts) + (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts) (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) (coerce f) (coerce y0) (coerce tt) of Left c -> Left $ fromIntegral c - Right (v, d) -> Right (reshape nC (coerce v), d) + Right (v, d) -> Right (reshape l (coerce v), d) where - nC = V.length y0 l = size y0 scise (X aTol rTol) = coerce (V.replicate l aTol, rTol) scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol) @@ -221,6 +221,7 @@ odeSolveVWith' opts method control initStepSize f y0 tt = vs = V.fromList $ map coerce $ concat $ toLists m solveOdeC :: + CInt -> CLong -> CDouble -> CInt -> @@ -231,7 +232,8 @@ solveOdeC :: -> V.Vector CDouble -- ^ Initial conditions -> V.Vector CDouble -- ^ Desired solution times -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution -solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts = +solveOdeC maxErrTestFails maxNumSteps_ minStep_ method initStepSize + jacH (aTols, rTol) fun f0 ts = unsafePerformIO $ do let isInitStepSize :: CInt @@ -243,6 +245,7 @@ solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts -- used :( Nothing -> 0.0 Just x -> x + let dim = V.length f0 nEq :: CLong nEq = fromIntegral dim @@ -271,7 +274,7 @@ solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts IO CInt jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do case jacH of - Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" + Nothing -> error "Numeric.Sundials.CVode.ODE: Jacobian not defined" Just jacI -> do j <- jacI t <$> getDataFromContents dim y poke jacS j -- FIXME: I don't understand what this comment means @@ -326,6 +329,8 @@ solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts if (check_flag(&flag, "CVodeSetMinStep", 1)) return 1; flag = CVodeSetMaxNumSteps(cvode_mem, $(long int maxNumSteps_)); if (check_flag(&flag, "CVodeSetMaxNumSteps", 1)) return 1; + flag = CVodeSetMaxErrTestFails(cvode_mem, $(int maxErrTestFails)); + if (check_flag(&flag, "CVodeSetMaxErrTestFails", 1)) return 1; /* Call CVodeSVtolerances to specify the scalar relative tolerance * and vector absolute tolerances */ diff --git a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs index 89f2306..027d99a 100644 --- a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs +++ b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs @@ -1,6 +1,6 @@ module Numeric.Sundials.ODEOpts where -import Data.Int (Int32) +import Data.Word (Word32) import qualified Data.Vector.Storable as VS import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) @@ -9,11 +9,12 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) type Jacobian = Double -> Vector Double -> Matrix Double data ODEOpts = ODEOpts { - maxNumSteps :: Int32 + maxNumSteps :: Word32 , minStep :: Double , relTol :: Double , absTols :: VS.Vector Double - , initStep :: Double + , initStep :: Maybe Double + , maxFail :: Word32 } deriving (Read, Show, Eq, Ord) data SundialsDiagnostics = SundialsDiagnostics { -- cgit v1.2.3 From 686bd51792648dee967c611225cb1a59efa6b1c2 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Thu, 3 May 2018 08:14:30 +0100 Subject: Improve documentation --- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 51 +++++++++++++++++++++- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index ce46968..fafc237 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -20,8 +20,7 @@ -- Stability : provisional -- -- Solution of ordinary differential equation (ODE) initial value problems. --- --- +-- See for more detail. -- -- A simple example: -- @@ -65,6 +64,54 @@ -- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) -- @ -- +-- With Sundials ARKode, it is possible to retrieve the Butcher tableau for the solver. +-- +-- @ +-- import Numeric.Sundials.ARKode.ODE +-- import Numeric.LinearAlgebra +-- +-- import Data.List (intercalate) +-- +-- import Text.PrettyPrint.HughesPJClass +-- +-- +-- butcherTableauTex :: ButcherTable -> String +-- butcherTableauTex (ButcherTable m c b b2) = +-- render $ +-- vcat [ text ("\n\\begin{array}{c|" ++ (concat $ replicate n "c") ++ "}") +-- , us +-- , text "\\hline" +-- , text bs <+> text "\\\\" +-- , text b2s <+> text "\\\\" +-- , text "\\end{array}" +-- ] +-- where +-- n = rows m +-- rs = toLists m +-- ss = map (\r -> intercalate " & " $ map show r) rs +-- ts = zipWith (\i r -> show i ++ " & " ++ r) (toList c) ss +-- us = vcat $ map (\r -> text r <+> text "\\\\") ts +-- bs = " & " ++ (intercalate " & " $ map show $ toList b) +-- b2s = " & " ++ (intercalate " & " $ map show $ toList b2) +-- +-- main :: IO () +-- main = do +-- +-- let res = butcherTable (SDIRK_2_1_2 undefined) +-- putStrLn $ show res +-- putStrLn $ butcherTableauTex res +-- +-- let resA = butcherTable (KVAERNO_4_2_3 undefined) +-- putStrLn $ show resA +-- putStrLn $ butcherTableauTex resA +-- +-- let resB = butcherTable (SDIRK_5_3_4 undefined) +-- putStrLn $ show resB +-- putStrLn $ butcherTableauTex resB +-- @ +-- +-- Using the code above from the examples gives +-- -- KVAERNO_4_2_3 -- -- \[ -- cgit v1.2.3