summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Main.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/sundials/src/Main.hs')
-rw-r--r--packages/sundials/src/Main.hs194
1 files changed, 194 insertions, 0 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs
new file mode 100644
index 0000000..9e8bc63
--- /dev/null
+++ b/packages/sundials/src/Main.hs
@@ -0,0 +1,194 @@
1{-# LANGUAGE QuasiQuotes #-}
2{-# LANGUAGE TemplateHaskell #-}
3{-# LANGUAGE MultiWayIf #-}
4{-# LANGUAGE OverloadedStrings #-}
5
6import qualified Language.C.Inline as C
7import qualified Language.C.Inline.Unsafe as CU
8import Data.Monoid ((<>))
9import Foreign.C.Types
10import Foreign.Ptr (Ptr)
11import Foreign.Marshal.Array
12import qualified Data.Vector.Storable as V
13
14import Data.Coerce (coerce)
15import Data.Monoid ((<>))
16import qualified Data.Vector.Storable as V
17import qualified Data.Vector.Storable.Mutable as VM
18import Foreign.C.Types
19import Foreign.ForeignPtr (newForeignPtr_)
20import Foreign.Ptr (Ptr)
21import Foreign.Storable (Storable)
22import qualified Language.C.Inline as C
23import qualified Language.C.Inline.Unsafe as CU
24import System.IO.Unsafe (unsafePerformIO)
25
26import qualified Language.Haskell.TH as TH
27import qualified Language.C.Types as CT
28import qualified Data.Map as Map
29import Language.C.Inline.Context
30
31C.context (C.baseCtx <> C.vecCtx <> C.funCtx)
32
33-- C includes
34C.include "<stdio.h>"
35C.include "<math.h>"
36C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts.
37C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros
38C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix
39C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver
40C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface
41C.include "<sundials/sundials_types.h>" -- definition of type realtype
42C.include "<sundials/sundials_math.h>"
43C.include "helpers.h"
44
45-- | Solves a system of ODEs. Every 'V.Vector' involved must be of the
46-- same size.
47-- {-# NOINLINE solveOdeC #-}
48-- solveOdeC
49-- :: (CDouble -> V.Vector CDouble -> V.Vector CDouble)
50-- -- ^ ODE to Solve
51-- -> CDouble
52-- -- ^ Start
53-- -> V.Vector CDouble
54-- -- ^ Solution at start point
55-- -> CDouble
56-- -- ^ End
57-- -> Either String (V.Vector CDouble)
58-- -- ^ Solution at end point, or error.
59-- solveOdeC fun x0 f0 xend = unsafePerformIO $ do
60-- let dim = V.length f0
61-- let dim_c = fromIntegral dim -- This is in CInt
62-- -- Convert the function to something of the right type to C.
63-- let funIO x y f _ptr = do
64-- -- Convert the pointer we get from C (y) to a vector, and then
65-- -- apply the user-supplied function.
66-- fImm <- fun x <$> vectorFromC dim y
67-- -- Fill in the provided pointer with the resulting vector.
68-- vectorToC fImm dim f
69-- -- Unsafe since the function will be called many times.
70-- [CU.exp| int{ GSL_SUCCESS } |]
71-- -- Create a mutable vector from the initial solution. This will be
72-- -- passed to the ODE solving function provided by GSL, and will
73-- -- contain the final solution.
74-- fMut <- V.thaw f0
75-- res <- [C.block| int {
76-- gsl_odeiv2_system sys = {
77-- $fun:(int (* funIO) (double t, const double y[], double dydt[], void * params)),
78-- // The ODE to solve, converted to function pointer using the `fun`
79-- // anti-quoter
80-- NULL, // We don't provide a Jacobian
81-- $(int dim_c), // The dimension
82-- NULL // We don't need the parameter pointer
83-- };
84-- // Create the driver, using some sensible values for the stepping
85-- // function and the tolerances
86-- gsl_odeiv2_driver *d = gsl_odeiv2_driver_alloc_y_new (
87-- &sys, gsl_odeiv2_step_rk8pd, 1e-6, 1e-6, 0.0);
88-- // Finally, apply the driver.
89-- int status = gsl_odeiv2_driver_apply(
90-- d, &$(double x0), $(double xend), $vec-ptr:(double *fMut));
91-- // Free the driver
92-- gsl_odeiv2_driver_free(d);
93-- return status;
94-- } |]
95-- -- Check the error code
96-- maxSteps <- [C.exp| int{ GSL_EMAXITER } |]
97-- smallStep <- [C.exp| int{ GSL_ENOPROG } |]
98-- good <- [C.exp| int{ GSL_SUCCESS } |]
99-- if | res == good -> Right <$> V.freeze fMut
100-- | res == maxSteps -> return $ Left "Too many steps"
101-- | res == smallStep -> return $ Left "Step size dropped below minimum allowed size"
102-- | otherwise -> return $ Left $ "Unknown error code " ++ show res
103
104-- -- Utils
105
106-- vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a)
107-- vectorFromC len ptr = do
108-- ptr' <- newForeignPtr_ ptr
109-- V.freeze $ VM.unsafeFromForeignPtr0 ptr' len
110
111-- vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO ()
112-- vectorToC vec len ptr = do
113-- ptr' <- newForeignPtr_ ptr
114-- V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
115
116
117-- /* Check function return value...
118-- opt == 0 means SUNDIALS function allocates memory so check if
119-- returned NULL pointer
120-- opt == 1 means SUNDIALS function returns a flag so check if
121-- flag >= 0
122-- opt == 2 means function allocates memory so check if returned
123-- NULL pointer
124-- */
125-- static int check_flag(void *flagvalue, const char *funcname, int opt)
126-- {
127-- int *errflag;
128
129-- /* Check if SUNDIALS function returned NULL pointer - no memory allocated */
130-- if (opt == 0 && flagvalue == NULL) {
131-- fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
132-- funcname);
133-- return 1; }
134
135-- /* Check if flag < 0 */
136-- else if (opt == 1) {
137-- errflag = (int *) flagvalue;
138-- if (*errflag < 0) {
139-- fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with flag = %d\n\n",
140-- funcname, *errflag);
141-- return 1; }}
142
143-- /* Check if function returned NULL pointer - no memory allocated */
144-- else if (opt == 2 && flagvalue == NULL) {
145-- fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n",
146-- funcname);
147-- return 1; }
148
149-- return 0;
150-- }
151
152main = do
153 res <- [C.block| int { /* general problem variables */
154 int flag; /* reusable error-checking flag */
155 N_Vector y = NULL; /* empty vector for storing solution */
156 SUNMatrix A = NULL; /* empty matrix for linear solver */
157 SUNLinearSolver LS = NULL; /* empty linear solver object */
158 void *arkode_mem = NULL; /* empty ARKode memory structure */
159 FILE *UFID;
160 realtype t, tout;
161 long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf;
162
163 /* general problem parameters */
164 realtype T0 = RCONST(0.0); /* initial time */
165 realtype Tf = RCONST(10.0); /* final time */
166 realtype dTout = RCONST(1.0); /* time between outputs */
167 sunindextype NEQ = 1; /* number of dependent vars. */
168 realtype reltol = 1.0e-6; /* tolerances */
169 realtype abstol = 1.0e-10;
170 realtype lamda = -100.0; /* stiffness parameter */
171
172 /* Initial diagnostics output */
173 printf("\nAnalytical ODE test problem:\n");
174 printf(" lamda = %"GSYM"\n", lamda);
175 printf(" reltol = %.1"ESYM"\n", reltol);
176 printf(" abstol = %.1"ESYM"\n\n",abstol);
177
178 /* Initialize data structures */
179 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
180 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
181 N_VConst(0.0, y); /* Specify initial condition */
182 arkode_mem = ARKodeCreate(); /* Create the solver memory */
183 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
184
185 /* Call ARKodeInit to initialize the integrator memory and specify the */
186 /* right-hand side function in y'=f(t,y), the inital time T0, and */
187 /* the initial dependent variable vector y. Note: since this */
188 /* problem is fully implicit, we set f_E to NULL and f_I to f. */
189 flag = ARKodeInit(arkode_mem, NULL, f, T0, y);
190 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
191
192 return 0;
193 } |]
194 putStrLn $ show res