From 149dedfc6ec8dea039a4df7ad1d31880820c52eb Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Fri, 27 Apr 2018 14:19:59 +0100 Subject: More restructuring --- packages/sundials/hmatrix-sundials.cabal | 6 +- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 56 ++++++++++---- packages/sundials/src/Numeric/Sundials/Arkode.hsc | 88 +++++++++++++++++++++- .../src/Numeric/Sundials/CLangToHaskellTypes.hs | 37 --------- .../sundials/src/Numeric/Sundials/CVode/ODE.hs | 41 +++++----- packages/sundials/src/Numeric/Sundials/ODEOpts.hs | 51 ------------- 6 files changed, 149 insertions(+), 130 deletions(-) delete mode 100644 packages/sundials/src/Numeric/Sundials/CLangToHaskellTypes.hs diff --git a/packages/sundials/hmatrix-sundials.cabal b/packages/sundials/hmatrix-sundials.cabal index 234bb9c..cd2be4e 100644 --- a/packages/sundials/hmatrix-sundials.cabal +++ b/packages/sundials/hmatrix-sundials.cabal @@ -32,16 +32,14 @@ library exposed-modules: Numeric.Sundials.ODEOpts, Numeric.Sundials.ARKode.ODE, Numeric.Sundials.CVode.ODE - other-modules: Numeric.Sundials.CLangToHaskellTypes, - Numeric.Sundials.Arkode + other-modules: Numeric.Sundials.Arkode c-sources: src/helpers.c src/helpers.h default-language: Haskell2010 test-suite hmatrix-sundials-testsuite type: exitcode-stdio-1.0 main-is: Main.hs - other-modules: Numeric.Sundials.CLangToHaskellTypes, - Numeric.Sundials.ODEOpts, + other-modules: Numeric.Sundials.ODEOpts, Numeric.Sundials.ARKode.ODE, Numeric.Sundials.CVode.ODE, Numeric.Sundials.Arkode diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index a8d418b..13b7eb8 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -125,6 +125,7 @@ import Data.Maybe (isJust) import Foreign.C.Types (CDouble, CInt, CLong) import Foreign.Ptr (Ptr) +import Foreign.Storable (poke) import qualified Data.Vector.Storable as V @@ -139,10 +140,33 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix, rows, cols, toLists, size, subVector) -import qualified Numeric.Sundials.CLangToHaskellTypes as T -import Numeric.Sundials.Arkode -import qualified Numeric.Sundials.Arkode as B import qualified Numeric.Sundials.ODEOpts as SO +import qualified Numeric.Sundials.Arkode as T +import Numeric.Sundials.Arkode (getDataFromContents, putDataInContents, arkSMax, + sDIRK_2_1_2, + bILLINGTON_3_3_2, + tRBDF2_3_3_2, + kVAERNO_4_2_3, + aRK324L2SA_DIRK_4_2_3, + cASH_5_2_4, + cASH_5_3_4, + sDIRK_5_3_4, + kVAERNO_5_3_4, + aRK436L2SA_DIRK_6_3_4, + kVAERNO_7_4_5, + aRK548L2SA_DIRK_8_4_5, + hEUN_EULER_2_1_2, + bOGACKI_SHAMPINE_4_2_3, + aRK324L2SA_ERK_4_2_3, + zONNEVELD_5_3_4, + aRK436L2SA_ERK_6_3_4, + sAYFY_ABURUB_6_3_4, + cASH_KARP_6_4_5, + fEHLBERG_6_4_5, + dORMAND_PRINCE_7_4_5, + aRK548L2SA_ERK_8_4_5, + vERNER_8_5_6, + fEHLBERG_13_7_8) C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) @@ -451,9 +475,9 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO funIO x y f _ptr = do -- Convert the pointer we get from C (y) to a vector, and then -- apply the user-supplied function. - fImm <- fun x <$> SO.getDataFromContents dim y + fImm <- fun x <$> getDataFromContents dim y -- Fill in the provided pointer with the resulting vector. - SO.putDataInContents fImm dim f + putDataInContents fImm dim f -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] @@ -465,8 +489,8 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do case jacH of Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" - Just jacI -> do j <- jacI t <$> SO.getDataFromContents dim y - SO.putMatrixDataFromContents j jacS + Just jacI -> do j <- jacI t <$> getDataFromContents dim y + poke jacS j -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] @@ -675,7 +699,7 @@ butcherTable method = case getBT method of Left c -> error $ show c -- FIXME Right (ButcherTable' v w x y, sqp) -> - ButcherTable { am = subMatrix (0, 0) (s, s) $ (B.arkSMax >< B.arkSMax) (V.toList v) + ButcherTable { am = subMatrix (0, 0) (s, s) $ (arkSMax >< arkSMax) (V.toList v) , cv = subVector 0 s w , bv = subVector 0 s x , b2v = subVector 0 s y @@ -710,25 +734,25 @@ getButcherTable method = unsafePerformIO $ do btSQP :: V.Vector CInt <- createVector 3 btSQPMut <- V.thaw btSQP - btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax) + btAs :: V.Vector CDouble <- createVector (arkSMax * arkSMax) btAsMut <- V.thaw btAs - btCs :: V.Vector CDouble <- createVector B.arkSMax - btBs :: V.Vector CDouble <- createVector B.arkSMax - btB2s :: V.Vector CDouble <- createVector B.arkSMax + btCs :: V.Vector CDouble <- createVector arkSMax + btBs :: V.Vector CDouble <- createVector arkSMax + btB2s :: V.Vector CDouble <- createVector arkSMax btCsMut <- V.thaw btCs btBsMut <- V.thaw btBs btB2sMut <- V.thaw btB2s let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt funIOI x y f _ptr = do - fImm <- funI x <$> SO.getDataFromContents dim y - SO.putDataInContents fImm dim f + fImm <- funI x <$> getDataFromContents dim y + putDataInContents fImm dim f -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt funIOE x y f _ptr = do - fImm <- funE x <$> SO.getDataFromContents dim y - SO.putDataInContents fImm dim f + fImm <- funE x <$> getDataFromContents dim y + putDataInContents fImm dim f -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] diff --git a/packages/sundials/src/Numeric/Sundials/Arkode.hsc b/packages/sundials/src/Numeric/Sundials/Arkode.hsc index 1700cdf..0850258 100644 --- a/packages/sundials/src/Numeric/Sundials/Arkode.hsc +++ b/packages/sundials/src/Numeric/Sundials/Arkode.hsc @@ -1,7 +1,23 @@ +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE EmptyDataDecls #-} + module Numeric.Sundials.Arkode where -import Foreign -import Foreign.C.Types +import Foreign +import Foreign.C.Types + +import Language.C.Types as CT + +import qualified Data.Vector.Storable as VS +import qualified Data.Vector.Storable.Mutable as VM + +import qualified Language.Haskell.TH as TH +import qualified Data.Map as Map +import Language.C.Inline.Context + +import qualified Data.Vector.Storable as V #include @@ -13,6 +29,74 @@ import Foreign.C.Types #include +data SunVector +data SunMatrix = SunMatrix { rows :: CInt + , cols :: CInt + , vals :: V.Vector CDouble + } + +-- | This is true only if configured/ built as 64 bits +type SunIndexType = CLong + +sunTypesTable :: Map.Map TypeSpecifier TH.TypeQ +sunTypesTable = Map.fromList + [ + (TypeName "sunindextype", [t| SunIndexType |] ) + , (TypeName "SunVector", [t| SunVector |] ) + , (TypeName "SunMatrix", [t| SunMatrix |] ) + ] + +sunCtx :: Context +sunCtx = mempty {ctxTypesTable = sunTypesTable} + +getMatrixDataFromContents :: Ptr SunMatrix -> IO SunMatrix +getMatrixDataFromContents ptr = do + qtr <- getContentMatrixPtr ptr + rs <- getNRows qtr + cs <- getNCols qtr + rtr <- getMatrixData qtr + vs <- vectorFromC (fromIntegral $ rs * cs) rtr + return $ SunMatrix { rows = rs, cols = cs, vals = vs } + +putMatrixDataFromContents :: SunMatrix -> Ptr SunMatrix -> IO () +putMatrixDataFromContents mat ptr = do + let rs = rows mat + cs = cols mat + vs = vals mat + qtr <- getContentMatrixPtr ptr + putNRows rs qtr + putNCols cs qtr + rtr <- getMatrixData qtr + vectorToC vs (fromIntegral $ rs * cs) rtr + +instance Storable SunMatrix where + poke = flip putMatrixDataFromContents + peek = getMatrixDataFromContents + sizeOf _ = error "sizeOf not supported for SunMatrix" + alignment _ = error "alignment not supported for SunMatrix" + +vectorFromC :: Storable a => Int -> Ptr a -> IO (VS.Vector a) +vectorFromC len ptr = do + ptr' <- newForeignPtr_ ptr + VS.freeze $ VM.unsafeFromForeignPtr0 ptr' len + +vectorToC :: Storable a => VS.Vector a -> Int -> Ptr a -> IO () +vectorToC vec len ptr = do + ptr' <- newForeignPtr_ ptr + VS.copy (VM.unsafeFromForeignPtr0 ptr' len) vec + +getDataFromContents :: Int -> Ptr SunVector -> IO (VS.Vector CDouble) +getDataFromContents len ptr = do + qtr <- getContentPtr ptr + rtr <- getData qtr + vectorFromC len rtr + +putDataInContents :: Storable a => VS.Vector a -> Int -> Ptr b -> IO () +putDataInContents vec len ptr = do + qtr <- getContentPtr ptr + rtr <- getData qtr + vectorToC vec len rtr + #def typedef struct _generic_N_Vector SunVector; #def typedef struct _N_VectorContent_Serial SunContent; diff --git a/packages/sundials/src/Numeric/Sundials/CLangToHaskellTypes.hs b/packages/sundials/src/Numeric/Sundials/CLangToHaskellTypes.hs deleted file mode 100644 index 0908cbe..0000000 --- a/packages/sundials/src/Numeric/Sundials/CLangToHaskellTypes.hs +++ /dev/null @@ -1,37 +0,0 @@ -{-# LANGUAGE QuasiQuotes #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE EmptyDataDecls #-} - -module Numeric.Sundials.CLangToHaskellTypes where - -import Foreign.C.Types - -import qualified Language.Haskell.TH as TH -import qualified Language.C.Types as CT -import qualified Data.Map as Map -import Language.C.Inline.Context - -import qualified Data.Vector.Storable as V - - -data SunVector -data SunMatrix = SunMatrix { rows :: CInt - , cols :: CInt - , vals :: V.Vector CDouble - } - --- | This is true only if configured/ built as 64 bits -type SunIndexType = CLong - -sunTypesTable :: Map.Map CT.TypeSpecifier TH.TypeQ -sunTypesTable = Map.fromList - [ - (CT.TypeName "sunindextype", [t| SunIndexType |] ) - , (CT.TypeName "SunVector", [t| SunVector |] ) - , (CT.TypeName "SunMatrix", [t| SunMatrix |] ) - ] - -sunCtx :: Context -sunCtx = mempty {ctxTypesTable = sunTypesTable} - diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs index 1cd072f..159fbe2 100644 --- a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs @@ -79,6 +79,7 @@ import Data.Maybe (isJust) import Foreign.C.Types (CDouble, CInt, CLong) import Foreign.Ptr (Ptr) +import Foreign.Storable (poke) import qualified Data.Vector.Storable as V @@ -90,10 +91,10 @@ import Numeric.LinearAlgebra.Devel (createVector) import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows, cols, toLists, size, reshape) -import qualified Numeric.Sundials.CLangToHaskellTypes as T -import Numeric.Sundials.Arkode (cV_ADAMS, cV_BDF) -import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian) -import qualified Numeric.Sundials.ODEOpts as SO +import Numeric.Sundials.Arkode (cV_ADAMS, cV_BDF, + getDataFromContents, putDataInContents) +import qualified Numeric.Sundials.Arkode as T +import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian, SundialsDiagnostics(..)) C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) @@ -195,7 +196,7 @@ odeSolveVWith' :: -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector Double -- ^ Initial conditions -> V.Vector Double -- ^ Desired solution times - -> Either Int (Matrix Double, SO.SundialsDiagnostics) -- ^ Error code or solution + -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution odeSolveVWith' opts method control initStepSize f y0 tt = case solveOdeC (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts) (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) @@ -229,7 +230,7 @@ solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector CDouble -- ^ Initial conditions -> V.Vector CDouble -- ^ Desired solution times - -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution + -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts = unsafePerformIO $ do @@ -257,9 +258,9 @@ solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts funIO x y f _ptr = do -- Convert the pointer we get from C (y) to a vector, and then -- apply the user-supplied function. - fImm <- fun x <$> SO.getDataFromContents dim y + fImm <- fun x <$> getDataFromContents dim y -- Fill in the provided pointer with the resulting vector. - SO.putDataInContents fImm dim f + putDataInContents fImm dim f -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] @@ -271,8 +272,8 @@ solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do case jacH of Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" - Just jacI -> do j <- jacI t <$> SO.getDataFromContents dim y - SO.putMatrixDataFromContents j jacS + Just jacI -> do j <- jacI t <$> getDataFromContents dim y + poke jacS j -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] @@ -431,16 +432,16 @@ solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts if res == 0 then do preD <- V.freeze diagMut - let d = SO.SundialsDiagnostics (fromIntegral $ preD V.!0) - (fromIntegral $ preD V.!1) - (fromIntegral $ preD V.!2) - (fromIntegral $ preD V.!3) - (fromIntegral $ preD V.!4) - (fromIntegral $ preD V.!5) - (fromIntegral $ preD V.!6) - (fromIntegral $ preD V.!7) - (fromIntegral $ preD V.!8) - (fromIntegral $ preD V.!9) + let d = SundialsDiagnostics (fromIntegral $ preD V.!0) + (fromIntegral $ preD V.!1) + (fromIntegral $ preD V.!2) + (fromIntegral $ preD V.!3) + (fromIntegral $ preD V.!4) + (fromIntegral $ preD V.!5) + (fromIntegral $ preD V.!6) + (fromIntegral $ preD V.!7) + (fromIntegral $ preD V.!8) + (fromIntegral $ preD V.!9) m <- V.freeze qMatMut return $ Right (m, d) else do diff --git a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs index 56dc12c..89f2306 100644 --- a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs +++ b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs @@ -1,17 +1,10 @@ module Numeric.Sundials.ODEOpts where import Data.Int (Int32) -import Foreign.Ptr (Ptr) -import Foreign.Storable as FS -import Foreign.ForeignPtr as FF -import Foreign.C.Types import qualified Data.Vector.Storable as VS -import qualified Data.Vector.Storable.Mutable as VM import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) -import qualified Numeric.Sundials.CLangToHaskellTypes as T -import qualified Numeric.Sundials.Arkode as B type Jacobian = Double -> Vector Double -> Matrix Double @@ -23,50 +16,6 @@ data ODEOpts = ODEOpts { , initStep :: Double } deriving (Read, Show, Eq, Ord) --- FIXME: Potentially an instance of Storable -_getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix -_getMatrixDataFromContents ptr = do - qtr <- B.getContentMatrixPtr ptr - rs <- B.getNRows qtr - cs <- B.getNCols qtr - rtr <- B.getMatrixData qtr - vs <- vectorFromC (fromIntegral $ rs * cs) rtr - return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } - -putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () -putMatrixDataFromContents mat ptr = do - let rs = T.rows mat - cs = T.cols mat - vs = T.vals mat - qtr <- B.getContentMatrixPtr ptr - B.putNRows rs qtr - B.putNCols cs qtr - rtr <- B.getMatrixData qtr - vectorToC vs (fromIntegral $ rs * cs) rtr --- FIXME: END - -vectorFromC :: Storable a => Int -> Ptr a -> IO (VS.Vector a) -vectorFromC len ptr = do - ptr' <- newForeignPtr_ ptr - VS.freeze $ VM.unsafeFromForeignPtr0 ptr' len - -vectorToC :: Storable a => VS.Vector a -> Int -> Ptr a -> IO () -vectorToC vec len ptr = do - ptr' <- newForeignPtr_ ptr - VS.copy (VM.unsafeFromForeignPtr0 ptr' len) vec - -getDataFromContents :: Int -> Ptr T.SunVector -> IO (VS.Vector CDouble) -getDataFromContents len ptr = do - qtr <- B.getContentPtr ptr - rtr <- B.getData qtr - vectorFromC len rtr - -putDataInContents :: Storable a => VS.Vector a -> Int -> Ptr b -> IO () -putDataInContents vec len ptr = do - qtr <- B.getContentPtr ptr - rtr <- B.getData qtr - vectorToC vec len rtr - data SundialsDiagnostics = SundialsDiagnostics { aRKodeGetNumSteps :: Int , aRKodeGetNumStepAttempts :: Int -- cgit v1.2.3