summaryrefslogtreecommitdiff
path: root/packages/sundials/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/sundials/src')
-rw-r--r--packages/sundials/src/Main.hs30
1 files changed, 16 insertions, 14 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs
index 473daf7..89d6668 100644
--- a/packages/sundials/src/Main.hs
+++ b/packages/sundials/src/Main.hs
@@ -30,7 +30,7 @@ import qualified Data.Map as Map
30import Language.C.Inline.Context 30import Language.C.Inline.Context
31 31
32import Foreign.C.String 32import Foreign.C.String
33import Foreign.Storable (peek, poke, peekByteOff) 33import Foreign.Storable (peek, poke, peekByteOff, pokeByteOff)
34import Data.Int 34import Data.Int
35 35
36import qualified Types as T 36import qualified Types as T
@@ -60,10 +60,17 @@ getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr
60getData :: Storable a => Ptr b -> IO a 60getData :: Storable a => Ptr b -> IO a
61getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr 61getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr
62 62
63getDataFromContents :: Storable a => Ptr b -> IO a 63getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b)
64getDataFromContents ptr = do 64getDataFromContents len ptr = do
65 qtr <- getContentPtr ptr 65 qtr <- getContentPtr ptr
66 getData qtr 66 rtr <- getData qtr
67 vectorFromC len rtr
68
69putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO ()
70putDataInContents vec len ptr = do
71 qtr <- getContentPtr ptr
72 rtr <- getData qtr
73 vectorToC vec len rtr
67 74
68-- Utils 75-- Utils
69 76
@@ -98,23 +105,19 @@ stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ]
98 105
99solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> 106solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) ->
100 V.Vector Double -> 107 V.Vector Double ->
101 CDouble ->
102 CInt 108 CInt
103solveOdeC fun f0 lambda = unsafePerformIO $ do 109solveOdeC fun f0 = unsafePerformIO $ do
104 let dim = V.length f0 110 let dim = V.length f0
105 -- We need the types that sundials expects. These are tied together 111 -- We need the types that sundials expects. These are tied together
106 -- in 'Types'. The Haskell type is currently empty! 112 -- in 'Types'. The Haskell type is currently empty!
107 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt 113 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt
108 funIO x y f _ptr = do 114 funIO x y f _ptr = do
109 z :: (Ptr (Ptr (CDouble))) <- getContentPtr y
110 u :: (Ptr (CDouble)) <- getDataFromContents y
111 v <- vectorFromC 1 u
112 error $ show y ++ " " ++ show z ++ " " ++ show u ++ " " ++ show v ++ " " ++ show dim
113 -- Convert the pointer we get from C (y) to a vector, and then 115 -- Convert the pointer we get from C (y) to a vector, and then
114 -- apply the user-supplied function. 116 -- apply the user-supplied function.
115 fImm <- fun x <$> vectorFromC dim u 117 fImm <- fun x <$> getDataFromContents dim y
116 -- Fill in the provided pointer with the resulting vector. 118 -- Fill in the provided pointer with the resulting vector.
117 -- vectorToC fImm dim f 119 putDataInContents fImm dim f
120 -- I don't understand what this comment means
118 -- Unsafe since the function will be called many times. 121 -- Unsafe since the function will be called many times.
119 [CU.exp| int{ 0 } |] 122 [CU.exp| int{ 0 } |]
120 res <- [C.block| int { 123 res <- [C.block| int {
@@ -140,7 +143,6 @@ solveOdeC fun f0 lambda = unsafePerformIO $ do
140 /* Initial diagnostics output */ 143 /* Initial diagnostics output */
141 printf("\nAnalytical ODE test problem:\n"); 144 printf("\nAnalytical ODE test problem:\n");
142 printf(" lamda = %"GSYM"\n", lamda); 145 printf(" lamda = %"GSYM"\n", lamda);
143 printf(" lambda = %"GSYM"\n", $(double lambda));
144 printf(" reltol = %.1"ESYM"\n", reltol); 146 printf(" reltol = %.1"ESYM"\n", reltol);
145 printf(" abstol = %.1"ESYM"\n\n",abstol); 147 printf(" abstol = %.1"ESYM"\n\n",abstol);
146 148
@@ -250,5 +252,5 @@ solveOdeC fun f0 lambda = unsafePerformIO $ do
250 return res 252 return res
251 253
252main = do 254main = do
253 let res = solveOdeC undefined (V.fromList [17.0]) (coerce (100.0 :: Double)) 255 let res = solveOdeC (coerce stiffish) (V.fromList [1.0])
254 putStrLn $ show res 256 putStrLn $ show res