diff options
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/Arkode/ODE.hs')
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/Arkode/ODE.hs | 313 |
1 files changed, 313 insertions, 0 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/Arkode/ODE.hs b/packages/sundials/src/Numeric/Sundials/Arkode/ODE.hs new file mode 100644 index 0000000..6d9a1b2 --- /dev/null +++ b/packages/sundials/src/Numeric/Sundials/Arkode/ODE.hs | |||
@@ -0,0 +1,313 @@ | |||
1 | {-# LANGUAGE MultiWayIf #-} | ||
2 | {-# LANGUAGE OverloadedStrings #-} | ||
3 | {-# LANGUAGE QuasiQuotes #-} | ||
4 | {-# LANGUAGE ScopedTypeVariables #-} | ||
5 | {-# LANGUAGE TemplateHaskell #-} | ||
6 | |||
7 | module Numeric.Sundials.Arkode.ODE | ||
8 | ( SundialsDiagnostics(..) | ||
9 | , solveOde | ||
10 | , odeSolve | ||
11 | ) where | ||
12 | |||
13 | import qualified Language.C.Inline as C | ||
14 | import qualified Language.C.Inline.Unsafe as CU | ||
15 | |||
16 | import Data.Monoid ((<>)) | ||
17 | |||
18 | import Foreign.C.Types | ||
19 | import Foreign.Ptr (Ptr) | ||
20 | import Foreign.ForeignPtr (newForeignPtr_) | ||
21 | import Foreign.Storable (Storable, peekByteOff) | ||
22 | |||
23 | import qualified Data.Vector.Storable as V | ||
24 | import qualified Data.Vector.Storable.Mutable as VM | ||
25 | |||
26 | import Data.Coerce (coerce) | ||
27 | import System.IO.Unsafe (unsafePerformIO) | ||
28 | |||
29 | import Numeric.LinearAlgebra.Devel (createVector) | ||
30 | |||
31 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><)) | ||
32 | |||
33 | import qualified Types as T | ||
34 | |||
35 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | ||
36 | |||
37 | -- C includes | ||
38 | C.include "<stdlib.h>" | ||
39 | C.include "<stdio.h>" | ||
40 | C.include "<math.h>" | ||
41 | C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts. | ||
42 | C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros | ||
43 | C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix | ||
44 | C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver | ||
45 | C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface | ||
46 | C.include "<sundials/sundials_types.h>" -- definition of type realtype | ||
47 | C.include "<sundials/sundials_math.h>" | ||
48 | C.include "../../../helpers.h" | ||
49 | |||
50 | |||
51 | -- These were semi-generated using hsc2hs with Bar.hsc as the | ||
52 | -- template. They are probably very fragile and could easily break on | ||
53 | -- different architectures and / or changes in the sundials package. | ||
54 | |||
55 | getContentPtr :: Storable a => Ptr b -> IO a | ||
56 | getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr | ||
57 | |||
58 | getData :: Storable a => Ptr b -> IO a | ||
59 | getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr | ||
60 | |||
61 | getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b) | ||
62 | getDataFromContents len ptr = do | ||
63 | qtr <- getContentPtr ptr | ||
64 | rtr <- getData qtr | ||
65 | vectorFromC len rtr | ||
66 | |||
67 | putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () | ||
68 | putDataInContents vec len ptr = do | ||
69 | qtr <- getContentPtr ptr | ||
70 | rtr <- getData qtr | ||
71 | vectorToC vec len rtr | ||
72 | |||
73 | -- Utils | ||
74 | |||
75 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | ||
76 | vectorFromC len ptr = do | ||
77 | ptr' <- newForeignPtr_ ptr | ||
78 | V.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
79 | |||
80 | vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () | ||
81 | vectorToC vec len ptr = do | ||
82 | ptr' <- newForeignPtr_ ptr | ||
83 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
84 | |||
85 | data SundialsDiagnostics = SundialsDiagnostics { | ||
86 | aRKodeGetNumSteps :: Int | ||
87 | , aRKodeGetNumStepAttempts :: Int | ||
88 | , aRKodeGetNumRhsEvals_fe :: Int | ||
89 | , aRKodeGetNumRhsEvals_fi :: Int | ||
90 | , aRKodeGetNumLinSolvSetups :: Int | ||
91 | , aRKodeGetNumErrTestFails :: Int | ||
92 | , aRKodeGetNumNonlinSolvIters :: Int | ||
93 | , aRKodeGetNumNonlinSolvConvFails :: Int | ||
94 | , aRKDlsGetNumJacEvals :: Int | ||
95 | , aRKDlsGetNumRhsEvals :: Int | ||
96 | } deriving Show | ||
97 | |||
98 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
99 | -> [Double] -- ^ initial conditions | ||
100 | -> Vector Double -- ^ desired solution times | ||
101 | -> Matrix Double -- ^ solution | ||
102 | odeSolve f y0 ts = case solveOde g (V.fromList y0) (V.fromList $ toList ts) of | ||
103 | Left c -> error $ show c -- FIXME | ||
104 | Right (v, _) -> (nR >< nC) (V.toList v) | ||
105 | where | ||
106 | us = toList ts | ||
107 | nR = length us | ||
108 | nC = length y0 | ||
109 | g t x0 = V.fromList $ f t (V.toList x0) | ||
110 | |||
111 | solveOde :: | ||
112 | (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
113 | -> V.Vector Double -- ^ Initial conditions | ||
114 | -> V.Vector Double -- ^ Desired solution times | ||
115 | -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution | ||
116 | solveOde f y0 tt = case solveOdeC (coerce f) (coerce y0) (coerce tt) of | ||
117 | Left c -> Left $ fromIntegral c | ||
118 | Right (v, d) -> Right (coerce v, d) | ||
119 | |||
120 | solveOdeC :: | ||
121 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
122 | -> V.Vector CDouble -- ^ Initial conditions | ||
123 | -> V.Vector CDouble -- ^ Desired solution times | ||
124 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution | ||
125 | solveOdeC fun f0 ts = unsafePerformIO $ do | ||
126 | let dim = V.length f0 | ||
127 | nEq :: CLong | ||
128 | nEq = fromIntegral dim | ||
129 | nTs :: CInt | ||
130 | nTs = fromIntegral $ V.length ts | ||
131 | -- FIXME: fMut is not actually mutatated | ||
132 | fMut <- V.thaw f0 | ||
133 | tMut <- V.thaw ts | ||
134 | -- FIXME: I believe this gets taken from the ghc heap and so should | ||
135 | -- be subject to garbage collection. | ||
136 | quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) | ||
137 | qMatMut <- V.thaw quasiMatrixRes | ||
138 | diagnostics :: V.Vector CLong <- createVector 10 -- FIXME | ||
139 | diagMut <- V.thaw diagnostics | ||
140 | -- We need the types that sundials expects. These are tied together | ||
141 | -- in 'Types'. FIXME: The Haskell type is currently empty! | ||
142 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | ||
143 | funIO x y f _ptr = do | ||
144 | -- Convert the pointer we get from C (y) to a vector, and then | ||
145 | -- apply the user-supplied function. | ||
146 | fImm <- fun x <$> getDataFromContents dim y | ||
147 | -- Fill in the provided pointer with the resulting vector. | ||
148 | putDataInContents fImm dim f | ||
149 | -- I don't understand what this comment means | ||
150 | -- Unsafe since the function will be called many times. | ||
151 | [CU.exp| int{ 0 } |] | ||
152 | res <- [C.block| int { | ||
153 | /* general problem variables */ | ||
154 | int flag; /* reusable error-checking flag */ | ||
155 | N_Vector y = NULL; /* empty vector for storing solution */ | ||
156 | SUNMatrix A = NULL; /* empty matrix for linear solver */ | ||
157 | SUNLinearSolver LS = NULL; /* empty linear solver object */ | ||
158 | void *arkode_mem = NULL; /* empty ARKode memory structure */ | ||
159 | realtype t; | ||
160 | long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; | ||
161 | |||
162 | /* general problem parameters */ | ||
163 | realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ | ||
164 | |||
165 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ | ||
166 | realtype reltol = 1.0e-6; /* tolerances */ | ||
167 | realtype abstol = 1.0e-10; | ||
168 | |||
169 | /* Initial diagnostics output */ | ||
170 | printf("\nAnalytical ODE test problem:\n"); | ||
171 | printf(" reltol = %.1"ESYM"\n", reltol); | ||
172 | printf(" abstol = %.1"ESYM"\n\n",abstol); | ||
173 | |||
174 | /* Initialize data structures */ | ||
175 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | ||
176 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | ||
177 | int i, j; | ||
178 | for (i = 0; i < NEQ; i++) { | ||
179 | NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; | ||
180 | }; /* Specify initial condition */ | ||
181 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | ||
182 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; | ||
183 | |||
184 | /* Call ARKodeInit to initialize the integrator memory and specify the */ | ||
185 | /* right-hand side function in y'=f(t,y), the inital time T0, and */ | ||
186 | /* the initial dependent variable vector y. Note: since this */ | ||
187 | /* problem is fully implicit, we set f_E to NULL and f_I to f. */ | ||
188 | |||
189 | /* Here we use the C types defined in helpers.h which tie up with */ | ||
190 | /* the Haskell types defined in Types */ | ||
191 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); | ||
192 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | ||
193 | |||
194 | /* Set routines */ | ||
195 | flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */ | ||
196 | if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; | ||
197 | |||
198 | /* Initialize dense matrix data structure and solver */ | ||
199 | A = SUNDenseMatrix(NEQ, NEQ); | ||
200 | if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1; | ||
201 | LS = SUNDenseLinearSolver(y, A); | ||
202 | if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1; | ||
203 | |||
204 | /* Linear solver interface */ | ||
205 | flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ | ||
206 | /* Store initial conditions */ | ||
207 | for (j = 0; j < NEQ; j++) { | ||
208 | ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); | ||
209 | } | ||
210 | |||
211 | flag = ARKodeSetIRKTableNum(arkode_mem, KVAERNO_4_2_3); | ||
212 | if (check_flag(&flag, "ARKode", 1)) return 1; | ||
213 | |||
214 | int s, q, p; | ||
215 | realtype *ai = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype)); | ||
216 | realtype *ae = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype)); | ||
217 | realtype *ci = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
218 | realtype *ce = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
219 | realtype *bi = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
220 | realtype *be = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
221 | realtype *b2i = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
222 | realtype *b2e = (realtype *)malloc(ARK_S_MAX * sizeof(realtype)); | ||
223 | flag = ARKodeGetCurrentButcherTables(arkode_mem, &s, &q, &p, ai, ae, ci, ce, bi, be, b2i, b2e); | ||
224 | if (check_flag(&flag, "ARKode", 1)) return 1; | ||
225 | fprintf(stderr, "s = %d, q = %d, p = %d\n", s, q, p); | ||
226 | for (i = 0; i < s; i++) { | ||
227 | for (j = 0; j < s; j++) { | ||
228 | fprintf(stderr, "ai[%d,%d] = %f", i, j, ai[i * ARK_S_MAX + j]); | ||
229 | } | ||
230 | fprintf(stderr, "\n"); | ||
231 | } | ||
232 | |||
233 | /* Main time-stepping loop: calls ARKode to perform the integration */ | ||
234 | /* Stops when the final time has been reached */ | ||
235 | for (i = 1; i < $(int nTs); i++) { | ||
236 | |||
237 | flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */ | ||
238 | if (check_flag(&flag, "ARKode", 1)) break; | ||
239 | |||
240 | /* Store the results for Haskell */ | ||
241 | for (j = 0; j < NEQ; j++) { | ||
242 | ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j); | ||
243 | } | ||
244 | |||
245 | if (flag < 0) { /* unsuccessful solve: break */ | ||
246 | fprintf(stderr,"Solver failure, stopping integration\n"); | ||
247 | break; | ||
248 | } | ||
249 | } | ||
250 | |||
251 | /* Get some final statistics on how the solve progressed */ | ||
252 | flag = ARKodeGetNumSteps(arkode_mem, &nst); | ||
253 | check_flag(&flag, "ARKodeGetNumSteps", 1); | ||
254 | ($vec-ptr:(long int *diagMut))[0] = nst; | ||
255 | |||
256 | flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a); | ||
257 | check_flag(&flag, "ARKodeGetNumStepAttempts", 1); | ||
258 | ($vec-ptr:(long int *diagMut))[1] = nst_a; | ||
259 | |||
260 | flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi); | ||
261 | check_flag(&flag, "ARKodeGetNumRhsEvals", 1); | ||
262 | ($vec-ptr:(long int *diagMut))[2] = nfe; | ||
263 | ($vec-ptr:(long int *diagMut))[3] = nfi; | ||
264 | |||
265 | flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups); | ||
266 | check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1); | ||
267 | ($vec-ptr:(long int *diagMut))[4] = nsetups; | ||
268 | |||
269 | flag = ARKodeGetNumErrTestFails(arkode_mem, &netf); | ||
270 | check_flag(&flag, "ARKodeGetNumErrTestFails", 1); | ||
271 | ($vec-ptr:(long int *diagMut))[5] = netf; | ||
272 | |||
273 | flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni); | ||
274 | check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1); | ||
275 | ($vec-ptr:(long int *diagMut))[6] = nni; | ||
276 | |||
277 | flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn); | ||
278 | check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1); | ||
279 | ($vec-ptr:(long int *diagMut))[7] = ncfn; | ||
280 | |||
281 | flag = ARKDlsGetNumJacEvals(arkode_mem, &nje); | ||
282 | check_flag(&flag, "ARKDlsGetNumJacEvals", 1); | ||
283 | ($vec-ptr:(long int *diagMut))[8] = ncfn; | ||
284 | |||
285 | flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS); | ||
286 | check_flag(&flag, "ARKDlsGetNumRhsEvals", 1); | ||
287 | ($vec-ptr:(long int *diagMut))[9] = ncfn; | ||
288 | |||
289 | /* Clean up and return */ | ||
290 | N_VDestroy(y); /* Free y vector */ | ||
291 | ARKodeFree(&arkode_mem); /* Free integrator memory */ | ||
292 | SUNLinSolFree(LS); /* Free linear solver */ | ||
293 | SUNMatDestroy(A); /* Free A matrix */ | ||
294 | |||
295 | return flag; | ||
296 | } |] | ||
297 | if res == 0 | ||
298 | then do | ||
299 | preD <- V.freeze diagMut | ||
300 | let d = SundialsDiagnostics (fromIntegral $ preD V.!0) | ||
301 | (fromIntegral $ preD V.!1) | ||
302 | (fromIntegral $ preD V.!2) | ||
303 | (fromIntegral $ preD V.!3) | ||
304 | (fromIntegral $ preD V.!4) | ||
305 | (fromIntegral $ preD V.!5) | ||
306 | (fromIntegral $ preD V.!6) | ||
307 | (fromIntegral $ preD V.!7) | ||
308 | (fromIntegral $ preD V.!8) | ||
309 | (fromIntegral $ preD V.!9) | ||
310 | m <- V.freeze qMatMut | ||
311 | return $ Right (m, d) | ||
312 | else do | ||
313 | return $ Left res | ||