summaryrefslogtreecommitdiff
path: root/packages/sundials
diff options
context:
space:
mode:
Diffstat (limited to 'packages/sundials')
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs48
1 files changed, 28 insertions, 20 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index 85f1b3d..8f83fe7 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -338,31 +338,34 @@ solveOdeC method jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do
338 338
339 res <- [C.block| int { 339 res <- [C.block| int {
340 /* general problem variables */ 340 /* general problem variables */
341 int flag; /* reusable error-checking flag */
342 N_Vector y = NULL; /* empty vector for storing solution */
343 /* empty vector for storing absolute tolerances */
344 N_Vector tv = NULL;
345 341
346 SUNMatrix A = NULL; /* empty matrix for linear solver */ 342 int flag; /* reusable error-checking flag */
347 SUNLinearSolver LS = NULL; /* empty linear solver object */ 343 int i, j; /* reusable loop indices */
348 void *arkode_mem = NULL; /* empty ARKode memory structure */ 344 N_Vector y = NULL; /* empty vector for storing solution */
345 N_Vector tv = NULL; /* empty vector for storing absolute tolerances */
346 SUNMatrix A = NULL; /* empty matrix for linear solver */
347 SUNLinearSolver LS = NULL; /* empty linear solver object */
348 void *arkode_mem = NULL; /* empty ARKode memory structure */
349 realtype t; 349 realtype t;
350 long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; 350 long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf;
351 351
352 /* general problem parameters */ 352 /* general problem parameters */
353 realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ 353
354 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ 354 realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */
355 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */
355 356
356 /* Initialize data structures */ 357 /* Initialize data structures */
357 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ 358
359 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
358 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; 360 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
359 int i, j; 361 /* Specify initial condition */
360 for (i = 0; i < NEQ; i++) { 362 for (i = 0; i < NEQ; i++) {
361 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; 363 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i];
362 }; /* Specify initial condition */ 364 };
363 365
364 tv = N_VNew_Serial(NEQ); /* Create serial vector for solution */ 366 tv = N_VNew_Serial(NEQ); /* Create serial vector for absolute tolerances */
365 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; 367 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1;
368 /* Specify tolerances */
366 for (i = 0; i < NEQ; i++) { 369 for (i = 0; i < NEQ; i++) {
367 NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i]; 370 NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i];
368 }; 371 };
@@ -371,9 +374,9 @@ solveOdeC method jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do
371 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; 374 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
372 375
373 /* Call ARKodeInit to initialize the integrator memory and specify the */ 376 /* Call ARKodeInit to initialize the integrator memory and specify the */
374 /* right-hand side function in y'=f(t,y), the inital time T0, and */ 377 /* right-hand side function in y'=f(t,y), the inital time T0, and */
375 /* the initial dependent variable vector y. Note: since this */ 378 /* the initial dependent variable vector y. Note: we treat the */
376 /* problem is fully implicit, we set f_E to NULL and f_I to f. */ 379 /* problem as fully implicit and set f_E to NULL and f_I to f. */
377 380
378 /* Here we use the C types defined in helpers.h which tie up with */ 381 /* Here we use the C types defined in helpers.h which tie up with */
379 /* the Haskell types defined in Types */ 382 /* the Haskell types defined in Types */
@@ -390,9 +393,11 @@ solveOdeC method jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do
390 LS = SUNDenseLinearSolver(y, A); 393 LS = SUNDenseLinearSolver(y, A);
391 if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1; 394 if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1;
392 395
393 /* Linear solver interface */ 396 /* Attach matrix and linear solver */
394 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ 397 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A);
398 if (check_flag(&flag, "ARKDlsSetLinearSolver", 1)) return 1;
395 399
400 /* Set the Jacobian if there is one */
396 if ($(int isJac)) { 401 if ($(int isJac)) {
397 flag = ARKDlsSetJacFn(arkode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[]))); 402 flag = ARKDlsSetJacFn(arkode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[])));
398 if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1; 403 if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1;
@@ -403,11 +408,12 @@ solveOdeC method jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do
403 ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); 408 ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j);
404 } 409 }
405 410
411 /* Explicitly set the method */
406 flag = ARKodeSetIRKTableNum(arkode_mem, $(int method)); 412 flag = ARKodeSetIRKTableNum(arkode_mem, $(int method));
407 if (check_flag(&flag, "ARKode", 1)) return 1; 413 if (check_flag(&flag, "ARKode", 1)) return 1;
408 414
409 /* Main time-stepping loop: calls ARKode to perform the integration */ 415 /* Main time-stepping loop: calls ARKode to perform the integration */
410 /* Stops when the final time has been reached */ 416 /* Stops when the final time has been reached */
411 for (i = 1; i < $(int nTs); i++) { 417 for (i = 1; i < $(int nTs); i++) {
412 418
413 flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */ 419 flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */
@@ -418,13 +424,15 @@ solveOdeC method jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do
418 ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j); 424 ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j);
419 } 425 }
420 426
421 if (flag < 0) { /* unsuccessful solve: break */ 427 /* unsuccessful solve: break */
428 if (flag < 0) {
422 fprintf(stderr,"Solver failure, stopping integration\n"); 429 fprintf(stderr,"Solver failure, stopping integration\n");
423 break; 430 break;
424 } 431 }
425 } 432 }
426 433
427 /* Get some final statistics on how the solve progressed */ 434 /* Get some final statistics on how the solve progressed */
435
428 flag = ARKodeGetNumSteps(arkode_mem, &nst); 436 flag = ARKodeGetNumSteps(arkode_mem, &nst);
429 check_flag(&flag, "ARKodeGetNumSteps", 1); 437 check_flag(&flag, "ARKodeGetNumSteps", 1);
430 ($vec-ptr:(long int *diagMut))[0] = nst; 438 ($vec-ptr:(long int *diagMut))[0] = nst;