summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/sundials/src/Main.hs10
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs38
2 files changed, 36 insertions, 12 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs
index 9978aa5..6d6a397 100644
--- a/packages/sundials/src/Main.hs
+++ b/packages/sundials/src/Main.hs
@@ -69,6 +69,12 @@ stiffish t v = [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ]
69 lamda = -100.0 69 lamda = -100.0
70 u = v !! 0 70 u = v !! 0
71 71
72stiffishV :: Double -> Vector Double -> Vector Double
73stiffishV t v = fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ]
74 where
75 lamda = -100.0
76 u = v ! 0
77
72stiffJac :: Double -> Vector Double -> Matrix Double 78stiffJac :: Double -> Vector Double -> Matrix Double
73stiffJac _t _v = (1><1) [ lamda ] 79stiffJac _t _v = (1><1) [ lamda ]
74 where 80 where
@@ -143,6 +149,10 @@ main = do
143 (D.dims2D 500.0 500.0) 149 (D.dims2D 500.0 500.0)
144 (renderAxis $ kSaxis $ zip [0.0, 0.1 .. 10.0] (concat $ toLists res2)) 150 (renderAxis $ kSaxis $ zip [0.0, 0.1 .. 10.0] (concat $ toLists res2))
145 151
152 let res2a = odeSolveV (SDIRK_5_3_4 stiffJac) 0.1 1e-3 1e-6 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0])
153 putStrLn "Lower tolerances"
154 putStrLn $ show res2a
155
146 let res3 = odeSolve' (SDIRK_5_3_4 lorenzJac) lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0]) 156 let res3 = odeSolve' (SDIRK_5_3_4 lorenzJac) lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0])
147 putStrLn $ show $ last ((toLists $ tr res3)!!0) 157 putStrLn $ show $ last ((toLists $ tr res3)!!0)
148 158
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index 9ddb3df..85f1b3d 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -215,7 +215,7 @@ odeSolveV
215 -> Vector Double -- ^ desired solution times 215 -> Vector Double -- ^ desired solution times
216 -> Matrix Double -- ^ solution 216 -> Matrix Double -- ^ solution
217odeSolveV meth _hi epsAbs epsRel f y0 ts = 217odeSolveV meth _hi epsAbs epsRel f y0 ts =
218 case odeSolveVWith meth (XX' epsAbs epsRel 1 1) epsAbs epsAbs g y0 ts of 218 case odeSolveVWith meth (X epsAbs epsRel) g y0 ts of
219 Left c -> error $ show c -- FIXME 219 Left c -> error $ show c -- FIXME
220 -- FIXME: Can we do better than using lists? 220 -- FIXME: Can we do better than using lists?
221 Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) 221 Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v)
@@ -235,7 +235,7 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y
235 -> Matrix Double -- ^ solution 235 -> Matrix Double -- ^ solution
236odeSolve f y0 ts = 236odeSolve f y0 ts =
237 -- FIXME: These tolerances are different from the ones in GSL 237 -- FIXME: These tolerances are different from the ones in GSL
238 case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of 238 case odeSolveVWith SDIRK_5_3_4' (XX' 1.0e-6 1.0e-10 1 1) g (V.fromList y0) (V.fromList $ toList ts) of
239 Left c -> error $ show c -- FIXME 239 Left c -> error $ show c -- FIXME
240 Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) 240 Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v)
241 where 241 where
@@ -250,7 +250,7 @@ odeSolve' :: ODEMethod
250 -> Vector Double -- ^ desired solution times 250 -> Vector Double -- ^ desired solution times
251 -> Matrix Double -- ^ solution 251 -> Matrix Double -- ^ solution
252odeSolve' method f y0 ts = 252odeSolve' method f y0 ts =
253 case odeSolveVWith method (XX' 1.0e-6 1.0e-10 1 1) 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of 253 case odeSolveVWith method (XX' 1.0e-6 1.0e-10 1 1) g (V.fromList y0) (V.fromList $ toList ts) of
254 Left c -> error $ show c -- FIXME 254 Left c -> error $ show c -- FIXME
255 Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v) 255 Right (v, d) -> trace (show d) $ (nR >< nC) (V.toList v)
256 where 256 where
@@ -262,36 +262,40 @@ odeSolve' method f y0 ts =
262odeSolveVWith :: 262odeSolveVWith ::
263 ODEMethod 263 ODEMethod
264 -> StepControl 264 -> StepControl
265 -> Double
266 -> Double
267 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 265 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
268 -> V.Vector Double -- ^ Initial conditions 266 -> V.Vector Double -- ^ Initial conditions
269 -> V.Vector Double -- ^ Desired solution times 267 -> V.Vector Double -- ^ Desired solution times
270 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution 268 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution
271odeSolveVWith method _control relTol absTol f y0 tt = 269odeSolveVWith method control f y0 tt =
272 case solveOdeC (fromIntegral $ getMethod method) jacH (CDouble relTol) (CDouble absTol) 270 case solveOdeC (fromIntegral $ getMethod method) jacH (scise control)
273 (coerce f) (coerce y0) (coerce tt) of 271 (coerce f) (coerce y0) (coerce tt) of
274 Left c -> Left $ fromIntegral c 272 Left c -> Left $ fromIntegral c
275 Right (v, d) -> Right (coerce v, d) 273 Right (v, d) -> Right (coerce v, d)
276 where 274 where
275 l = size y0
276 scise (X absTol relTol) = coerce (V.replicate l absTol, relTol)
277 scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol)
278 scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol)
279 -- FIXME; Should we check that the length of ss is correct?
280 scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol)
277 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ 281 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $
278 getJacobian method 282 getJacobian method
279 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } 283 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs }
280 where 284 where
281 nr = fromIntegral $ rows m 285 nr = fromIntegral $ rows m
282 nc = fromIntegral $ cols m 286 nc = fromIntegral $ cols m
287 -- FIXME: efficiency
283 vs = V.fromList $ map coerce $ concat $ toLists m 288 vs = V.fromList $ map coerce $ concat $ toLists m
284 289
285solveOdeC :: 290solveOdeC ::
286 CInt -> 291 CInt ->
287 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> 292 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) ->
288 CDouble -> 293 (V.Vector CDouble, CDouble) ->
289 CDouble ->
290 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 294 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
291 -> V.Vector CDouble -- ^ Initial conditions 295 -> V.Vector CDouble -- ^ Initial conditions
292 -> V.Vector CDouble -- ^ Desired solution times 296 -> V.Vector CDouble -- ^ Desired solution times
293 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution 297 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution
294solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do 298solveOdeC method jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do
295 let dim = V.length f0 299 let dim = V.length f0
296 nEq :: CLong 300 nEq :: CLong
297 nEq = fromIntegral dim 301 nEq = fromIntegral dim
@@ -336,6 +340,9 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do
336 /* general problem variables */ 340 /* general problem variables */
337 int flag; /* reusable error-checking flag */ 341 int flag; /* reusable error-checking flag */
338 N_Vector y = NULL; /* empty vector for storing solution */ 342 N_Vector y = NULL; /* empty vector for storing solution */
343 /* empty vector for storing absolute tolerances */
344 N_Vector tv = NULL;
345
339 SUNMatrix A = NULL; /* empty matrix for linear solver */ 346 SUNMatrix A = NULL; /* empty matrix for linear solver */
340 SUNLinearSolver LS = NULL; /* empty linear solver object */ 347 SUNLinearSolver LS = NULL; /* empty linear solver object */
341 void *arkode_mem = NULL; /* empty ARKode memory structure */ 348 void *arkode_mem = NULL; /* empty ARKode memory structure */
@@ -353,6 +360,13 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do
353 for (i = 0; i < NEQ; i++) { 360 for (i = 0; i < NEQ; i++) {
354 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; 361 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i];
355 }; /* Specify initial condition */ 362 }; /* Specify initial condition */
363
364 tv = N_VNew_Serial(NEQ); /* Create serial vector for solution */
365 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1;
366 for (i = 0; i < NEQ; i++) {
367 NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i];
368 };
369
356 arkode_mem = ARKodeCreate(); /* Create the solver memory */ 370 arkode_mem = ARKodeCreate(); /* Create the solver memory */
357 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; 371 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
358 372
@@ -367,8 +381,8 @@ solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do
367 if (check_flag(&flag, "ARKodeInit", 1)) return 1; 381 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
368 382
369 /* Set routines */ 383 /* Set routines */
370 flag = ARKodeSStolerances(arkode_mem, $(double relTol), $(double absTol)); 384 flag = ARKodeSVtolerances(arkode_mem, $(double relTol), tv);
371 if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; 385 if (check_flag(&flag, "ARKodeSVtolerances", 1)) return 1;
372 386
373 /* Initialize dense matrix data structure and solver */ 387 /* Initialize dense matrix data structure and solver */
374 A = SUNDenseMatrix(NEQ, NEQ); 388 A = SUNDenseMatrix(NEQ, NEQ);