From 5a51a987e014066d019473a68c1ceca9e30a348f Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Mon, 19 Mar 2018 15:46:25 +0000 Subject: A working example --- packages/sundials/src/Main.hs | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) (limited to 'packages') diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index 473daf7..89d6668 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs @@ -30,7 +30,7 @@ import qualified Data.Map as Map import Language.C.Inline.Context import Foreign.C.String -import Foreign.Storable (peek, poke, peekByteOff) +import Foreign.Storable (peek, poke, peekByteOff, pokeByteOff) import Data.Int import qualified Types as T @@ -60,10 +60,17 @@ getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr getData :: Storable a => Ptr b -> IO a getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr -getDataFromContents :: Storable a => Ptr b -> IO a -getDataFromContents ptr = do +getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b) +getDataFromContents len ptr = do qtr <- getContentPtr ptr - getData qtr + rtr <- getData qtr + vectorFromC len rtr + +putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () +putDataInContents vec len ptr = do + qtr <- getContentPtr ptr + rtr <- getData qtr + vectorToC vec len rtr -- Utils @@ -98,23 +105,19 @@ stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> V.Vector Double -> - CDouble -> CInt -solveOdeC fun f0 lambda = unsafePerformIO $ do +solveOdeC fun f0 = unsafePerformIO $ do let dim = V.length f0 -- We need the types that sundials expects. These are tied together -- in 'Types'. The Haskell type is currently empty! let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt funIO x y f _ptr = do - z :: (Ptr (Ptr (CDouble))) <- getContentPtr y - u :: (Ptr (CDouble)) <- getDataFromContents y - v <- vectorFromC 1 u - error $ show y ++ " " ++ show z ++ " " ++ show u ++ " " ++ show v ++ " " ++ show dim -- Convert the pointer we get from C (y) to a vector, and then -- apply the user-supplied function. - fImm <- fun x <$> vectorFromC dim u + fImm <- fun x <$> getDataFromContents dim y -- Fill in the provided pointer with the resulting vector. - -- vectorToC fImm dim f + putDataInContents fImm dim f + -- I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] res <- [C.block| int { @@ -140,7 +143,6 @@ solveOdeC fun f0 lambda = unsafePerformIO $ do /* Initial diagnostics output */ printf("\nAnalytical ODE test problem:\n"); printf(" lamda = %"GSYM"\n", lamda); - printf(" lambda = %"GSYM"\n", $(double lambda)); printf(" reltol = %.1"ESYM"\n", reltol); printf(" abstol = %.1"ESYM"\n\n",abstol); @@ -250,5 +252,5 @@ solveOdeC fun f0 lambda = unsafePerformIO $ do return res main = do - let res = solveOdeC undefined (V.fromList [17.0]) (coerce (100.0 :: Double)) + let res = solveOdeC (coerce stiffish) (V.fromList [1.0]) putStrLn $ show res -- cgit v1.2.3