diff options
Diffstat (limited to 'packages/sundials/src')
-rw-r--r-- | packages/sundials/src/Bar.hsc | 9 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 28 | ||||
-rw-r--r-- | packages/sundials/src/Types.hs | 14 | ||||
-rw-r--r-- | packages/sundials/src/helpers.h | 5 |
4 files changed, 18 insertions, 38 deletions
diff --git a/packages/sundials/src/Bar.hsc b/packages/sundials/src/Bar.hsc index 7d53af9..4fe1b4b 100644 --- a/packages/sundials/src/Bar.hsc +++ b/packages/sundials/src/Bar.hsc | |||
@@ -10,15 +10,14 @@ import Foreign.C.String | |||
10 | #include "/Users/dom/sundials/include/nvector/nvector_serial.h" | 10 | #include "/Users/dom/sundials/include/nvector/nvector_serial.h" |
11 | #include "/Users/dom/sundials/include/arkode/arkode.h" | 11 | #include "/Users/dom/sundials/include/arkode/arkode.h" |
12 | 12 | ||
13 | #def typedef struct _generic_N_Vector BarType; | 13 | #def typedef struct _generic_N_Vector SunVector; |
14 | #def typedef struct _N_VectorContent_Serial BazType; | 14 | #def typedef struct _N_VectorContent_Serial SunContent; |
15 | |||
16 | 15 | ||
17 | getContentPtr :: Storable a => Ptr b -> IO a | 16 | getContentPtr :: Storable a => Ptr b -> IO a |
18 | getContentPtr ptr = (#peek BarType, content) ptr | 17 | getContentPtr ptr = (#peek SunVector, content) ptr |
19 | 18 | ||
20 | getData :: Storable a => Ptr b -> IO a | 19 | getData :: Storable a => Ptr b -> IO a |
21 | getData ptr = (#peek BazType, data) ptr | 20 | getData ptr = (#peek SunContent, data) ptr |
22 | 21 | ||
23 | arkSMax :: Int | 22 | arkSMax :: Int |
24 | arkSMax = #const ARK_S_MAX | 23 | arkSMax = #const ARK_S_MAX |
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 44b724e..b621c58 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -61,9 +61,9 @@ import qualified Types as T | |||
61 | import Bar (sDIRK_2_1_2, kVAERNO_4_2_3) | 61 | import Bar (sDIRK_2_1_2, kVAERNO_4_2_3) |
62 | import qualified Bar as B | 62 | import qualified Bar as B |
63 | 63 | ||
64 | |||
64 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | 65 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) |
65 | 66 | ||
66 | -- C includes | ||
67 | C.include "<stdlib.h>" | 67 | C.include "<stdlib.h>" |
68 | C.include "<stdio.h>" | 68 | C.include "<stdio.h>" |
69 | C.include "<math.h>" | 69 | C.include "<math.h>" |
@@ -77,26 +77,16 @@ C.include "<sundials/sundials_math.h>" | |||
77 | C.include "../../../helpers.h" | 77 | C.include "../../../helpers.h" |
78 | 78 | ||
79 | 79 | ||
80 | -- These were semi-generated using hsc2hs with Bar.hsc as the | ||
81 | -- template. They are probably very fragile and could easily break on | ||
82 | -- different architectures and / or changes in the sundials package. | ||
83 | |||
84 | getContentPtr :: Storable a => Ptr b -> IO a | ||
85 | getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr | ||
86 | |||
87 | getData :: Storable a => Ptr b -> IO a | ||
88 | getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr | ||
89 | |||
90 | getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b) | 80 | getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b) |
91 | getDataFromContents len ptr = do | 81 | getDataFromContents len ptr = do |
92 | qtr <- getContentPtr ptr | 82 | qtr <- B.getContentPtr ptr |
93 | rtr <- getData qtr | 83 | rtr <- B.getData qtr |
94 | vectorFromC len rtr | 84 | vectorFromC len rtr |
95 | 85 | ||
96 | putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () | 86 | putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () |
97 | putDataInContents vec len ptr = do | 87 | putDataInContents vec len ptr = do |
98 | qtr <- getContentPtr ptr | 88 | qtr <- B.getContentPtr ptr |
99 | rtr <- getData qtr | 89 | rtr <- B.getData qtr |
100 | vectorToC vec len rtr | 90 | vectorToC vec len rtr |
101 | 91 | ||
102 | -- Utils | 92 | -- Utils |
@@ -199,7 +189,7 @@ solveOdeC method relTol absTol fun f0 ts = unsafePerformIO $ do | |||
199 | diagMut <- V.thaw diagnostics | 189 | diagMut <- V.thaw diagnostics |
200 | -- We need the types that sundials expects. These are tied together | 190 | -- We need the types that sundials expects. These are tied together |
201 | -- in 'Types'. FIXME: The Haskell type is currently empty! | 191 | -- in 'Types'. FIXME: The Haskell type is currently empty! |
202 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | 192 | let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt |
203 | funIO x y f _ptr = do | 193 | funIO x y f _ptr = do |
204 | -- Convert the pointer we get from C (y) to a vector, and then | 194 | -- Convert the pointer we get from C (y) to a vector, and then |
205 | -- apply the user-supplied function. | 195 | -- apply the user-supplied function. |
@@ -240,7 +230,7 @@ solveOdeC method relTol absTol fun f0 ts = unsafePerformIO $ do | |||
240 | 230 | ||
241 | /* Here we use the C types defined in helpers.h which tie up with */ | 231 | /* Here we use the C types defined in helpers.h which tie up with */ |
242 | /* the Haskell types defined in Types */ | 232 | /* the Haskell types defined in Types */ |
243 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); | 233 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); |
244 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | 234 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; |
245 | 235 | ||
246 | /* Set routines */ | 236 | /* Set routines */ |
@@ -378,7 +368,7 @@ getButcherTable method = unsafePerformIO $ do | |||
378 | btAsMut <- V.thaw btAs | 368 | btAsMut <- V.thaw btAs |
379 | -- We need the types that sundials expects. These are tied together | 369 | -- We need the types that sundials expects. These are tied together |
380 | -- in 'Types'. FIXME: The Haskell type is currently empty! | 370 | -- in 'Types'. FIXME: The Haskell type is currently empty! |
381 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | 371 | let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt |
382 | funIO x y f _ptr = do | 372 | funIO x y f _ptr = do |
383 | -- Convert the pointer we get from C (y) to a vector, and then | 373 | -- Convert the pointer we get from C (y) to a vector, and then |
384 | -- apply the user-supplied function. | 374 | -- apply the user-supplied function. |
@@ -411,7 +401,7 @@ getButcherTable method = unsafePerformIO $ do | |||
411 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | 401 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ |
412 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; | 402 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; |
413 | 403 | ||
414 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); | 404 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); |
415 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | 405 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; |
416 | 406 | ||
417 | flag = ARKodeSetIRKTableNum(arkode_mem, $(int mN)); | 407 | flag = ARKodeSetIRKTableNum(arkode_mem, $(int mN)); |
diff --git a/packages/sundials/src/Types.hs b/packages/sundials/src/Types.hs index 9654527..e910c57 100644 --- a/packages/sundials/src/Types.hs +++ b/packages/sundials/src/Types.hs | |||
@@ -9,31 +9,23 @@ | |||
9 | module Types where | 9 | module Types where |
10 | 10 | ||
11 | import Foreign.C.Types | 11 | import Foreign.C.Types |
12 | import Foreign.Ptr (Ptr) | ||
13 | |||
14 | import Foreign.Storable (Storable(..)) | ||
15 | 12 | ||
16 | import qualified Language.Haskell.TH as TH | 13 | import qualified Language.Haskell.TH as TH |
17 | import qualified Language.C.Types as CT | 14 | import qualified Language.C.Types as CT |
18 | import qualified Data.Map as Map | 15 | import qualified Data.Map as Map |
19 | import Language.C.Inline.Context | 16 | import Language.C.Inline.Context |
20 | 17 | ||
21 | data BarType | ||
22 | 18 | ||
23 | instance Storable BarType where | 19 | data SunVector |
24 | sizeOf _ = sizeOf (undefined :: BarType) | ||
25 | alignment _ = alignment (undefined :: Ptr ()) | ||
26 | peek _ = error "peek not implemented for BarType" | ||
27 | poke _ _ = error "poke not implemented for BarType" | ||
28 | 20 | ||
29 | -- This is a lie!!! | 21 | -- FIXME: Is this true? |
30 | type SunIndexType = CLong | 22 | type SunIndexType = CLong |
31 | 23 | ||
32 | sunTypesTable :: Map.Map CT.TypeSpecifier TH.TypeQ | 24 | sunTypesTable :: Map.Map CT.TypeSpecifier TH.TypeQ |
33 | sunTypesTable = Map.fromList | 25 | sunTypesTable = Map.fromList |
34 | [ | 26 | [ |
35 | (CT.TypeName "sunindextype", [t| SunIndexType |] ) | 27 | (CT.TypeName "sunindextype", [t| SunIndexType |] ) |
36 | , (CT.TypeName "BarType", [t| BarType |] ) | 28 | , (CT.TypeName "SunVector", [t| SunVector |] ) |
37 | ] | 29 | ] |
38 | 30 | ||
39 | sunCtx :: Context | 31 | sunCtx :: Context |
diff --git a/packages/sundials/src/helpers.h b/packages/sundials/src/helpers.h index 5c1d9f3..b41ab73 100644 --- a/packages/sundials/src/helpers.h +++ b/packages/sundials/src/helpers.h | |||
@@ -8,8 +8,7 @@ | |||
8 | #define FSYM "f" | 8 | #define FSYM "f" |
9 | #endif | 9 | #endif |
10 | 10 | ||
11 | typedef struct _generic_N_Vector BarType; | 11 | typedef struct _generic_N_Vector SunVector; |
12 | typedef struct _N_VectorContent_Serial BazType; | ||
13 | 12 | ||
14 | /* Check function return value... | 13 | /* Check function return value... |
15 | opt == 0 means SUNDIALS function allocates memory so check if | 14 | opt == 0 means SUNDIALS function allocates memory so check if |
@@ -17,6 +16,6 @@ typedef struct _N_VectorContent_Serial BazType; | |||
17 | opt == 1 means SUNDIALS function returns a flag so check if | 16 | opt == 1 means SUNDIALS function returns a flag so check if |
18 | flag >= 0 | 17 | flag >= 0 |
19 | opt == 2 means function allocates memory so check if returned | 18 | opt == 2 means function allocates memory so check if returned |
20 | NULL pointer | 19 | NULL pointer |
21 | */ | 20 | */ |
22 | int check_flag(void *flagvalue, const char *funcname, int opt); | 21 | int check_flag(void *flagvalue, const char *funcname, int opt); |