summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Numeric/Sundials
diff options
context:
space:
mode:
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials')
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs99
1 files changed, 75 insertions, 24 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index 30ff4c8..5af9e41 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -35,6 +35,7 @@ module Numeric.Sundials.Arkode.ODE ( solveOde
35 , getBT 35 , getBT
36 , btGet 36 , btGet
37 , ODEMethod(..) 37 , ODEMethod(..)
38 , odeSolveV
38 ) where 39 ) where
39 40
40import qualified Language.C.Inline as C 41import qualified Language.C.Inline as C
@@ -45,7 +46,7 @@ import Data.Monoid ((<>))
45import Foreign.C.Types 46import Foreign.C.Types
46import Foreign.Ptr (Ptr) 47import Foreign.Ptr (Ptr)
47import Foreign.ForeignPtr (newForeignPtr_) 48import Foreign.ForeignPtr (newForeignPtr_)
48import Foreign.Storable (Storable, peekByteOff) 49import Foreign.Storable (Storable)
49 50
50import qualified Data.Vector.Storable as V 51import qualified Data.Vector.Storable as V
51import qualified Data.Vector.Storable.Mutable as VM 52import qualified Data.Vector.Storable.Mutable as VM
@@ -55,7 +56,8 @@ import System.IO.Unsafe (unsafePerformIO)
55 56
56import Numeric.LinearAlgebra.Devel (createVector) 57import Numeric.LinearAlgebra.Devel (createVector)
57 58
58import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix) 59import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><),
60 subMatrix, rows, cols, toLists)
59 61
60import qualified Types as T 62import qualified Types as T
61import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3) 63import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3)
@@ -78,12 +80,34 @@ C.include "../../../helpers.h"
78C.include "Arkode_hsc.h" 80C.include "Arkode_hsc.h"
79 81
80 82
81getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b) 83getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble)
82getDataFromContents len ptr = do 84getDataFromContents len ptr = do
83 qtr <- B.getContentPtr ptr 85 qtr <- B.getContentPtr ptr
84 rtr <- B.getData qtr 86 rtr <- B.getData qtr
85 vectorFromC len rtr 87 vectorFromC len rtr
86 88
89-- FIXME: Potentially an instance of Storable
90getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix
91getMatrixDataFromContents ptr = do
92 qtr <- B.getContentMatrixPtr ptr
93 rs <- B.getNRows qtr
94 cs <- B.getNCols qtr
95 rtr <- B.getMatrixData qtr
96 vs <- vectorFromC (fromIntegral $ rs * cs) rtr
97 return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs }
98
99putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO ()
100putMatrixDataFromContents mat ptr = do
101 let rs = T.rows mat
102 cs = T.cols mat
103 vs = T.vals mat
104 qtr <- B.getContentMatrixPtr ptr
105 B.putNRows rs qtr
106 B.putNCols cs qtr
107 rtr <- B.getMatrixData qtr
108 vectorToC vs (fromIntegral $ rs * cs) rtr
109-- FIXME: END
110
87putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () 111putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO ()
88putDataInContents vec len ptr = do 112putDataInContents vec len ptr = do
89 qtr <- B.getContentPtr ptr 113 qtr <- B.getContentPtr ptr
@@ -103,16 +127,16 @@ vectorToC vec len ptr = do
103 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec 127 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
104 128
105data SundialsDiagnostics = SundialsDiagnostics { 129data SundialsDiagnostics = SundialsDiagnostics {
106 aRKodeGetNumSteps :: Int 130 _aRKodeGetNumSteps :: Int
107 , aRKodeGetNumStepAttempts :: Int 131 , _aRKodeGetNumStepAttempts :: Int
108 , aRKodeGetNumRhsEvals_fe :: Int 132 , _aRKodeGetNumRhsEvals_fe :: Int
109 , aRKodeGetNumRhsEvals_fi :: Int 133 , _aRKodeGetNumRhsEvals_fi :: Int
110 , aRKodeGetNumLinSolvSetups :: Int 134 , _aRKodeGetNumLinSolvSetups :: Int
111 , aRKodeGetNumErrTestFails :: Int 135 , _aRKodeGetNumErrTestFails :: Int
112 , aRKodeGetNumNonlinSolvIters :: Int 136 , _aRKodeGetNumNonlinSolvIters :: Int
113 , aRKodeGetNumNonlinSolvConvFails :: Int 137 , _aRKodeGetNumNonlinSolvConvFails :: Int
114 , aRKDlsGetNumJacEvals :: Int 138 , _aRKDlsGetNumJacEvals :: Int
115 , aRKDlsGetNumRhsEvals :: Int 139 , _aRKDlsGetNumRhsEvals :: Int
116 } deriving Show 140 } deriving Show
117 141
118-- | Stepping functions 142-- | Stepping functions
@@ -134,15 +158,16 @@ odeSolveV
134 -> Vector Double -- ^ initial conditions 158 -> Vector Double -- ^ initial conditions
135 -> Vector Double -- ^ desired solution times 159 -> Vector Double -- ^ desired solution times
136 -> Matrix Double -- ^ solution 160 -> Matrix Double -- ^ solution
137odeSolveV meth hi epsAbs epsRel = undefined 161odeSolveV _meth _hi _epsAbs _epsRel = undefined
138 162
139odeSolve :: ODEMethod 163odeSolve :: ODEMethod
164 -> (Double -> Vector Double -> Matrix Double)
140 -> (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 165 -> (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
141 -> [Double] -- ^ initial conditions 166 -> [Double] -- ^ initial conditions
142 -> Vector Double -- ^ desired solution times 167 -> Vector Double -- ^ desired solution times
143 -> Matrix Double -- ^ solution 168 -> Matrix Double -- ^ solution
144odeSolve method f y0 ts = 169odeSolve method jac f y0 ts =
145 case solveOde method 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of 170 case solveOde method jac' 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of
146 Left c -> error $ show c -- FIXME 171 Left c -> error $ show c -- FIXME
147 Right (v, _) -> (nR >< nC) (V.toList v) 172 Right (v, _) -> (nR >< nC) (V.toList v)
148 where 173 where
@@ -150,30 +175,40 @@ odeSolve method f y0 ts =
150 nR = length us 175 nR = length us
151 nC = length y0 176 nC = length y0
152 g t x0 = V.fromList $ f t (V.toList x0) 177 g t x0 = V.fromList $ f t (V.toList x0)
178 jac' t v = foo $ jac t (V.fromList $ toList v)
179 foo m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs }
180 where
181 nr = fromIntegral $ rows m
182 nc = fromIntegral $ cols m
183 vs = V.fromList $ map coerce $ concat $ toLists m
153 184
154solveOde :: 185solveOde ::
155 ODEMethod 186 ODEMethod
187 -> (Double -> V.Vector Double -> T.SunMatrix)
156 -> Double 188 -> Double
157 -> Double 189 -> Double
158 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 190 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
159 -> V.Vector Double -- ^ Initial conditions 191 -> V.Vector Double -- ^ Initial conditions
160 -> V.Vector Double -- ^ Desired solution times 192 -> V.Vector Double -- ^ Desired solution times
161 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution 193 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution
162solveOde method relTol absTol f y0 tt = 194solveOde method jac relTol absTol f y0 tt =
163 case solveOdeC (fromIntegral $ fromEnum method) (CDouble relTol) (CDouble absTol) 195 case solveOdeC (fromIntegral $ fromEnum method) jacH (CDouble relTol) (CDouble absTol)
164 (coerce f) (coerce y0) (coerce tt) of 196 (coerce f) (coerce y0) (coerce tt) of
165 Left c -> Left $ fromIntegral c 197 Left c -> Left $ fromIntegral c
166 Right (v, d) -> Right (coerce v, d) 198 Right (v, d) -> Right (coerce v, d)
199 where
200 jacH t v = jac (coerce t) (coerce v)
167 201
168solveOdeC :: 202solveOdeC ::
169 CInt -> 203 CInt ->
204 (CDouble -> V.Vector CDouble -> T.SunMatrix) ->
170 CDouble -> 205 CDouble ->
171 CDouble -> 206 CDouble ->
172 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 207 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
173 -> V.Vector CDouble -- ^ Initial conditions 208 -> V.Vector CDouble -- ^ Initial conditions
174 -> V.Vector CDouble -- ^ Desired solution times 209 -> V.Vector CDouble -- ^ Desired solution times
175 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution 210 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution
176solveOdeC method relTol absTol fun f0 ts = unsafePerformIO $ do 211solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do
177 let dim = V.length f0 212 let dim = V.length f0
178 nEq :: CLong 213 nEq :: CLong
179 nEq = fromIntegral dim 214 nEq = fromIntegral dim
@@ -197,9 +232,19 @@ solveOdeC method relTol absTol fun f0 ts = unsafePerformIO $ do
197 fImm <- fun x <$> getDataFromContents dim y 232 fImm <- fun x <$> getDataFromContents dim y
198 -- Fill in the provided pointer with the resulting vector. 233 -- Fill in the provided pointer with the resulting vector.
199 putDataInContents fImm dim f 234 putDataInContents fImm dim f
200 -- I don't understand what this comment means 235 -- FIXME: I don't understand what this comment means
236 -- Unsafe since the function will be called many times.
237 [CU.exp| int{ 0 } |]
238 let jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix ->
239 Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector ->
240 IO CInt
241 jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do
242 foo <- jacH t <$> getDataFromContents dim y
243 putMatrixDataFromContents foo jacS
244 -- FIXME: I don't understand what this comment means
201 -- Unsafe since the function will be called many times. 245 -- Unsafe since the function will be called many times.
202 [CU.exp| int{ 0 } |] 246 [CU.exp| int{ 0 } |]
247
203 res <- [C.block| int { 248 res <- [C.block| int {
204 /* general problem variables */ 249 /* general problem variables */
205 int flag; /* reusable error-checking flag */ 250 int flag; /* reusable error-checking flag */
@@ -246,6 +291,11 @@ solveOdeC method relTol absTol fun f0 ts = unsafePerformIO $ do
246 291
247 /* Linear solver interface */ 292 /* Linear solver interface */
248 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ 293 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */
294
295 flag = ARKDlsSetJacFn(arkode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[])));
296 if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1;
297
298
249 /* Store initial conditions */ 299 /* Store initial conditions */
250 for (j = 0; j < NEQ; j++) { 300 for (j = 0; j < NEQ; j++) {
251 ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); 301 ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j);
@@ -340,7 +390,8 @@ btGet :: ODEMethod -> Matrix Double
340btGet method = case getBT method of 390btGet method = case getBT method of
341 Left c -> error $ show c -- FIXME 391 Left c -> error $ show c -- FIXME
342 -- FIXME 392 -- FIXME
343 Right (v, sqp) -> subMatrix (0, 0) (2, 2) $ (B.arkSMax >< B.arkSMax) (V.toList v) 393 Right (v, _sqp) -> subMatrix (0, 0) (2, 2) $
394 (B.arkSMax >< B.arkSMax) (V.toList v)
344 395
345getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int) 396getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int)
346getBT method = case getButcherTable method of 397getBT method = case getButcherTable method of
@@ -352,9 +403,9 @@ getButcherTable method = unsafePerformIO $ do
352 -- arkode seems to want an ODE in order to set and then get the 403 -- arkode seems to want an ODE in order to set and then get the
353 -- Butcher tableau so here's one to keep it happy 404 -- Butcher tableau so here's one to keep it happy
354 let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble 405 let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble
355 fun t ys = V.fromList [ ys V.! 0 ] 406 fun _t ys = V.fromList [ ys V.! 0 ]
356 f0 = V.fromList [ 1.0 ] 407 f0 = V.fromList [ 1.0 ]
357 ts = V.fromList [ 0.0 ] 408 ts = V.fromList [ 0.0 ]
358 dim = V.length f0 409 dim = V.length f0
359 nEq :: CLong 410 nEq :: CLong
360 nEq = fromIntegral dim 411 nEq = fromIntegral dim