From d23f3abc8038e9669ef1aa6b7ab9fe5346f95410 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Wed, 14 Mar 2018 07:25:10 +0000 Subject: Now as a function --- packages/sundials/src/Main.hs | 258 ++++++++++++++++++++++++------------------ 1 file changed, 151 insertions(+), 107 deletions(-) (limited to 'packages/sundials/src') diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index 28b813a..b6855cb 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs @@ -42,114 +42,21 @@ C.include "" -- definition of type realtype C.include "" C.include "helpers.h" --- | Solves a system of ODEs. Every 'V.Vector' involved must be of the --- same size. --- {-# NOINLINE solveOdeC #-} --- solveOdeC --- :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) --- -- ^ ODE to Solve --- -> CDouble --- -- ^ Start --- -> V.Vector CDouble --- -- ^ Solution at start point --- -> CDouble --- -- ^ End --- -> Either String (V.Vector CDouble) --- -- ^ Solution at end point, or error. --- solveOdeC fun x0 f0 xend = unsafePerformIO $ do --- let dim = V.length f0 --- let dim_c = fromIntegral dim -- This is in CInt --- -- Convert the function to something of the right type to C. --- let funIO x y f _ptr = do --- -- Convert the pointer we get from C (y) to a vector, and then --- -- apply the user-supplied function. --- fImm <- fun x <$> vectorFromC dim y --- -- Fill in the provided pointer with the resulting vector. --- vectorToC fImm dim f --- -- Unsafe since the function will be called many times. --- [CU.exp| int{ GSL_SUCCESS } |] --- -- Create a mutable vector from the initial solution. This will be --- -- passed to the ODE solving function provided by GSL, and will --- -- contain the final solution. --- fMut <- V.thaw f0 --- res <- [C.block| int { --- gsl_odeiv2_system sys = { --- $fun:(int (* funIO) (double t, const double y[], double dydt[], void * params)), --- // The ODE to solve, converted to function pointer using the `fun` --- // anti-quoter --- NULL, // We don't provide a Jacobian --- $(int dim_c), // The dimension --- NULL // We don't need the parameter pointer --- }; --- // Create the driver, using some sensible values for the stepping --- // function and the tolerances --- gsl_odeiv2_driver *d = gsl_odeiv2_driver_alloc_y_new ( --- &sys, gsl_odeiv2_step_rk8pd, 1e-6, 1e-6, 0.0); --- // Finally, apply the driver. --- int status = gsl_odeiv2_driver_apply( --- d, &$(double x0), $(double xend), $vec-ptr:(double *fMut)); --- // Free the driver --- gsl_odeiv2_driver_free(d); --- return status; --- } |] --- -- Check the error code --- maxSteps <- [C.exp| int{ GSL_EMAXITER } |] --- smallStep <- [C.exp| int{ GSL_ENOPROG } |] --- good <- [C.exp| int{ GSL_SUCCESS } |] --- if | res == good -> Right <$> V.freeze fMut --- | res == maxSteps -> return $ Left "Too many steps" --- | res == smallStep -> return $ Left "Step size dropped below minimum allowed size" --- | otherwise -> return $ Left $ "Unknown error code " ++ show res - --- -- Utils - --- vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) --- vectorFromC len ptr = do --- ptr' <- newForeignPtr_ ptr --- V.freeze $ VM.unsafeFromForeignPtr0 ptr' len - --- vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () --- vectorToC vec len ptr = do --- ptr' <- newForeignPtr_ ptr --- V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec - - --- /* Check function return value... --- opt == 0 means SUNDIALS function allocates memory so check if --- returned NULL pointer --- opt == 1 means SUNDIALS function returns a flag so check if --- flag >= 0 --- opt == 2 means function allocates memory so check if returned --- NULL pointer --- */ --- static int check_flag(void *flagvalue, const char *funcname, int opt) --- { --- int *errflag; - --- /* Check if SUNDIALS function returned NULL pointer - no memory allocated */ --- if (opt == 0 && flagvalue == NULL) { --- fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n", --- funcname); --- return 1; } - --- /* Check if flag < 0 */ --- else if (opt == 1) { --- errflag = (int *) flagvalue; --- if (*errflag < 0) { --- fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with flag = %d\n\n", --- funcname, *errflag); --- return 1; }} - --- /* Check if function returned NULL pointer - no memory allocated */ --- else if (opt == 2 && flagvalue == NULL) { --- fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n", --- funcname); --- return 1; } - --- return 0; --- } -main = do +-- Utils + +vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) +vectorFromC len ptr = do + ptr' <- newForeignPtr_ ptr + V.freeze $ VM.unsafeFromForeignPtr0 ptr' len + +vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () +vectorToC vec len ptr = do + ptr' <- newForeignPtr_ ptr + V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec + +solve :: CDouble -> CInt +solve lambda = unsafePerformIO $ do res <- [C.block| int { /* general problem variables */ int flag; /* reusable error-checking flag */ N_Vector y = NULL; /* empty vector for storing solution */ @@ -172,6 +79,7 @@ main = do /* Initial diagnostics output */ printf("\nAnalytical ODE test problem:\n"); printf(" lamda = %"GSYM"\n", lamda); + printf(" lambda = %"GSYM"\n", $(double lambda)); printf(" reltol = %.1"ESYM"\n", reltol); printf(" abstol = %.1"ESYM"\n\n",abstol); @@ -282,4 +190,140 @@ main = do return flag; } |] + return res + +main = do + let res = solve (coerce (100.0 :: Double)) + -- res <- [C.block| int { /* general problem variables */ + -- int flag; /* reusable error-checking flag */ + -- N_Vector y = NULL; /* empty vector for storing solution */ + -- SUNMatrix A = NULL; /* empty matrix for linear solver */ + -- SUNLinearSolver LS = NULL; /* empty linear solver object */ + -- void *arkode_mem = NULL; /* empty ARKode memory structure */ + -- FILE *UFID; + -- realtype t, tout; + -- long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; + + -- /* general problem parameters */ + -- realtype T0 = RCONST(0.0); /* initial time */ + -- realtype Tf = RCONST(10.0); /* final time */ + -- realtype dTout = RCONST(1.0); /* time between outputs */ + -- sunindextype NEQ = 1; /* number of dependent vars. */ + -- realtype reltol = 1.0e-6; /* tolerances */ + -- realtype abstol = 1.0e-10; + -- realtype lamda = -100.0; /* stiffness parameter */ + + -- /* Initial diagnostics output */ + -- printf("\nAnalytical ODE test problem:\n"); + -- printf(" lamda = %"GSYM"\n", lamda); + -- printf(" reltol = %.1"ESYM"\n", reltol); + -- printf(" abstol = %.1"ESYM"\n\n",abstol); + + -- /* Initialize data structures */ + -- y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ + -- if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; + -- N_VConst(0.0, y); /* Specify initial condition */ + -- arkode_mem = ARKodeCreate(); /* Create the solver memory */ + -- if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; + + -- /* Call ARKodeInit to initialize the integrator memory and specify the */ + -- /* right-hand side function in y'=f(t,y), the inital time T0, and */ + -- /* the initial dependent variable vector y. Note: since this */ + -- /* problem is fully implicit, we set f_E to NULL and f_I to f. */ + -- flag = ARKodeInit(arkode_mem, NULL, f, T0, y); + -- if (check_flag(&flag, "ARKodeInit", 1)) return 1; + + -- /* Set routines */ + -- flag = ARKodeSetUserData(arkode_mem, (void *) &lamda); /* Pass lamda to user functions */ + -- if (check_flag(&flag, "ARKodeSetUserData", 1)) return 1; + -- flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */ + -- if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; + + -- /* Initialize dense matrix data structure and solver */ + -- A = SUNDenseMatrix(NEQ, NEQ); + -- if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1; + -- LS = SUNDenseLinearSolver(y, A); + -- if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1; + + -- /* Linear solver interface */ + -- flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ + -- if (check_flag(&flag, "ARKDlsSetLinearSolver", 1)) return 1; + -- flag = ARKDlsSetJacFn(arkode_mem, Jac); /* Set Jacobian routine */ + -- if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1; + + -- /* Specify linearly implicit RHS, with non-time-dependent Jacobian */ + -- flag = ARKodeSetLinear(arkode_mem, 0); + -- if (check_flag(&flag, "ARKodeSetLinear", 1)) return 1; + + -- /* Open output stream for results, output comment line */ + -- UFID = fopen("solution.txt","w"); + -- fprintf(UFID,"# t u\n"); + + -- /* output initial condition to disk */ + -- fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", T0, NV_Ith_S(y,0)); + + -- /* Main time-stepping loop: calls ARKode to perform the integration, then + -- prints results. Stops when the final time has been reached */ + -- t = T0; + -- tout = T0+dTout; + -- printf(" t u\n"); + -- printf(" ---------------------\n"); + -- while (Tf - t > 1.0e-15) { + + -- flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */ + -- if (check_flag(&flag, "ARKode", 1)) break; + -- printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */ + -- fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", t, NV_Ith_S(y,0)); + -- if (flag >= 0) { /* successful solve: update time */ + -- tout += dTout; + -- tout = (tout > Tf) ? Tf : tout; + -- } else { /* unsuccessful solve: break */ + -- fprintf(stderr,"Solver failure, stopping integration\n"); + -- break; + -- } + -- } + -- printf(" ---------------------\n"); + -- fclose(UFID); + + -- /* Get/print some final statistics on how the solve progressed */ + -- flag = ARKodeGetNumSteps(arkode_mem, &nst); + -- check_flag(&flag, "ARKodeGetNumSteps", 1); + -- flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a); + -- check_flag(&flag, "ARKodeGetNumStepAttempts", 1); + -- flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi); + -- check_flag(&flag, "ARKodeGetNumRhsEvals", 1); + -- flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups); + -- check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1); + -- flag = ARKodeGetNumErrTestFails(arkode_mem, &netf); + -- check_flag(&flag, "ARKodeGetNumErrTestFails", 1); + -- flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni); + -- check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1); + -- flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn); + -- check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1); + -- flag = ARKDlsGetNumJacEvals(arkode_mem, &nje); + -- check_flag(&flag, "ARKDlsGetNumJacEvals", 1); + -- flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS); + -- check_flag(&flag, "ARKDlsGetNumRhsEvals", 1); + + -- printf("\nFinal Solver Statistics:\n"); + -- printf(" Internal solver steps = %li (attempted = %li)\n", nst, nst_a); + -- printf(" Total RHS evals: Fe = %li, Fi = %li\n", nfe, nfi); + -- printf(" Total linear solver setups = %li\n", nsetups); + -- printf(" Total RHS evals for setting up the linear system = %li\n", nfeLS); + -- printf(" Total number of Jacobian evaluations = %li\n", nje); + -- printf(" Total number of Newton iterations = %li\n", nni); + -- printf(" Total number of linear solver convergence failures = %li\n", ncfn); + -- printf(" Total number of error test failures = %li\n\n", netf); + + -- /* check the solution error */ + -- flag = check_ans(y, t, reltol, abstol); + + -- /* Clean up and return */ + -- N_VDestroy(y); /* Free y vector */ + -- ARKodeFree(&arkode_mem); /* Free integrator memory */ + -- SUNLinSolFree(LS); /* Free linear solver */ + -- SUNMatDestroy(A); /* Free A matrix */ + + -- return flag; + -- } |] putStrLn $ show res -- cgit v1.2.3