summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Main.hs
diff options
context:
space:
mode:
authorDominic Steinitz <dominic@steinitz.org>2018-03-18 13:58:12 +0000
committerDominic Steinitz <dominic@steinitz.org>2018-03-18 13:58:12 +0000
commite6a217ec2615f6fc12c777aeb878a5a207f3b17c (patch)
tree79afcbf392066a7d3ec337ee9f48d1010b93f05c /packages/sundials/src/Main.hs
parent1635f317b5fe8bfcea33c5e7428598fffb0446d0 (diff)
Check we are passed the correct values
Diffstat (limited to 'packages/sundials/src/Main.hs')
-rw-r--r--packages/sundials/src/Main.hs20
1 files changed, 12 insertions, 8 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs
index 328af08..473daf7 100644
--- a/packages/sundials/src/Main.hs
+++ b/packages/sundials/src/Main.hs
@@ -2,6 +2,7 @@
2{-# LANGUAGE TemplateHaskell #-} 2{-# LANGUAGE TemplateHaskell #-}
3{-# LANGUAGE MultiWayIf #-} 3{-# LANGUAGE MultiWayIf #-}
4{-# LANGUAGE OverloadedStrings #-} 4{-# LANGUAGE OverloadedStrings #-}
5{-# LANGUAGE ScopedTypeVariables #-}
5 6
6import qualified Language.C.Inline as C 7import qualified Language.C.Inline as C
7import qualified Language.C.Inline.Unsafe as CU 8import qualified Language.C.Inline.Unsafe as CU
@@ -95,20 +96,23 @@ stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ]
95 u = v V.! 0 96 u = v V.! 0
96 lamda = -100.0 97 lamda = -100.0
97 98
98solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> 99solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) ->
99 V.Vector Double -> 100 V.Vector Double ->
100 CDouble -> 101 CDouble ->
101 CInt 102 CInt
102solve fun f0 lambda = unsafePerformIO $ do 103solveOdeC fun f0 lambda = unsafePerformIO $ do
103 let dim = V.length f0 104 let dim = V.length f0
104 -- We need the types that sundials expects. These are tied together 105 -- We need the types that sundials expects. These are tied together
105 -- in 'Types'. The Haskell type is currently empty! 106 -- in 'Types'. The Haskell type is currently empty!
106 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt 107 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt
107 funIO x y f _ptr = do 108 funIO x y f _ptr = do
108 error $ show x 109 z :: (Ptr (Ptr (CDouble))) <- getContentPtr y
110 u :: (Ptr (CDouble)) <- getDataFromContents y
111 v <- vectorFromC 1 u
112 error $ show y ++ " " ++ show z ++ " " ++ show u ++ " " ++ show v ++ " " ++ show dim
109 -- Convert the pointer we get from C (y) to a vector, and then 113 -- Convert the pointer we get from C (y) to a vector, and then
110 -- apply the user-supplied function. 114 -- apply the user-supplied function.
111 -- fImm <- fun x <$> vectorFromC dim y 115 fImm <- fun x <$> vectorFromC dim u
112 -- Fill in the provided pointer with the resulting vector. 116 -- Fill in the provided pointer with the resulting vector.
113 -- vectorToC fImm dim f 117 -- vectorToC fImm dim f
114 -- Unsafe since the function will be called many times. 118 -- Unsafe since the function will be called many times.
@@ -246,5 +250,5 @@ solve fun f0 lambda = unsafePerformIO $ do
246 return res 250 return res
247 251
248main = do 252main = do
249 let res = solve undefined undefined (coerce (100.0 :: Double)) 253 let res = solveOdeC undefined (V.fromList [17.0]) (coerce (100.0 :: Double))
250 putStrLn $ show res 254 putStrLn $ show res