summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Test.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/sundials/src/Test.hs')
-rw-r--r--packages/sundials/src/Test.hs164
1 files changed, 164 insertions, 0 deletions
diff --git a/packages/sundials/src/Test.hs b/packages/sundials/src/Test.hs
new file mode 100644
index 0000000..a99582a
--- /dev/null
+++ b/packages/sundials/src/Test.hs
@@ -0,0 +1,164 @@
1{-# LANGUAGE QuasiQuotes #-}
2{-# LANGUAGE TemplateHaskell #-}
3{-# LANGUAGE MultiWayIf #-}
4{-# LANGUAGE OverloadedStrings #-}
5
6import qualified Language.C.Inline as C
7import qualified Language.C.Inline.Unsafe as CU
8import Data.Monoid ((<>))
9import Foreign.C.Types
10import Foreign.Ptr (Ptr)
11import Foreign.Marshal.Array
12import qualified Data.Vector.Storable as V
13
14import Data.Coerce (coerce)
15import Data.Monoid ((<>))
16import qualified Data.Vector.Storable as V
17import qualified Data.Vector.Storable.Mutable as VM
18import Foreign.C.Types
19import Foreign.ForeignPtr (newForeignPtr_)
20import Foreign.Ptr (Ptr)
21import Foreign.Storable (Storable)
22import qualified Language.C.Inline as C
23import qualified Language.C.Inline.Unsafe as CU
24import System.IO.Unsafe (unsafePerformIO)
25
26import qualified Language.Haskell.TH as TH
27import qualified Language.C.Types as CT
28import qualified Data.Map as Map
29import Language.C.Inline.Context
30
31C.context (C.baseCtx <> C.vecCtx <> C.funCtx)
32
33-- C includes
34C.include "<stdio.h>"
35C.include "<math.h>"
36C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts.
37C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros
38C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix
39C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver
40C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface
41C.include "<sundials/sundials_types.h>" -- definition of type realtype
42C.include "<sundials/sundials_math.h>"
43C.include "helpers.h"
44
45-- | Solves a system of ODEs. Every 'V.Vector' involved must be of the
46-- same size.
47-- {-# NOINLINE solveOdeC #-}
48-- solveOdeC
49-- :: (CDouble -> V.Vector CDouble -> V.Vector CDouble)
50-- -- ^ ODE to Solve
51-- -> CDouble
52-- -- ^ Start
53-- -> V.Vector CDouble
54-- -- ^ Solution at start point
55-- -> CDouble
56-- -- ^ End
57-- -> Either String (V.Vector CDouble)
58-- -- ^ Solution at end point, or error.
59-- solveOdeC fun x0 f0 xend = unsafePerformIO $ do
60-- let dim = V.length f0
61-- let dim_c = fromIntegral dim -- This is in CInt
62-- -- Convert the function to something of the right type to C.
63-- let funIO x y f _ptr = do
64-- -- Convert the pointer we get from C (y) to a vector, and then
65-- -- apply the user-supplied function.
66-- fImm <- fun x <$> vectorFromC dim y
67-- -- Fill in the provided pointer with the resulting vector.
68-- vectorToC fImm dim f
69-- -- Unsafe since the function will be called many times.
70-- [CU.exp| int{ GSL_SUCCESS } |]
71-- -- Create a mutable vector from the initial solution. This will be
72-- -- passed to the ODE solving function provided by GSL, and will
73-- -- contain the final solution.
74-- fMut <- V.thaw f0
75-- res <- [C.block| int {
76-- gsl_odeiv2_system sys = {
77-- $fun:(int (* funIO) (double t, const double y[], double dydt[], void * params)),
78-- // The ODE to solve, converted to function pointer using the `fun`
79-- // anti-quoter
80-- NULL, // We don't provide a Jacobian
81-- $(int dim_c), // The dimension
82-- NULL // We don't need the parameter pointer
83-- };
84-- // Create the driver, using some sensible values for the stepping
85-- // function and the tolerances
86-- gsl_odeiv2_driver *d = gsl_odeiv2_driver_alloc_y_new (
87-- &sys, gsl_odeiv2_step_rk8pd, 1e-6, 1e-6, 0.0);
88-- // Finally, apply the driver.
89-- int status = gsl_odeiv2_driver_apply(
90-- d, &$(double x0), $(double xend), $vec-ptr:(double *fMut));
91-- // Free the driver
92-- gsl_odeiv2_driver_free(d);
93-- return status;
94-- } |]
95-- -- Check the error code
96-- maxSteps <- [C.exp| int{ GSL_EMAXITER } |]
97-- smallStep <- [C.exp| int{ GSL_ENOPROG } |]
98-- good <- [C.exp| int{ GSL_SUCCESS } |]
99-- if | res == good -> Right <$> V.freeze fMut
100-- | res == maxSteps -> return $ Left "Too many steps"
101-- | res == smallStep -> return $ Left "Step size dropped below minimum allowed size"
102-- | otherwise -> return $ Left $ "Unknown error code " ++ show res
103
104-- -- Utils
105
106-- vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a)
107-- vectorFromC len ptr = do
108-- ptr' <- newForeignPtr_ ptr
109-- V.freeze $ VM.unsafeFromForeignPtr0 ptr' len
110
111-- vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO ()
112-- vectorToC vec len ptr = do
113-- ptr' <- newForeignPtr_ ptr
114-- V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
115
116
117-- /* Check function return value...
118-- opt == 0 means SUNDIALS function allocates memory so check if
119-- returned NULL pointer
120-- opt == 1 means SUNDIALS function returns a flag so check if
121-- flag >= 0
122-- opt == 2 means function allocates memory so check if returned
123-- NULL pointer
124-- */
125-- static int check_flag(void *flagvalue, const char *funcname, int opt)
126-- {
127-- int *errflag;
128
129-- /* Check if SUNDIALS function returned NULL pointer - no memory allocated */
130-- if (opt == 0 && flagvalue == NULL) {
131-- fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
132-- funcname);
133-- return 1; }
134
135-- /* Check if flag < 0 */
136-- else if (opt == 1) {
137-- errflag = (int *) flagvalue;
138-- if (*errflag < 0) {
139-- fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with flag = %d\n\n",
140-- funcname, *errflag);
141-- return 1; }}
142
143-- /* Check if function returned NULL pointer - no memory allocated */
144-- else if (opt == 2 && flagvalue == NULL) {
145-- fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n",
146-- funcname);
147-- return 1; }
148
149-- return 0;
150-- }
151
152main = do
153 res <- [C.block| int { sunindextype NEQ = 1; /* number of dependent vars. */
154 N_Vector y = NULL; /* empty vector for storing solution */
155 void *arkode_mem = NULL; /* empty ARKode memory structure */
156 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
157 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
158
159 N_VConst(0.0, y); /* Specify initial condition */
160 arkode_mem = ARKodeCreate(); /* Create the solver memory */
161 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
162 return 0;
163 } |]
164 putStrLn $ show res