diff options
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/Arkode.hsc')
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/Arkode.hsc | 88 |
1 files changed, 86 insertions, 2 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/Arkode.hsc b/packages/sundials/src/Numeric/Sundials/Arkode.hsc index 1700cdf..0850258 100644 --- a/packages/sundials/src/Numeric/Sundials/Arkode.hsc +++ b/packages/sundials/src/Numeric/Sundials/Arkode.hsc | |||
@@ -1,7 +1,23 @@ | |||
1 | {-# LANGUAGE QuasiQuotes #-} | ||
2 | {-# LANGUAGE TemplateHaskell #-} | ||
3 | {-# LANGUAGE OverloadedStrings #-} | ||
4 | {-# LANGUAGE EmptyDataDecls #-} | ||
5 | |||
1 | module Numeric.Sundials.Arkode where | 6 | module Numeric.Sundials.Arkode where |
2 | 7 | ||
3 | import Foreign | 8 | import Foreign |
4 | import Foreign.C.Types | 9 | import Foreign.C.Types |
10 | |||
11 | import Language.C.Types as CT | ||
12 | |||
13 | import qualified Data.Vector.Storable as VS | ||
14 | import qualified Data.Vector.Storable.Mutable as VM | ||
15 | |||
16 | import qualified Language.Haskell.TH as TH | ||
17 | import qualified Data.Map as Map | ||
18 | import Language.C.Inline.Context | ||
19 | |||
20 | import qualified Data.Vector.Storable as V | ||
5 | 21 | ||
6 | 22 | ||
7 | #include <stdio.h> | 23 | #include <stdio.h> |
@@ -13,6 +29,74 @@ import Foreign.C.Types | |||
13 | #include <cvode/cvode.h> | 29 | #include <cvode/cvode.h> |
14 | 30 | ||
15 | 31 | ||
32 | data SunVector | ||
33 | data SunMatrix = SunMatrix { rows :: CInt | ||
34 | , cols :: CInt | ||
35 | , vals :: V.Vector CDouble | ||
36 | } | ||
37 | |||
38 | -- | This is true only if configured/ built as 64 bits | ||
39 | type SunIndexType = CLong | ||
40 | |||
41 | sunTypesTable :: Map.Map TypeSpecifier TH.TypeQ | ||
42 | sunTypesTable = Map.fromList | ||
43 | [ | ||
44 | (TypeName "sunindextype", [t| SunIndexType |] ) | ||
45 | , (TypeName "SunVector", [t| SunVector |] ) | ||
46 | , (TypeName "SunMatrix", [t| SunMatrix |] ) | ||
47 | ] | ||
48 | |||
49 | sunCtx :: Context | ||
50 | sunCtx = mempty {ctxTypesTable = sunTypesTable} | ||
51 | |||
52 | getMatrixDataFromContents :: Ptr SunMatrix -> IO SunMatrix | ||
53 | getMatrixDataFromContents ptr = do | ||
54 | qtr <- getContentMatrixPtr ptr | ||
55 | rs <- getNRows qtr | ||
56 | cs <- getNCols qtr | ||
57 | rtr <- getMatrixData qtr | ||
58 | vs <- vectorFromC (fromIntegral $ rs * cs) rtr | ||
59 | return $ SunMatrix { rows = rs, cols = cs, vals = vs } | ||
60 | |||
61 | putMatrixDataFromContents :: SunMatrix -> Ptr SunMatrix -> IO () | ||
62 | putMatrixDataFromContents mat ptr = do | ||
63 | let rs = rows mat | ||
64 | cs = cols mat | ||
65 | vs = vals mat | ||
66 | qtr <- getContentMatrixPtr ptr | ||
67 | putNRows rs qtr | ||
68 | putNCols cs qtr | ||
69 | rtr <- getMatrixData qtr | ||
70 | vectorToC vs (fromIntegral $ rs * cs) rtr | ||
71 | |||
72 | instance Storable SunMatrix where | ||
73 | poke = flip putMatrixDataFromContents | ||
74 | peek = getMatrixDataFromContents | ||
75 | sizeOf _ = error "sizeOf not supported for SunMatrix" | ||
76 | alignment _ = error "alignment not supported for SunMatrix" | ||
77 | |||
78 | vectorFromC :: Storable a => Int -> Ptr a -> IO (VS.Vector a) | ||
79 | vectorFromC len ptr = do | ||
80 | ptr' <- newForeignPtr_ ptr | ||
81 | VS.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
82 | |||
83 | vectorToC :: Storable a => VS.Vector a -> Int -> Ptr a -> IO () | ||
84 | vectorToC vec len ptr = do | ||
85 | ptr' <- newForeignPtr_ ptr | ||
86 | VS.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
87 | |||
88 | getDataFromContents :: Int -> Ptr SunVector -> IO (VS.Vector CDouble) | ||
89 | getDataFromContents len ptr = do | ||
90 | qtr <- getContentPtr ptr | ||
91 | rtr <- getData qtr | ||
92 | vectorFromC len rtr | ||
93 | |||
94 | putDataInContents :: Storable a => VS.Vector a -> Int -> Ptr b -> IO () | ||
95 | putDataInContents vec len ptr = do | ||
96 | qtr <- getContentPtr ptr | ||
97 | rtr <- getData qtr | ||
98 | vectorToC vec len rtr | ||
99 | |||
16 | #def typedef struct _generic_N_Vector SunVector; | 100 | #def typedef struct _generic_N_Vector SunVector; |
17 | #def typedef struct _N_VectorContent_Serial SunContent; | 101 | #def typedef struct _N_VectorContent_Serial SunContent; |
18 | 102 | ||