summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Main.hs
diff options
context:
space:
mode:
authorDominic Steinitz <dominic@steinitz.org>2018-03-17 12:37:41 +0000
committerDominic Steinitz <dominic@steinitz.org>2018-03-17 12:37:41 +0000
commit1635f317b5fe8bfcea33c5e7428598fffb0446d0 (patch)
treea068683ee467b8b6749952b61b2508cf161c1e3e /packages/sundials/src/Main.hs
parent7b98f43f6ddca22ec0581e1a48e0626e2a8bc3de (diff)
Start of the correct (non_Fortran) way
Diffstat (limited to 'packages/sundials/src/Main.hs')
-rw-r--r--packages/sundials/src/Main.hs57
1 files changed, 49 insertions, 8 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs
index bab5710..328af08 100644
--- a/packages/sundials/src/Main.hs
+++ b/packages/sundials/src/Main.hs
@@ -28,7 +28,13 @@ import qualified Language.C.Types as CT
28import qualified Data.Map as Map 28import qualified Data.Map as Map
29import Language.C.Inline.Context 29import Language.C.Inline.Context
30 30
31C.context (C.baseCtx <> C.vecCtx <> C.funCtx) 31import Foreign.C.String
32import Foreign.Storable (peek, poke, peekByteOff)
33import Data.Int
34
35import qualified Types as T
36
37C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx)
32 38
33-- C includes 39-- C includes
34C.include "<stdio.h>" 40C.include "<stdio.h>"
@@ -43,6 +49,21 @@ C.include "<sundials/sundials_math.h>"
43C.include "helpers.h" 49C.include "helpers.h"
44 50
45 51
52-- These were semi-generated using hsc2hs with Bar.hsc as the
53-- template. They are probably very fragile and could easily break on
54-- different architectures and / or changes in the sundials package.
55
56getContentPtr :: Storable a => Ptr b -> IO a
57getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr
58
59getData :: Storable a => Ptr b -> IO a
60getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr
61
62getDataFromContents :: Storable a => Ptr b -> IO a
63getDataFromContents ptr = do
64 qtr <- getContentPtr ptr
65 getData qtr
66
46-- Utils 67-- Utils
47 68
48vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) 69vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a)
@@ -55,11 +76,23 @@ vectorToC vec len ptr = do
55 ptr' <- newForeignPtr_ ptr 76 ptr' <- newForeignPtr_ ptr
56 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec 77 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
57 78
58foreign export ccall singleEq :: Double -> Double -> IO Double 79-- Provided you always call your function 'multiEq' then we can
80-- probably solve any set of ODEs! But of course we don't want to
81-- follow the Fortran way of interacting with sundials.
82
83-- foreign export ccall multiEq :: Ptr CDouble -> Ptr CDouble -> Ptr CDouble -> Ptr CLong -> Ptr CDouble -> Ptr CInt -> IO ()
59 84
60singleEq :: Double -> Double -> IO Double 85multiEq :: Ptr CDouble -> Ptr CDouble -> Ptr CDouble -> Ptr CLong -> Ptr CDouble -> Ptr CInt -> IO ()
61singleEq t u = return $ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t 86multiEq tPtr yPtr yDotPtr iParPtr rParPtr ierPtr = do
87 t <- peek tPtr
88 y <- vectorFromC 1 yPtr
89 vectorToC (V.map realToFrac $ stiffish (realToFrac t) (V.map realToFrac y)) 1 yDotPtr
90 poke ierPtr 0
91
92stiffish :: Double -> V.Vector Double -> V.Vector Double
93stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ]
62 where 94 where
95 u = v V.! 0
63 lamda = -100.0 96 lamda = -100.0
64 97
65solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> 98solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) ->
@@ -68,12 +101,16 @@ solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) ->
68 CInt 101 CInt
69solve fun f0 lambda = unsafePerformIO $ do 102solve fun f0 lambda = unsafePerformIO $ do
70 let dim = V.length f0 103 let dim = V.length f0
71 let funIO x y f _ptr = do 104 -- We need the types that sundials expects. These are tied together
105 -- in 'Types'. The Haskell type is currently empty!
106 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt
107 funIO x y f _ptr = do
108 error $ show x
72 -- Convert the pointer we get from C (y) to a vector, and then 109 -- Convert the pointer we get from C (y) to a vector, and then
73 -- apply the user-supplied function. 110 -- apply the user-supplied function.
74 fImm <- fun x <$> vectorFromC dim y 111 -- fImm <- fun x <$> vectorFromC dim y
75 -- Fill in the provided pointer with the resulting vector. 112 -- Fill in the provided pointer with the resulting vector.
76 vectorToC fImm dim f 113 -- vectorToC fImm dim f
77 -- Unsafe since the function will be called many times. 114 -- Unsafe since the function will be called many times.
78 [CU.exp| int{ 0 } |] 115 [CU.exp| int{ 0 } |]
79 res <- [C.block| int { 116 res <- [C.block| int {
@@ -114,7 +151,11 @@ solve fun f0 lambda = unsafePerformIO $ do
114 /* right-hand side function in y'=f(t,y), the inital time T0, and */ 151 /* right-hand side function in y'=f(t,y), the inital time T0, and */
115 /* the initial dependent variable vector y. Note: since this */ 152 /* the initial dependent variable vector y. Note: since this */
116 /* problem is fully implicit, we set f_E to NULL and f_I to f. */ 153 /* problem is fully implicit, we set f_E to NULL and f_I to f. */
117 flag = ARKodeInit(arkode_mem, NULL, FARKfi, T0, y); 154
155 /* Here we use the C types defined in helpers.h which tie up with */
156 /* the Haskell types defined in Types */
157 flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y);
158 /* flag = ARKodeInit(arkode_mem, NULL, FARKfi, T0, y); */
118 if (check_flag(&flag, "ARKodeInit", 1)) return 1; 159 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
119 160
120 /* Set routines */ 161 /* Set routines */