diff options
-rw-r--r-- | packages/sundials/src/Main.hs | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index 328af08..473daf7 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs | |||
@@ -2,6 +2,7 @@ | |||
2 | {-# LANGUAGE TemplateHaskell #-} | 2 | {-# LANGUAGE TemplateHaskell #-} |
3 | {-# LANGUAGE MultiWayIf #-} | 3 | {-# LANGUAGE MultiWayIf #-} |
4 | {-# LANGUAGE OverloadedStrings #-} | 4 | {-# LANGUAGE OverloadedStrings #-} |
5 | {-# LANGUAGE ScopedTypeVariables #-} | ||
5 | 6 | ||
6 | import qualified Language.C.Inline as C | 7 | import qualified Language.C.Inline as C |
7 | import qualified Language.C.Inline.Unsafe as CU | 8 | import qualified Language.C.Inline.Unsafe as CU |
@@ -95,20 +96,23 @@ stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] | |||
95 | u = v V.! 0 | 96 | u = v V.! 0 |
96 | lamda = -100.0 | 97 | lamda = -100.0 |
97 | 98 | ||
98 | solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> | 99 | solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> |
99 | V.Vector Double -> | 100 | V.Vector Double -> |
100 | CDouble -> | 101 | CDouble -> |
101 | CInt | 102 | CInt |
102 | solve fun f0 lambda = unsafePerformIO $ do | 103 | solveOdeC fun f0 lambda = unsafePerformIO $ do |
103 | let dim = V.length f0 | 104 | let dim = V.length f0 |
104 | -- We need the types that sundials expects. These are tied together | 105 | -- We need the types that sundials expects. These are tied together |
105 | -- in 'Types'. The Haskell type is currently empty! | 106 | -- in 'Types'. The Haskell type is currently empty! |
106 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | 107 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt |
107 | funIO x y f _ptr = do | 108 | funIO x y f _ptr = do |
108 | error $ show x | 109 | z :: (Ptr (Ptr (CDouble))) <- getContentPtr y |
110 | u :: (Ptr (CDouble)) <- getDataFromContents y | ||
111 | v <- vectorFromC 1 u | ||
112 | error $ show y ++ " " ++ show z ++ " " ++ show u ++ " " ++ show v ++ " " ++ show dim | ||
109 | -- Convert the pointer we get from C (y) to a vector, and then | 113 | -- Convert the pointer we get from C (y) to a vector, and then |
110 | -- apply the user-supplied function. | 114 | -- apply the user-supplied function. |
111 | -- fImm <- fun x <$> vectorFromC dim y | 115 | fImm <- fun x <$> vectorFromC dim u |
112 | -- Fill in the provided pointer with the resulting vector. | 116 | -- Fill in the provided pointer with the resulting vector. |
113 | -- vectorToC fImm dim f | 117 | -- vectorToC fImm dim f |
114 | -- Unsafe since the function will be called many times. | 118 | -- Unsafe since the function will be called many times. |
@@ -246,5 +250,5 @@ solve fun f0 lambda = unsafePerformIO $ do | |||
246 | return res | 250 | return res |
247 | 251 | ||
248 | main = do | 252 | main = do |
249 | let res = solve undefined undefined (coerce (100.0 :: Double)) | 253 | let res = solveOdeC undefined (V.fromList [17.0]) (coerce (100.0 :: Double)) |
250 | putStrLn $ show res | 254 | putStrLn $ show res |