summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Main.hs
diff options
context:
space:
mode:
authorDominic Steinitz <dominic@steinitz.org>2018-03-14 07:25:10 +0000
committerDominic Steinitz <dominic@steinitz.org>2018-03-14 07:25:10 +0000
commitd23f3abc8038e9669ef1aa6b7ab9fe5346f95410 (patch)
tree2b0a86f8c241bee693adbf7534fb39a0f121619f /packages/sundials/src/Main.hs
parent07df48225553adc441aa68d65a518b145b80a7f5 (diff)
Now as a function
Diffstat (limited to 'packages/sundials/src/Main.hs')
-rw-r--r--packages/sundials/src/Main.hs258
1 files changed, 151 insertions, 107 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs
index 28b813a..b6855cb 100644
--- a/packages/sundials/src/Main.hs
+++ b/packages/sundials/src/Main.hs
@@ -42,114 +42,21 @@ C.include "<sundials/sundials_types.h>" -- definition of type realtype
42C.include "<sundials/sundials_math.h>" 42C.include "<sundials/sundials_math.h>"
43C.include "helpers.h" 43C.include "helpers.h"
44 44
45-- | Solves a system of ODEs. Every 'V.Vector' involved must be of the
46-- same size.
47-- {-# NOINLINE solveOdeC #-}
48-- solveOdeC
49-- :: (CDouble -> V.Vector CDouble -> V.Vector CDouble)
50-- -- ^ ODE to Solve
51-- -> CDouble
52-- -- ^ Start
53-- -> V.Vector CDouble
54-- -- ^ Solution at start point
55-- -> CDouble
56-- -- ^ End
57-- -> Either String (V.Vector CDouble)
58-- -- ^ Solution at end point, or error.
59-- solveOdeC fun x0 f0 xend = unsafePerformIO $ do
60-- let dim = V.length f0
61-- let dim_c = fromIntegral dim -- This is in CInt
62-- -- Convert the function to something of the right type to C.
63-- let funIO x y f _ptr = do
64-- -- Convert the pointer we get from C (y) to a vector, and then
65-- -- apply the user-supplied function.
66-- fImm <- fun x <$> vectorFromC dim y
67-- -- Fill in the provided pointer with the resulting vector.
68-- vectorToC fImm dim f
69-- -- Unsafe since the function will be called many times.
70-- [CU.exp| int{ GSL_SUCCESS } |]
71-- -- Create a mutable vector from the initial solution. This will be
72-- -- passed to the ODE solving function provided by GSL, and will
73-- -- contain the final solution.
74-- fMut <- V.thaw f0
75-- res <- [C.block| int {
76-- gsl_odeiv2_system sys = {
77-- $fun:(int (* funIO) (double t, const double y[], double dydt[], void * params)),
78-- // The ODE to solve, converted to function pointer using the `fun`
79-- // anti-quoter
80-- NULL, // We don't provide a Jacobian
81-- $(int dim_c), // The dimension
82-- NULL // We don't need the parameter pointer
83-- };
84-- // Create the driver, using some sensible values for the stepping
85-- // function and the tolerances
86-- gsl_odeiv2_driver *d = gsl_odeiv2_driver_alloc_y_new (
87-- &sys, gsl_odeiv2_step_rk8pd, 1e-6, 1e-6, 0.0);
88-- // Finally, apply the driver.
89-- int status = gsl_odeiv2_driver_apply(
90-- d, &$(double x0), $(double xend), $vec-ptr:(double *fMut));
91-- // Free the driver
92-- gsl_odeiv2_driver_free(d);
93-- return status;
94-- } |]
95-- -- Check the error code
96-- maxSteps <- [C.exp| int{ GSL_EMAXITER } |]
97-- smallStep <- [C.exp| int{ GSL_ENOPROG } |]
98-- good <- [C.exp| int{ GSL_SUCCESS } |]
99-- if | res == good -> Right <$> V.freeze fMut
100-- | res == maxSteps -> return $ Left "Too many steps"
101-- | res == smallStep -> return $ Left "Step size dropped below minimum allowed size"
102-- | otherwise -> return $ Left $ "Unknown error code " ++ show res
103
104-- -- Utils
105
106-- vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a)
107-- vectorFromC len ptr = do
108-- ptr' <- newForeignPtr_ ptr
109-- V.freeze $ VM.unsafeFromForeignPtr0 ptr' len
110
111-- vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO ()
112-- vectorToC vec len ptr = do
113-- ptr' <- newForeignPtr_ ptr
114-- V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
115
116
117-- /* Check function return value...
118-- opt == 0 means SUNDIALS function allocates memory so check if
119-- returned NULL pointer
120-- opt == 1 means SUNDIALS function returns a flag so check if
121-- flag >= 0
122-- opt == 2 means function allocates memory so check if returned
123-- NULL pointer
124-- */
125-- static int check_flag(void *flagvalue, const char *funcname, int opt)
126-- {
127-- int *errflag;
128
129-- /* Check if SUNDIALS function returned NULL pointer - no memory allocated */
130-- if (opt == 0 && flagvalue == NULL) {
131-- fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
132-- funcname);
133-- return 1; }
134
135-- /* Check if flag < 0 */
136-- else if (opt == 1) {
137-- errflag = (int *) flagvalue;
138-- if (*errflag < 0) {
139-- fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with flag = %d\n\n",
140-- funcname, *errflag);
141-- return 1; }}
142
143-- /* Check if function returned NULL pointer - no memory allocated */
144-- else if (opt == 2 && flagvalue == NULL) {
145-- fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n",
146-- funcname);
147-- return 1; }
148
149-- return 0;
150-- }
151 45
152main = do 46-- Utils
47
48vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a)
49vectorFromC len ptr = do
50 ptr' <- newForeignPtr_ ptr
51 V.freeze $ VM.unsafeFromForeignPtr0 ptr' len
52
53vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO ()
54vectorToC vec len ptr = do
55 ptr' <- newForeignPtr_ ptr
56 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
57
58solve :: CDouble -> CInt
59solve lambda = unsafePerformIO $ do
153 res <- [C.block| int { /* general problem variables */ 60 res <- [C.block| int { /* general problem variables */
154 int flag; /* reusable error-checking flag */ 61 int flag; /* reusable error-checking flag */
155 N_Vector y = NULL; /* empty vector for storing solution */ 62 N_Vector y = NULL; /* empty vector for storing solution */
@@ -172,6 +79,7 @@ main = do
172 /* Initial diagnostics output */ 79 /* Initial diagnostics output */
173 printf("\nAnalytical ODE test problem:\n"); 80 printf("\nAnalytical ODE test problem:\n");
174 printf(" lamda = %"GSYM"\n", lamda); 81 printf(" lamda = %"GSYM"\n", lamda);
82 printf(" lambda = %"GSYM"\n", $(double lambda));
175 printf(" reltol = %.1"ESYM"\n", reltol); 83 printf(" reltol = %.1"ESYM"\n", reltol);
176 printf(" abstol = %.1"ESYM"\n\n",abstol); 84 printf(" abstol = %.1"ESYM"\n\n",abstol);
177 85
@@ -282,4 +190,140 @@ main = do
282 190
283 return flag; 191 return flag;
284 } |] 192 } |]
193 return res
194
195main = do
196 let res = solve (coerce (100.0 :: Double))
197 -- res <- [C.block| int { /* general problem variables */
198 -- int flag; /* reusable error-checking flag */
199 -- N_Vector y = NULL; /* empty vector for storing solution */
200 -- SUNMatrix A = NULL; /* empty matrix for linear solver */
201 -- SUNLinearSolver LS = NULL; /* empty linear solver object */
202 -- void *arkode_mem = NULL; /* empty ARKode memory structure */
203 -- FILE *UFID;
204 -- realtype t, tout;
205 -- long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf;
206
207 -- /* general problem parameters */
208 -- realtype T0 = RCONST(0.0); /* initial time */
209 -- realtype Tf = RCONST(10.0); /* final time */
210 -- realtype dTout = RCONST(1.0); /* time between outputs */
211 -- sunindextype NEQ = 1; /* number of dependent vars. */
212 -- realtype reltol = 1.0e-6; /* tolerances */
213 -- realtype abstol = 1.0e-10;
214 -- realtype lamda = -100.0; /* stiffness parameter */
215
216 -- /* Initial diagnostics output */
217 -- printf("\nAnalytical ODE test problem:\n");
218 -- printf(" lamda = %"GSYM"\n", lamda);
219 -- printf(" reltol = %.1"ESYM"\n", reltol);
220 -- printf(" abstol = %.1"ESYM"\n\n",abstol);
221
222 -- /* Initialize data structures */
223 -- y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
224 -- if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
225 -- N_VConst(0.0, y); /* Specify initial condition */
226 -- arkode_mem = ARKodeCreate(); /* Create the solver memory */
227 -- if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
228
229 -- /* Call ARKodeInit to initialize the integrator memory and specify the */
230 -- /* right-hand side function in y'=f(t,y), the inital time T0, and */
231 -- /* the initial dependent variable vector y. Note: since this */
232 -- /* problem is fully implicit, we set f_E to NULL and f_I to f. */
233 -- flag = ARKodeInit(arkode_mem, NULL, f, T0, y);
234 -- if (check_flag(&flag, "ARKodeInit", 1)) return 1;
235
236 -- /* Set routines */
237 -- flag = ARKodeSetUserData(arkode_mem, (void *) &lamda); /* Pass lamda to user functions */
238 -- if (check_flag(&flag, "ARKodeSetUserData", 1)) return 1;
239 -- flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */
240 -- if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1;
241
242 -- /* Initialize dense matrix data structure and solver */
243 -- A = SUNDenseMatrix(NEQ, NEQ);
244 -- if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1;
245 -- LS = SUNDenseLinearSolver(y, A);
246 -- if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1;
247
248 -- /* Linear solver interface */
249 -- flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */
250 -- if (check_flag(&flag, "ARKDlsSetLinearSolver", 1)) return 1;
251 -- flag = ARKDlsSetJacFn(arkode_mem, Jac); /* Set Jacobian routine */
252 -- if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1;
253
254 -- /* Specify linearly implicit RHS, with non-time-dependent Jacobian */
255 -- flag = ARKodeSetLinear(arkode_mem, 0);
256 -- if (check_flag(&flag, "ARKodeSetLinear", 1)) return 1;
257
258 -- /* Open output stream for results, output comment line */
259 -- UFID = fopen("solution.txt","w");
260 -- fprintf(UFID,"# t u\n");
261
262 -- /* output initial condition to disk */
263 -- fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", T0, NV_Ith_S(y,0));
264
265 -- /* Main time-stepping loop: calls ARKode to perform the integration, then
266 -- prints results. Stops when the final time has been reached */
267 -- t = T0;
268 -- tout = T0+dTout;
269 -- printf(" t u\n");
270 -- printf(" ---------------------\n");
271 -- while (Tf - t > 1.0e-15) {
272
273 -- flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */
274 -- if (check_flag(&flag, "ARKode", 1)) break;
275 -- printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */
276 -- fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", t, NV_Ith_S(y,0));
277 -- if (flag >= 0) { /* successful solve: update time */
278 -- tout += dTout;
279 -- tout = (tout > Tf) ? Tf : tout;
280 -- } else { /* unsuccessful solve: break */
281 -- fprintf(stderr,"Solver failure, stopping integration\n");
282 -- break;
283 -- }
284 -- }
285 -- printf(" ---------------------\n");
286 -- fclose(UFID);
287
288 -- /* Get/print some final statistics on how the solve progressed */
289 -- flag = ARKodeGetNumSteps(arkode_mem, &nst);
290 -- check_flag(&flag, "ARKodeGetNumSteps", 1);
291 -- flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a);
292 -- check_flag(&flag, "ARKodeGetNumStepAttempts", 1);
293 -- flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi);
294 -- check_flag(&flag, "ARKodeGetNumRhsEvals", 1);
295 -- flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups);
296 -- check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1);
297 -- flag = ARKodeGetNumErrTestFails(arkode_mem, &netf);
298 -- check_flag(&flag, "ARKodeGetNumErrTestFails", 1);
299 -- flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni);
300 -- check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1);
301 -- flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn);
302 -- check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1);
303 -- flag = ARKDlsGetNumJacEvals(arkode_mem, &nje);
304 -- check_flag(&flag, "ARKDlsGetNumJacEvals", 1);
305 -- flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS);
306 -- check_flag(&flag, "ARKDlsGetNumRhsEvals", 1);
307
308 -- printf("\nFinal Solver Statistics:\n");
309 -- printf(" Internal solver steps = %li (attempted = %li)\n", nst, nst_a);
310 -- printf(" Total RHS evals: Fe = %li, Fi = %li\n", nfe, nfi);
311 -- printf(" Total linear solver setups = %li\n", nsetups);
312 -- printf(" Total RHS evals for setting up the linear system = %li\n", nfeLS);
313 -- printf(" Total number of Jacobian evaluations = %li\n", nje);
314 -- printf(" Total number of Newton iterations = %li\n", nni);
315 -- printf(" Total number of linear solver convergence failures = %li\n", ncfn);
316 -- printf(" Total number of error test failures = %li\n\n", netf);
317
318 -- /* check the solution error */
319 -- flag = check_ans(y, t, reltol, abstol);
320
321 -- /* Clean up and return */
322 -- N_VDestroy(y); /* Free y vector */
323 -- ARKodeFree(&arkode_mem); /* Free integrator memory */
324 -- SUNLinSolFree(LS); /* Free linear solver */
325 -- SUNMatDestroy(A); /* Free A matrix */
326
327 -- return flag;
328 -- } |]
285 putStrLn $ show res 329 putStrLn $ show res