summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDominic Steinitz <dominic@steinitz.org>2018-04-24 11:53:50 +0100
committerDominic Steinitz <dominic@steinitz.org>2018-04-24 11:53:50 +0100
commitc73f86f64a60209a50b9c4cc3b137726955f2df7 (patch)
tree2324868db4dec3a51580592bda7898e9cb1e757a
parent79962d2141f356b6a8018d767e49db162a146405 (diff)
CVODE now supported somewhat
-rw-r--r--packages/sundials/src/Main.hs20
-rw-r--r--packages/sundials/src/Numeric/Sundials/CVode/ODE.hs144
2 files changed, 147 insertions, 17 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs
index 3904b09..85928e2 100644
--- a/packages/sundials/src/Main.hs
+++ b/packages/sundials/src/Main.hs
@@ -117,15 +117,21 @@ main = do
117 117
118 let res2b = ARK.odeSolveV (ARK.TRBDF2_3_3_2') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) 118 let res2b = ARK.odeSolveV (ARK.TRBDF2_3_3_2') Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0])
119 119
120 let maxDiff = maximum $ map abs $ 120 let maxDiffA = maximum $ map abs $
121 zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2b)!!0) 121 zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2b)!!0)
122
123 hspec $ describe "Compare results" $ do
124 it "for two different RK methods" $
125 maxDiff < 1.0e-6
126 122
127 let res2c = CV.odeSolveV (CV.BDF) Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0]) 123 let res2c = CV.odeSolveV (CV.BDF) Nothing 1e-6 1e-10 stiffishV (fromList [0.0]) (fromList [0.0, 0.1 .. 10.0])
128 putStrLn $ show res2c 124
125 let maxDiffB = maximum $ map abs $
126 zipWith (-) ((toLists $ tr res2a)!!0) ((toLists $ tr res2c)!!0)
127
128 let maxDiffC = maximum $ map abs $
129 zipWith (-) ((toLists $ tr res2b)!!0) ((toLists $ tr res2c)!!0)
130
131 hspec $ describe "Compare results" $ do
132 it "for SDIRK_5_3_4' and TRBDF2_3_3_2'" $ maxDiffA < 1.0e-6
133 it "for SDIRK_5_3_4' and BDF" $ maxDiffB < 1.0e-6
134 it "for TRBDF2_3_3_2' and BDF" $ maxDiffC < 1.0e-6
129 135
130 let res3 = ARK.odeSolve lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0]) 136 let res3 = ARK.odeSolve lorenz [-5.0, -5.0, 1.0] (fromList [0.0, 0.01 .. 10.0])
131 137
diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs
index f75d91f..abe1bfe 100644
--- a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs
@@ -132,8 +132,7 @@ import System.IO.Unsafe (unsafePerformIO)
132import Numeric.LinearAlgebra.Devel (createVector) 132import Numeric.LinearAlgebra.Devel (createVector)
133 133
134import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), 134import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><),
135 subMatrix, rows, cols, toLists, 135 rows, cols, toLists, size)
136 size, subVector)
137 136
138import qualified Types as T 137import qualified Types as T
139import Arkode (cV_ADAMS, cV_BDF) 138import Arkode (cV_ADAMS, cV_BDF)
@@ -247,7 +246,7 @@ odeSolveV meth hi epsAbs epsRel f y0 ts =
247 case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of 246 case odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts of
248 Left c -> error $ show c -- FIXME 247 Left c -> error $ show c -- FIXME
249 -- FIXME: Can we do better than using lists? 248 -- FIXME: Can we do better than using lists?
250 Right (v, d) -> (nR >< nC) (V.toList v) 249 Right (v, _d) -> (nR >< nC) (V.toList v)
251 where 250 where
252 us = toList ts 251 us = toList ts
253 nR = length us 252 nR = length us
@@ -266,7 +265,7 @@ odeSolve f y0 ts =
266 -- FIXME: These tolerances are different from the ones in GSL 265 -- FIXME: These tolerances are different from the ones in GSL
267 case odeSolveVWith BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of 266 case odeSolveVWith BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of
268 Left c -> error $ show c -- FIXME 267 Left c -> error $ show c -- FIXME
269 Right (v, d) -> (nR >< nC) (V.toList v) 268 Right (v, _d) -> (nR >< nC) (V.toList v)
270 where 269 where
271 us = toList ts 270 us = toList ts
272 nR = length us 271 nR = length us
@@ -353,13 +352,12 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
353 nEq = fromIntegral dim 352 nEq = fromIntegral dim
354 nTs :: CInt 353 nTs :: CInt
355 nTs = fromIntegral $ V.length ts 354 nTs = fromIntegral $ V.length ts
356 -- FIXME: fMut is not actually mutatated 355 -- FIXME: tMut is not actually mutatated?
357 fMut <- V.thaw f0
358 tMut <- V.thaw ts 356 tMut <- V.thaw ts
359 -- FIXME: I believe this gets taken from the ghc heap and so should 357 -- FIXME: I believe this gets taken from the ghc heap and so should
360 -- be subject to garbage collection. 358 -- be subject to garbage collection.
361 -- quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) 359 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs))
362 -- qMatMut <- V.thaw quasiMatrixRes 360 qMatMut <- V.thaw quasiMatrixRes
363 diagnostics :: V.Vector CLong <- createVector 10 -- FIXME 361 diagnostics :: V.Vector CLong <- createVector 10 -- FIXME
364 diagMut <- V.thaw diagnostics 362 diagMut <- V.thaw diagnostics
365 -- We need the types that sundials expects. These are tied together 363 -- We need the types that sundials expects. These are tied together
@@ -394,7 +392,13 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
394 int flag; /* reusable error-checking flag */ 392 int flag; /* reusable error-checking flag */
395 int i, j; /* reusable loop indices */ 393 int i, j; /* reusable loop indices */
396 N_Vector y = NULL; /* empty vector for storing solution */ 394 N_Vector y = NULL; /* empty vector for storing solution */
395 N_Vector tv = NULL; /* empty vector for storing absolute tolerances */
396
397 SUNMatrix A = NULL; /* empty matrix for linear solver */
398 SUNLinearSolver LS = NULL; /* empty linear solver object */
397 void *cvode_mem = NULL; /* empty CVODE memory structure */ 399 void *cvode_mem = NULL; /* empty CVODE memory structure */
400 realtype t;
401 long int nst, nfe, nsetups, nje, nfeLS, nni, ncfn, netf, nge;
398 402
399 /* general problem parameters */ 403 /* general problem parameters */
400 404
@@ -410,7 +414,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
410 NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; 414 NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i];
411 }; 415 };
412 416
413 cvode_mem = CVodeCreate(CV_BDF, CV_NEWTON); 417 cvode_mem = CVodeCreate($(int method), CV_NEWTON);
414 if (check_flag((void *)cvode_mem, "CVodeCreate", 0)) return(1); 418 if (check_flag((void *)cvode_mem, "CVodeCreate", 0)) return(1);
415 419
416 /* Call CVodeInit to initialize the integrator memory and specify the 420 /* Call CVodeInit to initialize the integrator memory and specify the
@@ -419,16 +423,136 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
419 flag = CVodeInit(cvode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y); 423 flag = CVodeInit(cvode_mem, $fun:(int (* funIO) (double t, SunVector y[], SunVector dydt[], void * params)), T0, y);
420 if (check_flag(&flag, "CVodeInit", 1)) return(1); 424 if (check_flag(&flag, "CVodeInit", 1)) return(1);
421 425
426 tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */
427 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1;
428 /* Specify tolerances */
429 for (i = 0; i < NEQ; i++) {
430 NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i];
431 };
432
433 /* FIXME: A hack for initial testing */
434 flag = CVodeSetMinStep(cvode_mem, 1.0e-12);
435 if (check_flag(&flag, "CVodeSetMinStep", 1)) return 1;
436 flag = CVodeSetMaxNumSteps(cvode_mem, 10000);
437 if (check_flag(&flag, "CVodeSetMaxNumSteps", 1)) return 1;
438
439 /* Call CVodeSVtolerances to specify the scalar relative tolerance
440 * and vector absolute tolerances */
441 flag = CVodeSVtolerances(cvode_mem, $(double relTol), tv);
442 if (check_flag(&flag, "CVodeSVtolerances", 1)) return(1);
443
444 /* Initialize dense matrix data structure and solver */
445 A = SUNDenseMatrix(NEQ, NEQ);
446 if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1;
447 LS = SUNDenseLinearSolver(y, A);
448 if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1;
449
450 /* Attach matrix and linear solver */
451 flag = CVDlsSetLinearSolver(cvode_mem, LS, A);
452 if (check_flag(&flag, "CVDlsSetLinearSolver", 1)) return 1;
453
454 /* Set the initial step size if there is one */
455 if ($(int isInitStepSize)) {
456 /* FIXME: We could check if the initial step size is 0 */
457 /* or even NaN and then throw an error */
458 flag = CVodeSetInitStep(cvode_mem, $(double ss));
459 if (check_flag(&flag, "CVodeSetInitStep", 1)) return 1;
460 }
461
462 /* Set the Jacobian if there is one */
463 if ($(int isJac)) {
464 flag = CVDlsSetJacFn(cvode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[])));
465 if (check_flag(&flag, "CVDlsSetJacFn", 1)) return 1;
466 }
467
468 /* Store initial conditions */
469 for (j = 0; j < NEQ; j++) {
470 ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j);
471 }
472
473 /* Main time-stepping loop: calls CVode to perform the integration */
474 /* Stops when the final time has been reached */
475 for (i = 1; i < $(int nTs); i++) {
476
477 flag = CVode(cvode_mem, ($vec-ptr:(double *tMut))[i], y, &t, CV_NORMAL); /* call integrator */
478 if (check_flag(&flag, "CVode", 1)) break;
479
480 /* Store the results for Haskell */
481 for (j = 0; j < NEQ; j++) {
482 ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j);
483 }
484
485 /* unsuccessful solve: break */
486 if (flag < 0) {
487 fprintf(stderr,"Solver failure, stopping integration\n");
488 break;
489 }
490 }
491
492 /* Get some final statistics on how the solve progressed */
493
494 flag = CVodeGetNumSteps(cvode_mem, &nst);
495 check_flag(&flag, "CVodeGetNumSteps", 1);
496 ($vec-ptr:(long int *diagMut))[0] = nst;
497
498 /* FIXME */
499 ($vec-ptr:(long int *diagMut))[1] = 0;
500
501 flag = CVodeGetNumRhsEvals(cvode_mem, &nfe);
502 check_flag(&flag, "CVodeGetNumRhsEvals", 1);
503 ($vec-ptr:(long int *diagMut))[2] = nfe;
504 /* FIXME */
505 ($vec-ptr:(long int *diagMut))[3] = 0;
506
507 flag = CVodeGetNumLinSolvSetups(cvode_mem, &nsetups);
508 check_flag(&flag, "CVodeGetNumLinSolvSetups", 1);
509 ($vec-ptr:(long int *diagMut))[4] = nsetups;
510
511 flag = CVodeGetNumErrTestFails(cvode_mem, &netf);
512 check_flag(&flag, "CVodeGetNumErrTestFails", 1);
513 ($vec-ptr:(long int *diagMut))[5] = netf;
514
515 flag = CVodeGetNumNonlinSolvIters(cvode_mem, &nni);
516 check_flag(&flag, "CVodeGetNumNonlinSolvIters", 1);
517 ($vec-ptr:(long int *diagMut))[6] = nni;
518
519 flag = CVodeGetNumNonlinSolvConvFails(cvode_mem, &ncfn);
520 check_flag(&flag, "CVodeGetNumNonlinSolvConvFails", 1);
521 ($vec-ptr:(long int *diagMut))[7] = ncfn;
522
523 flag = CVDlsGetNumJacEvals(cvode_mem, &nje);
524 check_flag(&flag, "CVDlsGetNumJacEvals", 1);
525 ($vec-ptr:(long int *diagMut))[8] = ncfn;
526
527 flag = CVDlsGetNumRhsEvals(cvode_mem, &nfeLS);
528 check_flag(&flag, "CVDlsGetNumRhsEvals", 1);
529 ($vec-ptr:(long int *diagMut))[9] = ncfn;
530
422 /* Clean up and return */ 531 /* Clean up and return */
423 532
424 N_VDestroy(y); /* Free y vector */ 533 N_VDestroy(y); /* Free y vector */
534 N_VDestroy(tv); /* Free tv vector */
425 CVodeFree(&cvode_mem); /* Free integrator memory */ 535 CVodeFree(&cvode_mem); /* Free integrator memory */
536 SUNLinSolFree(LS); /* Free linear solver */
537 SUNMatDestroy(A); /* Free A matrix */
426 538
427 return flag; 539 return flag;
428 } |] 540 } |]
429 if res == 0 541 if res == 0
430 then do 542 then do
431 return $ Left res 543 preD <- V.freeze diagMut
544 let d = SundialsDiagnostics (fromIntegral $ preD V.!0)
545 (fromIntegral $ preD V.!1)
546 (fromIntegral $ preD V.!2)
547 (fromIntegral $ preD V.!3)
548 (fromIntegral $ preD V.!4)
549 (fromIntegral $ preD V.!5)
550 (fromIntegral $ preD V.!6)
551 (fromIntegral $ preD V.!7)
552 (fromIntegral $ preD V.!8)
553 (fromIntegral $ preD V.!9)
554 m <- V.freeze qMatMut
555 return $ Right (m, d)
432 else do 556 else do
433 return $ Left res 557 return $ Left res
434 558