summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Numeric/Sundials/ARKode
diff options
context:
space:
mode:
authorDominic Steinitz <dominic@steinitz.org>2018-04-25 16:25:20 +0100
committerDominic Steinitz <dominic@steinitz.org>2018-04-25 16:25:20 +0100
commit729eb192cf77d4cddf33d2724b4409ab7d828921 (patch)
treeb0646ef5ea95179b96029b662dafdf4740fa11f1 /packages/sundials/src/Numeric/Sundials/ARKode
parentc73f86f64a60209a50b9c4cc3b137726955f2df7 (diff)
Pull out common code and start to follow gsl naming convention
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/ARKode')
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs123
1 files changed, 31 insertions, 92 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index e5a2e4d..8b713c6 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -117,7 +117,6 @@ module Numeric.Sundials.ARKode.ODE ( odeSolve
117 , ODEMethod(..) 117 , ODEMethod(..)
118 , StepControl(..) 118 , StepControl(..)
119 , Jacobian 119 , Jacobian
120 , SundialsDiagnostics(..)
121 ) where 120 ) where
122 121
123import qualified Language.C.Inline as C 122import qualified Language.C.Inline as C
@@ -126,17 +125,15 @@ import qualified Language.C.Inline.Unsafe as CU
126import Data.Monoid ((<>)) 125import Data.Monoid ((<>))
127import Data.Maybe (isJust) 126import Data.Maybe (isJust)
128 127
129import Foreign.C.Types 128import Foreign.C.Types (CDouble, CInt, CLong)
130import Foreign.Ptr (Ptr) 129import Foreign.Ptr (Ptr)
131import Foreign.ForeignPtr (newForeignPtr_)
132import Foreign.Storable (Storable)
133 130
134import qualified Data.Vector.Storable as V 131import qualified Data.Vector.Storable as V
135import qualified Data.Vector.Storable.Mutable as VM
136 132
137import Data.Coerce (coerce) 133import Data.Coerce (coerce)
138import System.IO.Unsafe (unsafePerformIO) 134import System.IO.Unsafe (unsafePerformIO)
139import GHC.Generics 135import GHC.Generics (C1, Constructor, (:+:)(..), D1, Rep, Generic, M1(..),
136 from, conName)
140 137
141import Numeric.LinearAlgebra.Devel (createVector) 138import Numeric.LinearAlgebra.Devel (createVector)
142 139
@@ -147,6 +144,7 @@ import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><),
147import qualified Types as T 144import qualified Types as T
148import Arkode 145import Arkode
149import qualified Arkode as B 146import qualified Arkode as B
147import qualified Numeric.Sundials.ODEOpts as SO
150 148
151 149
152C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) 150C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx)
@@ -165,65 +163,6 @@ C.include "../../../helpers.h"
165C.include "Arkode_hsc.h" 163C.include "Arkode_hsc.h"
166 164
167 165
168getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble)
169getDataFromContents len ptr = do
170 qtr <- B.getContentPtr ptr
171 rtr <- B.getData qtr
172 vectorFromC len rtr
173
174-- FIXME: Potentially an instance of Storable
175_getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix
176_getMatrixDataFromContents ptr = do
177 qtr <- B.getContentMatrixPtr ptr
178 rs <- B.getNRows qtr
179 cs <- B.getNCols qtr
180 rtr <- B.getMatrixData qtr
181 vs <- vectorFromC (fromIntegral $ rs * cs) rtr
182 return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs }
183
184putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO ()
185putMatrixDataFromContents mat ptr = do
186 let rs = T.rows mat
187 cs = T.cols mat
188 vs = T.vals mat
189 qtr <- B.getContentMatrixPtr ptr
190 B.putNRows rs qtr
191 B.putNCols cs qtr
192 rtr <- B.getMatrixData qtr
193 vectorToC vs (fromIntegral $ rs * cs) rtr
194-- FIXME: END
195
196putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO ()
197putDataInContents vec len ptr = do
198 qtr <- B.getContentPtr ptr
199 rtr <- B.getData qtr
200 vectorToC vec len rtr
201
202-- Utils
203
204vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a)
205vectorFromC len ptr = do
206 ptr' <- newForeignPtr_ ptr
207 V.freeze $ VM.unsafeFromForeignPtr0 ptr' len
208
209vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO ()
210vectorToC vec len ptr = do
211 ptr' <- newForeignPtr_ ptr
212 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
213
214data SundialsDiagnostics = SundialsDiagnostics {
215 aRKodeGetNumSteps :: Int
216 , aRKodeGetNumStepAttempts :: Int
217 , aRKodeGetNumRhsEvals_fe :: Int
218 , aRKodeGetNumRhsEvals_fi :: Int
219 , aRKodeGetNumLinSolvSetups :: Int
220 , aRKodeGetNumErrTestFails :: Int
221 , aRKodeGetNumNonlinSolvIters :: Int
222 , aRKodeGetNumNonlinSolvConvFails :: Int
223 , aRKDlsGetNumJacEvals :: Int
224 , aRKDlsGetNumRhsEvals :: Int
225 } deriving Show
226
227type Jacobian = Double -> Vector Double -> Matrix Double 166type Jacobian = Double -> Vector Double -> Matrix Double
228 167
229-- | Stepping functions 168-- | Stepping functions
@@ -390,7 +329,7 @@ odeSolveV
390 -> Vector Double -- ^ desired solution times 329 -> Vector Double -- ^ desired solution times
391 -> Matrix Double -- ^ solution 330 -> Matrix Double -- ^ solution
392odeSolveV meth hi epsAbs epsRel f y0 ts = 331odeSolveV meth hi epsAbs epsRel f y0 ts =
393 case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of 332 case odeSolveVWith' meth (X epsAbs epsRel) hi g y0 ts of
394 Left c -> error $ show c -- FIXME 333 Left c -> error $ show c -- FIXME
395 -- FIXME: Can we do better than using lists? 334 -- FIXME: Can we do better than using lists?
396 Right (v, _d) -> (nR >< nC) (V.toList v) 335 Right (v, _d) -> (nR >< nC) (V.toList v)
@@ -410,7 +349,7 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y
410 -> Matrix Double -- ^ solution 349 -> Matrix Double -- ^ solution
411odeSolve f y0 ts = 350odeSolve f y0 ts =
412 -- FIXME: These tolerances are different from the ones in GSL 351 -- FIXME: These tolerances are different from the ones in GSL
413 case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of 352 case odeSolveVWith' SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of
414 Left c -> error $ show c -- FIXME 353 Left c -> error $ show c -- FIXME
415 Right (v, _d) -> (nR >< nC) (V.toList v) 354 Right (v, _d) -> (nR >< nC) (V.toList v)
416 where 355 where
@@ -419,7 +358,7 @@ odeSolve f y0 ts =
419 nC = length y0 358 nC = length y0
420 g t x0 = V.fromList $ f t (V.toList x0) 359 g t x0 = V.fromList $ f t (V.toList x0)
421 360
422odeSolveVWith' :: 361odeSolveVWith ::
423 ODEMethod 362 ODEMethod
424 -> StepControl 363 -> StepControl
425 -> Maybe Double -- ^ initial step size - by default, ARKode 364 -> Maybe Double -- ^ initial step size - by default, ARKode
@@ -432,15 +371,15 @@ odeSolveVWith' ::
432 -> V.Vector Double -- ^ Initial conditions 371 -> V.Vector Double -- ^ Initial conditions
433 -> V.Vector Double -- ^ Desired solution times 372 -> V.Vector Double -- ^ Desired solution times
434 -> Matrix Double -- ^ Error code or solution 373 -> Matrix Double -- ^ Error code or solution
435odeSolveVWith' method control initStepSize f y0 tt = 374odeSolveVWith method control initStepSize f y0 tt =
436 case odeSolveVWith method control initStepSize f y0 tt of 375 case odeSolveVWith' method control initStepSize f y0 tt of
437 Left c -> error $ show c -- FIXME 376 Left c -> error $ show c -- FIXME
438 Right (v, _d) -> (nR >< nC) (V.toList v) 377 Right (v, _d) -> (nR >< nC) (V.toList v)
439 where 378 where
440 nR = V.length tt 379 nR = V.length tt
441 nC = V.length y0 380 nC = V.length y0
442 381
443odeSolveVWith :: 382odeSolveVWith' ::
444 ODEMethod 383 ODEMethod
445 -> StepControl 384 -> StepControl
446 -> Maybe Double -- ^ initial step size - by default, ARKode 385 -> Maybe Double -- ^ initial step size - by default, ARKode
@@ -452,8 +391,8 @@ odeSolveVWith ::
452 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 391 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
453 -> V.Vector Double -- ^ Initial conditions 392 -> V.Vector Double -- ^ Initial conditions
454 -> V.Vector Double -- ^ Desired solution times 393 -> V.Vector Double -- ^ Desired solution times
455 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution 394 -> Either Int ((V.Vector Double), SO.SundialsDiagnostics) -- ^ Error code or solution
456odeSolveVWith method control initStepSize f y0 tt = 395odeSolveVWith' method control initStepSize f y0 tt =
457 case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) 396 case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control)
458 (coerce f) (coerce y0) (coerce tt) of 397 (coerce f) (coerce y0) (coerce tt) of
459 Left c -> Left $ fromIntegral c 398 Left c -> Left $ fromIntegral c
@@ -482,7 +421,7 @@ solveOdeC ::
482 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 421 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
483 -> V.Vector CDouble -- ^ Initial conditions 422 -> V.Vector CDouble -- ^ Initial conditions
484 -> V.Vector CDouble -- ^ Desired solution times 423 -> V.Vector CDouble -- ^ Desired solution times
485 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution 424 -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution
486solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do 425solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do
487 426
488 let isInitStepSize :: CInt 427 let isInitStepSize :: CInt
@@ -514,9 +453,9 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
514 funIO x y f _ptr = do 453 funIO x y f _ptr = do
515 -- Convert the pointer we get from C (y) to a vector, and then 454 -- Convert the pointer we get from C (y) to a vector, and then
516 -- apply the user-supplied function. 455 -- apply the user-supplied function.
517 fImm <- fun x <$> getDataFromContents dim y 456 fImm <- fun x <$> SO.getDataFromContents dim y
518 -- Fill in the provided pointer with the resulting vector. 457 -- Fill in the provided pointer with the resulting vector.
519 putDataInContents fImm dim f 458 SO.putDataInContents fImm dim f
520 -- FIXME: I don't understand what this comment means 459 -- FIXME: I don't understand what this comment means
521 -- Unsafe since the function will be called many times. 460 -- Unsafe since the function will be called many times.
522 [CU.exp| int{ 0 } |] 461 [CU.exp| int{ 0 } |]
@@ -528,8 +467,8 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
528 jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do 467 jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do
529 case jacH of 468 case jacH of
530 Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined" 469 Nothing -> error "Numeric.Sundials.ARKode.ODE: Jacobian not defined"
531 Just jacI -> do j <- jacI t <$> getDataFromContents dim y 470 Just jacI -> do j <- jacI t <$> SO.getDataFromContents dim y
532 putMatrixDataFromContents j jacS 471 SO.putMatrixDataFromContents j jacS
533 -- FIXME: I don't understand what this comment means 472 -- FIXME: I don't understand what this comment means
534 -- Unsafe since the function will be called many times. 473 -- Unsafe since the function will be called many times.
535 [CU.exp| int{ 0 } |] 474 [CU.exp| int{ 0 } |]
@@ -704,16 +643,16 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
704 if res == 0 643 if res == 0
705 then do 644 then do
706 preD <- V.freeze diagMut 645 preD <- V.freeze diagMut
707 let d = SundialsDiagnostics (fromIntegral $ preD V.!0) 646 let d = SO.SundialsDiagnostics (fromIntegral $ preD V.!0)
708 (fromIntegral $ preD V.!1) 647 (fromIntegral $ preD V.!1)
709 (fromIntegral $ preD V.!2) 648 (fromIntegral $ preD V.!2)
710 (fromIntegral $ preD V.!3) 649 (fromIntegral $ preD V.!3)
711 (fromIntegral $ preD V.!4) 650 (fromIntegral $ preD V.!4)
712 (fromIntegral $ preD V.!5) 651 (fromIntegral $ preD V.!5)
713 (fromIntegral $ preD V.!6) 652 (fromIntegral $ preD V.!6)
714 (fromIntegral $ preD V.!7) 653 (fromIntegral $ preD V.!7)
715 (fromIntegral $ preD V.!8) 654 (fromIntegral $ preD V.!8)
716 (fromIntegral $ preD V.!9) 655 (fromIntegral $ preD V.!9)
717 m <- V.freeze qMatMut 656 m <- V.freeze qMatMut
718 return $ Right (m, d) 657 return $ Right (m, d)
719 else do 658 else do
@@ -783,15 +722,15 @@ getButcherTable method = unsafePerformIO $ do
783 btB2sMut <- V.thaw btB2s 722 btB2sMut <- V.thaw btB2s
784 let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt 723 let funIOI :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt
785 funIOI x y f _ptr = do 724 funIOI x y f _ptr = do
786 fImm <- funI x <$> getDataFromContents dim y 725 fImm <- funI x <$> SO.getDataFromContents dim y
787 putDataInContents fImm dim f 726 SO.putDataInContents fImm dim f
788 -- FIXME: I don't understand what this comment means 727 -- FIXME: I don't understand what this comment means
789 -- Unsafe since the function will be called many times. 728 -- Unsafe since the function will be called many times.
790 [CU.exp| int{ 0 } |] 729 [CU.exp| int{ 0 } |]
791 let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt 730 let funIOE :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt
792 funIOE x y f _ptr = do 731 funIOE x y f _ptr = do
793 fImm <- funE x <$> getDataFromContents dim y 732 fImm <- funE x <$> SO.getDataFromContents dim y
794 putDataInContents fImm dim f 733 SO.putDataInContents fImm dim f
795 -- FIXME: I don't understand what this comment means 734 -- FIXME: I don't understand what this comment means
796 -- Unsafe since the function will be called many times. 735 -- Unsafe since the function will be called many times.
797 [CU.exp| int{ 0 } |] 736 [CU.exp| int{ 0 } |]