diff options
author | Dominic Steinitz <dominic@steinitz.org> | 2018-04-25 16:25:20 +0100 |
---|---|---|
committer | Dominic Steinitz <dominic@steinitz.org> | 2018-04-25 16:25:20 +0100 |
commit | 729eb192cf77d4cddf33d2724b4409ab7d828921 (patch) | |
tree | b0646ef5ea95179b96029b662dafdf4740fa11f1 /packages/sundials/src/Numeric/Sundials/ARKode | |
parent | c73f86f64a60209a50b9c4cc3b137726955f2df7 (diff) |
Pull out common code and start to follow gsl naming convention
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/ARKode')
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 123 |
1 files changed, 31 insertions, 92 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index e5a2e4d..8b713c6 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -117,7 +117,6 @@ module Numeric.Sundials.ARKode.ODE ( odeSolve | |||
117 | , ODEMethod(..) | 117 | , ODEMethod(..) |
118 | , StepControl(..) | 118 | , StepControl(..) |
119 | , Jacobian | 119 | , Jacobian |
120 | , SundialsDiagnostics(..) | ||
121 | ) where | 120 | ) where |
122 | 121 | ||
123 | import qualified Language.C.Inline as C | 122 | import qualified Language.C.Inline as C |
@@ -126,17 +125,15 @@ import qualified Language.C.Inline.Unsafe as CU | |||
126 | import Data.Monoid ((<>)) | 125 | import Data.Monoid ((<>)) |
127 | import Data.Maybe (isJust) | 126 | import Data.Maybe (isJust) |
128 | 127 | ||
129 | import Foreign.C.Types | 128 | import Foreign.C.Types (CDouble, CInt, CLong) |
130 | import Foreign.Ptr (Ptr) | 129 | import Foreign.Ptr (Ptr) |
131 | import Foreign.ForeignPtr (newForeignPtr_) | ||
132 | import Foreign.Storable (Storable) | ||
133 | 130 | ||
134 | import qualified Data.Vector.Storable as V | 131 | import qualified Data.Vector.Storable as V |
135 | import qualified Data.Vector.Storable.Mutable as VM | ||
136 | 132 | ||
137 | import Data.Coerce (coerce) | 133 | import Data.Coerce (coerce) |
138 | import System.IO.Unsafe (unsafePerformIO) | 134 | import System.IO.Unsafe (unsafePerformIO) |
139 | import GHC.Generics | 135 | import GHC.Generics (C1, Constructor, (:+:)(..), D1, Rep, Generic, M1(..), |
136 | from, conName) | ||
140 | 137 | ||
141 | import Numeric.LinearAlgebra.Devel (createVector) | 138 | import Numeric.LinearAlgebra.Devel (createVector) |
142 | 139 | ||
@@ -147,6 +144,7 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), | |||
147 | import qualified Types as T | 144 | import qualified Types as T |
148 | import Arkode | 145 | import Arkode |
149 | import qualified Arkode as B | 146 | import qualified Arkode as B |
147 | import qualified Numeric.Sundials.ODEOpts as SO | ||
150 | 148 | ||
151 | 149 | ||
152 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | 150 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) |
@@ -165,65 +163,6 @@ C.include "../../../helpers.h" | |||
165 | C.include "Arkode_hsc.h" | 163 | C.include "Arkode_hsc.h" |
166 | 164 | ||
167 | 165 | ||
168 | getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble) | ||
169 | getDataFromContents len ptr = do | ||
170 | qtr <- B.getContentPtr ptr | ||
171 | rtr <- B.getData qtr | ||
172 | vectorFromC len rtr | ||
173 | |||
174 | -- FIXME: Potentially an instance of Storable | ||
175 | _getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix | ||
176 | _getMatrixDataFromContents ptr = do | ||
177 | qtr <- B.getContentMatrixPtr ptr | ||
178 | rs <- B.getNRows qtr | ||
179 | cs <- B.getNCols qtr | ||
180 | rtr <- B.getMatrixData qtr | ||
181 | vs <- vectorFromC (fromIntegral $ rs * cs) rtr | ||
182 | return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } | ||
183 | |||
184 | putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () | ||
185 | putMatrixDataFromContents mat ptr = do | ||
186 | let rs = T.rows mat | ||
187 | cs = T.cols mat | ||
188 | vs = T.vals mat | ||
189 | qtr <- B.getContentMatrixPtr ptr | ||
190 | B.putNRows rs qtr | ||
191 | B.putNCols cs qtr | ||
192 | rtr <- B.getMatrixData qtr | ||
193 | vectorToC vs (fromIntegral $ rs * cs) rtr | ||
194 | -- FIXME: END | ||
195 | |||
196 | putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () | ||
197 | putDataInContents vec len ptr = do | ||
198 | qtr <- B.getContentPtr ptr | ||
199 | rtr <- B.getData qtr | ||
200 | vectorToC vec len rtr | ||
201 | |||
202 | -- Utils | ||
203 | |||
204 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | ||
205 | vectorFromC len ptr = do | ||
206 | ptr' <- newForeignPtr_ ptr | ||
207 | V.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
208 | |||
209 | vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () | ||
210 | vectorToC vec len ptr = do | ||
211 | ptr' <- newForeignPtr_ ptr | ||
212 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
213 | |||
214 | data SundialsDiagnostics = SundialsDiagnostics { | ||
215 | aRKodeGetNumSteps :: Int | ||
216 | , aRKodeGetNumStepAttempts :: Int | ||
217 | , aRKodeGetNumRhsEvals_fe :: Int | ||
218 | , aRKodeGetNumRhsEvals_fi :: Int | ||
219 | , aRKodeGetNumLinSolvSetups :: Int | ||
220 | , aRKodeGetNumErrTestFails :: Int | ||
221 | , aRKodeGetNumNonlinSolvIters :: Int | ||
222 | , aRKodeGetNumNonlinSolvConvFails :: Int | ||
223 | , aRKDlsGetNumJacEvals :: Int | ||
224 | , aRKDlsGetNumRhsEvals :: Int | ||
225 | } deriving Show | ||
226 | |||
227 | type Jacobian = Double -> Vector Double -> Matrix Double | 166 | type Jacobian = Double -> Vector Double -> Matrix Double |
228 | 167 | ||
229 | -- | Stepping functions | 168 | -- | Stepping functions |
@@ -390,7 +329,7 @@ odeSolveV | |||
390 | -> Vector Double -- ^ desired solution times | 329 | -> Vector Double -- ^ desired solution times |
391 | -> Matrix Double -- ^ solution | 330 | -> Matrix Double -- ^ solution |
392 | odeSolveV meth hi epsAbs epsRel f y0 ts = | 331 | odeSolveV meth hi epsAbs epsRel f y0 ts = |
393 | case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of | 332 | case odeSolveVWith' meth (X epsAbs epsRel) hi g y0 ts of |
394 | Left c -> error $ show c -- FIXME | 333 | Left c -> error $ show c -- FIXME |
395 | -- FIXME: Can we do better than using lists? | 334 | -- FIXME: Can we do better than using lists? |
396 | Right (v, _d) -> (nR >< nC) (V.toList v) | 335 | Right (v, _d) -> (nR >< nC) (V.toList v) |
@@ -410,7 +349,7 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y | |||
410 | -> Matrix Double -- ^ solution | 349 | -> Matrix Double -- ^ solution |
411 | odeSolve f y0 ts = | 350 | odeSolve f y0 ts = |
412 | -- FIXME: These tolerances are different from the ones in GSL | 351 | -- FIXME: These tolerances are different from the ones in GSL |
413 | case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of | 352 | case odeSolveVWith' SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of |
414 | Left c -> error $ show c -- FIXME | 353 | Left c -> error $ show c -- FIXME |
415 | Right (v, _d) -> (nR >< nC) (V.toList v) | 354 | Right (v, _d) -> (nR >< nC) (V.toList v) |
416 | where | 355 | where |
@@ -419,7 +358,7 @@ odeSolve f y0 ts = | |||
419 | nC = length y0 | 358 | nC = length y0 |
420 | g t x0 = V.fromList $ f t (V.toList x0) | 359 | g t x0 = V.fromList $ f t (V.toList x0) |
421 | 360 | ||
422 | odeSolveVWith' :: | 361 | odeSolveVWith :: |
423 | ODEMethod | 362 | ODEMethod |
424 | -> StepControl | 363 | -> StepControl |
425 | -> Maybe Double -- ^ initial step size - by default, ARKode | 364 | -> Maybe Double -- ^ initial step size - by default, ARKode |
@@ -432,15 +371,15 @@ odeSolveVWith' :: | |||
432 | -> V.Vector Double -- ^ Initial conditions | 371 | -> V.Vector Double -- ^ Initial conditions |
433 | -> V.Vector Double -- ^ Desired solution times | 372 | -> V.Vector Double -- ^ Desired solution times |
434 | -> Matrix Double -- ^ Error code or solution | 373 | -> Matrix Double -- ^ Error code or solution |
435 | odeSolveVWith' method control initStepSize f y0 tt = | 374 | odeSolveVWith method control initStepSize f y0 tt = |
436 | case odeSolveVWith method control initStepSize f y0 tt of | 375 | case odeSolveVWith' method control initStepSize f y0 tt of |
437 | Left c -> error $ show c -- FIXME | 376 | Left c -> error $ show c -- FIXME |
438 | Right (v, _d) -> (nR >< nC) (V.toList v) | 377 | Right (v, _d) -> (nR >< nC) (V.toList v) |
439 | where | 378 | where |
440 | nR = V.length tt | 379 | nR = V.length tt |
441 | nC = V.length y0 | 380 | nC = V.length y0 |
442 | 381 | ||
443 | odeSolveVWith :: | 382 | odeSolveVWith' :: |
444 | ODEMethod | 383 | ODEMethod |
445 | -> StepControl | 384 | -> StepControl |
446 | -> Maybe Double -- ^ initial step size - by default, ARKode | 385 | -> Maybe Double -- ^ initial step size - by default, ARKode |
@@ -452,8 +391,8 @@ odeSolveVWith :: | |||
452 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 391 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
453 | -> V.Vector Double -- ^ Initial conditions | 392 | -> V.Vector Double -- ^ Initial conditions |
454 | -> V.Vector Double -- ^ Desired solution times | 393 | -> V.Vector Double -- ^ Desired solution times |
455 | -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution | 394 | -> Either Int ((V.Vector Double), SO.SundialsDiagnostics) -- ^ Error code or solution |
456 | odeSolveVWith method control initStepSize f y0 tt = | 395 | odeSolveVWith' method control initStepSize f y0 tt = |
457 | case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) | 396 | case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) |
458 | (coerce f) (coerce y0) (coerce tt) of | 397 | (coerce f) (coerce y0) (coerce tt) of |
459 | Left c -> Left $ fromIntegral c | 398 | Left c -> Left $ fromIntegral c |
@@ -482,7 +421,7 @@ solveOdeC :: | |||
482 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 421 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
483 | -> V.Vector CDouble -- ^ Initial conditions | 422 | -> V.Vector CDouble -- ^ Initial conditions |
484 | -> V.Vector CDouble -- ^ Desired solution times | 423 | -> V.Vector CDouble -- ^ Desired solution times |
485 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution | 424 | -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution |
486 | solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do | 425 | solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do |
487 | 426 | ||
488 | let isInitStepSize :: CInt | 427 | let isInitStepSize :: CInt |
@@ -514,9 +453,9 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
514 | funIO x y f _ptr = do | 453 | funIO x y f _ptr = do |
515 | -- Convert the pointer we get from C (y) to a vector, and then | 454 | -- Convert the pointer we get from C (y) to a vector, and then |
516 | -- apply the user-supplied function. | 455 | -- apply the user-supplied function. |
517 | fImm <- fun x <$> getDataFromContents dim y | 456 | fImm <- fun x <$> SO.getDataFromContents dim y |
518 | -- Fill in the provided pointer with the resulting vector. | 457 | -- Fill in the provided pointer with the resulting vector. |
519 | putDataInContents fImm dim f | 458 | SO.putDataInContents fImm dim f |
520 | -- FIXME: I don't understand what this comment means | 459 | -- FIXME: I don't understand what this comment means |
521 | -- Unsafe since the function will be called many times. | 460 | -- Unsafe since the function will be called many times. |
522 | [CU.exp| int{ 0 } |] | 461 | [CU.exp| int{ 0 } |] |
@@ -528,8 +467,8 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
528 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do | 467 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do |
529 | case jacH of | 468 | case jacH of |
530 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" | 469 | Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" |
531 | Just jacI -> do j <- jacI t <$> getDataFromContents dim y | 470 | Just jacI -> do j <- jacI t <$> SO.getDataFromContents dim y |
532 | putMatrixDataFromContents j jacS | 471 | SO.putMatrixDataFromContents j jacS |
533 | -- FIXME: I don't understand what this comment means | 472 | -- FIXME: I don't understand what this comment means |
534 | -- Unsafe since the function will be called many times. | 473 | -- Unsafe since the function will be called many times. |
535 | [CU.exp| int{ 0 } |] | 474 | [CU.exp| int{ 0 } |] |
@@ -704,16 +643,16 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
704 | if res == 0 | 643 | if res == 0 |
705 | then do | 644 | then do |
706 | preD <- V.freeze diagMut | 645 | preD <- V.freeze diagMut |
707 | let d = SundialsDiagnostics (fromIntegral $ preD V.!0) | 646 | let d = SO.SundialsDiagnostics (fromIntegral $ preD V.!0) |
708 | (fromIntegral $ preD V.!1) | 647 | (fromIntegral $ preD V.!1) |
709 | (fromIntegral $ preD V.!2) | 648 | (fromIntegral $ preD V.!2) |
710 | (fromIntegral $ preD V.!3) | 649 | (fromIntegral $ preD V.!3) |
711 | (fromIntegral $ preD V.!4) | 650 | (fromIntegral $ preD V.!4) |
712 | (fromIntegral $ preD V.!5) | 651 | (fromIntegral $ preD V.!5) |
713 | (fromIntegral $ preD V.!6) | 652 | (fromIntegral $ preD V.!6) |
714 | (fromIntegral $ preD V.!7) | 653 | (fromIntegral $ preD V.!7) |
715 | (fromIntegral $ preD V.!8) | 654 | (fromIntegral $ preD V.!8) |
716 | (fromIntegral $ preD V.!9) | 655 | (fromIntegral $ preD V.!9) |
717 | m <- V.freeze qMatMut | 656 | m <- V.freeze qMatMut |
718 | return $ Right (m, d) | 657 | return $ Right (m, d) |
719 | else do | 658 | else do |
@@ -783,15 +722,15 @@ getButcherTable method = unsafePerformIO $ do | |||
783 | btB2sMut <- V.thaw btB2s | 722 | btB2sMut <- V.thaw btB2s |
784 | let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | 723 | let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt |
785 | funIOI x y f _ptr = do | 724 | funIOI x y f _ptr = do |
786 | fImm <- funI x <$> getDataFromContents dim y | 725 | fImm <- funI x <$> SO.getDataFromContents dim y |
787 | putDataInContents fImm dim f | 726 | SO.putDataInContents fImm dim f |
788 | -- FIXME: I don't understand what this comment means | 727 | -- FIXME: I don't understand what this comment means |
789 | -- Unsafe since the function will be called many times. | 728 | -- Unsafe since the function will be called many times. |
790 | [CU.exp| int{ 0 } |] | 729 | [CU.exp| int{ 0 } |] |
791 | let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt | 730 | let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt |
792 | funIOE x y f _ptr = do | 731 | funIOE x y f _ptr = do |
793 | fImm <- funE x <$> getDataFromContents dim y | 732 | fImm <- funE x <$> SO.getDataFromContents dim y |
794 | putDataInContents fImm dim f | 733 | SO.putDataInContents fImm dim f |
795 | -- FIXME: I don't understand what this comment means | 734 | -- FIXME: I don't understand what this comment means |
796 | -- Unsafe since the function will be called many times. | 735 | -- Unsafe since the function will be called many times. |
797 | [CU.exp| int{ 0 } |] | 736 | [CU.exp| int{ 0 } |] |