diff options
author | Dominic Steinitz <dominic@steinitz.org> | 2018-03-17 12:37:41 +0000 |
---|---|---|
committer | Dominic Steinitz <dominic@steinitz.org> | 2018-03-17 12:37:41 +0000 |
commit | 1635f317b5fe8bfcea33c5e7428598fffb0446d0 (patch) | |
tree | a068683ee467b8b6749952b61b2508cf161c1e3e /packages/sundials/src/Main.hs | |
parent | 7b98f43f6ddca22ec0581e1a48e0626e2a8bc3de (diff) |
Start of the correct (non_Fortran) way
Diffstat (limited to 'packages/sundials/src/Main.hs')
-rw-r--r-- | packages/sundials/src/Main.hs | 57 |
1 files changed, 49 insertions, 8 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index bab5710..328af08 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs | |||
@@ -28,7 +28,13 @@ import qualified Language.C.Types as CT | |||
28 | import qualified Data.Map as Map | 28 | import qualified Data.Map as Map |
29 | import Language.C.Inline.Context | 29 | import Language.C.Inline.Context |
30 | 30 | ||
31 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx) | 31 | import Foreign.C.String |
32 | import Foreign.Storable (peek, poke, peekByteOff) | ||
33 | import Data.Int | ||
34 | |||
35 | import qualified Types as T | ||
36 | |||
37 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | ||
32 | 38 | ||
33 | -- C includes | 39 | -- C includes |
34 | C.include "<stdio.h>" | 40 | C.include "<stdio.h>" |
@@ -43,6 +49,21 @@ C.include "<sundials/sundials_math.h>" | |||
43 | C.include "helpers.h" | 49 | C.include "helpers.h" |
44 | 50 | ||
45 | 51 | ||
52 | -- These were semi-generated using hsc2hs with Bar.hsc as the | ||
53 | -- template. They are probably very fragile and could easily break on | ||
54 | -- different architectures and / or changes in the sundials package. | ||
55 | |||
56 | getContentPtr :: Storable a => Ptr b -> IO a | ||
57 | getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr | ||
58 | |||
59 | getData :: Storable a => Ptr b -> IO a | ||
60 | getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr | ||
61 | |||
62 | getDataFromContents :: Storable a => Ptr b -> IO a | ||
63 | getDataFromContents ptr = do | ||
64 | qtr <- getContentPtr ptr | ||
65 | getData qtr | ||
66 | |||
46 | -- Utils | 67 | -- Utils |
47 | 68 | ||
48 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | 69 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) |
@@ -55,11 +76,23 @@ vectorToC vec len ptr = do | |||
55 | ptr' <- newForeignPtr_ ptr | 76 | ptr' <- newForeignPtr_ ptr |
56 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | 77 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec |
57 | 78 | ||
58 | foreign export ccall singleEq :: Double -> Double -> IO Double | 79 | -- Provided you always call your function 'multiEq' then we can |
80 | -- probably solve any set of ODEs! But of course we don't want to | ||
81 | -- follow the Fortran way of interacting with sundials. | ||
82 | |||
83 | -- foreign export ccall multiEq :: Ptr CDouble -> Ptr CDouble -> Ptr CDouble -> Ptr CLong -> Ptr CDouble -> Ptr CInt -> IO () | ||
59 | 84 | ||
60 | singleEq :: Double -> Double -> IO Double | 85 | multiEq :: Ptr CDouble -> Ptr CDouble -> Ptr CDouble -> Ptr CLong -> Ptr CDouble -> Ptr CInt -> IO () |
61 | singleEq t u = return $ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t | 86 | multiEq tPtr yPtr yDotPtr iParPtr rParPtr ierPtr = do |
87 | t <- peek tPtr | ||
88 | y <- vectorFromC 1 yPtr | ||
89 | vectorToC (V.map realToFrac $ stiffish (realToFrac t) (V.map realToFrac y)) 1 yDotPtr | ||
90 | poke ierPtr 0 | ||
91 | |||
92 | stiffish :: Double -> V.Vector Double -> V.Vector Double | ||
93 | stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] | ||
62 | where | 94 | where |
95 | u = v V.! 0 | ||
63 | lamda = -100.0 | 96 | lamda = -100.0 |
64 | 97 | ||
65 | solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> | 98 | solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> |
@@ -68,12 +101,16 @@ solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> | |||
68 | CInt | 101 | CInt |
69 | solve fun f0 lambda = unsafePerformIO $ do | 102 | solve fun f0 lambda = unsafePerformIO $ do |
70 | let dim = V.length f0 | 103 | let dim = V.length f0 |
71 | let funIO x y f _ptr = do | 104 | -- We need the types that sundials expects. These are tied together |
105 | -- in 'Types'. The Haskell type is currently empty! | ||
106 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | ||
107 | funIO x y f _ptr = do | ||
108 | error $ show x | ||
72 | -- Convert the pointer we get from C (y) to a vector, and then | 109 | -- Convert the pointer we get from C (y) to a vector, and then |
73 | -- apply the user-supplied function. | 110 | -- apply the user-supplied function. |
74 | fImm <- fun x <$> vectorFromC dim y | 111 | -- fImm <- fun x <$> vectorFromC dim y |
75 | -- Fill in the provided pointer with the resulting vector. | 112 | -- Fill in the provided pointer with the resulting vector. |
76 | vectorToC fImm dim f | 113 | -- vectorToC fImm dim f |
77 | -- Unsafe since the function will be called many times. | 114 | -- Unsafe since the function will be called many times. |
78 | [CU.exp| int{ 0 } |] | 115 | [CU.exp| int{ 0 } |] |
79 | res <- [C.block| int { | 116 | res <- [C.block| int { |
@@ -114,7 +151,11 @@ solve fun f0 lambda = unsafePerformIO $ do | |||
114 | /* right-hand side function in y'=f(t,y), the inital time T0, and */ | 151 | /* right-hand side function in y'=f(t,y), the inital time T0, and */ |
115 | /* the initial dependent variable vector y. Note: since this */ | 152 | /* the initial dependent variable vector y. Note: since this */ |
116 | /* problem is fully implicit, we set f_E to NULL and f_I to f. */ | 153 | /* problem is fully implicit, we set f_E to NULL and f_I to f. */ |
117 | flag = ARKodeInit(arkode_mem, NULL, FARKfi, T0, y); | 154 | |
155 | /* Here we use the C types defined in helpers.h which tie up with */ | ||
156 | /* the Haskell types defined in Types */ | ||
157 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); | ||
158 | /* flag = ARKodeInit(arkode_mem, NULL, FARKfi, T0, y); */ | ||
118 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | 159 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; |
119 | 160 | ||
120 | /* Set routines */ | 161 | /* Set routines */ |