diff options
author | Dominic Steinitz <dominic@steinitz.org> | 2018-03-14 07:25:10 +0000 |
---|---|---|
committer | Dominic Steinitz <dominic@steinitz.org> | 2018-03-14 07:25:10 +0000 |
commit | d23f3abc8038e9669ef1aa6b7ab9fe5346f95410 (patch) | |
tree | 2b0a86f8c241bee693adbf7534fb39a0f121619f /packages | |
parent | 07df48225553adc441aa68d65a518b145b80a7f5 (diff) |
Now as a function
Diffstat (limited to 'packages')
-rw-r--r-- | packages/sundials/src/Main.hs | 258 |
1 files changed, 151 insertions, 107 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index 28b813a..b6855cb 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs | |||
@@ -42,114 +42,21 @@ C.include "<sundials/sundials_types.h>" -- definition of type realtype | |||
42 | C.include "<sundials/sundials_math.h>" | 42 | C.include "<sundials/sundials_math.h>" |
43 | C.include "helpers.h" | 43 | C.include "helpers.h" |
44 | 44 | ||
45 | -- | Solves a system of ODEs. Every 'V.Vector' involved must be of the | ||
46 | -- same size. | ||
47 | -- {-# NOINLINE solveOdeC #-} | ||
48 | -- solveOdeC | ||
49 | -- :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) | ||
50 | -- -- ^ ODE to Solve | ||
51 | -- -> CDouble | ||
52 | -- -- ^ Start | ||
53 | -- -> V.Vector CDouble | ||
54 | -- -- ^ Solution at start point | ||
55 | -- -> CDouble | ||
56 | -- -- ^ End | ||
57 | -- -> Either String (V.Vector CDouble) | ||
58 | -- -- ^ Solution at end point, or error. | ||
59 | -- solveOdeC fun x0 f0 xend = unsafePerformIO $ do | ||
60 | -- let dim = V.length f0 | ||
61 | -- let dim_c = fromIntegral dim -- This is in CInt | ||
62 | -- -- Convert the function to something of the right type to C. | ||
63 | -- let funIO x y f _ptr = do | ||
64 | -- -- Convert the pointer we get from C (y) to a vector, and then | ||
65 | -- -- apply the user-supplied function. | ||
66 | -- fImm <- fun x <$> vectorFromC dim y | ||
67 | -- -- Fill in the provided pointer with the resulting vector. | ||
68 | -- vectorToC fImm dim f | ||
69 | -- -- Unsafe since the function will be called many times. | ||
70 | -- [CU.exp| int{ GSL_SUCCESS } |] | ||
71 | -- -- Create a mutable vector from the initial solution. This will be | ||
72 | -- -- passed to the ODE solving function provided by GSL, and will | ||
73 | -- -- contain the final solution. | ||
74 | -- fMut <- V.thaw f0 | ||
75 | -- res <- [C.block| int { | ||
76 | -- gsl_odeiv2_system sys = { | ||
77 | -- $fun:(int (* funIO) (double t, const double y[], double dydt[], void * params)), | ||
78 | -- // The ODE to solve, converted to function pointer using the `fun` | ||
79 | -- // anti-quoter | ||
80 | -- NULL, // We don't provide a Jacobian | ||
81 | -- $(int dim_c), // The dimension | ||
82 | -- NULL // We don't need the parameter pointer | ||
83 | -- }; | ||
84 | -- // Create the driver, using some sensible values for the stepping | ||
85 | -- // function and the tolerances | ||
86 | -- gsl_odeiv2_driver *d = gsl_odeiv2_driver_alloc_y_new ( | ||
87 | -- &sys, gsl_odeiv2_step_rk8pd, 1e-6, 1e-6, 0.0); | ||
88 | -- // Finally, apply the driver. | ||
89 | -- int status = gsl_odeiv2_driver_apply( | ||
90 | -- d, &$(double x0), $(double xend), $vec-ptr:(double *fMut)); | ||
91 | -- // Free the driver | ||
92 | -- gsl_odeiv2_driver_free(d); | ||
93 | -- return status; | ||
94 | -- } |] | ||
95 | -- -- Check the error code | ||
96 | -- maxSteps <- [C.exp| int{ GSL_EMAXITER } |] | ||
97 | -- smallStep <- [C.exp| int{ GSL_ENOPROG } |] | ||
98 | -- good <- [C.exp| int{ GSL_SUCCESS } |] | ||
99 | -- if | res == good -> Right <$> V.freeze fMut | ||
100 | -- | res == maxSteps -> return $ Left "Too many steps" | ||
101 | -- | res == smallStep -> return $ Left "Step size dropped below minimum allowed size" | ||
102 | -- | otherwise -> return $ Left $ "Unknown error code " ++ show res | ||
103 | |||
104 | -- -- Utils | ||
105 | |||
106 | -- vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | ||
107 | -- vectorFromC len ptr = do | ||
108 | -- ptr' <- newForeignPtr_ ptr | ||
109 | -- V.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
110 | |||
111 | -- vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () | ||
112 | -- vectorToC vec len ptr = do | ||
113 | -- ptr' <- newForeignPtr_ ptr | ||
114 | -- V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
115 | |||
116 | |||
117 | -- /* Check function return value... | ||
118 | -- opt == 0 means SUNDIALS function allocates memory so check if | ||
119 | -- returned NULL pointer | ||
120 | -- opt == 1 means SUNDIALS function returns a flag so check if | ||
121 | -- flag >= 0 | ||
122 | -- opt == 2 means function allocates memory so check if returned | ||
123 | -- NULL pointer | ||
124 | -- */ | ||
125 | -- static int check_flag(void *flagvalue, const char *funcname, int opt) | ||
126 | -- { | ||
127 | -- int *errflag; | ||
128 | |||
129 | -- /* Check if SUNDIALS function returned NULL pointer - no memory allocated */ | ||
130 | -- if (opt == 0 && flagvalue == NULL) { | ||
131 | -- fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n", | ||
132 | -- funcname); | ||
133 | -- return 1; } | ||
134 | |||
135 | -- /* Check if flag < 0 */ | ||
136 | -- else if (opt == 1) { | ||
137 | -- errflag = (int *) flagvalue; | ||
138 | -- if (*errflag < 0) { | ||
139 | -- fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with flag = %d\n\n", | ||
140 | -- funcname, *errflag); | ||
141 | -- return 1; }} | ||
142 | |||
143 | -- /* Check if function returned NULL pointer - no memory allocated */ | ||
144 | -- else if (opt == 2 && flagvalue == NULL) { | ||
145 | -- fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n", | ||
146 | -- funcname); | ||
147 | -- return 1; } | ||
148 | |||
149 | -- return 0; | ||
150 | -- } | ||
151 | 45 | ||
152 | main = do | 46 | -- Utils |
47 | |||
48 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | ||
49 | vectorFromC len ptr = do | ||
50 | ptr' <- newForeignPtr_ ptr | ||
51 | V.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
52 | |||
53 | vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () | ||
54 | vectorToC vec len ptr = do | ||
55 | ptr' <- newForeignPtr_ ptr | ||
56 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
57 | |||
58 | solve :: CDouble -> CInt | ||
59 | solve lambda = unsafePerformIO $ do | ||
153 | res <- [C.block| int { /* general problem variables */ | 60 | res <- [C.block| int { /* general problem variables */ |
154 | int flag; /* reusable error-checking flag */ | 61 | int flag; /* reusable error-checking flag */ |
155 | N_Vector y = NULL; /* empty vector for storing solution */ | 62 | N_Vector y = NULL; /* empty vector for storing solution */ |
@@ -172,6 +79,7 @@ main = do | |||
172 | /* Initial diagnostics output */ | 79 | /* Initial diagnostics output */ |
173 | printf("\nAnalytical ODE test problem:\n"); | 80 | printf("\nAnalytical ODE test problem:\n"); |
174 | printf(" lamda = %"GSYM"\n", lamda); | 81 | printf(" lamda = %"GSYM"\n", lamda); |
82 | printf(" lambda = %"GSYM"\n", $(double lambda)); | ||
175 | printf(" reltol = %.1"ESYM"\n", reltol); | 83 | printf(" reltol = %.1"ESYM"\n", reltol); |
176 | printf(" abstol = %.1"ESYM"\n\n",abstol); | 84 | printf(" abstol = %.1"ESYM"\n\n",abstol); |
177 | 85 | ||
@@ -282,4 +190,140 @@ main = do | |||
282 | 190 | ||
283 | return flag; | 191 | return flag; |
284 | } |] | 192 | } |] |
193 | return res | ||
194 | |||
195 | main = do | ||
196 | let res = solve (coerce (100.0 :: Double)) | ||
197 | -- res <- [C.block| int { /* general problem variables */ | ||
198 | -- int flag; /* reusable error-checking flag */ | ||
199 | -- N_Vector y = NULL; /* empty vector for storing solution */ | ||
200 | -- SUNMatrix A = NULL; /* empty matrix for linear solver */ | ||
201 | -- SUNLinearSolver LS = NULL; /* empty linear solver object */ | ||
202 | -- void *arkode_mem = NULL; /* empty ARKode memory structure */ | ||
203 | -- FILE *UFID; | ||
204 | -- realtype t, tout; | ||
205 | -- long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; | ||
206 | |||
207 | -- /* general problem parameters */ | ||
208 | -- realtype T0 = RCONST(0.0); /* initial time */ | ||
209 | -- realtype Tf = RCONST(10.0); /* final time */ | ||
210 | -- realtype dTout = RCONST(1.0); /* time between outputs */ | ||
211 | -- sunindextype NEQ = 1; /* number of dependent vars. */ | ||
212 | -- realtype reltol = 1.0e-6; /* tolerances */ | ||
213 | -- realtype abstol = 1.0e-10; | ||
214 | -- realtype lamda = -100.0; /* stiffness parameter */ | ||
215 | |||
216 | -- /* Initial diagnostics output */ | ||
217 | -- printf("\nAnalytical ODE test problem:\n"); | ||
218 | -- printf(" lamda = %"GSYM"\n", lamda); | ||
219 | -- printf(" reltol = %.1"ESYM"\n", reltol); | ||
220 | -- printf(" abstol = %.1"ESYM"\n\n",abstol); | ||
221 | |||
222 | -- /* Initialize data structures */ | ||
223 | -- y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | ||
224 | -- if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | ||
225 | -- N_VConst(0.0, y); /* Specify initial condition */ | ||
226 | -- arkode_mem = ARKodeCreate(); /* Create the solver memory */ | ||
227 | -- if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; | ||
228 | |||
229 | -- /* Call ARKodeInit to initialize the integrator memory and specify the */ | ||
230 | -- /* right-hand side function in y'=f(t,y), the inital time T0, and */ | ||
231 | -- /* the initial dependent variable vector y. Note: since this */ | ||
232 | -- /* problem is fully implicit, we set f_E to NULL and f_I to f. */ | ||
233 | -- flag = ARKodeInit(arkode_mem, NULL, f, T0, y); | ||
234 | -- if (check_flag(&flag, "ARKodeInit", 1)) return 1; | ||
235 | |||
236 | -- /* Set routines */ | ||
237 | -- flag = ARKodeSetUserData(arkode_mem, (void *) &lamda); /* Pass lamda to user functions */ | ||
238 | -- if (check_flag(&flag, "ARKodeSetUserData", 1)) return 1; | ||
239 | -- flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */ | ||
240 | -- if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; | ||
241 | |||
242 | -- /* Initialize dense matrix data structure and solver */ | ||
243 | -- A = SUNDenseMatrix(NEQ, NEQ); | ||
244 | -- if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1; | ||
245 | -- LS = SUNDenseLinearSolver(y, A); | ||
246 | -- if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1; | ||
247 | |||
248 | -- /* Linear solver interface */ | ||
249 | -- flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ | ||
250 | -- if (check_flag(&flag, "ARKDlsSetLinearSolver", 1)) return 1; | ||
251 | -- flag = ARKDlsSetJacFn(arkode_mem, Jac); /* Set Jacobian routine */ | ||
252 | -- if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1; | ||
253 | |||
254 | -- /* Specify linearly implicit RHS, with non-time-dependent Jacobian */ | ||
255 | -- flag = ARKodeSetLinear(arkode_mem, 0); | ||
256 | -- if (check_flag(&flag, "ARKodeSetLinear", 1)) return 1; | ||
257 | |||
258 | -- /* Open output stream for results, output comment line */ | ||
259 | -- UFID = fopen("solution.txt","w"); | ||
260 | -- fprintf(UFID,"# t u\n"); | ||
261 | |||
262 | -- /* output initial condition to disk */ | ||
263 | -- fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", T0, NV_Ith_S(y,0)); | ||
264 | |||
265 | -- /* Main time-stepping loop: calls ARKode to perform the integration, then | ||
266 | -- prints results. Stops when the final time has been reached */ | ||
267 | -- t = T0; | ||
268 | -- tout = T0+dTout; | ||
269 | -- printf(" t u\n"); | ||
270 | -- printf(" ---------------------\n"); | ||
271 | -- while (Tf - t > 1.0e-15) { | ||
272 | |||
273 | -- flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */ | ||
274 | -- if (check_flag(&flag, "ARKode", 1)) break; | ||
275 | -- printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */ | ||
276 | -- fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", t, NV_Ith_S(y,0)); | ||
277 | -- if (flag >= 0) { /* successful solve: update time */ | ||
278 | -- tout += dTout; | ||
279 | -- tout = (tout > Tf) ? Tf : tout; | ||
280 | -- } else { /* unsuccessful solve: break */ | ||
281 | -- fprintf(stderr,"Solver failure, stopping integration\n"); | ||
282 | -- break; | ||
283 | -- } | ||
284 | -- } | ||
285 | -- printf(" ---------------------\n"); | ||
286 | -- fclose(UFID); | ||
287 | |||
288 | -- /* Get/print some final statistics on how the solve progressed */ | ||
289 | -- flag = ARKodeGetNumSteps(arkode_mem, &nst); | ||
290 | -- check_flag(&flag, "ARKodeGetNumSteps", 1); | ||
291 | -- flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a); | ||
292 | -- check_flag(&flag, "ARKodeGetNumStepAttempts", 1); | ||
293 | -- flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi); | ||
294 | -- check_flag(&flag, "ARKodeGetNumRhsEvals", 1); | ||
295 | -- flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups); | ||
296 | -- check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1); | ||
297 | -- flag = ARKodeGetNumErrTestFails(arkode_mem, &netf); | ||
298 | -- check_flag(&flag, "ARKodeGetNumErrTestFails", 1); | ||
299 | -- flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni); | ||
300 | -- check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1); | ||
301 | -- flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn); | ||
302 | -- check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1); | ||
303 | -- flag = ARKDlsGetNumJacEvals(arkode_mem, &nje); | ||
304 | -- check_flag(&flag, "ARKDlsGetNumJacEvals", 1); | ||
305 | -- flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS); | ||
306 | -- check_flag(&flag, "ARKDlsGetNumRhsEvals", 1); | ||
307 | |||
308 | -- printf("\nFinal Solver Statistics:\n"); | ||
309 | -- printf(" Internal solver steps = %li (attempted = %li)\n", nst, nst_a); | ||
310 | -- printf(" Total RHS evals: Fe = %li, Fi = %li\n", nfe, nfi); | ||
311 | -- printf(" Total linear solver setups = %li\n", nsetups); | ||
312 | -- printf(" Total RHS evals for setting up the linear system = %li\n", nfeLS); | ||
313 | -- printf(" Total number of Jacobian evaluations = %li\n", nje); | ||
314 | -- printf(" Total number of Newton iterations = %li\n", nni); | ||
315 | -- printf(" Total number of linear solver convergence failures = %li\n", ncfn); | ||
316 | -- printf(" Total number of error test failures = %li\n\n", netf); | ||
317 | |||
318 | -- /* check the solution error */ | ||
319 | -- flag = check_ans(y, t, reltol, abstol); | ||
320 | |||
321 | -- /* Clean up and return */ | ||
322 | -- N_VDestroy(y); /* Free y vector */ | ||
323 | -- ARKodeFree(&arkode_mem); /* Free integrator memory */ | ||
324 | -- SUNLinSolFree(LS); /* Free linear solver */ | ||
325 | -- SUNMatDestroy(A); /* Free A matrix */ | ||
326 | |||
327 | -- return flag; | ||
328 | -- } |] | ||
285 | putStrLn $ show res | 329 | putStrLn $ show res |