summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Numeric
diff options
context:
space:
mode:
Diffstat (limited to 'packages/sundials/src/Numeric')
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs265
1 files changed, 265 insertions, 0 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
new file mode 100644
index 0000000..58acef3
--- /dev/null
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -0,0 +1,265 @@
1{-# OPTIONS_GHC -Wall #-}
2
3{-# LANGUAGE QuasiQuotes #-}
4{-# LANGUAGE TemplateHaskell #-}
5{-# LANGUAGE MultiWayIf #-}
6{-# LANGUAGE OverloadedStrings #-}
7{-# LANGUAGE ScopedTypeVariables #-}
8
9module Numeric.Sundials.Arkode.ODE ( solveOdeC ) where
10
11import qualified Language.C.Inline as C
12import qualified Language.C.Inline.Unsafe as CU
13import Data.Monoid ((<>))
14import Foreign.C.Types
15import Foreign.Ptr (Ptr)
16import qualified Data.Vector.Storable as V
17
18import Data.Coerce (coerce)
19import qualified Data.Vector.Storable.Mutable as VM
20import Foreign.ForeignPtr (newForeignPtr_)
21import Foreign.Storable (Storable)
22import System.IO.Unsafe (unsafePerformIO)
23
24import Foreign.Storable (peekByteOff)
25
26import Numeric.LinearAlgebra.HMatrix (Vector, Matrix)
27
28import qualified Types as T
29
30C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx)
31
32-- C includes
33C.include "<stdio.h>"
34C.include "<math.h>"
35C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts.
36C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros
37C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix
38C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver
39C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface
40C.include "<sundials/sundials_types.h>" -- definition of type realtype
41C.include "<sundials/sundials_math.h>"
42C.include "../../../helpers.h"
43
44
45-- These were semi-generated using hsc2hs with Bar.hsc as the
46-- template. They are probably very fragile and could easily break on
47-- different architectures and / or changes in the sundials package.
48
49getContentPtr :: Storable a => Ptr b -> IO a
50getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr
51
52getData :: Storable a => Ptr b -> IO a
53getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr
54
55getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b)
56getDataFromContents len ptr = do
57 qtr <- getContentPtr ptr
58 rtr <- getData qtr
59 vectorFromC len rtr
60
61putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO ()
62putDataInContents vec len ptr = do
63 qtr <- getContentPtr ptr
64 rtr <- getData qtr
65 vectorToC vec len rtr
66
67-- Utils
68
69vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a)
70vectorFromC len ptr = do
71 ptr' <- newForeignPtr_ ptr
72 V.freeze $ VM.unsafeFromForeignPtr0 ptr' len
73
74vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO ()
75vectorToC vec len ptr = do
76 ptr' <- newForeignPtr_ ptr
77 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
78
79stiffish :: Double -> V.Vector Double -> V.Vector Double
80stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ]
81 where
82 u = v V.! 0
83 lamda = -100.0
84
85brusselator :: Double -> V.Vector Double -> V.Vector Double
86brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2
87 , w * u - v * u^2
88 , (b - w) / eps - w * u
89 ]
90 where
91 a = 1.0
92 b = 3.5
93 eps = 5.0e-6
94 u = x V.! 0
95 v = x V.! 1
96 w = x V.! 2
97
98
99odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
100 -> [Double] -- ^ initial conditions
101 -> Vector Double -- ^ desired solution times
102 -> Matrix Double -- ^ solution
103odeSolve = undefined
104
105solveOdeC ::
106 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
107 -> V.Vector CDouble -- ^ Initial conditions
108 -> V.Vector CDouble -- ^ Desired solution times
109 -> Either CInt (V.Vector CDouble) -- ^ Error code or solution
110solveOdeC fun f0 ts = unsafePerformIO $ do
111 let dim = V.length f0
112 nEq :: CLong
113 nEq = fromIntegral dim
114 fMut <- V.thaw f0
115 -- We need the types that sundials expects. These are tied together
116 -- in 'Types'. The Haskell type is currently empty!
117 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt
118 funIO x y f _ptr = do
119 -- Convert the pointer we get from C (y) to a vector, and then
120 -- apply the user-supplied function.
121 fImm <- fun x <$> getDataFromContents dim y
122 -- Fill in the provided pointer with the resulting vector.
123 putDataInContents fImm dim f
124 -- I don't understand what this comment means
125 -- Unsafe since the function will be called many times.
126 [CU.exp| int{ 0 } |]
127 res <- [C.block| int {
128 /* general problem variables */
129 int flag; /* reusable error-checking flag */
130 N_Vector y = NULL; /* empty vector for storing solution */
131 SUNMatrix A = NULL; /* empty matrix for linear solver */
132 SUNLinearSolver LS = NULL; /* empty linear solver object */
133 void *arkode_mem = NULL; /* empty ARKode memory structure */
134 FILE *UFID;
135 realtype t, tout;
136 long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf;
137
138 /* general problem parameters */
139 realtype T0 = RCONST(0.0); /* initial time */
140 realtype Tf = RCONST(10.0); /* final time */
141 realtype dTout = RCONST(1.0); /* time between outputs */
142 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */
143 realtype reltol = 1.0e-6; /* tolerances */
144 realtype abstol = 1.0e-10;
145
146 /* Initial diagnostics output */
147 printf("\nAnalytical ODE test problem:\n");
148 printf(" reltol = %.1"ESYM"\n", reltol);
149 printf(" abstol = %.1"ESYM"\n\n",abstol);
150
151 /* Initialize data structures */
152 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
153 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
154 int i;
155 for (i = 0; i < NEQ; i++) {
156 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i];
157 }; /* Specify initial condition */
158 arkode_mem = ARKodeCreate(); /* Create the solver memory */
159 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
160
161 /* Call ARKodeInit to initialize the integrator memory and specify the */
162 /* right-hand side function in y'=f(t,y), the inital time T0, and */
163 /* the initial dependent variable vector y. Note: since this */
164 /* problem is fully implicit, we set f_E to NULL and f_I to f. */
165
166 /* Here we use the C types defined in helpers.h which tie up with */
167 /* the Haskell types defined in Types */
168 flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y);
169 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
170
171 /* Set routines */
172 flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */
173 if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1;
174
175 /* Initialize dense matrix data structure and solver */
176 A = SUNDenseMatrix(NEQ, NEQ);
177 if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1;
178 LS = SUNDenseLinearSolver(y, A);
179 if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1;
180
181 /* Linear solver interface */
182 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */
183 /* Open output stream for results, output comment line */
184 UFID = fopen("solution.txt","w");
185 fprintf(UFID,"# t u\n");
186
187 /* output initial condition to disk */
188 fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", T0, NV_Ith_S(y,0));
189
190 /* Main time-stepping loop: calls ARKode to perform the integration, then
191 prints results. Stops when the final time has been reached */
192 t = T0;
193 tout = T0+dTout;
194 printf(" t u\n");
195 printf(" ---------------------\n");
196 while (Tf - t > 1.0e-15) {
197
198 flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */
199 if (check_flag(&flag, "ARKode", 1)) break;
200 printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */
201 fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", t, NV_Ith_S(y,0));
202 if (flag >= 0) { /* successful solve: update time */
203 tout += dTout;
204 tout = (tout > Tf) ? Tf : tout;
205 } else { /* unsuccessful solve: break */
206 fprintf(stderr,"Solver failure, stopping integration\n");
207 break;
208 }
209 }
210 printf(" ---------------------\n");
211 fclose(UFID);
212
213 for (i = 0; i < NEQ; i++) {
214 ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i);
215 };
216
217 /* Get/print some final statistics on how the solve progressed */
218 flag = ARKodeGetNumSteps(arkode_mem, &nst);
219 check_flag(&flag, "ARKodeGetNumSteps", 1);
220 flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a);
221 check_flag(&flag, "ARKodeGetNumStepAttempts", 1);
222 flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi);
223 check_flag(&flag, "ARKodeGetNumRhsEvals", 1);
224 flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups);
225 check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1);
226 flag = ARKodeGetNumErrTestFails(arkode_mem, &netf);
227 check_flag(&flag, "ARKodeGetNumErrTestFails", 1);
228 flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni);
229 check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1);
230 flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn);
231 check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1);
232 flag = ARKDlsGetNumJacEvals(arkode_mem, &nje);
233 check_flag(&flag, "ARKDlsGetNumJacEvals", 1);
234 flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS);
235 check_flag(&flag, "ARKDlsGetNumRhsEvals", 1);
236
237 printf("\nFinal Solver Statistics:\n");
238 printf(" Internal solver steps = %li (attempted = %li)\n", nst, nst_a);
239 printf(" Total RHS evals: Fe = %li, Fi = %li\n", nfe, nfi);
240 printf(" Total linear solver setups = %li\n", nsetups);
241 printf(" Total RHS evals for setting up the linear system = %li\n", nfeLS);
242 printf(" Total number of Jacobian evaluations = %li\n", nje);
243 printf(" Total number of Newton iterations = %li\n", nni);
244 printf(" Total number of linear solver convergence failures = %li\n", ncfn);
245 printf(" Total number of error test failures = %li\n\n", netf);
246
247 /* Clean up and return */
248 N_VDestroy(y); /* Free y vector */
249 ARKodeFree(&arkode_mem); /* Free integrator memory */
250 SUNLinSolFree(LS); /* Free linear solver */
251 SUNMatDestroy(A); /* Free A matrix */
252
253 return flag;
254 } |]
255 if res ==0
256 then do
257 v <- V.freeze fMut
258 return $ Right v
259 else do
260 return $ Left res
261
262main :: IO ()
263main = do
264 let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0]) undefined
265 putStrLn $ show res