summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
diff options
context:
space:
mode:
authoridontgetoutmuch <dominic@steinitz.org>2018-03-26 10:32:51 +0100
committerGitHub <noreply@github.com>2018-03-26 10:32:51 +0100
commit9fd7adf7dda75077b85f0337a548be9138fc1ed5 (patch)
tree03e23b7027dd1e7983a98328b47aa795256005de /packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
parent560f38ab27bcc44c80ce7d9c2e4972342170fe28 (diff)
Revert "Cleanups to Sundials PR"
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs')
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs314
1 files changed, 314 insertions, 0 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
new file mode 100644
index 0000000..f432951
--- /dev/null
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -0,0 +1,314 @@
1{-# OPTIONS_GHC -Wall #-}
2
3{-# LANGUAGE QuasiQuotes #-}
4{-# LANGUAGE TemplateHaskell #-}
5{-# LANGUAGE MultiWayIf #-}
6{-# LANGUAGE OverloadedStrings #-}
7{-# LANGUAGE ScopedTypeVariables #-}
8
9module Numeric.Sundials.Arkode.ODE ( solveOde
10 , odeSolve
11 ) where
12
13import qualified Language.C.Inline as C
14import qualified Language.C.Inline.Unsafe as CU
15
16import Data.Monoid ((<>))
17
18import Foreign.C.Types
19import Foreign.Ptr (Ptr)
20import Foreign.ForeignPtr (newForeignPtr_)
21import Foreign.Storable (Storable, peekByteOff)
22
23import qualified Data.Vector.Storable as V
24import qualified Data.Vector.Storable.Mutable as VM
25
26import Data.Coerce (coerce)
27import System.IO.Unsafe (unsafePerformIO)
28
29import Numeric.LinearAlgebra.Devel (createVector)
30
31import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><))
32
33import qualified Types as T
34
35C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx)
36
37-- C includes
38C.include "<stdlib.h>"
39C.include "<stdio.h>"
40C.include "<math.h>"
41C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts.
42C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros
43C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix
44C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver
45C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface
46C.include "<sundials/sundials_types.h>" -- definition of type realtype
47C.include "<sundials/sundials_math.h>"
48C.include "../../../helpers.h"
49
50
51-- These were semi-generated using hsc2hs with Bar.hsc as the
52-- template. They are probably very fragile and could easily break on
53-- different architectures and / or changes in the sundials package.
54
55getContentPtr :: Storable a => Ptr b -> IO a
56getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr
57
58getData :: Storable a => Ptr b -> IO a
59getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr
60
61getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b)
62getDataFromContents len ptr = do
63 qtr <- getContentPtr ptr
64 rtr <- getData qtr
65 vectorFromC len rtr
66
67putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO ()
68putDataInContents vec len ptr = do
69 qtr <- getContentPtr ptr
70 rtr <- getData qtr
71 vectorToC vec len rtr
72
73-- Utils
74
75vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a)
76vectorFromC len ptr = do
77 ptr' <- newForeignPtr_ ptr
78 V.freeze $ VM.unsafeFromForeignPtr0 ptr' len
79
80vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO ()
81vectorToC vec len ptr = do
82 ptr' <- newForeignPtr_ ptr
83 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
84
85data SundialsDiagnostics = SundialsDiagnostics {
86 aRKodeGetNumSteps :: Int
87 , aRKodeGetNumStepAttempts :: Int
88 , aRKodeGetNumRhsEvals_fe :: Int
89 , aRKodeGetNumRhsEvals_fi :: Int
90 , aRKodeGetNumLinSolvSetups :: Int
91 , aRKodeGetNumErrTestFails :: Int
92 , aRKodeGetNumNonlinSolvIters :: Int
93 , aRKodeGetNumNonlinSolvConvFails :: Int
94 , aRKDlsGetNumJacEvals :: Int
95 , aRKDlsGetNumRhsEvals :: Int
96 } deriving Show
97
98odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
99 -> [Double] -- ^ initial conditions
100 -> Vector Double -- ^ desired solution times
101 -> Matrix Double -- ^ solution
102odeSolve f y0 ts = case solveOde g (V.fromList y0) (V.fromList $ toList ts) of
103 Left c -> error $ show c -- FIXME
104 Right (v, _) -> (nR >< nC) (V.toList v)
105 where
106 us = toList ts
107 nR = length us
108 nC = length y0
109 g t x0 = V.fromList $ f t (V.toList x0)
110
111solveOde ::
112 (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
113 -> V.Vector Double -- ^ Initial conditions
114 -> V.Vector Double -- ^ Desired solution times
115 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution
116solveOde f y0 tt = case solveOdeC (coerce f) (coerce y0) (coerce tt) of
117 Left c -> Left $ fromIntegral c
118 Right (v, d) -> Right (coerce v, d)
119
120solveOdeC ::
121 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
122 -> V.Vector CDouble -- ^ Initial conditions
123 -> V.Vector CDouble -- ^ Desired solution times
124 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution
125solveOdeC fun f0 ts = unsafePerformIO $ do
126 let dim = V.length f0
127 nEq :: CLong
128 nEq = fromIntegral dim
129 nTs :: CInt
130 nTs = fromIntegral $ V.length ts
131 -- FIXME: fMut is not actually mutatated
132 fMut <- V.thaw f0
133 tMut <- V.thaw ts
134 -- FIXME: I believe this gets taken from the ghc heap and so should
135 -- be subject to garbage collection.
136 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs))
137 qMatMut <- V.thaw quasiMatrixRes
138 diagnostics :: V.Vector CLong <- createVector 10 -- FIXME
139 diagMut <- V.thaw diagnostics
140 -- We need the types that sundials expects. These are tied together
141 -- in 'Types'. FIXME: The Haskell type is currently empty!
142 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt
143 funIO x y f _ptr = do
144 -- Convert the pointer we get from C (y) to a vector, and then
145 -- apply the user-supplied function.
146 fImm <- fun x <$> getDataFromContents dim y
147 -- Fill in the provided pointer with the resulting vector.
148 putDataInContents fImm dim f
149 -- I don't understand what this comment means
150 -- Unsafe since the function will be called many times.
151 [CU.exp| int{ 0 } |]
152 res <- [C.block| int {
153 /* general problem variables */
154 int flag; /* reusable error-checking flag */
155 N_Vector y = NULL; /* empty vector for storing solution */
156 SUNMatrix A = NULL; /* empty matrix for linear solver */
157 SUNLinearSolver LS = NULL; /* empty linear solver object */
158 void *arkode_mem = NULL; /* empty ARKode memory structure */
159 realtype t;
160 long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf;
161
162 /* general problem parameters */
163 realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */
164
165 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */
166 realtype reltol = 1.0e-6; /* tolerances */
167 realtype abstol = 1.0e-10;
168
169 /* Initial diagnostics output */
170 printf("\nAnalytical ODE test problem:\n");
171 printf(" reltol = %.1"ESYM"\n", reltol);
172 printf(" abstol = %.1"ESYM"\n\n",abstol);
173
174 /* Initialize data structures */
175 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
176 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
177 int i, j;
178 for (i = 0; i < NEQ; i++) {
179 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i];
180 }; /* Specify initial condition */
181 arkode_mem = ARKodeCreate(); /* Create the solver memory */
182 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
183
184 /* Call ARKodeInit to initialize the integrator memory and specify the */
185 /* right-hand side function in y'=f(t,y), the inital time T0, and */
186 /* the initial dependent variable vector y. Note: since this */
187 /* problem is fully implicit, we set f_E to NULL and f_I to f. */
188
189 /* Here we use the C types defined in helpers.h which tie up with */
190 /* the Haskell types defined in Types */
191 flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y);
192 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
193
194 /* Set routines */
195 flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */
196 if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1;
197
198 /* Initialize dense matrix data structure and solver */
199 A = SUNDenseMatrix(NEQ, NEQ);
200 if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1;
201 LS = SUNDenseLinearSolver(y, A);
202 if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1;
203
204 /* Linear solver interface */
205 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */
206 /* Store initial conditions */
207 for (j = 0; j < NEQ; j++) {
208 ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j);
209 }
210
211 flag = ARKodeSetIRKTableNum(arkode_mem, KVAERNO_4_2_3);
212 if (check_flag(&flag, "ARKode", 1)) return 1;
213
214 int s, q, p;
215 realtype *ai = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype));
216 realtype *ae = (realtype *)malloc(ARK_S_MAX * ARK_S_MAX * sizeof(realtype));
217 realtype *ci = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
218 realtype *ce = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
219 realtype *bi = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
220 realtype *be = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
221 realtype *b2i = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
222 realtype *b2e = (realtype *)malloc(ARK_S_MAX * sizeof(realtype));
223 flag = ARKodeGetCurrentButcherTables(arkode_mem, &s, &q, &p, ai, ae, ci, ce, bi, be, b2i, b2e);
224 if (check_flag(&flag, "ARKode", 1)) return 1;
225 fprintf(stderr, "s = %d, q = %d, p = %d\n", s, q, p);
226 for (i = 0; i < s; i++) {
227 for (j = 0; j < s; j++) {
228 fprintf(stderr, "ai[%d,%d] = %f", i, j, ai[i * ARK_S_MAX + j]);
229 }
230 fprintf(stderr, "\n");
231 }
232
233 /* Main time-stepping loop: calls ARKode to perform the integration */
234 /* Stops when the final time has been reached */
235 for (i = 1; i < $(int nTs); i++) {
236
237 flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */
238 if (check_flag(&flag, "ARKode", 1)) break;
239
240 /* Store the results for Haskell */
241 for (j = 0; j < NEQ; j++) {
242 ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j);
243 }
244
245 if (flag < 0) { /* unsuccessful solve: break */
246 fprintf(stderr,"Solver failure, stopping integration\n");
247 break;
248 }
249 }
250
251 /* Get some final statistics on how the solve progressed */
252 flag = ARKodeGetNumSteps(arkode_mem, &nst);
253 check_flag(&flag, "ARKodeGetNumSteps", 1);
254 ($vec-ptr:(long int *diagMut))[0] = nst;
255
256 flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a);
257 check_flag(&flag, "ARKodeGetNumStepAttempts", 1);
258 ($vec-ptr:(long int *diagMut))[1] = nst_a;
259
260 flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi);
261 check_flag(&flag, "ARKodeGetNumRhsEvals", 1);
262 ($vec-ptr:(long int *diagMut))[2] = nfe;
263 ($vec-ptr:(long int *diagMut))[3] = nfi;
264
265 flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups);
266 check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1);
267 ($vec-ptr:(long int *diagMut))[4] = nsetups;
268
269 flag = ARKodeGetNumErrTestFails(arkode_mem, &netf);
270 check_flag(&flag, "ARKodeGetNumErrTestFails", 1);
271 ($vec-ptr:(long int *diagMut))[5] = netf;
272
273 flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni);
274 check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1);
275 ($vec-ptr:(long int *diagMut))[6] = nni;
276
277 flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn);
278 check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1);
279 ($vec-ptr:(long int *diagMut))[7] = ncfn;
280
281 flag = ARKDlsGetNumJacEvals(arkode_mem, &nje);
282 check_flag(&flag, "ARKDlsGetNumJacEvals", 1);
283 ($vec-ptr:(long int *diagMut))[8] = ncfn;
284
285 flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS);
286 check_flag(&flag, "ARKDlsGetNumRhsEvals", 1);
287 ($vec-ptr:(long int *diagMut))[9] = ncfn;
288
289 /* Clean up and return */
290 N_VDestroy(y); /* Free y vector */
291 ARKodeFree(&arkode_mem); /* Free integrator memory */
292 SUNLinSolFree(LS); /* Free linear solver */
293 SUNMatDestroy(A); /* Free A matrix */
294
295 return flag;
296 } |]
297 if res == 0
298 then do
299 preD <- V.freeze diagMut
300 let d = SundialsDiagnostics (fromIntegral $ preD V.!0)
301 (fromIntegral $ preD V.!1)
302 (fromIntegral $ preD V.!2)
303 (fromIntegral $ preD V.!3)
304 (fromIntegral $ preD V.!4)
305 (fromIntegral $ preD V.!5)
306 (fromIntegral $ preD V.!6)
307 (fromIntegral $ preD V.!7)
308 (fromIntegral $ preD V.!8)
309 (fromIntegral $ preD V.!9)
310 m <- V.freeze qMatMut
311 return $ Right (m, d)
312 else do
313 return $ Left res
314