From 1635f317b5fe8bfcea33c5e7428598fffb0446d0 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Sat, 17 Mar 2018 12:37:41 +0000 Subject: Start of the correct (non_Fortran) way --- packages/sundials/src/Bar.hsc | 24 +++++++++++++++++ packages/sundials/src/Main.hs | 57 +++++++++++++++++++++++++++++++++++------ packages/sundials/src/Types.hs | 12 +++++++-- packages/sundials/src/helpers.c | 5 +--- packages/sundials/src/helpers.h | 3 +++ 5 files changed, 87 insertions(+), 14 deletions(-) create mode 100644 packages/sundials/src/Bar.hsc (limited to 'packages/sundials') diff --git a/packages/sundials/src/Bar.hsc b/packages/sundials/src/Bar.hsc new file mode 100644 index 0000000..b1159b6 --- /dev/null +++ b/packages/sundials/src/Bar.hsc @@ -0,0 +1,24 @@ +{-# LANGUAGE RecordWildCards #-} + +module Example where + +import Foreign +import Foreign.C.Types +import Foreign.C.String + +#include "/Users/dom/sundials/include/sundials/sundials_nvector.h" +#include "/Users/dom/sundials/include/nvector/nvector_serial.h" + +#def typedef struct _generic_N_Vector BarType; +#def typedef struct _N_VectorContent_Serial BazType; + + +getContentPtr :: Storable a => Ptr b -> IO a +getContPtr ptr = (#peek BarType, content) ptr + +getData ptr = (#peek BazType, data) ptr + +foo ptr = do + qtr <- getContPtr ptr + getData qtr + diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index bab5710..328af08 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs @@ -28,7 +28,13 @@ import qualified Language.C.Types as CT import qualified Data.Map as Map import Language.C.Inline.Context -C.context (C.baseCtx <> C.vecCtx <> C.funCtx) +import Foreign.C.String +import Foreign.Storable (peek, poke, peekByteOff) +import Data.Int + +import qualified Types as T + +C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) -- C includes C.include "" @@ -43,6 +49,21 @@ C.include "" C.include "helpers.h" +-- These were semi-generated using hsc2hs with Bar.hsc as the +-- template. They are probably very fragile and could easily break on +-- different architectures and / or changes in the sundials package. + +getContentPtr :: Storable a => Ptr b -> IO a +getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr + +getData :: Storable a => Ptr b -> IO a +getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr + +getDataFromContents :: Storable a => Ptr b -> IO a +getDataFromContents ptr = do + qtr <- getContentPtr ptr + getData qtr + -- Utils vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) @@ -55,11 +76,23 @@ vectorToC vec len ptr = do ptr' <- newForeignPtr_ ptr V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec -foreign export ccall singleEq :: Double -> Double -> IO Double +-- Provided you always call your function 'multiEq' then we can +-- probably solve any set of ODEs! But of course we don't want to +-- follow the Fortran way of interacting with sundials. + +-- foreign export ccall multiEq :: Ptr CDouble -> Ptr CDouble -> Ptr CDouble -> Ptr CLong -> Ptr CDouble -> Ptr CInt -> IO () -singleEq :: Double -> Double -> IO Double -singleEq t u = return $ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t +multiEq :: Ptr CDouble -> Ptr CDouble -> Ptr CDouble -> Ptr CLong -> Ptr CDouble -> Ptr CInt -> IO () +multiEq tPtr yPtr yDotPtr iParPtr rParPtr ierPtr = do + t <- peek tPtr + y <- vectorFromC 1 yPtr + vectorToC (V.map realToFrac $ stiffish (realToFrac t) (V.map realToFrac y)) 1 yDotPtr + poke ierPtr 0 + +stiffish :: Double -> V.Vector Double -> V.Vector Double +stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] where + u = v V.! 0 lamda = -100.0 solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> @@ -68,12 +101,16 @@ solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> CInt solve fun f0 lambda = unsafePerformIO $ do let dim = V.length f0 - let funIO x y f _ptr = do + -- We need the types that sundials expects. These are tied together + -- in 'Types'. The Haskell type is currently empty! + let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt + funIO x y f _ptr = do + error $ show x -- Convert the pointer we get from C (y) to a vector, and then -- apply the user-supplied function. - fImm <- fun x <$> vectorFromC dim y + -- fImm <- fun x <$> vectorFromC dim y -- Fill in the provided pointer with the resulting vector. - vectorToC fImm dim f + -- vectorToC fImm dim f -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] res <- [C.block| int { @@ -114,7 +151,11 @@ solve fun f0 lambda = unsafePerformIO $ do /* right-hand side function in y'=f(t,y), the inital time T0, and */ /* the initial dependent variable vector y. Note: since this */ /* problem is fully implicit, we set f_E to NULL and f_I to f. */ - flag = ARKodeInit(arkode_mem, NULL, FARKfi, T0, y); + + /* Here we use the C types defined in helpers.h which tie up with */ + /* the Haskell types defined in Types */ + flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); + /* flag = ARKodeInit(arkode_mem, NULL, FARKfi, T0, y); */ if (check_flag(&flag, "ARKodeInit", 1)) return 1; /* Set routines */ diff --git a/packages/sundials/src/Types.hs b/packages/sundials/src/Types.hs index 355850d..325072c 100644 --- a/packages/sundials/src/Types.hs +++ b/packages/sundials/src/Types.hs @@ -20,7 +20,7 @@ import qualified Data.Vector.Storable.Mutable as VM import Foreign.C.Types import Foreign.ForeignPtr (newForeignPtr_) import Foreign.Ptr (Ptr) -import Foreign.Storable (Storable) +import Foreign.Storable (Storable(..)) import qualified Language.C.Inline as C import qualified Language.C.Inline.Unsafe as CU import System.IO.Unsafe (unsafePerformIO) @@ -30,6 +30,13 @@ import qualified Language.C.Types as CT import qualified Data.Map as Map import Language.C.Inline.Context +data BarType + +instance Storable BarType where + sizeOf _ = sizeOf (undefined :: BarType) + alignment _ = alignment (undefined :: Ptr ()) + peek _ = error "peek not implemented for BarType" + poke _ _ = error "poke not implemented for BarType" -- This is a lie!!! type SunIndexType = CLong @@ -38,7 +45,8 @@ sunTypesTable :: Map.Map CT.TypeSpecifier TH.TypeQ sunTypesTable = Map.fromList [ (CT.TypeName "sunindextype", [t| SunIndexType |] ) + , (CT.TypeName "BarType", [t| BarType |] ) ] -sunctx = mempty {ctxTypesTable = sunTypesTable} +sunCtx = mempty {ctxTypesTable = sunTypesTable} diff --git a/packages/sundials/src/helpers.c b/packages/sundials/src/helpers.c index eab5ac9..6162b71 100644 --- a/packages/sundials/src/helpers.c +++ b/packages/sundials/src/helpers.c @@ -64,10 +64,7 @@ int f(realtype t, N_Vector y, N_Vector ydot, void *user_data) int FARK_IMP_FUN(realtype *T, realtype *Y, realtype *YDOT, long int *IPAR, realtype *RPAR, int *IER) { - realtype t = *T; - realtype u = Y[0]; - realtype lamda = -100.0; - YDOT[0] = singleEq(t, u); + multiEq(T, Y, YDOT, IPAR, RPAR, IER); return 0; } diff --git a/packages/sundials/src/helpers.h b/packages/sundials/src/helpers.h index 3b50163..7f4ba02 100644 --- a/packages/sundials/src/helpers.h +++ b/packages/sundials/src/helpers.h @@ -8,6 +8,9 @@ #define FSYM "f" #endif +typedef struct _generic_N_Vector BarType; +typedef struct _N_VectorContent_Serial BazType; + /* Check function return value... opt == 0 means SUNDIALS function allocates memory so check if returned NULL pointer -- cgit v1.2.3