summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/CVode/ODE.hs')
-rw-r--r--packages/sundials/src/Numeric/Sundials/CVode/ODE.hs75
1 files changed, 34 insertions, 41 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs
index d7a2b53..0871f9b 100644
--- a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs
@@ -87,11 +87,12 @@ import System.IO.Unsafe (unsafePerformIO)
87 87
88import Numeric.LinearAlgebra.Devel (createVector) 88import Numeric.LinearAlgebra.Devel (createVector)
89 89
90import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), 90import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows,
91 rows, cols, toLists, size) 91 cols, toLists, size, reshape)
92 92
93import qualified Types as T 93import qualified Types as T
94import Arkode (cV_ADAMS, cV_BDF) 94import Arkode (cV_ADAMS, cV_BDF)
95import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian)
95import qualified Numeric.Sundials.ODEOpts as SO 96import qualified Numeric.Sundials.ODEOpts as SO
96 97
97 98
@@ -111,8 +112,6 @@ C.include "../../../helpers.h"
111C.include "Arkode_hsc.h" 112C.include "Arkode_hsc.h"
112 113
113 114
114type Jacobian = Double -> Vector Double -> Matrix Double
115
116-- | Stepping functions 115-- | Stepping functions
117data ODEMethod = ADAMS 116data ODEMethod = ADAMS
118 | BDF 117 | BDF
@@ -140,14 +139,8 @@ odeSolveV
140 -> Vector Double -- ^ desired solution times 139 -> Vector Double -- ^ desired solution times
141 -> Matrix Double -- ^ solution 140 -> Matrix Double -- ^ solution
142odeSolveV meth hi epsAbs epsRel f y0 ts = 141odeSolveV meth hi epsAbs epsRel f y0 ts =
143 case odeSolveVWith' meth (X epsAbs epsRel) hi g y0 ts of 142 odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts
144 Left c -> error $ show c -- FIXME
145 -- FIXME: Can we do better than using lists?
146 Right (v, _d) -> (nR >< nC) (V.toList v)
147 where 143 where
148 us = toList ts
149 nR = length us
150 nC = size y0
151 g t x0 = coerce $ f t x0 144 g t x0 = coerce $ f t x0
152 145
153-- | A version of 'odeSolveV' with reasonable default parameters and 146-- | A version of 'odeSolveV' with reasonable default parameters and
@@ -160,13 +153,8 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y
160 -> Matrix Double -- ^ solution 153 -> Matrix Double -- ^ solution
161odeSolve f y0 ts = 154odeSolve f y0 ts =
162 -- FIXME: These tolerances are different from the ones in GSL 155 -- FIXME: These tolerances are different from the ones in GSL
163 case odeSolveVWith' BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of 156 odeSolveVWith BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts)
164 Left c -> error $ show c -- FIXME
165 Right (v, _d) -> (nR >< nC) (V.toList v)
166 where 157 where
167 us = toList ts
168 nR = length us
169 nC = length y0
170 g t x0 = V.fromList $ f t (V.toList x0) 158 g t x0 = V.fromList $ f t (V.toList x0)
171 159
172odeSolveVWith :: 160odeSolveVWith ::
@@ -183,15 +171,20 @@ odeSolveVWith ::
183 -> V.Vector Double -- ^ Desired solution times 171 -> V.Vector Double -- ^ Desired solution times
184 -> Matrix Double -- ^ Error code or solution 172 -> Matrix Double -- ^ Error code or solution
185odeSolveVWith method control initStepSize f y0 tt = 173odeSolveVWith method control initStepSize f y0 tt =
186 case odeSolveVWith' method control initStepSize f y0 tt of 174 case odeSolveVWith' opts method control initStepSize f y0 tt of
187 Left c -> error $ show c -- FIXME 175 Left c -> error $ show c -- FIXME
188 Right (v, _d) -> (nR >< nC) (V.toList v) 176 Right (v, _d) -> v
189 where 177 where
190 nR = V.length tt 178 opts = ODEOpts { maxNumSteps = 10000
191 nC = V.length y0 179 , minStep = 1.0e-12
180 , relTol = error "relTol"
181 , absTols = error "absTol"
182 , initStep = error "initStep"
183 }
192 184
193odeSolveVWith' :: 185odeSolveVWith' ::
194 ODEMethod 186 ODEOpts
187 -> ODEMethod
195 -> StepControl 188 -> StepControl
196 -> Maybe Double -- ^ initial step size - by default, ARKode 189 -> Maybe Double -- ^ initial step size - by default, ARKode
197 -- estimates the initial step size to be the 190 -- estimates the initial step size to be the
@@ -202,19 +195,21 @@ odeSolveVWith' ::
202 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 195 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
203 -> V.Vector Double -- ^ Initial conditions 196 -> V.Vector Double -- ^ Initial conditions
204 -> V.Vector Double -- ^ Desired solution times 197 -> V.Vector Double -- ^ Desired solution times
205 -> Either Int ((V.Vector Double), SO.SundialsDiagnostics) -- ^ Error code or solution 198 -> Either Int (Matrix Double, SO.SundialsDiagnostics) -- ^ Error code or solution
206odeSolveVWith' method control initStepSize f y0 tt = 199odeSolveVWith' opts method control initStepSize f y0 tt =
207 case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) 200 case solveOdeC (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts)
201 (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control)
208 (coerce f) (coerce y0) (coerce tt) of 202 (coerce f) (coerce y0) (coerce tt) of
209 Left c -> Left $ fromIntegral c 203 Left c -> Left $ fromIntegral c
210 Right (v, d) -> Right (coerce v, d) 204 Right (v, d) -> Right (reshape nC (coerce v), d)
211 where 205 where
206 nC = V.length y0
212 l = size y0 207 l = size y0
213 scise (X absTol relTol) = coerce (V.replicate l absTol, relTol) 208 scise (X aTol rTol) = coerce (V.replicate l aTol, rTol)
214 scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol) 209 scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol)
215 scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol) 210 scise (XX' aTol rTol yScale _yDotScale) = coerce (V.replicate l aTol, yScale * rTol)
216 -- FIXME; Should we check that the length of ss is correct? 211 -- FIXME; Should we check that the length of ss is correct?
217 scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol) 212 scise (ScXX' aTol rTol yScale _yDotScale ss) = coerce (V.map (* aTol) ss, yScale * rTol)
218 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ 213 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $
219 getJacobian method 214 getJacobian method
220 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } 215 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs }
@@ -225,6 +220,8 @@ odeSolveVWith' method control initStepSize f y0 tt =
225 vs = V.fromList $ map coerce $ concat $ toLists m 220 vs = V.fromList $ map coerce $ concat $ toLists m
226 221
227solveOdeC :: 222solveOdeC ::
223 CLong ->
224 CDouble ->
228 CInt -> 225 CInt ->
229 Maybe CDouble -> 226 Maybe CDouble ->
230 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> 227 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) ->
@@ -233,7 +230,8 @@ solveOdeC ::
233 -> V.Vector CDouble -- ^ Initial conditions 230 -> V.Vector CDouble -- ^ Initial conditions
234 -> V.Vector CDouble -- ^ Desired solution times 231 -> V.Vector CDouble -- ^ Desired solution times
235 -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution 232 -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution
236solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do 233solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts =
234 unsafePerformIO $ do
237 235
238 let isInitStepSize :: CInt 236 let isInitStepSize :: CInt
239 isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize 237 isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize
@@ -249,10 +247,6 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
249 nEq = fromIntegral dim 247 nEq = fromIntegral dim
250 nTs :: CInt 248 nTs :: CInt
251 nTs = fromIntegral $ V.length ts 249 nTs = fromIntegral $ V.length ts
252 -- FIXME: tMut is not actually mutatated?
253 tMut <- V.thaw ts
254 -- FIXME: I believe this gets taken from the ghc heap and so should
255 -- be subject to garbage collection.
256 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) 250 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs))
257 qMatMut <- V.thaw quasiMatrixRes 251 qMatMut <- V.thaw quasiMatrixRes
258 diagnostics :: V.Vector CLong <- createVector 10 -- FIXME 252 diagnostics :: V.Vector CLong <- createVector 10 -- FIXME
@@ -324,18 +318,17 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
324 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; 318 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1;
325 /* Specify tolerances */ 319 /* Specify tolerances */
326 for (i = 0; i < NEQ; i++) { 320 for (i = 0; i < NEQ; i++) {
327 NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i]; 321 NV_Ith_S(tv,i) = ($vec-ptr:(double *aTols))[i];
328 }; 322 };
329 323
330 /* FIXME: A hack for initial testing */ 324 flag = CVodeSetMinStep(cvode_mem, $(double minStep_));
331 flag = CVodeSetMinStep(cvode_mem, 1.0e-12);
332 if (check_flag(&flag, "CVodeSetMinStep", 1)) return 1; 325 if (check_flag(&flag, "CVodeSetMinStep", 1)) return 1;
333 flag = CVodeSetMaxNumSteps(cvode_mem, 10000); 326 flag = CVodeSetMaxNumSteps(cvode_mem, $(long int maxNumSteps_));
334 if (check_flag(&flag, "CVodeSetMaxNumSteps", 1)) return 1; 327 if (check_flag(&flag, "CVodeSetMaxNumSteps", 1)) return 1;
335 328
336 /* Call CVodeSVtolerances to specify the scalar relative tolerance 329 /* Call CVodeSVtolerances to specify the scalar relative tolerance
337 * and vector absolute tolerances */ 330 * and vector absolute tolerances */
338 flag = CVodeSVtolerances(cvode_mem, $(double relTol), tv); 331 flag = CVodeSVtolerances(cvode_mem, $(double rTol), tv);
339 if (check_flag(&flag, "CVodeSVtolerances", 1)) return(1); 332 if (check_flag(&flag, "CVodeSVtolerances", 1)) return(1);
340 333
341 /* Initialize dense matrix data structure and solver */ 334 /* Initialize dense matrix data structure and solver */
@@ -371,7 +364,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
371 /* Stops when the final time has been reached */ 364 /* Stops when the final time has been reached */
372 for (i = 1; i < $(int nTs); i++) { 365 for (i = 1; i < $(int nTs); i++) {
373 366
374 flag = CVode(cvode_mem, ($vec-ptr:(double *tMut))[i], y, &t, CV_NORMAL); /* call integrator */ 367 flag = CVode(cvode_mem, ($vec-ptr:(double *ts))[i], y, &t, CV_NORMAL); /* call integrator */
375 if (check_flag(&flag, "CVode", 1)) break; 368 if (check_flag(&flag, "CVode", 1)) break;
376 369
377 /* Store the results for Haskell */ 370 /* Store the results for Haskell */