summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs')
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs38
1 files changed, 26 insertions, 12 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index 9ddb3df..85f1b3d 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -215,7 +215,7 @@ odeSolveV
215 -> Vector Double -- ^ desired solution times 215 -> Vector Double -- ^ desired solution times
216 -> Matrix Double -- ^ solution 216 -> Matrix Double -- ^ solution
217odeSolveV meth _hi epsAbs epsRel f y0 ts = 217odeSolveV meth _hi epsAbs epsRel f y0 ts =
218 case odeSolveVWith meth (XX' epsAbs epsRel 1 1) epsAbs epsAbs g y0 ts of 218 case odeSolveVWith meth (X epsAbs epsRel) g y0 ts of
219 Left c -> error $ show c -- FIXME 219 Left c -> error $ show c -- FIXME
220 -- FIXME: Can we do better than using lists? 220 -- FIXME: Can we do better than using lists?
221 Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) 221 Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v)
@@ -235,7 +235,7 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y
235 -> Matrix Double -- ^ solution 235 -> Matrix Double -- ^ solution
236odeSolve f y0 ts = 236odeSolve f y0 ts =
237 -- FIXME: These tolerances are different from the ones in GSL 237 -- FIXME: These tolerances are different from the ones in GSL
238 case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of 238 case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) g (V.fromList y0) (V.fromList $ toList ts) of
239 Left c -> error $ show c -- FIXME 239 Left c -> error $ show c -- FIXME
240 Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) 240 Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v)
241 where 241 where
@@ -250,7 +250,7 @@ odeSolve' :: ODEMethod
250 -> Vector Double -- ^ desired solution times 250 -> Vector Double -- ^ desired solution times
251 -> Matrix Double -- ^ solution 251 -> Matrix Double -- ^ solution
252odeSolve' method f y0 ts = 252odeSolve' method f y0 ts =
253 case odeSolveVWith method (XX' 1.0e-6 1.0e-10 1 1) 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of 253 case odeSolveVWith method (XX' 1.0e-6 1.0e-10 1 1) g (V.fromList y0) (V.fromList $ toList ts) of
254 Left c -> error $ show c -- FIXME 254 Left c -> error $ show c -- FIXME
255 Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) 255 Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v)
256 where 256 where
@@ -262,36 +262,40 @@ odeSolve' method f y0 ts =
262odeSolveVWith :: 262odeSolveVWith ::
263 ODEMethod 263 ODEMethod
264 -> StepControl 264 -> StepControl
265 -> Double
266 -> Double
267 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 265 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
268 -> V.Vector Double -- ^ Initial conditions 266 -> V.Vector Double -- ^ Initial conditions
269 -> V.Vector Double -- ^ Desired solution times 267 -> V.Vector Double -- ^ Desired solution times
270 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution 268 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution
271odeSolveVWith method _control relTol absTol f y0 tt = 269odeSolveVWith method control f y0 tt =
272 case solveOdeC (fromIntegral $ getMethod method) jacH (CDouble relTol) (CDouble absTol) 270 case solveOdeC (fromIntegral $ getMethod method) jacH (scise control)
273 (coerce f) (coerce y0) (coerce tt) of 271 (coerce f) (coerce y0) (coerce tt) of
274 Left c -> Left $ fromIntegral c 272 Left c -> Left $ fromIntegral c
275 Right (v, d) -> Right (coerce v, d) 273 Right (v, d) -> Right (coerce v, d)
276 where 274 where
275 l = size y0
276 scise (X absTol relTol) = coerce (V.replicate l absTol, relTol)
277 scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol)
278 scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol)
279 -- FIXME; Should we check that the length of ss is correct?
280 scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol)
277 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ 281 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $
278 getJacobian method 282 getJacobian method
279 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } 283 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs }
280 where 284 where
281 nr = fromIntegral $ rows m 285 nr = fromIntegral $ rows m
282 nc = fromIntegral $ cols m 286 nc = fromIntegral $ cols m
287 -- FIXME: efficiency
283 vs = V.fromList $ map coerce $ concat $ toLists m 288 vs = V.fromList $ map coerce $ concat $ toLists m
284 289
285solveOdeC :: 290solveOdeC ::
286 CInt -> 291 CInt ->
287 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> 292 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) ->
288 CDouble -> 293 (V.Vector CDouble, CDouble) ->
289 CDouble ->
290 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 294 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
291 -> V.Vector CDouble -- ^ Initial conditions 295 -> V.Vector CDouble -- ^ Initial conditions
292 -> V.Vector CDouble -- ^ Desired solution times 296 -> V.Vector CDouble -- ^ Desired solution times
293 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution 297 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution
294solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do 298solveOdeC method jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do
295 let dim = V.length f0 299 let dim = V.length f0
296 nEq :: CLong 300 nEq :: CLong
297 nEq = fromIntegral dim 301 nEq = fromIntegral dim
@@ -336,6 +340,9 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do
336 /* general problem variables */ 340 /* general problem variables */
337 int flag; /* reusable error-checking flag */ 341 int flag; /* reusable error-checking flag */
338 N_Vector y = NULL; /* empty vector for storing solution */ 342 N_Vector y = NULL; /* empty vector for storing solution */
343 /* empty vector for storing absolute tolerances */
344 N_Vector tv = NULL;
345
339 SUNMatrix A = NULL; /* empty matrix for linear solver */ 346 SUNMatrix A = NULL; /* empty matrix for linear solver */
340 SUNLinearSolver LS = NULL; /* empty linear solver object */ 347 SUNLinearSolver LS = NULL; /* empty linear solver object */
341 void *arkode_mem = NULL; /* empty ARKode memory structure */ 348 void *arkode_mem = NULL; /* empty ARKode memory structure */
@@ -353,6 +360,13 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do
353 for (i = 0; i < NEQ; i++) { 360 for (i = 0; i < NEQ; i++) {
354 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; 361 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i];
355 }; /* Specify initial condition */ 362 }; /* Specify initial condition */
363
364 tv = N_VNew_Serial(NEQ); /* Create serial vector for solution */
365 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1;
366 for (i = 0; i < NEQ; i++) {
367 NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i];
368 };
369
356 arkode_mem = ARKodeCreate(); /* Create the solver memory */ 370 arkode_mem = ARKodeCreate(); /* Create the solver memory */
357 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; 371 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
358 372
@@ -367,8 +381,8 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do
367 if (check_flag(&flag, "ARKodeInit", 1)) return 1; 381 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
368 382
369 /* Set routines */ 383 /* Set routines */
370 flag = ARKodeSStolerances(arkode_mem, $(double relTol), $(double absTol)); 384 flag = ARKodeSVtolerances(arkode_mem, $(double relTol), tv);
371 if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; 385 if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1;
372 386
373 /* Initialize dense matrix data structure and solver */ 387 /* Initialize dense matrix data structure and solver */
374 A = SUNDenseMatrix(NEQ, NEQ); 388 A = SUNDenseMatrix(NEQ, NEQ);