diff options
Diffstat (limited to 'packages')
-rw-r--r-- | packages/sundials/hmatrix-sundials.cabal | 6 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 56 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/Arkode.hsc | 88 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/CLangToHaskellTypes.hs | 37 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/CVode/ODE.hs | 41 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ODEOpts.hs | 51 |
6 files changed, 149 insertions, 130 deletions
diff --git a/packages/sundials/hmatrix-sundials.cabal b/packages/sundials/hmatrix-sundials.cabal index 234bb9c..cd2be4e 100644 --- a/packages/sundials/hmatrix-sundials.cabal +++ b/packages/sundials/hmatrix-sundials.cabal | |||
@@ -32,16 +32,14 @@ library | |||
32 | exposed-modules: Numeric.Sundials.ODEOpts, | 32 | exposed-modules: Numeric.Sundials.ODEOpts, |
33 | Numeric.Sundials.ARKode.ODE, | 33 | Numeric.Sundials.ARKode.ODE, |
34 | Numeric.Sundials.CVode.ODE | 34 | Numeric.Sundials.CVode.ODE |
35 | other-modules: Numeric.Sundials.CLangToHaskellTypes, | 35 | other-modules: Numeric.Sundials.Arkode |
36 | Numeric.Sundials.Arkode | ||
37 | c-sources: src/helpers.c src/helpers.h | 36 | c-sources: src/helpers.c src/helpers.h |
38 | default-language: Haskell2010 | 37 | default-language: Haskell2010 |
39 | 38 | ||
40 | test-suite hmatrix-sundials-testsuite | 39 | test-suite hmatrix-sundials-testsuite |
41 | type: exitcode-stdio-1.0 | 40 | type: exitcode-stdio-1.0 |
42 | main-is: Main.hs | 41 | main-is: Main.hs |
43 | other-modules: Numeric.Sundials.CLangToHaskellTypes, | 42 | other-modules: Numeric.Sundials.ODEOpts, |
44 | Numeric.Sundials.ODEOpts, | ||
45 | Numeric.Sundials.ARKode.ODE, | 43 | Numeric.Sundials.ARKode.ODE, |
46 | Numeric.Sundials.CVode.ODE, | 44 | Numeric.Sundials.CVode.ODE, |
47 | Numeric.Sundials.Arkode | 45 | Numeric.Sundials.Arkode |
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index a8d418b..13b7eb8 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -125,6 +125,7 @@ import Data.Maybe (isJust) | |||
125 | 125 | ||
126 | import Foreign.C.Types (CDouble, CInt, CLong) | 126 | import Foreign.C.Types (CDouble, CInt, CLong) |
127 | import Foreign.Ptr (Ptr) | 127 | import Foreign.Ptr (Ptr) |
128 | import Foreign.Storable (poke) | ||
128 | 129 | ||
129 | import qualified Data.Vector.Storable as V | 130 | import qualified Data.Vector.Storable as V |
130 | 131 | ||
@@ -139,10 +140,33 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), | |||
139 | subMatrix, rows, cols, toLists, | 140 | subMatrix, rows, cols, toLists, |
140 | size, subVector) | 141 | size, subVector) |
141 | 142 | ||
142 | import qualified Numeric.Sundials.CLangToHaskellTypes as T | ||
143 | import Numeric.Sundials.Arkode | ||
144 | import qualified Numeric.Sundials.Arkode as B | ||
145 | import qualified Numeric.Sundials.ODEOpts as SO | 143 | import qualified Numeric.Sundials.ODEOpts as SO |
144 | import qualified Numeric.Sundials.Arkode as T | ||
145 | import Numeric.Sundials.Arkode (getDataFromContents, putDataInContents, arkSMax, | ||
146 | sDIRK_2_1_2, | ||
147 | bILLINGTON_3_3_2, | ||
148 | tRBDF2_3_3_2, | ||
149 | kVAERNO_4_2_3, | ||
150 | aRK324L2SA_DIRK_4_2_3, | ||
151 | cASH_5_2_4, | ||
152 | cASH_5_3_4, | ||
153 | sDIRK_5_3_4, | ||
154 | kVAERNO_5_3_4, | ||
155 | aRK436L2SA_DIRK_6_3_4, | ||
156 | kVAERNO_7_4_5, | ||
157 | aRK548L2SA_DIRK_8_4_5, | ||
158 | hEUN_EULER_2_1_2, | ||
159 | bOGACKI_SHAMPINE_4_2_3, | ||
160 | aRK324L2SA_ERK_4_2_3, | ||
161 | zONNEVELD_5_3_4, | ||
162 | aRK436L2SA_ERK_6_3_4, | ||
163 | sAYFY_ABURUB_6_3_4, | ||
164 | cASH_KARP_6_4_5, | ||
165 | fEHLBERG_6_4_5, | ||
166 | dORMAND_PRINCE_7_4_5, | ||
167 | aRK548L2SA_ERK_8_4_5, | ||
168 | vERNER_8_5_6, | ||
169 | fEHLBERG_13_7_8) | ||
146 | 170 | ||
147 | 171 | ||
148 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | 172 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) |
@@ -451,9 +475,9 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
451 | funIO x y f _ptr = do | 475 | funIO x y f _ptr = do |
452 | -- Convert the pointer we get from C (y) to a vector, and then | 476 | -- Convert the pointer we get from C (y) to a vector, and then |
453 | -- apply the user-supplied function. | 477 | -- apply the user-supplied function. |
454 | fImm <- fun x <$> SO.getDataFromContents dim y | 478 | fImm <- fun x <$> getDataFromContents dim y |
455 | -- Fill in the provided pointer with the resulting vector. | 479 | -- Fill in the provided pointer with the resulting vector. |
456 | SO.putDataInContents fImm dim f | 480 | putDataInContents fImm dim f |
457 | -- FIXME: I don't understand what this comment means | 481 | -- FIXME: I don't understand what this comment means |
458 | -- Unsafe since the function will be called many times. | 482 | -- Unsafe since the function will be called many times. |
459 | [CU.exp| int{ 0 } |] | 483 | [CU.exp| int{ 0 } |] |
@@ -465,8 +489,8 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
465 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do | 489 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do |
466 | case jacH of | 490 | case jacH of |
467 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" | 491 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" |
468 | Just jacI -> do j <- jacI t <$> SO.getDataFromContents dim y | 492 | Just jacI -> do j <- jacI t <$> getDataFromContents dim y |
469 | SO.putMatrixDataFromContents j jacS | 493 | poke jacS j |
470 | -- FIXME: I don't understand what this comment means | 494 | -- FIXME: I don't understand what this comment means |
471 | -- Unsafe since the function will be called many times. | 495 | -- Unsafe since the function will be called many times. |
472 | [CU.exp| int{ 0 } |] | 496 | [CU.exp| int{ 0 } |] |
@@ -675,7 +699,7 @@ butcherTable method = | |||
675 | case getBT method of | 699 | case getBT method of |
676 | Left c -> error $ show c -- FIXME | 700 | Left c -> error $ show c -- FIXME |
677 | Right (ButcherTable' v w x y, sqp) -> | 701 | Right (ButcherTable' v w x y, sqp) -> |
678 | ButcherTable { am = subMatrix (0, 0) (s, s) $ (B.arkSMax >< B.arkSMax) (V.toList v) | 702 | ButcherTable { am = subMatrix (0, 0) (s, s) $ (arkSMax >< arkSMax) (V.toList v) |
679 | , cv = subVector 0 s w | 703 | , cv = subVector 0 s w |
680 | , bv = subVector 0 s x | 704 | , bv = subVector 0 s x |
681 | , b2v = subVector 0 s y | 705 | , b2v = subVector 0 s y |
@@ -710,25 +734,25 @@ getButcherTable method = unsafePerformIO $ do | |||
710 | 734 | ||
711 | btSQP :: V.Vector CInt <- createVector 3 | 735 | btSQP :: V.Vector CInt <- createVector 3 |
712 | btSQPMut <- V.thaw btSQP | 736 | btSQPMut <- V.thaw btSQP |
713 | btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax) | 737 | btAs :: V.Vector CDouble <- createVector (arkSMax * arkSMax) |
714 | btAsMut <- V.thaw btAs | 738 | btAsMut <- V.thaw btAs |
715 | btCs :: V.Vector CDouble <- createVector B.arkSMax | 739 | btCs :: V.Vector CDouble <- createVector arkSMax |
716 | btBs :: V.Vector CDouble <- createVector B.arkSMax | 740 | btBs :: V.Vector CDouble <- createVector arkSMax |
717 | btB2s :: V.Vector CDouble <- createVector B.arkSMax | 741 | btB2s :: V.Vector CDouble <- createVector arkSMax |
718 | btCsMut <- V.thaw btCs | 742 | btCsMut <- V.thaw btCs |
719 | btBsMut <- V.thaw btBs | 743 | btBsMut <- V.thaw btBs |
720 | btB2sMut <- V.thaw btB2s | 744 | btB2sMut <- V.thaw btB2s |
721 | let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | 745 | let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt |
722 | funIOI x y f _ptr = do | 746 | funIOI x y f _ptr = do |
723 | fImm <- funI x <$> SO.getDataFromContents dim y | 747 | fImm <- funI x <$> getDataFromContents dim y |
724 | SO.putDataInContents fImm dim f | 748 | putDataInContents fImm dim f |
725 | -- FIXME: I don't understand what this comment means | 749 | -- FIXME: I don't understand what this comment means |
726 | -- Unsafe since the function will be called many times. | 750 | -- Unsafe since the function will be called many times. |
727 | [CU.exp| int{ 0 } |] | 751 | [CU.exp| int{ 0 } |] |
728 | let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | 752 | let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt |
729 | funIOE x y f _ptr = do | 753 | funIOE x y f _ptr = do |
730 | fImm <- funE x <$> SO.getDataFromContents dim y | 754 | fImm <- funE x <$> getDataFromContents dim y |
731 | SO.putDataInContents fImm dim f | 755 | putDataInContents fImm dim f |
732 | -- FIXME: I don't understand what this comment means | 756 | -- FIXME: I don't understand what this comment means |
733 | -- Unsafe since the function will be called many times. | 757 | -- Unsafe since the function will be called many times. |
734 | [CU.exp| int{ 0 } |] | 758 | [CU.exp| int{ 0 } |] |
diff --git a/packages/sundials/src/Numeric/Sundials/Arkode.hsc b/packages/sundials/src/Numeric/Sundials/Arkode.hsc index 1700cdf..0850258 100644 --- a/packages/sundials/src/Numeric/Sundials/Arkode.hsc +++ b/packages/sundials/src/Numeric/Sundials/Arkode.hsc | |||
@@ -1,7 +1,23 @@ | |||
1 | {-# LANGUAGE QuasiQuotes #-} | ||
2 | {-# LANGUAGE TemplateHaskell #-} | ||
3 | {-# LANGUAGE OverloadedStrings #-} | ||
4 | {-# LANGUAGE EmptyDataDecls #-} | ||
5 | |||
1 | module Numeric.Sundials.Arkode where | 6 | module Numeric.Sundials.Arkode where |
2 | 7 | ||
3 | import Foreign | 8 | import Foreign |
4 | import Foreign.C.Types | 9 | import Foreign.C.Types |
10 | |||
11 | import Language.C.Types as CT | ||
12 | |||
13 | import qualified Data.Vector.Storable as VS | ||
14 | import qualified Data.Vector.Storable.Mutable as VM | ||
15 | |||
16 | import qualified Language.Haskell.TH as TH | ||
17 | import qualified Data.Map as Map | ||
18 | import Language.C.Inline.Context | ||
19 | |||
20 | import qualified Data.Vector.Storable as V | ||
5 | 21 | ||
6 | 22 | ||
7 | #include <stdio.h> | 23 | #include <stdio.h> |
@@ -13,6 +29,74 @@ import Foreign.C.Types | |||
13 | #include <cvode/cvode.h> | 29 | #include <cvode/cvode.h> |
14 | 30 | ||
15 | 31 | ||
32 | data SunVector | ||
33 | data SunMatrix = SunMatrix { rows :: CInt | ||
34 | , cols :: CInt | ||
35 | , vals :: V.Vector CDouble | ||
36 | } | ||
37 | |||
38 | -- | This is true only if configured/ built as 64 bits | ||
39 | type SunIndexType = CLong | ||
40 | |||
41 | sunTypesTable :: Map.Map TypeSpecifier TH.TypeQ | ||
42 | sunTypesTable = Map.fromList | ||
43 | [ | ||
44 | (TypeName "sunindextype", [t| SunIndexType |] ) | ||
45 | , (TypeName "SunVector", [t| SunVector |] ) | ||
46 | , (TypeName "SunMatrix", [t| SunMatrix |] ) | ||
47 | ] | ||
48 | |||
49 | sunCtx :: Context | ||
50 | sunCtx = mempty {ctxTypesTable = sunTypesTable} | ||
51 | |||
52 | getMatrixDataFromContents :: Ptr SunMatrix -> IO SunMatrix | ||
53 | getMatrixDataFromContents ptr = do | ||
54 | qtr <- getContentMatrixPtr ptr | ||
55 | rs <- getNRows qtr | ||
56 | cs <- getNCols qtr | ||
57 | rtr <- getMatrixData qtr | ||
58 | vs <- vectorFromC (fromIntegral $ rs * cs) rtr | ||
59 | return $ SunMatrix { rows = rs, cols = cs, vals = vs } | ||
60 | |||
61 | putMatrixDataFromContents :: SunMatrix -> Ptr SunMatrix -> IO () | ||
62 | putMatrixDataFromContents mat ptr = do | ||
63 | let rs = rows mat | ||
64 | cs = cols mat | ||
65 | vs = vals mat | ||
66 | qtr <- getContentMatrixPtr ptr | ||
67 | putNRows rs qtr | ||
68 | putNCols cs qtr | ||
69 | rtr <- getMatrixData qtr | ||
70 | vectorToC vs (fromIntegral $ rs * cs) rtr | ||
71 | |||
72 | instance Storable SunMatrix where | ||
73 | poke = flip putMatrixDataFromContents | ||
74 | peek = getMatrixDataFromContents | ||
75 | sizeOf _ = error "sizeOf not supported for SunMatrix" | ||
76 | alignment _ = error "alignment not supported for SunMatrix" | ||
77 | |||
78 | vectorFromC :: Storable a => Int -> Ptr a -> IO (VS.Vector a) | ||
79 | vectorFromC len ptr = do | ||
80 | ptr' <- newForeignPtr_ ptr | ||
81 | VS.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
82 | |||
83 | vectorToC :: Storable a => VS.Vector a -> Int -> Ptr a -> IO () | ||
84 | vectorToC vec len ptr = do | ||
85 | ptr' <- newForeignPtr_ ptr | ||
86 | VS.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
87 | |||
88 | getDataFromContents :: Int -> Ptr SunVector -> IO (VS.Vector CDouble) | ||
89 | getDataFromContents len ptr = do | ||
90 | qtr <- getContentPtr ptr | ||
91 | rtr <- getData qtr | ||
92 | vectorFromC len rtr | ||
93 | |||
94 | putDataInContents :: Storable a => VS.Vector a -> Int -> Ptr b -> IO () | ||
95 | putDataInContents vec len ptr = do | ||
96 | qtr <- getContentPtr ptr | ||
97 | rtr <- getData qtr | ||
98 | vectorToC vec len rtr | ||
99 | |||
16 | #def typedef struct _generic_N_Vector SunVector; | 100 | #def typedef struct _generic_N_Vector SunVector; |
17 | #def typedef struct _N_VectorContent_Serial SunContent; | 101 | #def typedef struct _N_VectorContent_Serial SunContent; |
18 | 102 | ||
diff --git a/packages/sundials/src/Numeric/Sundials/CLangToHaskellTypes.hs b/packages/sundials/src/Numeric/Sundials/CLangToHaskellTypes.hs deleted file mode 100644 index 0908cbe..0000000 --- a/packages/sundials/src/Numeric/Sundials/CLangToHaskellTypes.hs +++ /dev/null | |||
@@ -1,37 +0,0 @@ | |||
1 | {-# LANGUAGE QuasiQuotes #-} | ||
2 | {-# LANGUAGE TemplateHaskell #-} | ||
3 | {-# LANGUAGE OverloadedStrings #-} | ||
4 | {-# LANGUAGE EmptyDataDecls #-} | ||
5 | |||
6 | module Numeric.Sundials.CLangToHaskellTypes where | ||
7 | |||
8 | import Foreign.C.Types | ||
9 | |||
10 | import qualified Language.Haskell.TH as TH | ||
11 | import qualified Language.C.Types as CT | ||
12 | import qualified Data.Map as Map | ||
13 | import Language.C.Inline.Context | ||
14 | |||
15 | import qualified Data.Vector.Storable as V | ||
16 | |||
17 | |||
18 | data SunVector | ||
19 | data SunMatrix = SunMatrix { rows :: CInt | ||
20 | , cols :: CInt | ||
21 | , vals :: V.Vector CDouble | ||
22 | } | ||
23 | |||
24 | -- | This is true only if configured/ built as 64 bits | ||
25 | type SunIndexType = CLong | ||
26 | |||
27 | sunTypesTable :: Map.Map CT.TypeSpecifier TH.TypeQ | ||
28 | sunTypesTable = Map.fromList | ||
29 | [ | ||
30 | (CT.TypeName "sunindextype", [t| SunIndexType |] ) | ||
31 | , (CT.TypeName "SunVector", [t| SunVector |] ) | ||
32 | , (CT.TypeName "SunMatrix", [t| SunMatrix |] ) | ||
33 | ] | ||
34 | |||
35 | sunCtx :: Context | ||
36 | sunCtx = mempty {ctxTypesTable = sunTypesTable} | ||
37 | |||
diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs index 1cd072f..159fbe2 100644 --- a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs | |||
@@ -79,6 +79,7 @@ import Data.Maybe (isJust) | |||
79 | 79 | ||
80 | import Foreign.C.Types (CDouble, CInt, CLong) | 80 | import Foreign.C.Types (CDouble, CInt, CLong) |
81 | import Foreign.Ptr (Ptr) | 81 | import Foreign.Ptr (Ptr) |
82 | import Foreign.Storable (poke) | ||
82 | 83 | ||
83 | import qualified Data.Vector.Storable as V | 84 | import qualified Data.Vector.Storable as V |
84 | 85 | ||
@@ -90,10 +91,10 @@ import Numeric.LinearAlgebra.Devel (createVector) | |||
90 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows, | 91 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows, |
91 | cols, toLists, size, reshape) | 92 | cols, toLists, size, reshape) |
92 | 93 | ||
93 | import qualified Numeric.Sundials.CLangToHaskellTypes as T | 94 | import Numeric.Sundials.Arkode (cV_ADAMS, cV_BDF, |
94 | import Numeric.Sundials.Arkode (cV_ADAMS, cV_BDF) | 95 | getDataFromContents, putDataInContents) |
95 | import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian) | 96 | import qualified Numeric.Sundials.Arkode as T |
96 | import qualified Numeric.Sundials.ODEOpts as SO | 97 | import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian, SundialsDiagnostics(..)) |
97 | 98 | ||
98 | 99 | ||
99 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | 100 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) |
@@ -195,7 +196,7 @@ odeSolveVWith' :: | |||
195 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 196 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
196 | -> V.Vector Double -- ^ Initial conditions | 197 | -> V.Vector Double -- ^ Initial conditions |
197 | -> V.Vector Double -- ^ Desired solution times | 198 | -> V.Vector Double -- ^ Desired solution times |
198 | -> Either Int (Matrix Double, SO.SundialsDiagnostics) -- ^ Error code or solution | 199 | -> Either Int (Matrix Double, SundialsDiagnostics) -- ^ Error code or solution |
199 | odeSolveVWith' opts method control initStepSize f y0 tt = | 200 | odeSolveVWith' opts method control initStepSize f y0 tt = |
200 | case solveOdeC (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts) | 201 | case solveOdeC (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts) |
201 | (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) | 202 | (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) |
@@ -229,7 +230,7 @@ solveOdeC :: | |||
229 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 230 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
230 | -> V.Vector CDouble -- ^ Initial conditions | 231 | -> V.Vector CDouble -- ^ Initial conditions |
231 | -> V.Vector CDouble -- ^ Desired solution times | 232 | -> V.Vector CDouble -- ^ Desired solution times |
232 | -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution | 233 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution |
233 | solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts = | 234 | solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts = |
234 | unsafePerformIO $ do | 235 | unsafePerformIO $ do |
235 | 236 | ||
@@ -257,9 +258,9 @@ solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts | |||
257 | funIO x y f _ptr = do | 258 | funIO x y f _ptr = do |
258 | -- Convert the pointer we get from C (y) to a vector, and then | 259 | -- Convert the pointer we get from C (y) to a vector, and then |
259 | -- apply the user-supplied function. | 260 | -- apply the user-supplied function. |
260 | fImm <- fun x <$> SO.getDataFromContents dim y | 261 | fImm <- fun x <$> getDataFromContents dim y |
261 | -- Fill in the provided pointer with the resulting vector. | 262 | -- Fill in the provided pointer with the resulting vector. |
262 | SO.putDataInContents fImm dim f | 263 | putDataInContents fImm dim f |
263 | -- FIXME: I don't understand what this comment means | 264 | -- FIXME: I don't understand what this comment means |
264 | -- Unsafe since the function will be called many times. | 265 | -- Unsafe since the function will be called many times. |
265 | [CU.exp| int{ 0 } |] | 266 | [CU.exp| int{ 0 } |] |
@@ -271,8 +272,8 @@ solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts | |||
271 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do | 272 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do |
272 | case jacH of | 273 | case jacH of |
273 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" | 274 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" |
274 | Just jacI -> do j <- jacI t <$> SO.getDataFromContents dim y | 275 | Just jacI -> do j <- jacI t <$> getDataFromContents dim y |
275 | SO.putMatrixDataFromContents j jacS | 276 | poke jacS j |
276 | -- FIXME: I don't understand what this comment means | 277 | -- FIXME: I don't understand what this comment means |
277 | -- Unsafe since the function will be called many times. | 278 | -- Unsafe since the function will be called many times. |
278 | [CU.exp| int{ 0 } |] | 279 | [CU.exp| int{ 0 } |] |
@@ -431,16 +432,16 @@ solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts | |||
431 | if res == 0 | 432 | if res == 0 |
432 | then do | 433 | then do |
433 | preD <- V.freeze diagMut | 434 | preD <- V.freeze diagMut |
434 | let d = SO.SundialsDiagnostics (fromIntegral $ preD V.!0) | 435 | let d = SundialsDiagnostics (fromIntegral $ preD V.!0) |
435 | (fromIntegral $ preD V.!1) | 436 | (fromIntegral $ preD V.!1) |
436 | (fromIntegral $ preD V.!2) | 437 | (fromIntegral $ preD V.!2) |
437 | (fromIntegral $ preD V.!3) | 438 | (fromIntegral $ preD V.!3) |
438 | (fromIntegral $ preD V.!4) | 439 | (fromIntegral $ preD V.!4) |
439 | (fromIntegral $ preD V.!5) | 440 | (fromIntegral $ preD V.!5) |
440 | (fromIntegral $ preD V.!6) | 441 | (fromIntegral $ preD V.!6) |
441 | (fromIntegral $ preD V.!7) | 442 | (fromIntegral $ preD V.!7) |
442 | (fromIntegral $ preD V.!8) | 443 | (fromIntegral $ preD V.!8) |
443 | (fromIntegral $ preD V.!9) | 444 | (fromIntegral $ preD V.!9) |
444 | m <- V.freeze qMatMut | 445 | m <- V.freeze qMatMut |
445 | return $ Right (m, d) | 446 | return $ Right (m, d) |
446 | else do | 447 | else do |
diff --git a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs index 56dc12c..89f2306 100644 --- a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs +++ b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs | |||
@@ -1,17 +1,10 @@ | |||
1 | module Numeric.Sundials.ODEOpts where | 1 | module Numeric.Sundials.ODEOpts where |
2 | 2 | ||
3 | import Data.Int (Int32) | 3 | import Data.Int (Int32) |
4 | import Foreign.Ptr (Ptr) | ||
5 | import Foreign.Storable as FS | ||
6 | import Foreign.ForeignPtr as FF | ||
7 | import Foreign.C.Types | ||
8 | import qualified Data.Vector.Storable as VS | 4 | import qualified Data.Vector.Storable as VS |
9 | import qualified Data.Vector.Storable.Mutable as VM | ||
10 | 5 | ||
11 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) | 6 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) |
12 | 7 | ||
13 | import qualified Numeric.Sundials.CLangToHaskellTypes as T | ||
14 | import qualified Numeric.Sundials.Arkode as B | ||
15 | 8 | ||
16 | type Jacobian = Double -> Vector Double -> Matrix Double | 9 | type Jacobian = Double -> Vector Double -> Matrix Double |
17 | 10 | ||
@@ -23,50 +16,6 @@ data ODEOpts = ODEOpts { | |||
23 | , initStep :: Double | 16 | , initStep :: Double |
24 | } deriving (Read, Show, Eq, Ord) | 17 | } deriving (Read, Show, Eq, Ord) |
25 | 18 | ||
26 | -- FIXME: Potentially an instance of Storable | ||
27 | _getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix | ||
28 | _getMatrixDataFromContents ptr = do | ||
29 | qtr <- B.getContentMatrixPtr ptr | ||
30 | rs <- B.getNRows qtr | ||
31 | cs <- B.getNCols qtr | ||
32 | rtr <- B.getMatrixData qtr | ||
33 | vs <- vectorFromC (fromIntegral $ rs * cs) rtr | ||
34 | return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } | ||
35 | |||
36 | putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () | ||
37 | putMatrixDataFromContents mat ptr = do | ||
38 | let rs = T.rows mat | ||
39 | cs = T.cols mat | ||
40 | vs = T.vals mat | ||
41 | qtr <- B.getContentMatrixPtr ptr | ||
42 | B.putNRows rs qtr | ||
43 | B.putNCols cs qtr | ||
44 | rtr <- B.getMatrixData qtr | ||
45 | vectorToC vs (fromIntegral $ rs * cs) rtr | ||
46 | -- FIXME: END | ||
47 | |||
48 | vectorFromC :: Storable a => Int -> Ptr a -> IO (VS.Vector a) | ||
49 | vectorFromC len ptr = do | ||
50 | ptr' <- newForeignPtr_ ptr | ||
51 | VS.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
52 | |||
53 | vectorToC :: Storable a => VS.Vector a -> Int -> Ptr a -> IO () | ||
54 | vectorToC vec len ptr = do | ||
55 | ptr' <- newForeignPtr_ ptr | ||
56 | VS.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
57 | |||
58 | getDataFromContents :: Int -> Ptr T.SunVector -> IO (VS.Vector CDouble) | ||
59 | getDataFromContents len ptr = do | ||
60 | qtr <- B.getContentPtr ptr | ||
61 | rtr <- B.getData qtr | ||
62 | vectorFromC len rtr | ||
63 | |||
64 | putDataInContents :: Storable a => VS.Vector a -> Int -> Ptr b -> IO () | ||
65 | putDataInContents vec len ptr = do | ||
66 | qtr <- B.getContentPtr ptr | ||
67 | rtr <- B.getData qtr | ||
68 | vectorToC vec len rtr | ||
69 | |||
70 | data SundialsDiagnostics = SundialsDiagnostics { | 19 | data SundialsDiagnostics = SundialsDiagnostics { |
71 | aRKodeGetNumSteps :: Int | 20 | aRKodeGetNumSteps :: Int |
72 | , aRKodeGetNumStepAttempts :: Int | 21 | , aRKodeGetNumStepAttempts :: Int |