summaryrefslogtreecommitdiff
path: root/packages/sundials/src
diff options
context:
space:
mode:
authorDominic Steinitz <dominic@steinitz.org>2018-03-17 12:37:41 +0000
committerDominic Steinitz <dominic@steinitz.org>2018-03-17 12:37:41 +0000
commit1635f317b5fe8bfcea33c5e7428598fffb0446d0 (patch)
treea068683ee467b8b6749952b61b2508cf161c1e3e /packages/sundials/src
parent7b98f43f6ddca22ec0581e1a48e0626e2a8bc3de (diff)
Start of the correct (non_Fortran) way
Diffstat (limited to 'packages/sundials/src')
-rw-r--r--packages/sundials/src/Bar.hsc24
-rw-r--r--packages/sundials/src/Main.hs57
-rw-r--r--packages/sundials/src/Types.hs12
-rw-r--r--packages/sundials/src/helpers.c5
-rw-r--r--packages/sundials/src/helpers.h3
5 files changed, 87 insertions, 14 deletions
diff --git a/packages/sundials/src/Bar.hsc b/packages/sundials/src/Bar.hsc
new file mode 100644
index 0000000..b1159b6
--- /dev/null
+++ b/packages/sundials/src/Bar.hsc
@@ -0,0 +1,24 @@
1{-# LANGUAGE RecordWildCards #-}
2
3module Example where
4
5import Foreign
6import Foreign.C.Types
7import Foreign.C.String
8
9#include "/Users/dom/sundials/include/sundials/sundials_nvector.h"
10#include "/Users/dom/sundials/include/nvector/nvector_serial.h"
11
12#def typedef struct _generic_N_Vector BarType;
13#def typedef struct _N_VectorContent_Serial BazType;
14
15
16getContentPtr :: Storable a => Ptr b -> IO a
17getContPtr ptr = (#peek BarType, content) ptr
18
19getData ptr = (#peek BazType, data) ptr
20
21foo ptr = do
22 qtr <- getContPtr ptr
23 getData qtr
24
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs
index bab5710..328af08 100644
--- a/packages/sundials/src/Main.hs
+++ b/packages/sundials/src/Main.hs
@@ -28,7 +28,13 @@ import qualified Language.C.Types as CT
28import qualified Data.Map as Map 28import qualified Data.Map as Map
29import Language.C.Inline.Context 29import Language.C.Inline.Context
30 30
31C.context (C.baseCtx <> C.vecCtx <> C.funCtx) 31import Foreign.C.String
32import Foreign.Storable (peek, poke, peekByteOff)
33import Data.Int
34
35import qualified Types as T
36
37C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx)
32 38
33-- C includes 39-- C includes
34C.include "<stdio.h>" 40C.include "<stdio.h>"
@@ -43,6 +49,21 @@ C.include "<sundials/sundials_math.h>"
43C.include "helpers.h" 49C.include "helpers.h"
44 50
45 51
52-- These were semi-generated using hsc2hs with Bar.hsc as the
53-- template. They are probably very fragile and could easily break on
54-- different architectures and / or changes in the sundials package.
55
56getContentPtr :: Storable a => Ptr b -> IO a
57getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr
58
59getData :: Storable a => Ptr b -> IO a
60getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr
61
62getDataFromContents :: Storable a => Ptr b -> IO a
63getDataFromContents ptr = do
64 qtr <- getContentPtr ptr
65 getData qtr
66
46-- Utils 67-- Utils
47 68
48vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) 69vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a)
@@ -55,11 +76,23 @@ vectorToC vec len ptr = do
55 ptr' <- newForeignPtr_ ptr 76 ptr' <- newForeignPtr_ ptr
56 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec 77 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
57 78
58foreign export ccall singleEq :: Double -> Double -> IO Double 79-- Provided you always call your function 'multiEq' then we can
80-- probably solve any set of ODEs! But of course we don't want to
81-- follow the Fortran way of interacting with sundials.
82
83-- foreign export ccall multiEq :: Ptr CDouble -> Ptr CDouble -> Ptr CDouble -> Ptr CLong -> Ptr CDouble -> Ptr CInt -> IO ()
59 84
60singleEq :: Double -> Double -> IO Double 85multiEq :: Ptr CDouble -> Ptr CDouble -> Ptr CDouble -> Ptr CLong -> Ptr CDouble -> Ptr CInt -> IO ()
61singleEq t u = return $ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t 86multiEq tPtr yPtr yDotPtr iParPtr rParPtr ierPtr = do
87 t <- peek tPtr
88 y <- vectorFromC 1 yPtr
89 vectorToC (V.map realToFrac $ stiffish (realToFrac t) (V.map realToFrac y)) 1 yDotPtr
90 poke ierPtr 0
91
92stiffish :: Double -> V.Vector Double -> V.Vector Double
93stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ]
62 where 94 where
95 u = v V.! 0
63 lamda = -100.0 96 lamda = -100.0
64 97
65solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> 98solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) ->
@@ -68,12 +101,16 @@ solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) ->
68 CInt 101 CInt
69solve fun f0 lambda = unsafePerformIO $ do 102solve fun f0 lambda = unsafePerformIO $ do
70 let dim = V.length f0 103 let dim = V.length f0
71 let funIO x y f _ptr = do 104 -- We need the types that sundials expects. These are tied together
105 -- in 'Types'. The Haskell type is currently empty!
106 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt
107 funIO x y f _ptr = do
108 error $ show x
72 -- Convert the pointer we get from C (y) to a vector, and then 109 -- Convert the pointer we get from C (y) to a vector, and then
73 -- apply the user-supplied function. 110 -- apply the user-supplied function.
74 fImm <- fun x <$> vectorFromC dim y 111 -- fImm <- fun x <$> vectorFromC dim y
75 -- Fill in the provided pointer with the resulting vector. 112 -- Fill in the provided pointer with the resulting vector.
76 vectorToC fImm dim f 113 -- vectorToC fImm dim f
77 -- Unsafe since the function will be called many times. 114 -- Unsafe since the function will be called many times.
78 [CU.exp| int{ 0 } |] 115 [CU.exp| int{ 0 } |]
79 res <- [C.block| int { 116 res <- [C.block| int {
@@ -114,7 +151,11 @@ solve fun f0 lambda = unsafePerformIO $ do
114 /* right-hand side function in y'=f(t,y), the inital time T0, and */ 151 /* right-hand side function in y'=f(t,y), the inital time T0, and */
115 /* the initial dependent variable vector y. Note: since this */ 152 /* the initial dependent variable vector y. Note: since this */
116 /* problem is fully implicit, we set f_E to NULL and f_I to f. */ 153 /* problem is fully implicit, we set f_E to NULL and f_I to f. */
117 flag = ARKodeInit(arkode_mem, NULL, FARKfi, T0, y); 154
155 /* Here we use the C types defined in helpers.h which tie up with */
156 /* the Haskell types defined in Types */
157 flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y);
158 /* flag = ARKodeInit(arkode_mem, NULL, FARKfi, T0, y); */
118 if (check_flag(&flag, "ARKodeInit", 1)) return 1; 159 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
119 160
120 /* Set routines */ 161 /* Set routines */
diff --git a/packages/sundials/src/Types.hs b/packages/sundials/src/Types.hs
index 355850d..325072c 100644
--- a/packages/sundials/src/Types.hs
+++ b/packages/sundials/src/Types.hs
@@ -20,7 +20,7 @@ import qualified Data.Vector.Storable.Mutable as VM
20import Foreign.C.Types 20import Foreign.C.Types
21import Foreign.ForeignPtr (newForeignPtr_) 21import Foreign.ForeignPtr (newForeignPtr_)
22import Foreign.Ptr (Ptr) 22import Foreign.Ptr (Ptr)
23import Foreign.Storable (Storable) 23import Foreign.Storable (Storable(..))
24import qualified Language.C.Inline as C 24import qualified Language.C.Inline as C
25import qualified Language.C.Inline.Unsafe as CU 25import qualified Language.C.Inline.Unsafe as CU
26import System.IO.Unsafe (unsafePerformIO) 26import System.IO.Unsafe (unsafePerformIO)
@@ -30,6 +30,13 @@ import qualified Language.C.Types as CT
30import qualified Data.Map as Map 30import qualified Data.Map as Map
31import Language.C.Inline.Context 31import Language.C.Inline.Context
32 32
33data BarType
34
35instance Storable BarType where
36 sizeOf _ = sizeOf (undefined :: BarType)
37 alignment _ = alignment (undefined :: Ptr ())
38 peek _ = error "peek not implemented for BarType"
39 poke _ _ = error "poke not implemented for BarType"
33 40
34-- This is a lie!!! 41-- This is a lie!!!
35type SunIndexType = CLong 42type SunIndexType = CLong
@@ -38,7 +45,8 @@ sunTypesTable :: Map.Map CT.TypeSpecifier TH.TypeQ
38sunTypesTable = Map.fromList 45sunTypesTable = Map.fromList
39 [ 46 [
40 (CT.TypeName "sunindextype", [t| SunIndexType |] ) 47 (CT.TypeName "sunindextype", [t| SunIndexType |] )
48 , (CT.TypeName "BarType", [t| BarType |] )
41 ] 49 ]
42 50
43sunctx = mempty {ctxTypesTable = sunTypesTable} 51sunCtx = mempty {ctxTypesTable = sunTypesTable}
44 52
diff --git a/packages/sundials/src/helpers.c b/packages/sundials/src/helpers.c
index eab5ac9..6162b71 100644
--- a/packages/sundials/src/helpers.c
+++ b/packages/sundials/src/helpers.c
@@ -64,10 +64,7 @@ int f(realtype t, N_Vector y, N_Vector ydot, void *user_data)
64 64
65int FARK_IMP_FUN(realtype *T, realtype *Y, realtype *YDOT, 65int FARK_IMP_FUN(realtype *T, realtype *Y, realtype *YDOT,
66 long int *IPAR, realtype *RPAR, int *IER) { 66 long int *IPAR, realtype *RPAR, int *IER) {
67 realtype t = *T; 67 multiEq(T, Y, YDOT, IPAR, RPAR, IER);
68 realtype u = Y[0];
69 realtype lamda = -100.0;
70 YDOT[0] = singleEq(t, u);
71 return 0; 68 return 0;
72} 69}
73 70
diff --git a/packages/sundials/src/helpers.h b/packages/sundials/src/helpers.h
index 3b50163..7f4ba02 100644
--- a/packages/sundials/src/helpers.h
+++ b/packages/sundials/src/helpers.h
@@ -8,6 +8,9 @@
8#define FSYM "f" 8#define FSYM "f"
9#endif 9#endif
10 10
11typedef struct _generic_N_Vector BarType;
12typedef struct _N_VectorContent_Serial BazType;
13
11/* Check function return value... 14/* Check function return value...
12 opt == 0 means SUNDIALS function allocates memory so check if 15 opt == 0 means SUNDIALS function allocates memory so check if
13 returned NULL pointer 16 returned NULL pointer