diff options
Diffstat (limited to 'packages')
-rw-r--r-- | packages/sundials/hmatrix-sundials.cabal | 4 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 123 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/CVode/ODE.hs | 153 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ODEOpts.hs | 78 |
4 files changed, 137 insertions, 221 deletions
diff --git a/packages/sundials/hmatrix-sundials.cabal b/packages/sundials/hmatrix-sundials.cabal index 4cc02c6..b7fa0fe 100644 --- a/packages/sundials/hmatrix-sundials.cabal +++ b/packages/sundials/hmatrix-sundials.cabal | |||
@@ -29,7 +29,8 @@ library | |||
29 | sundials_cvode | 29 | sundials_cvode |
30 | other-extensions: QuasiQuotes | 30 | other-extensions: QuasiQuotes |
31 | hs-source-dirs: src | 31 | hs-source-dirs: src |
32 | exposed-modules: Numeric.Sundials.ARKode.ODE, | 32 | exposed-modules: Numeric.Sundials.ODEOpts, |
33 | Numeric.Sundials.ARKode.ODE, | ||
33 | Numeric.Sundials.CVode.ODE | 34 | Numeric.Sundials.CVode.ODE |
34 | other-modules: Types, | 35 | other-modules: Types, |
35 | Arkode | 36 | Arkode |
@@ -40,6 +41,7 @@ test-suite hmatrix-sundials-testsuite | |||
40 | type: exitcode-stdio-1.0 | 41 | type: exitcode-stdio-1.0 |
41 | main-is: Main.hs | 42 | main-is: Main.hs |
42 | other-modules: Types, | 43 | other-modules: Types, |
44 | Numeric.Sundials.ODEOpts, | ||
43 | Numeric.Sundials.ARKode.ODE, | 45 | Numeric.Sundials.ARKode.ODE, |
44 | Numeric.Sundials.CVode.ODE, | 46 | Numeric.Sundials.CVode.ODE, |
45 | Arkode | 47 | Arkode |
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index e5a2e4d..8b713c6 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -117,7 +117,6 @@ module Numeric.Sundials.ARKode.ODE ( odeSolve | |||
117 | , ODEMethod(..) | 117 | , ODEMethod(..) |
118 | , StepControl(..) | 118 | , StepControl(..) |
119 | , Jacobian | 119 | , Jacobian |
120 | , SundialsDiagnostics(..) | ||
121 | ) where | 120 | ) where |
122 | 121 | ||
123 | import qualified Language.C.Inline as C | 122 | import qualified Language.C.Inline as C |
@@ -126,17 +125,15 @@ import qualified Language.C.Inline.Unsafe as CU | |||
126 | import Data.Monoid ((<>)) | 125 | import Data.Monoid ((<>)) |
127 | import Data.Maybe (isJust) | 126 | import Data.Maybe (isJust) |
128 | 127 | ||
129 | import Foreign.C.Types | 128 | import Foreign.C.Types (CDouble, CInt, CLong) |
130 | import Foreign.Ptr (Ptr) | 129 | import Foreign.Ptr (Ptr) |
131 | import Foreign.ForeignPtr (newForeignPtr_) | ||
132 | import Foreign.Storable (Storable) | ||
133 | 130 | ||
134 | import qualified Data.Vector.Storable as V | 131 | import qualified Data.Vector.Storable as V |
135 | import qualified Data.Vector.Storable.Mutable as VM | ||
136 | 132 | ||
137 | import Data.Coerce (coerce) | 133 | import Data.Coerce (coerce) |
138 | import System.IO.Unsafe (unsafePerformIO) | 134 | import System.IO.Unsafe (unsafePerformIO) |
139 | import GHC.Generics | 135 | import GHC.Generics (C1, Constructor, (:+:)(..), D1, Rep, Generic, M1(..), |
136 | from, conName) | ||
140 | 137 | ||
141 | import Numeric.LinearAlgebra.Devel (createVector) | 138 | import Numeric.LinearAlgebra.Devel (createVector) |
142 | 139 | ||
@@ -147,6 +144,7 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), | |||
147 | import qualified Types as T | 144 | import qualified Types as T |
148 | import Arkode | 145 | import Arkode |
149 | import qualified Arkode as B | 146 | import qualified Arkode as B |
147 | import qualified Numeric.Sundials.ODEOpts as SO | ||
150 | 148 | ||
151 | 149 | ||
152 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | 150 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) |
@@ -165,65 +163,6 @@ C.include "../../../helpers.h" | |||
165 | C.include "Arkode_hsc.h" | 163 | C.include "Arkode_hsc.h" |
166 | 164 | ||
167 | 165 | ||
168 | getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble) | ||
169 | getDataFromContents len ptr = do | ||
170 | qtr <- B.getContentPtr ptr | ||
171 | rtr <- B.getData qtr | ||
172 | vectorFromC len rtr | ||
173 | |||
174 | -- FIXME: Potentially an instance of Storable | ||
175 | _getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix | ||
176 | _getMatrixDataFromContents ptr = do | ||
177 | qtr <- B.getContentMatrixPtr ptr | ||
178 | rs <- B.getNRows qtr | ||
179 | cs <- B.getNCols qtr | ||
180 | rtr <- B.getMatrixData qtr | ||
181 | vs <- vectorFromC (fromIntegral $ rs * cs) rtr | ||
182 | return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } | ||
183 | |||
184 | putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () | ||
185 | putMatrixDataFromContents mat ptr = do | ||
186 | let rs = T.rows mat | ||
187 | cs = T.cols mat | ||
188 | vs = T.vals mat | ||
189 | qtr <- B.getContentMatrixPtr ptr | ||
190 | B.putNRows rs qtr | ||
191 | B.putNCols cs qtr | ||
192 | rtr <- B.getMatrixData qtr | ||
193 | vectorToC vs (fromIntegral $ rs * cs) rtr | ||
194 | -- FIXME: END | ||
195 | |||
196 | putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () | ||
197 | putDataInContents vec len ptr = do | ||
198 | qtr <- B.getContentPtr ptr | ||
199 | rtr <- B.getData qtr | ||
200 | vectorToC vec len rtr | ||
201 | |||
202 | -- Utils | ||
203 | |||
204 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | ||
205 | vectorFromC len ptr = do | ||
206 | ptr' <- newForeignPtr_ ptr | ||
207 | V.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
208 | |||
209 | vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () | ||
210 | vectorToC vec len ptr = do | ||
211 | ptr' <- newForeignPtr_ ptr | ||
212 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
213 | |||
214 | data SundialsDiagnostics = SundialsDiagnostics { | ||
215 | aRKodeGetNumSteps :: Int | ||
216 | , aRKodeGetNumStepAttempts :: Int | ||
217 | , aRKodeGetNumRhsEvals_fe :: Int | ||
218 | , aRKodeGetNumRhsEvals_fi :: Int | ||
219 | , aRKodeGetNumLinSolvSetups :: Int | ||
220 | , aRKodeGetNumErrTestFails :: Int | ||
221 | , aRKodeGetNumNonlinSolvIters :: Int | ||
222 | , aRKodeGetNumNonlinSolvConvFails :: Int | ||
223 | , aRKDlsGetNumJacEvals :: Int | ||
224 | , aRKDlsGetNumRhsEvals :: Int | ||
225 | } deriving Show | ||
226 | |||
227 | type Jacobian = Double -> Vector Double -> Matrix Double | 166 | type Jacobian = Double -> Vector Double -> Matrix Double |
228 | 167 | ||
229 | -- | Stepping functions | 168 | -- | Stepping functions |
@@ -390,7 +329,7 @@ odeSolveV | |||
390 | -> Vector Double -- ^ desired solution times | 329 | -> Vector Double -- ^ desired solution times |
391 | -> Matrix Double -- ^ solution | 330 | -> Matrix Double -- ^ solution |
392 | odeSolveV meth hi epsAbs epsRel f y0 ts = | 331 | odeSolveV meth hi epsAbs epsRel f y0 ts = |
393 | case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of | 332 | case odeSolveVWith' meth (X epsAbs epsRel) hi g y0 ts of |
394 | Left c -> error $ show c -- FIXME | 333 | Left c -> error $ show c -- FIXME |
395 | -- FIXME: Can we do better than using lists? | 334 | -- FIXME: Can we do better than using lists? |
396 | Right (v, _d) -> (nR >< nC) (V.toList v) | 335 | Right (v, _d) -> (nR >< nC) (V.toList v) |
@@ -410,7 +349,7 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y | |||
410 | -> Matrix Double -- ^ solution | 349 | -> Matrix Double -- ^ solution |
411 | odeSolve f y0 ts = | 350 | odeSolve f y0 ts = |
412 | -- FIXME: These tolerances are different from the ones in GSL | 351 | -- FIXME: These tolerances are different from the ones in GSL |
413 | case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of | 352 | case odeSolveVWith' SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of |
414 | Left c -> error $ show c -- FIXME | 353 | Left c -> error $ show c -- FIXME |
415 | Right (v, _d) -> (nR >< nC) (V.toList v) | 354 | Right (v, _d) -> (nR >< nC) (V.toList v) |
416 | where | 355 | where |
@@ -419,7 +358,7 @@ odeSolve f y0 ts = | |||
419 | nC = length y0 | 358 | nC = length y0 |
420 | g t x0 = V.fromList $ f t (V.toList x0) | 359 | g t x0 = V.fromList $ f t (V.toList x0) |
421 | 360 | ||
422 | odeSolveVWith' :: | 361 | odeSolveVWith :: |
423 | ODEMethod | 362 | ODEMethod |
424 | -> StepControl | 363 | -> StepControl |
425 | -> Maybe Double -- ^ initial step size - by default, ARKode | 364 | -> Maybe Double -- ^ initial step size - by default, ARKode |
@@ -432,15 +371,15 @@ odeSolveVWith' :: | |||
432 | -> V.Vector Double -- ^ Initial conditions | 371 | -> V.Vector Double -- ^ Initial conditions |
433 | -> V.Vector Double -- ^ Desired solution times | 372 | -> V.Vector Double -- ^ Desired solution times |
434 | -> Matrix Double -- ^ Error code or solution | 373 | -> Matrix Double -- ^ Error code or solution |
435 | odeSolveVWith' method control initStepSize f y0 tt = | 374 | odeSolveVWith method control initStepSize f y0 tt = |
436 | case odeSolveVWith method control initStepSize f y0 tt of | 375 | case odeSolveVWith' method control initStepSize f y0 tt of |
437 | Left c -> error $ show c -- FIXME | 376 | Left c -> error $ show c -- FIXME |
438 | Right (v, _d) -> (nR >< nC) (V.toList v) | 377 | Right (v, _d) -> (nR >< nC) (V.toList v) |
439 | where | 378 | where |
440 | nR = V.length tt | 379 | nR = V.length tt |
441 | nC = V.length y0 | 380 | nC = V.length y0 |
442 | 381 | ||
443 | odeSolveVWith :: | 382 | odeSolveVWith' :: |
444 | ODEMethod | 383 | ODEMethod |
445 | -> StepControl | 384 | -> StepControl |
446 | -> Maybe Double -- ^ initial step size - by default, ARKode | 385 | -> Maybe Double -- ^ initial step size - by default, ARKode |
@@ -452,8 +391,8 @@ odeSolveVWith :: | |||
452 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 391 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
453 | -> V.Vector Double -- ^ Initial conditions | 392 | -> V.Vector Double -- ^ Initial conditions |
454 | -> V.Vector Double -- ^ Desired solution times | 393 | -> V.Vector Double -- ^ Desired solution times |
455 | -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution | 394 | -> Either Int ((V.Vector Double), SO.SundialsDiagnostics) -- ^ Error code or solution |
456 | odeSolveVWith method control initStepSize f y0 tt = | 395 | odeSolveVWith' method control initStepSize f y0 tt = |
457 | case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) | 396 | case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) |
458 | (coerce f) (coerce y0) (coerce tt) of | 397 | (coerce f) (coerce y0) (coerce tt) of |
459 | Left c -> Left $ fromIntegral c | 398 | Left c -> Left $ fromIntegral c |
@@ -482,7 +421,7 @@ solveOdeC :: | |||
482 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 421 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
483 | -> V.Vector CDouble -- ^ Initial conditions | 422 | -> V.Vector CDouble -- ^ Initial conditions |
484 | -> V.Vector CDouble -- ^ Desired solution times | 423 | -> V.Vector CDouble -- ^ Desired solution times |
485 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution | 424 | -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution |
486 | solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do | 425 | solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do |
487 | 426 | ||
488 | let isInitStepSize :: CInt | 427 | let isInitStepSize :: CInt |
@@ -514,9 +453,9 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
514 | funIO x y f _ptr = do | 453 | funIO x y f _ptr = do |
515 | -- Convert the pointer we get from C (y) to a vector, and then | 454 | -- Convert the pointer we get from C (y) to a vector, and then |
516 | -- apply the user-supplied function. | 455 | -- apply the user-supplied function. |
517 | fImm <- fun x <$> getDataFromContents dim y | 456 | fImm <- fun x <$> SO.getDataFromContents dim y |
518 | -- Fill in the provided pointer with the resulting vector. | 457 | -- Fill in the provided pointer with the resulting vector. |
519 | putDataInContents fImm dim f | 458 | SO.putDataInContents fImm dim f |
520 | -- FIXME: I don't understand what this comment means | 459 | -- FIXME: I don't understand what this comment means |
521 | -- Unsafe since the function will be called many times. | 460 | -- Unsafe since the function will be called many times. |
522 | [CU.exp| int{ 0 } |] | 461 | [CU.exp| int{ 0 } |] |
@@ -528,8 +467,8 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
528 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do | 467 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do |
529 | case jacH of | 468 | case jacH of |
530 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" | 469 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" |
531 | Just jacI -> do j <- jacI t <$> getDataFromContents dim y | 470 | Just jacI -> do j <- jacI t <$> SO.getDataFromContents dim y |
532 | putMatrixDataFromContents j jacS | 471 | SO.putMatrixDataFromContents j jacS |
533 | -- FIXME: I don't understand what this comment means | 472 | -- FIXME: I don't understand what this comment means |
534 | -- Unsafe since the function will be called many times. | 473 | -- Unsafe since the function will be called many times. |
535 | [CU.exp| int{ 0 } |] | 474 | [CU.exp| int{ 0 } |] |
@@ -704,16 +643,16 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
704 | if res == 0 | 643 | if res == 0 |
705 | then do | 644 | then do |
706 | preD <- V.freeze diagMut | 645 | preD <- V.freeze diagMut |
707 | let d = SundialsDiagnostics (fromIntegral $ preD V.!0) | 646 | let d = SO.SundialsDiagnostics (fromIntegral $ preD V.!0) |
708 | (fromIntegral $ preD V.!1) | 647 | (fromIntegral $ preD V.!1) |
709 | (fromIntegral $ preD V.!2) | 648 | (fromIntegral $ preD V.!2) |
710 | (fromIntegral $ preD V.!3) | 649 | (fromIntegral $ preD V.!3) |
711 | (fromIntegral $ preD V.!4) | 650 | (fromIntegral $ preD V.!4) |
712 | (fromIntegral $ preD V.!5) | 651 | (fromIntegral $ preD V.!5) |
713 | (fromIntegral $ preD V.!6) | 652 | (fromIntegral $ preD V.!6) |
714 | (fromIntegral $ preD V.!7) | 653 | (fromIntegral $ preD V.!7) |
715 | (fromIntegral $ preD V.!8) | 654 | (fromIntegral $ preD V.!8) |
716 | (fromIntegral $ preD V.!9) | 655 | (fromIntegral $ preD V.!9) |
717 | m <- V.freeze qMatMut | 656 | m <- V.freeze qMatMut |
718 | return $ Right (m, d) | 657 | return $ Right (m, d) |
719 | else do | 658 | else do |
@@ -783,15 +722,15 @@ getButcherTable method = unsafePerformIO $ do | |||
783 | btB2sMut <- V.thaw btB2s | 722 | btB2sMut <- V.thaw btB2s |
784 | let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | 723 | let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt |
785 | funIOI x y f _ptr = do | 724 | funIOI x y f _ptr = do |
786 | fImm <- funI x <$> getDataFromContents dim y | 725 | fImm <- funI x <$> SO.getDataFromContents dim y |
787 | putDataInContents fImm dim f | 726 | SO.putDataInContents fImm dim f |
788 | -- FIXME: I don't understand what this comment means | 727 | -- FIXME: I don't understand what this comment means |
789 | -- Unsafe since the function will be called many times. | 728 | -- Unsafe since the function will be called many times. |
790 | [CU.exp| int{ 0 } |] | 729 | [CU.exp| int{ 0 } |] |
791 | let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | 730 | let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt |
792 | funIOE x y f _ptr = do | 731 | funIOE x y f _ptr = do |
793 | fImm <- funE x <$> getDataFromContents dim y | 732 | fImm <- funE x <$> SO.getDataFromContents dim y |
794 | putDataInContents fImm dim f | 733 | SO.putDataInContents fImm dim f |
795 | -- FIXME: I don't understand what this comment means | 734 | -- FIXME: I don't understand what this comment means |
796 | -- Unsafe since the function will be called many times. | 735 | -- Unsafe since the function will be called many times. |
797 | [CU.exp| int{ 0 } |] | 736 | [CU.exp| int{ 0 } |] |
diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs index abe1bfe..d7a2b53 100644 --- a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs | |||
@@ -61,46 +61,6 @@ | |||
61 | -- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) | 61 | -- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) |
62 | -- @ | 62 | -- @ |
63 | -- | 63 | -- |
64 | -- KVAERNO_4_2_3 | ||
65 | -- | ||
66 | -- \[ | ||
67 | -- \begin{array}{c|cccc} | ||
68 | -- 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\ | ||
69 | -- 0.871733043 & 0.4358665215 & 0.4358665215 & 0.0 & 0.0 \\ | ||
70 | -- 1.0 & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\ | ||
71 | -- 1.0 & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\ | ||
72 | -- \hline | ||
73 | -- & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\ | ||
74 | -- & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\ | ||
75 | -- \end{array} | ||
76 | -- \] | ||
77 | -- | ||
78 | -- SDIRK_2_1_2 | ||
79 | -- | ||
80 | -- \[ | ||
81 | -- \begin{array}{c|cc} | ||
82 | -- 1.0 & 1.0 & 0.0 \\ | ||
83 | -- 0.0 & -1.0 & 1.0 \\ | ||
84 | -- \hline | ||
85 | -- & 0.5 & 0.5 \\ | ||
86 | -- & 1.0 & 0.0 \\ | ||
87 | -- \end{array} | ||
88 | -- \] | ||
89 | -- | ||
90 | -- SDIRK_5_3_4 | ||
91 | -- | ||
92 | -- \[ | ||
93 | -- \begin{array}{c|ccccc} | ||
94 | -- 0.25 & 0.25 & 0.0 & 0.0 & 0.0 & 0.0 \\ | ||
95 | -- 0.75 & 0.5 & 0.25 & 0.0 & 0.0 & 0.0 \\ | ||
96 | -- 0.55 & 0.34 & -4.0e-2 & 0.25 & 0.0 & 0.0 \\ | ||
97 | -- 0.5 & 0.2727941176470588 & -5.036764705882353e-2 & 2.7573529411764705e-2 & 0.25 & 0.0 \\ | ||
98 | -- 1.0 & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\ | ||
99 | -- \hline | ||
100 | -- & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\ | ||
101 | -- & 1.2291666666666667 & -0.17708333333333334 & 7.03125 & -7.083333333333333 & 0.0 \\ | ||
102 | -- \end{array} | ||
103 | -- \] | ||
104 | ----------------------------------------------------------------------------- | 64 | ----------------------------------------------------------------------------- |
105 | module Numeric.Sundials.CVode.ODE ( odeSolve | 65 | module Numeric.Sundials.CVode.ODE ( odeSolve |
106 | , odeSolveV | 66 | , odeSolveV |
@@ -109,7 +69,6 @@ module Numeric.Sundials.CVode.ODE ( odeSolve | |||
109 | , ODEMethod(..) | 69 | , ODEMethod(..) |
110 | , StepControl(..) | 70 | , StepControl(..) |
111 | , Jacobian | 71 | , Jacobian |
112 | , SundialsDiagnostics(..) | ||
113 | ) where | 72 | ) where |
114 | 73 | ||
115 | import qualified Language.C.Inline as C | 74 | import qualified Language.C.Inline as C |
@@ -118,13 +77,10 @@ import qualified Language.C.Inline.Unsafe as CU | |||
118 | import Data.Monoid ((<>)) | 77 | import Data.Monoid ((<>)) |
119 | import Data.Maybe (isJust) | 78 | import Data.Maybe (isJust) |
120 | 79 | ||
121 | import Foreign.C.Types | 80 | import Foreign.C.Types (CDouble, CInt, CLong) |
122 | import Foreign.Ptr (Ptr) | 81 | import Foreign.Ptr (Ptr) |
123 | import Foreign.ForeignPtr (newForeignPtr_) | ||
124 | import Foreign.Storable (Storable) | ||
125 | 82 | ||
126 | import qualified Data.Vector.Storable as V | 83 | import qualified Data.Vector.Storable as V |
127 | import qualified Data.Vector.Storable.Mutable as VM | ||
128 | 84 | ||
129 | import Data.Coerce (coerce) | 85 | import Data.Coerce (coerce) |
130 | import System.IO.Unsafe (unsafePerformIO) | 86 | import System.IO.Unsafe (unsafePerformIO) |
@@ -136,7 +92,7 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), | |||
136 | 92 | ||
137 | import qualified Types as T | 93 | import qualified Types as T |
138 | import Arkode (cV_ADAMS, cV_BDF) | 94 | import Arkode (cV_ADAMS, cV_BDF) |
139 | import qualified Arkode as B | 95 | import qualified Numeric.Sundials.ODEOpts as SO |
140 | 96 | ||
141 | 97 | ||
142 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | 98 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) |
@@ -155,65 +111,6 @@ C.include "../../../helpers.h" | |||
155 | C.include "Arkode_hsc.h" | 111 | C.include "Arkode_hsc.h" |
156 | 112 | ||
157 | 113 | ||
158 | getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble) | ||
159 | getDataFromContents len ptr = do | ||
160 | qtr <- B.getContentPtr ptr | ||
161 | rtr <- B.getData qtr | ||
162 | vectorFromC len rtr | ||
163 | |||
164 | -- FIXME: Potentially an instance of Storable | ||
165 | _getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix | ||
166 | _getMatrixDataFromContents ptr = do | ||
167 | qtr <- B.getContentMatrixPtr ptr | ||
168 | rs <- B.getNRows qtr | ||
169 | cs <- B.getNCols qtr | ||
170 | rtr <- B.getMatrixData qtr | ||
171 | vs <- vectorFromC (fromIntegral $ rs * cs) rtr | ||
172 | return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } | ||
173 | |||
174 | putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () | ||
175 | putMatrixDataFromContents mat ptr = do | ||
176 | let rs = T.rows mat | ||
177 | cs = T.cols mat | ||
178 | vs = T.vals mat | ||
179 | qtr <- B.getContentMatrixPtr ptr | ||
180 | B.putNRows rs qtr | ||
181 | B.putNCols cs qtr | ||
182 | rtr <- B.getMatrixData qtr | ||
183 | vectorToC vs (fromIntegral $ rs * cs) rtr | ||
184 | -- FIXME: END | ||
185 | |||
186 | putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () | ||
187 | putDataInContents vec len ptr = do | ||
188 | qtr <- B.getContentPtr ptr | ||
189 | rtr <- B.getData qtr | ||
190 | vectorToC vec len rtr | ||
191 | |||
192 | -- Utils | ||
193 | |||
194 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | ||
195 | vectorFromC len ptr = do | ||
196 | ptr' <- newForeignPtr_ ptr | ||
197 | V.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
198 | |||
199 | vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () | ||
200 | vectorToC vec len ptr = do | ||
201 | ptr' <- newForeignPtr_ ptr | ||
202 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
203 | |||
204 | data SundialsDiagnostics = SundialsDiagnostics { | ||
205 | aRKodeGetNumSteps :: Int | ||
206 | , aRKodeGetNumStepAttempts :: Int | ||
207 | , aRKodeGetNumRhsEvals_fe :: Int | ||
208 | , aRKodeGetNumRhsEvals_fi :: Int | ||
209 | , aRKodeGetNumLinSolvSetups :: Int | ||
210 | , aRKodeGetNumErrTestFails :: Int | ||
211 | , aRKodeGetNumNonlinSolvIters :: Int | ||
212 | , aRKodeGetNumNonlinSolvConvFails :: Int | ||
213 | , aRKDlsGetNumJacEvals :: Int | ||
214 | , aRKDlsGetNumRhsEvals :: Int | ||
215 | } deriving Show | ||
216 | |||
217 | type Jacobian = Double -> Vector Double -> Matrix Double | 114 | type Jacobian = Double -> Vector Double -> Matrix Double |
218 | 115 | ||
219 | -- | Stepping functions | 116 | -- | Stepping functions |
@@ -243,7 +140,7 @@ odeSolveV | |||
243 | -> Vector Double -- ^ desired solution times | 140 | -> Vector Double -- ^ desired solution times |
244 | -> Matrix Double -- ^ solution | 141 | -> Matrix Double -- ^ solution |
245 | odeSolveV meth hi epsAbs epsRel f y0 ts = | 142 | odeSolveV meth hi epsAbs epsRel f y0 ts = |
246 | case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of | 143 | case odeSolveVWith' meth (X epsAbs epsRel) hi g y0 ts of |
247 | Left c -> error $ show c -- FIXME | 144 | Left c -> error $ show c -- FIXME |
248 | -- FIXME: Can we do better than using lists? | 145 | -- FIXME: Can we do better than using lists? |
249 | Right (v, _d) -> (nR >< nC) (V.toList v) | 146 | Right (v, _d) -> (nR >< nC) (V.toList v) |
@@ -263,7 +160,7 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y | |||
263 | -> Matrix Double -- ^ solution | 160 | -> Matrix Double -- ^ solution |
264 | odeSolve f y0 ts = | 161 | odeSolve f y0 ts = |
265 | -- FIXME: These tolerances are different from the ones in GSL | 162 | -- FIXME: These tolerances are different from the ones in GSL |
266 | case odeSolveVWith BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of | 163 | case odeSolveVWith' BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of |
267 | Left c -> error $ show c -- FIXME | 164 | Left c -> error $ show c -- FIXME |
268 | Right (v, _d) -> (nR >< nC) (V.toList v) | 165 | Right (v, _d) -> (nR >< nC) (V.toList v) |
269 | where | 166 | where |
@@ -272,7 +169,7 @@ odeSolve f y0 ts = | |||
272 | nC = length y0 | 169 | nC = length y0 |
273 | g t x0 = V.fromList $ f t (V.toList x0) | 170 | g t x0 = V.fromList $ f t (V.toList x0) |
274 | 171 | ||
275 | odeSolveVWith' :: | 172 | odeSolveVWith :: |
276 | ODEMethod | 173 | ODEMethod |
277 | -> StepControl | 174 | -> StepControl |
278 | -> Maybe Double -- ^ initial step size - by default, ARKode | 175 | -> Maybe Double -- ^ initial step size - by default, ARKode |
@@ -285,15 +182,15 @@ odeSolveVWith' :: | |||
285 | -> V.Vector Double -- ^ Initial conditions | 182 | -> V.Vector Double -- ^ Initial conditions |
286 | -> V.Vector Double -- ^ Desired solution times | 183 | -> V.Vector Double -- ^ Desired solution times |
287 | -> Matrix Double -- ^ Error code or solution | 184 | -> Matrix Double -- ^ Error code or solution |
288 | odeSolveVWith' method control initStepSize f y0 tt = | 185 | odeSolveVWith method control initStepSize f y0 tt = |
289 | case odeSolveVWith method control initStepSize f y0 tt of | 186 | case odeSolveVWith' method control initStepSize f y0 tt of |
290 | Left c -> error $ show c -- FIXME | 187 | Left c -> error $ show c -- FIXME |
291 | Right (v, _d) -> (nR >< nC) (V.toList v) | 188 | Right (v, _d) -> (nR >< nC) (V.toList v) |
292 | where | 189 | where |
293 | nR = V.length tt | 190 | nR = V.length tt |
294 | nC = V.length y0 | 191 | nC = V.length y0 |
295 | 192 | ||
296 | odeSolveVWith :: | 193 | odeSolveVWith' :: |
297 | ODEMethod | 194 | ODEMethod |
298 | -> StepControl | 195 | -> StepControl |
299 | -> Maybe Double -- ^ initial step size - by default, ARKode | 196 | -> Maybe Double -- ^ initial step size - by default, ARKode |
@@ -305,8 +202,8 @@ odeSolveVWith :: | |||
305 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 202 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
306 | -> V.Vector Double -- ^ Initial conditions | 203 | -> V.Vector Double -- ^ Initial conditions |
307 | -> V.Vector Double -- ^ Desired solution times | 204 | -> V.Vector Double -- ^ Desired solution times |
308 | -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution | 205 | -> Either Int ((V.Vector Double), SO.SundialsDiagnostics) -- ^ Error code or solution |
309 | odeSolveVWith method control initStepSize f y0 tt = | 206 | odeSolveVWith' method control initStepSize f y0 tt = |
310 | case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) | 207 | case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) |
311 | (coerce f) (coerce y0) (coerce tt) of | 208 | (coerce f) (coerce y0) (coerce tt) of |
312 | Left c -> Left $ fromIntegral c | 209 | Left c -> Left $ fromIntegral c |
@@ -335,7 +232,7 @@ solveOdeC :: | |||
335 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 232 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
336 | -> V.Vector CDouble -- ^ Initial conditions | 233 | -> V.Vector CDouble -- ^ Initial conditions |
337 | -> V.Vector CDouble -- ^ Desired solution times | 234 | -> V.Vector CDouble -- ^ Desired solution times |
338 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution | 235 | -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution |
339 | solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do | 236 | solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do |
340 | 237 | ||
341 | let isInitStepSize :: CInt | 238 | let isInitStepSize :: CInt |
@@ -366,9 +263,9 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
366 | funIO x y f _ptr = do | 263 | funIO x y f _ptr = do |
367 | -- Convert the pointer we get from C (y) to a vector, and then | 264 | -- Convert the pointer we get from C (y) to a vector, and then |
368 | -- apply the user-supplied function. | 265 | -- apply the user-supplied function. |
369 | fImm <- fun x <$> getDataFromContents dim y | 266 | fImm <- fun x <$> SO.getDataFromContents dim y |
370 | -- Fill in the provided pointer with the resulting vector. | 267 | -- Fill in the provided pointer with the resulting vector. |
371 | putDataInContents fImm dim f | 268 | SO.putDataInContents fImm dim f |
372 | -- FIXME: I don't understand what this comment means | 269 | -- FIXME: I don't understand what this comment means |
373 | -- Unsafe since the function will be called many times. | 270 | -- Unsafe since the function will be called many times. |
374 | [CU.exp| int{ 0 } |] | 271 | [CU.exp| int{ 0 } |] |
@@ -380,8 +277,8 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
380 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do | 277 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do |
381 | case jacH of | 278 | case jacH of |
382 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" | 279 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" |
383 | Just jacI -> do j <- jacI t <$> getDataFromContents dim y | 280 | Just jacI -> do j <- jacI t <$> SO.getDataFromContents dim y |
384 | putMatrixDataFromContents j jacS | 281 | SO.putMatrixDataFromContents j jacS |
385 | -- FIXME: I don't understand what this comment means | 282 | -- FIXME: I don't understand what this comment means |
386 | -- Unsafe since the function will be called many times. | 283 | -- Unsafe since the function will be called many times. |
387 | [CU.exp| int{ 0 } |] | 284 | [CU.exp| int{ 0 } |] |
@@ -541,16 +438,16 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
541 | if res == 0 | 438 | if res == 0 |
542 | then do | 439 | then do |
543 | preD <- V.freeze diagMut | 440 | preD <- V.freeze diagMut |
544 | let d = SundialsDiagnostics (fromIntegral $ preD V.!0) | 441 | let d = SO.SundialsDiagnostics (fromIntegral $ preD V.!0) |
545 | (fromIntegral $ preD V.!1) | 442 | (fromIntegral $ preD V.!1) |
546 | (fromIntegral $ preD V.!2) | 443 | (fromIntegral $ preD V.!2) |
547 | (fromIntegral $ preD V.!3) | 444 | (fromIntegral $ preD V.!3) |
548 | (fromIntegral $ preD V.!4) | 445 | (fromIntegral $ preD V.!4) |
549 | (fromIntegral $ preD V.!5) | 446 | (fromIntegral $ preD V.!5) |
550 | (fromIntegral $ preD V.!6) | 447 | (fromIntegral $ preD V.!6) |
551 | (fromIntegral $ preD V.!7) | 448 | (fromIntegral $ preD V.!7) |
552 | (fromIntegral $ preD V.!8) | 449 | (fromIntegral $ preD V.!8) |
553 | (fromIntegral $ preD V.!9) | 450 | (fromIntegral $ preD V.!9) |
554 | m <- V.freeze qMatMut | 451 | m <- V.freeze qMatMut |
555 | return $ Right (m, d) | 452 | return $ Right (m, d) |
556 | else do | 453 | else do |
diff --git a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs new file mode 100644 index 0000000..e924292 --- /dev/null +++ b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs | |||
@@ -0,0 +1,78 @@ | |||
1 | module Numeric.Sundials.ODEOpts where | ||
2 | |||
3 | import Data.Int (Int32) | ||
4 | import Foreign.Ptr (Ptr) | ||
5 | import Foreign.Storable as FS | ||
6 | import Foreign.ForeignPtr as FF | ||
7 | import Foreign.C.Types | ||
8 | import qualified Data.Vector.Storable as VS | ||
9 | import qualified Data.Vector.Storable.Mutable as VM | ||
10 | |||
11 | import qualified Types as T | ||
12 | import qualified Arkode as B | ||
13 | |||
14 | data ODEOpts = ODEOpts { | ||
15 | maxNumSteps :: Int32 | ||
16 | , minStep :: Double | ||
17 | , relTol :: Double | ||
18 | , absTols :: VS.Vector Double | ||
19 | , initStep :: Double | ||
20 | } deriving (Read, Show, Eq, Ord) | ||
21 | |||
22 | -- FIXME: Potentially an instance of Storable | ||
23 | _getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix | ||
24 | _getMatrixDataFromContents ptr = do | ||
25 | qtr <- B.getContentMatrixPtr ptr | ||
26 | rs <- B.getNRows qtr | ||
27 | cs <- B.getNCols qtr | ||
28 | rtr <- B.getMatrixData qtr | ||
29 | vs <- vectorFromC (fromIntegral $ rs * cs) rtr | ||
30 | return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } | ||
31 | |||
32 | putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () | ||
33 | putMatrixDataFromContents mat ptr = do | ||
34 | let rs = T.rows mat | ||
35 | cs = T.cols mat | ||
36 | vs = T.vals mat | ||
37 | qtr <- B.getContentMatrixPtr ptr | ||
38 | B.putNRows rs qtr | ||
39 | B.putNCols cs qtr | ||
40 | rtr <- B.getMatrixData qtr | ||
41 | vectorToC vs (fromIntegral $ rs * cs) rtr | ||
42 | -- FIXME: END | ||
43 | |||
44 | vectorFromC :: Storable a => Int -> Ptr a -> IO (VS.Vector a) | ||
45 | vectorFromC len ptr = do | ||
46 | ptr' <- newForeignPtr_ ptr | ||
47 | VS.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
48 | |||
49 | vectorToC :: Storable a => VS.Vector a -> Int -> Ptr a -> IO () | ||
50 | vectorToC vec len ptr = do | ||
51 | ptr' <- newForeignPtr_ ptr | ||
52 | VS.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
53 | |||
54 | getDataFromContents :: Int -> Ptr T.SunVector -> IO (VS.Vector CDouble) | ||
55 | getDataFromContents len ptr = do | ||
56 | qtr <- B.getContentPtr ptr | ||
57 | rtr <- B.getData qtr | ||
58 | vectorFromC len rtr | ||
59 | |||
60 | putDataInContents :: Storable a => VS.Vector a -> Int -> Ptr b -> IO () | ||
61 | putDataInContents vec len ptr = do | ||
62 | qtr <- B.getContentPtr ptr | ||
63 | rtr <- B.getData qtr | ||
64 | vectorToC vec len rtr | ||
65 | |||
66 | data SundialsDiagnostics = SundialsDiagnostics { | ||
67 | aRKodeGetNumSteps :: Int | ||
68 | , aRKodeGetNumStepAttempts :: Int | ||
69 | , aRKodeGetNumRhsEvals_fe :: Int | ||
70 | , aRKodeGetNumRhsEvals_fi :: Int | ||
71 | , aRKodeGetNumLinSolvSetups :: Int | ||
72 | , aRKodeGetNumErrTestFails :: Int | ||
73 | , aRKodeGetNumNonlinSolvIters :: Int | ||
74 | , aRKodeGetNumNonlinSolvConvFails :: Int | ||
75 | , aRKDlsGetNumJacEvals :: Int | ||
76 | , aRKDlsGetNumRhsEvals :: Int | ||
77 | } deriving Show | ||
78 | |||