diff options
Diffstat (limited to 'packages/sundials/src/Test.hs')
-rw-r--r-- | packages/sundials/src/Test.hs | 164 |
1 files changed, 164 insertions, 0 deletions
diff --git a/packages/sundials/src/Test.hs b/packages/sundials/src/Test.hs new file mode 100644 index 0000000..a99582a --- /dev/null +++ b/packages/sundials/src/Test.hs | |||
@@ -0,0 +1,164 @@ | |||
1 | {-# LANGUAGE QuasiQuotes #-} | ||
2 | {-# LANGUAGE TemplateHaskell #-} | ||
3 | {-# LANGUAGE MultiWayIf #-} | ||
4 | {-# LANGUAGE OverloadedStrings #-} | ||
5 | |||
6 | import qualified Language.C.Inline as C | ||
7 | import qualified Language.C.Inline.Unsafe as CU | ||
8 | import Data.Monoid ((<>)) | ||
9 | import Foreign.C.Types | ||
10 | import Foreign.Ptr (Ptr) | ||
11 | import Foreign.Marshal.Array | ||
12 | import qualified Data.Vector.Storable as V | ||
13 | |||
14 | import Data.Coerce (coerce) | ||
15 | import Data.Monoid ((<>)) | ||
16 | import qualified Data.Vector.Storable as V | ||
17 | import qualified Data.Vector.Storable.Mutable as VM | ||
18 | import Foreign.C.Types | ||
19 | import Foreign.ForeignPtr (newForeignPtr_) | ||
20 | import Foreign.Ptr (Ptr) | ||
21 | import Foreign.Storable (Storable) | ||
22 | import qualified Language.C.Inline as C | ||
23 | import qualified Language.C.Inline.Unsafe as CU | ||
24 | import System.IO.Unsafe (unsafePerformIO) | ||
25 | |||
26 | import qualified Language.Haskell.TH as TH | ||
27 | import qualified Language.C.Types as CT | ||
28 | import qualified Data.Map as Map | ||
29 | import Language.C.Inline.Context | ||
30 | |||
31 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx) | ||
32 | |||
33 | -- C includes | ||
34 | C.include "<stdio.h>" | ||
35 | C.include "<math.h>" | ||
36 | C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts. | ||
37 | C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros | ||
38 | C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix | ||
39 | C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver | ||
40 | C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface | ||
41 | C.include "<sundials/sundials_types.h>" -- definition of type realtype | ||
42 | C.include "<sundials/sundials_math.h>" | ||
43 | C.include "helpers.h" | ||
44 | |||
45 | -- | Solves a system of ODEs. Every 'V.Vector' involved must be of the | ||
46 | -- same size. | ||
47 | -- {-# NOINLINE solveOdeC #-} | ||
48 | -- solveOdeC | ||
49 | -- :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) | ||
50 | -- -- ^ ODE to Solve | ||
51 | -- -> CDouble | ||
52 | -- -- ^ Start | ||
53 | -- -> V.Vector CDouble | ||
54 | -- -- ^ Solution at start point | ||
55 | -- -> CDouble | ||
56 | -- -- ^ End | ||
57 | -- -> Either String (V.Vector CDouble) | ||
58 | -- -- ^ Solution at end point, or error. | ||
59 | -- solveOdeC fun x0 f0 xend = unsafePerformIO $ do | ||
60 | -- let dim = V.length f0 | ||
61 | -- let dim_c = fromIntegral dim -- This is in CInt | ||
62 | -- -- Convert the function to something of the right type to C. | ||
63 | -- let funIO x y f _ptr = do | ||
64 | -- -- Convert the pointer we get from C (y) to a vector, and then | ||
65 | -- -- apply the user-supplied function. | ||
66 | -- fImm <- fun x <$> vectorFromC dim y | ||
67 | -- -- Fill in the provided pointer with the resulting vector. | ||
68 | -- vectorToC fImm dim f | ||
69 | -- -- Unsafe since the function will be called many times. | ||
70 | -- [CU.exp| int{ GSL_SUCCESS } |] | ||
71 | -- -- Create a mutable vector from the initial solution. This will be | ||
72 | -- -- passed to the ODE solving function provided by GSL, and will | ||
73 | -- -- contain the final solution. | ||
74 | -- fMut <- V.thaw f0 | ||
75 | -- res <- [C.block| int { | ||
76 | -- gsl_odeiv2_system sys = { | ||
77 | -- $fun:(int (* funIO) (double t, const double y[], double dydt[], void * params)), | ||
78 | -- // The ODE to solve, converted to function pointer using the `fun` | ||
79 | -- // anti-quoter | ||
80 | -- NULL, // We don't provide a Jacobian | ||
81 | -- $(int dim_c), // The dimension | ||
82 | -- NULL // We don't need the parameter pointer | ||
83 | -- }; | ||
84 | -- // Create the driver, using some sensible values for the stepping | ||
85 | -- // function and the tolerances | ||
86 | -- gsl_odeiv2_driver *d = gsl_odeiv2_driver_alloc_y_new ( | ||
87 | -- &sys, gsl_odeiv2_step_rk8pd, 1e-6, 1e-6, 0.0); | ||
88 | -- // Finally, apply the driver. | ||
89 | -- int status = gsl_odeiv2_driver_apply( | ||
90 | -- d, &$(double x0), $(double xend), $vec-ptr:(double *fMut)); | ||
91 | -- // Free the driver | ||
92 | -- gsl_odeiv2_driver_free(d); | ||
93 | -- return status; | ||
94 | -- } |] | ||
95 | -- -- Check the error code | ||
96 | -- maxSteps <- [C.exp| int{ GSL_EMAXITER } |] | ||
97 | -- smallStep <- [C.exp| int{ GSL_ENOPROG } |] | ||
98 | -- good <- [C.exp| int{ GSL_SUCCESS } |] | ||
99 | -- if | res == good -> Right <$> V.freeze fMut | ||
100 | -- | res == maxSteps -> return $ Left "Too many steps" | ||
101 | -- | res == smallStep -> return $ Left "Step size dropped below minimum allowed size" | ||
102 | -- | otherwise -> return $ Left $ "Unknown error code " ++ show res | ||
103 | |||
104 | -- -- Utils | ||
105 | |||
106 | -- vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | ||
107 | -- vectorFromC len ptr = do | ||
108 | -- ptr' <- newForeignPtr_ ptr | ||
109 | -- V.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
110 | |||
111 | -- vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () | ||
112 | -- vectorToC vec len ptr = do | ||
113 | -- ptr' <- newForeignPtr_ ptr | ||
114 | -- V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
115 | |||
116 | |||
117 | -- /* Check function return value... | ||
118 | -- opt == 0 means SUNDIALS function allocates memory so check if | ||
119 | -- returned NULL pointer | ||
120 | -- opt == 1 means SUNDIALS function returns a flag so check if | ||
121 | -- flag >= 0 | ||
122 | -- opt == 2 means function allocates memory so check if returned | ||
123 | -- NULL pointer | ||
124 | -- */ | ||
125 | -- static int check_flag(void *flagvalue, const char *funcname, int opt) | ||
126 | -- { | ||
127 | -- int *errflag; | ||
128 | |||
129 | -- /* Check if SUNDIALS function returned NULL pointer - no memory allocated */ | ||
130 | -- if (opt == 0 && flagvalue == NULL) { | ||
131 | -- fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n", | ||
132 | -- funcname); | ||
133 | -- return 1; } | ||
134 | |||
135 | -- /* Check if flag < 0 */ | ||
136 | -- else if (opt == 1) { | ||
137 | -- errflag = (int *) flagvalue; | ||
138 | -- if (*errflag < 0) { | ||
139 | -- fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with flag = %d\n\n", | ||
140 | -- funcname, *errflag); | ||
141 | -- return 1; }} | ||
142 | |||
143 | -- /* Check if function returned NULL pointer - no memory allocated */ | ||
144 | -- else if (opt == 2 && flagvalue == NULL) { | ||
145 | -- fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n", | ||
146 | -- funcname); | ||
147 | -- return 1; } | ||
148 | |||
149 | -- return 0; | ||
150 | -- } | ||
151 | |||
152 | main = do | ||
153 | res <- [C.block| int { sunindextype NEQ = 1; /* number of dependent vars. */ | ||
154 | N_Vector y = NULL; /* empty vector for storing solution */ | ||
155 | void *arkode_mem = NULL; /* empty ARKode memory structure */ | ||
156 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | ||
157 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | ||
158 | |||
159 | N_VConst(0.0, y); /* Specify initial condition */ | ||
160 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | ||
161 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; | ||
162 | return 0; | ||
163 | } |] | ||
164 | putStrLn $ show res | ||