diff options
author | Dominic Steinitz <dominic@steinitz.org> | 2018-03-17 12:37:41 +0000 |
---|---|---|
committer | Dominic Steinitz <dominic@steinitz.org> | 2018-03-17 12:37:41 +0000 |
commit | 1635f317b5fe8bfcea33c5e7428598fffb0446d0 (patch) | |
tree | a068683ee467b8b6749952b61b2508cf161c1e3e /packages/sundials/src | |
parent | 7b98f43f6ddca22ec0581e1a48e0626e2a8bc3de (diff) |
Start of the correct (non_Fortran) way
Diffstat (limited to 'packages/sundials/src')
-rw-r--r-- | packages/sundials/src/Bar.hsc | 24 | ||||
-rw-r--r-- | packages/sundials/src/Main.hs | 57 | ||||
-rw-r--r-- | packages/sundials/src/Types.hs | 12 | ||||
-rw-r--r-- | packages/sundials/src/helpers.c | 5 | ||||
-rw-r--r-- | packages/sundials/src/helpers.h | 3 |
5 files changed, 87 insertions, 14 deletions
diff --git a/packages/sundials/src/Bar.hsc b/packages/sundials/src/Bar.hsc new file mode 100644 index 0000000..b1159b6 --- /dev/null +++ b/packages/sundials/src/Bar.hsc | |||
@@ -0,0 +1,24 @@ | |||
1 | {-# LANGUAGE RecordWildCards #-} | ||
2 | |||
3 | module Example where | ||
4 | |||
5 | import Foreign | ||
6 | import Foreign.C.Types | ||
7 | import Foreign.C.String | ||
8 | |||
9 | #include "/Users/dom/sundials/include/sundials/sundials_nvector.h" | ||
10 | #include "/Users/dom/sundials/include/nvector/nvector_serial.h" | ||
11 | |||
12 | #def typedef struct _generic_N_Vector BarType; | ||
13 | #def typedef struct _N_VectorContent_Serial BazType; | ||
14 | |||
15 | |||
16 | getContentPtr :: Storable a => Ptr b -> IO a | ||
17 | getContPtr ptr = (#peek BarType, content) ptr | ||
18 | |||
19 | getData ptr = (#peek BazType, data) ptr | ||
20 | |||
21 | foo ptr = do | ||
22 | qtr <- getContPtr ptr | ||
23 | getData qtr | ||
24 | |||
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index bab5710..328af08 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs | |||
@@ -28,7 +28,13 @@ import qualified Language.C.Types as CT | |||
28 | import qualified Data.Map as Map | 28 | import qualified Data.Map as Map |
29 | import Language.C.Inline.Context | 29 | import Language.C.Inline.Context |
30 | 30 | ||
31 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx) | 31 | import Foreign.C.String |
32 | import Foreign.Storable (peek, poke, peekByteOff) | ||
33 | import Data.Int | ||
34 | |||
35 | import qualified Types as T | ||
36 | |||
37 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | ||
32 | 38 | ||
33 | -- C includes | 39 | -- C includes |
34 | C.include "<stdio.h>" | 40 | C.include "<stdio.h>" |
@@ -43,6 +49,21 @@ C.include "<sundials/sundials_math.h>" | |||
43 | C.include "helpers.h" | 49 | C.include "helpers.h" |
44 | 50 | ||
45 | 51 | ||
52 | -- These were semi-generated using hsc2hs with Bar.hsc as the | ||
53 | -- template. They are probably very fragile and could easily break on | ||
54 | -- different architectures and / or changes in the sundials package. | ||
55 | |||
56 | getContentPtr :: Storable a => Ptr b -> IO a | ||
57 | getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr | ||
58 | |||
59 | getData :: Storable a => Ptr b -> IO a | ||
60 | getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr | ||
61 | |||
62 | getDataFromContents :: Storable a => Ptr b -> IO a | ||
63 | getDataFromContents ptr = do | ||
64 | qtr <- getContentPtr ptr | ||
65 | getData qtr | ||
66 | |||
46 | -- Utils | 67 | -- Utils |
47 | 68 | ||
48 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | 69 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) |
@@ -55,11 +76,23 @@ vectorToC vec len ptr = do | |||
55 | ptr' <- newForeignPtr_ ptr | 76 | ptr' <- newForeignPtr_ ptr |
56 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | 77 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec |
57 | 78 | ||
58 | foreign export ccall singleEq :: Double -> Double -> IO Double | 79 | -- Provided you always call your function 'multiEq' then we can |
80 | -- probably solve any set of ODEs! But of course we don't want to | ||
81 | -- follow the Fortran way of interacting with sundials. | ||
82 | |||
83 | -- foreign export ccall multiEq :: Ptr CDouble -> Ptr CDouble -> Ptr CDouble -> Ptr CLong -> Ptr CDouble -> Ptr CInt -> IO () | ||
59 | 84 | ||
60 | singleEq :: Double -> Double -> IO Double | 85 | multiEq :: Ptr CDouble -> Ptr CDouble -> Ptr CDouble -> Ptr CLong -> Ptr CDouble -> Ptr CInt -> IO () |
61 | singleEq t u = return $ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t | 86 | multiEq tPtr yPtr yDotPtr iParPtr rParPtr ierPtr = do |
87 | t <- peek tPtr | ||
88 | y <- vectorFromC 1 yPtr | ||
89 | vectorToC (V.map realToFrac $ stiffish (realToFrac t) (V.map realToFrac y)) 1 yDotPtr | ||
90 | poke ierPtr 0 | ||
91 | |||
92 | stiffish :: Double -> V.Vector Double -> V.Vector Double | ||
93 | stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] | ||
62 | where | 94 | where |
95 | u = v V.! 0 | ||
63 | lamda = -100.0 | 96 | lamda = -100.0 |
64 | 97 | ||
65 | solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> | 98 | solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> |
@@ -68,12 +101,16 @@ solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> | |||
68 | CInt | 101 | CInt |
69 | solve fun f0 lambda = unsafePerformIO $ do | 102 | solve fun f0 lambda = unsafePerformIO $ do |
70 | let dim = V.length f0 | 103 | let dim = V.length f0 |
71 | let funIO x y f _ptr = do | 104 | -- We need the types that sundials expects. These are tied together |
105 | -- in 'Types'. The Haskell type is currently empty! | ||
106 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | ||
107 | funIO x y f _ptr = do | ||
108 | error $ show x | ||
72 | -- Convert the pointer we get from C (y) to a vector, and then | 109 | -- Convert the pointer we get from C (y) to a vector, and then |
73 | -- apply the user-supplied function. | 110 | -- apply the user-supplied function. |
74 | fImm <- fun x <$> vectorFromC dim y | 111 | -- fImm <- fun x <$> vectorFromC dim y |
75 | -- Fill in the provided pointer with the resulting vector. | 112 | -- Fill in the provided pointer with the resulting vector. |
76 | vectorToC fImm dim f | 113 | -- vectorToC fImm dim f |
77 | -- Unsafe since the function will be called many times. | 114 | -- Unsafe since the function will be called many times. |
78 | [CU.exp| int{ 0 } |] | 115 | [CU.exp| int{ 0 } |] |
79 | res <- [C.block| int { | 116 | res <- [C.block| int { |
@@ -114,7 +151,11 @@ solve fun f0 lambda = unsafePerformIO $ do | |||
114 | /* right-hand side function in y'=f(t,y), the inital time T0, and */ | 151 | /* right-hand side function in y'=f(t,y), the inital time T0, and */ |
115 | /* the initial dependent variable vector y. Note: since this */ | 152 | /* the initial dependent variable vector y. Note: since this */ |
116 | /* problem is fully implicit, we set f_E to NULL and f_I to f. */ | 153 | /* problem is fully implicit, we set f_E to NULL and f_I to f. */ |
117 | flag = ARKodeInit(arkode_mem, NULL, FARKfi, T0, y); | 154 | |
155 | /* Here we use the C types defined in helpers.h which tie up with */ | ||
156 | /* the Haskell types defined in Types */ | ||
157 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); | ||
158 | /* flag = ARKodeInit(arkode_mem, NULL, FARKfi, T0, y); */ | ||
118 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | 159 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; |
119 | 160 | ||
120 | /* Set routines */ | 161 | /* Set routines */ |
diff --git a/packages/sundials/src/Types.hs b/packages/sundials/src/Types.hs index 355850d..325072c 100644 --- a/packages/sundials/src/Types.hs +++ b/packages/sundials/src/Types.hs | |||
@@ -20,7 +20,7 @@ import qualified Data.Vector.Storable.Mutable as VM | |||
20 | import Foreign.C.Types | 20 | import Foreign.C.Types |
21 | import Foreign.ForeignPtr (newForeignPtr_) | 21 | import Foreign.ForeignPtr (newForeignPtr_) |
22 | import Foreign.Ptr (Ptr) | 22 | import Foreign.Ptr (Ptr) |
23 | import Foreign.Storable (Storable) | 23 | import Foreign.Storable (Storable(..)) |
24 | import qualified Language.C.Inline as C | 24 | import qualified Language.C.Inline as C |
25 | import qualified Language.C.Inline.Unsafe as CU | 25 | import qualified Language.C.Inline.Unsafe as CU |
26 | import System.IO.Unsafe (unsafePerformIO) | 26 | import System.IO.Unsafe (unsafePerformIO) |
@@ -30,6 +30,13 @@ import qualified Language.C.Types as CT | |||
30 | import qualified Data.Map as Map | 30 | import qualified Data.Map as Map |
31 | import Language.C.Inline.Context | 31 | import Language.C.Inline.Context |
32 | 32 | ||
33 | data BarType | ||
34 | |||
35 | instance Storable BarType where | ||
36 | sizeOf _ = sizeOf (undefined :: BarType) | ||
37 | alignment _ = alignment (undefined :: Ptr ()) | ||
38 | peek _ = error "peek not implemented for BarType" | ||
39 | poke _ _ = error "poke not implemented for BarType" | ||
33 | 40 | ||
34 | -- This is a lie!!! | 41 | -- This is a lie!!! |
35 | type SunIndexType = CLong | 42 | type SunIndexType = CLong |
@@ -38,7 +45,8 @@ sunTypesTable :: Map.Map CT.TypeSpecifier TH.TypeQ | |||
38 | sunTypesTable = Map.fromList | 45 | sunTypesTable = Map.fromList |
39 | [ | 46 | [ |
40 | (CT.TypeName "sunindextype", [t| SunIndexType |] ) | 47 | (CT.TypeName "sunindextype", [t| SunIndexType |] ) |
48 | , (CT.TypeName "BarType", [t| BarType |] ) | ||
41 | ] | 49 | ] |
42 | 50 | ||
43 | sunctx = mempty {ctxTypesTable = sunTypesTable} | 51 | sunCtx = mempty {ctxTypesTable = sunTypesTable} |
44 | 52 | ||
diff --git a/packages/sundials/src/helpers.c b/packages/sundials/src/helpers.c index eab5ac9..6162b71 100644 --- a/packages/sundials/src/helpers.c +++ b/packages/sundials/src/helpers.c | |||
@@ -64,10 +64,7 @@ int f(realtype t, N_Vector y, N_Vector ydot, void *user_data) | |||
64 | 64 | ||
65 | int FARK_IMP_FUN(realtype *T, realtype *Y, realtype *YDOT, | 65 | int FARK_IMP_FUN(realtype *T, realtype *Y, realtype *YDOT, |
66 | long int *IPAR, realtype *RPAR, int *IER) { | 66 | long int *IPAR, realtype *RPAR, int *IER) { |
67 | realtype t = *T; | 67 | multiEq(T, Y, YDOT, IPAR, RPAR, IER); |
68 | realtype u = Y[0]; | ||
69 | realtype lamda = -100.0; | ||
70 | YDOT[0] = singleEq(t, u); | ||
71 | return 0; | 68 | return 0; |
72 | } | 69 | } |
73 | 70 | ||
diff --git a/packages/sundials/src/helpers.h b/packages/sundials/src/helpers.h index 3b50163..7f4ba02 100644 --- a/packages/sundials/src/helpers.h +++ b/packages/sundials/src/helpers.h | |||
@@ -8,6 +8,9 @@ | |||
8 | #define FSYM "f" | 8 | #define FSYM "f" |
9 | #endif | 9 | #endif |
10 | 10 | ||
11 | typedef struct _generic_N_Vector BarType; | ||
12 | typedef struct _N_VectorContent_Serial BazType; | ||
13 | |||
11 | /* Check function return value... | 14 | /* Check function return value... |
12 | opt == 0 means SUNDIALS function allocates memory so check if | 15 | opt == 0 means SUNDIALS function allocates memory so check if |
13 | returned NULL pointer | 16 | returned NULL pointer |