From 729eb192cf77d4cddf33d2724b4409ab7d828921 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Wed, 25 Apr 2018 16:25:20 +0100 Subject: Pull out common code and start to follow gsl naming convention --- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 123 ++++++--------------- 1 file changed, 31 insertions(+), 92 deletions(-) (limited to 'packages/sundials/src/Numeric/Sundials/ARKode') diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index e5a2e4d..8b713c6 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -117,7 +117,6 @@ module Numeric.Sundials.ARKode.ODE ( odeSolve , ODEMethod(..) , StepControl(..) , Jacobian - , SundialsDiagnostics(..) ) where import qualified Language.C.Inline as C @@ -126,17 +125,15 @@ import qualified Language.C.Inline.Unsafe as CU import Data.Monoid ((<>)) import Data.Maybe (isJust) -import Foreign.C.Types +import Foreign.C.Types (CDouble, CInt, CLong) import Foreign.Ptr (Ptr) -import Foreign.ForeignPtr (newForeignPtr_) -import Foreign.Storable (Storable) import qualified Data.Vector.Storable as V -import qualified Data.Vector.Storable.Mutable as VM import Data.Coerce (coerce) import System.IO.Unsafe (unsafePerformIO) -import GHC.Generics +import GHC.Generics (C1, Constructor, (:+:)(..), D1, Rep, Generic, M1(..), + from, conName) import Numeric.LinearAlgebra.Devel (createVector) @@ -147,6 +144,7 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), import qualified Types as T import Arkode import qualified Arkode as B +import qualified Numeric.Sundials.ODEOpts as SO C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) @@ -165,65 +163,6 @@ C.include "../../../helpers.h" C.include "Arkode_hsc.h" -getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble) -getDataFromContents len ptr = do - qtr <- B.getContentPtr ptr - rtr <- B.getData qtr - vectorFromC len rtr - --- FIXME: Potentially an instance of Storable -_getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix -_getMatrixDataFromContents ptr = do - qtr <- B.getContentMatrixPtr ptr - rs <- B.getNRows qtr - cs <- B.getNCols qtr - rtr <- B.getMatrixData qtr - vs <- vectorFromC (fromIntegral $ rs * cs) rtr - return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } - -putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () -putMatrixDataFromContents mat ptr = do - let rs = T.rows mat - cs = T.cols mat - vs = T.vals mat - qtr <- B.getContentMatrixPtr ptr - B.putNRows rs qtr - B.putNCols cs qtr - rtr <- B.getMatrixData qtr - vectorToC vs (fromIntegral $ rs * cs) rtr --- FIXME: END - -putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () -putDataInContents vec len ptr = do - qtr <- B.getContentPtr ptr - rtr <- B.getData qtr - vectorToC vec len rtr - --- Utils - -vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) -vectorFromC len ptr = do - ptr' <- newForeignPtr_ ptr - V.freeze $ VM.unsafeFromForeignPtr0 ptr' len - -vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () -vectorToC vec len ptr = do - ptr' <- newForeignPtr_ ptr - V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec - -data SundialsDiagnostics = SundialsDiagnostics { - aRKodeGetNumSteps :: Int - , aRKodeGetNumStepAttempts :: Int - , aRKodeGetNumRhsEvals_fe :: Int - , aRKodeGetNumRhsEvals_fi :: Int - , aRKodeGetNumLinSolvSetups :: Int - , aRKodeGetNumErrTestFails :: Int - , aRKodeGetNumNonlinSolvIters :: Int - , aRKodeGetNumNonlinSolvConvFails :: Int - , aRKDlsGetNumJacEvals :: Int - , aRKDlsGetNumRhsEvals :: Int - } deriving Show - type Jacobian = Double -> Vector Double -> Matrix Double -- | Stepping functions @@ -390,7 +329,7 @@ odeSolveV -> Vector Double -- ^ desired solution times -> Matrix Double -- ^ solution odeSolveV meth hi epsAbs epsRel f y0 ts = - case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of + case odeSolveVWith' meth (X epsAbs epsRel) hi g y0 ts of Left c -> error $ show c -- FIXME -- FIXME: Can we do better than using lists? Right (v, _d) -> (nR >< nC) (V.toList v) @@ -410,7 +349,7 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y -> Matrix Double -- ^ solution odeSolve f y0 ts = -- FIXME: These tolerances are different from the ones in GSL - case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of + case odeSolveVWith' SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of Left c -> error $ show c -- FIXME Right (v, _d) -> (nR >< nC) (V.toList v) where @@ -419,7 +358,7 @@ odeSolve f y0 ts = nC = length y0 g t x0 = V.fromList $ f t (V.toList x0) -odeSolveVWith' :: +odeSolveVWith :: ODEMethod -> StepControl -> Maybe Double -- ^ initial step size - by default, ARKode @@ -432,15 +371,15 @@ odeSolveVWith' :: -> V.Vector Double -- ^ Initial conditions -> V.Vector Double -- ^ Desired solution times -> Matrix Double -- ^ Error code or solution -odeSolveVWith' method control initStepSize f y0 tt = - case odeSolveVWith method control initStepSize f y0 tt of +odeSolveVWith method control initStepSize f y0 tt = + case odeSolveVWith' method control initStepSize f y0 tt of Left c -> error $ show c -- FIXME Right (v, _d) -> (nR >< nC) (V.toList v) where nR = V.length tt nC = V.length y0 -odeSolveVWith :: +odeSolveVWith' :: ODEMethod -> StepControl -> Maybe Double -- ^ initial step size - by default, ARKode @@ -452,8 +391,8 @@ odeSolveVWith :: -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector Double -- ^ Initial conditions -> V.Vector Double -- ^ Desired solution times - -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution -odeSolveVWith method control initStepSize f y0 tt = + -> Either Int ((V.Vector Double), SO.SundialsDiagnostics) -- ^ Error code or solution +odeSolveVWith' method control initStepSize f y0 tt = case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) (coerce f) (coerce y0) (coerce tt) of Left c -> Left $ fromIntegral c @@ -482,7 +421,7 @@ solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector CDouble -- ^ Initial conditions -> V.Vector CDouble -- ^ Desired solution times - -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution + -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do let isInitStepSize :: CInt @@ -514,9 +453,9 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO funIO x y f _ptr = do -- Convert the pointer we get from C (y) to a vector, and then -- apply the user-supplied function. - fImm <- fun x <$> getDataFromContents dim y + fImm <- fun x <$> SO.getDataFromContents dim y -- Fill in the provided pointer with the resulting vector. - putDataInContents fImm dim f + SO.putDataInContents fImm dim f -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] @@ -528,8 +467,8 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do case jacH of Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" - Just jacI -> do j <- jacI t <$> getDataFromContents dim y - putMatrixDataFromContents j jacS + Just jacI -> do j <- jacI t <$> SO.getDataFromContents dim y + SO.putMatrixDataFromContents j jacS -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] @@ -704,16 +643,16 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO if res == 0 then do preD <- V.freeze diagMut - let d = SundialsDiagnostics (fromIntegral $ preD V.!0) - (fromIntegral $ preD V.!1) - (fromIntegral $ preD V.!2) - (fromIntegral $ preD V.!3) - (fromIntegral $ preD V.!4) - (fromIntegral $ preD V.!5) - (fromIntegral $ preD V.!6) - (fromIntegral $ preD V.!7) - (fromIntegral $ preD V.!8) - (fromIntegral $ preD V.!9) + let d = SO.SundialsDiagnostics (fromIntegral $ preD V.!0) + (fromIntegral $ preD V.!1) + (fromIntegral $ preD V.!2) + (fromIntegral $ preD V.!3) + (fromIntegral $ preD V.!4) + (fromIntegral $ preD V.!5) + (fromIntegral $ preD V.!6) + (fromIntegral $ preD V.!7) + (fromIntegral $ preD V.!8) + (fromIntegral $ preD V.!9) m <- V.freeze qMatMut return $ Right (m, d) else do @@ -783,15 +722,15 @@ getButcherTable method = unsafePerformIO $ do btB2sMut <- V.thaw btB2s let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt funIOI x y f _ptr = do - fImm <- funI x <$> getDataFromContents dim y - putDataInContents fImm dim f + fImm <- funI x <$> SO.getDataFromContents dim y + SO.putDataInContents fImm dim f -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt funIOE x y f _ptr = do - fImm <- funE x <$> getDataFromContents dim y - putDataInContents fImm dim f + fImm <- funE x <$> SO.getDataFromContents dim y + SO.putDataInContents fImm dim f -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] -- cgit v1.2.3