diff options
Diffstat (limited to 'packages/sundials/src')
-rw-r--r-- | packages/sundials/src/Main.hs | 30 |
1 files changed, 16 insertions, 14 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index 473daf7..89d6668 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs | |||
@@ -30,7 +30,7 @@ import qualified Data.Map as Map | |||
30 | import Language.C.Inline.Context | 30 | import Language.C.Inline.Context |
31 | 31 | ||
32 | import Foreign.C.String | 32 | import Foreign.C.String |
33 | import Foreign.Storable (peek, poke, peekByteOff) | 33 | import Foreign.Storable (peek, poke, peekByteOff, pokeByteOff) |
34 | import Data.Int | 34 | import Data.Int |
35 | 35 | ||
36 | import qualified Types as T | 36 | import qualified Types as T |
@@ -60,10 +60,17 @@ getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr | |||
60 | getData :: Storable a => Ptr b -> IO a | 60 | getData :: Storable a => Ptr b -> IO a |
61 | getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr | 61 | getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr |
62 | 62 | ||
63 | getDataFromContents :: Storable a => Ptr b -> IO a | 63 | getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b) |
64 | getDataFromContents ptr = do | 64 | getDataFromContents len ptr = do |
65 | qtr <- getContentPtr ptr | 65 | qtr <- getContentPtr ptr |
66 | getData qtr | 66 | rtr <- getData qtr |
67 | vectorFromC len rtr | ||
68 | |||
69 | putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () | ||
70 | putDataInContents vec len ptr = do | ||
71 | qtr <- getContentPtr ptr | ||
72 | rtr <- getData qtr | ||
73 | vectorToC vec len rtr | ||
67 | 74 | ||
68 | -- Utils | 75 | -- Utils |
69 | 76 | ||
@@ -98,23 +105,19 @@ stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] | |||
98 | 105 | ||
99 | solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> | 106 | solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> |
100 | V.Vector Double -> | 107 | V.Vector Double -> |
101 | CDouble -> | ||
102 | CInt | 108 | CInt |
103 | solveOdeC fun f0 lambda = unsafePerformIO $ do | 109 | solveOdeC fun f0 = unsafePerformIO $ do |
104 | let dim = V.length f0 | 110 | let dim = V.length f0 |
105 | -- We need the types that sundials expects. These are tied together | 111 | -- We need the types that sundials expects. These are tied together |
106 | -- in 'Types'. The Haskell type is currently empty! | 112 | -- in 'Types'. The Haskell type is currently empty! |
107 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | 113 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt |
108 | funIO x y f _ptr = do | 114 | funIO x y f _ptr = do |
109 | z :: (Ptr (Ptr (CDouble))) <- getContentPtr y | ||
110 | u :: (Ptr (CDouble)) <- getDataFromContents y | ||
111 | v <- vectorFromC 1 u | ||
112 | error $ show y ++ " " ++ show z ++ " " ++ show u ++ " " ++ show v ++ " " ++ show dim | ||
113 | -- Convert the pointer we get from C (y) to a vector, and then | 115 | -- Convert the pointer we get from C (y) to a vector, and then |
114 | -- apply the user-supplied function. | 116 | -- apply the user-supplied function. |
115 | fImm <- fun x <$> vectorFromC dim u | 117 | fImm <- fun x <$> getDataFromContents dim y |
116 | -- Fill in the provided pointer with the resulting vector. | 118 | -- Fill in the provided pointer with the resulting vector. |
117 | -- vectorToC fImm dim f | 119 | putDataInContents fImm dim f |
120 | -- I don't understand what this comment means | ||
118 | -- Unsafe since the function will be called many times. | 121 | -- Unsafe since the function will be called many times. |
119 | [CU.exp| int{ 0 } |] | 122 | [CU.exp| int{ 0 } |] |
120 | res <- [C.block| int { | 123 | res <- [C.block| int { |
@@ -140,7 +143,6 @@ solveOdeC fun f0 lambda = unsafePerformIO $ do | |||
140 | /* Initial diagnostics output */ | 143 | /* Initial diagnostics output */ |
141 | printf("\nAnalytical ODE test problem:\n"); | 144 | printf("\nAnalytical ODE test problem:\n"); |
142 | printf(" lamda = %"GSYM"\n", lamda); | 145 | printf(" lamda = %"GSYM"\n", lamda); |
143 | printf(" lambda = %"GSYM"\n", $(double lambda)); | ||
144 | printf(" reltol = %.1"ESYM"\n", reltol); | 146 | printf(" reltol = %.1"ESYM"\n", reltol); |
145 | printf(" abstol = %.1"ESYM"\n\n",abstol); | 147 | printf(" abstol = %.1"ESYM"\n\n",abstol); |
146 | 148 | ||
@@ -250,5 +252,5 @@ solveOdeC fun f0 lambda = unsafePerformIO $ do | |||
250 | return res | 252 | return res |
251 | 253 | ||
252 | main = do | 254 | main = do |
253 | let res = solveOdeC undefined (V.fromList [17.0]) (coerce (100.0 :: Double)) | 255 | let res = solveOdeC (coerce stiffish) (V.fromList [1.0]) |
254 | putStrLn $ show res | 256 | putStrLn $ show res |