diff options
Diffstat (limited to 'packages/sundials/src/Numeric')
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 265 |
1 files changed, 265 insertions, 0 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs new file mode 100644 index 0000000..58acef3 --- /dev/null +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -0,0 +1,265 @@ | |||
1 | {-# OPTIONS_GHC -Wall #-} | ||
2 | |||
3 | {-# LANGUAGE QuasiQuotes #-} | ||
4 | {-# LANGUAGE TemplateHaskell #-} | ||
5 | {-# LANGUAGE MultiWayIf #-} | ||
6 | {-# LANGUAGE OverloadedStrings #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | |||
9 | module Numeric.Sundials.Arkode.ODE ( solveOdeC ) where | ||
10 | |||
11 | import qualified Language.C.Inline as C | ||
12 | import qualified Language.C.Inline.Unsafe as CU | ||
13 | import Data.Monoid ((<>)) | ||
14 | import Foreign.C.Types | ||
15 | import Foreign.Ptr (Ptr) | ||
16 | import qualified Data.Vector.Storable as V | ||
17 | |||
18 | import Data.Coerce (coerce) | ||
19 | import qualified Data.Vector.Storable.Mutable as VM | ||
20 | import Foreign.ForeignPtr (newForeignPtr_) | ||
21 | import Foreign.Storable (Storable) | ||
22 | import System.IO.Unsafe (unsafePerformIO) | ||
23 | |||
24 | import Foreign.Storable (peekByteOff) | ||
25 | |||
26 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) | ||
27 | |||
28 | import qualified Types as T | ||
29 | |||
30 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | ||
31 | |||
32 | -- C includes | ||
33 | C.include "<stdio.h>" | ||
34 | C.include "<math.h>" | ||
35 | C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts. | ||
36 | C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros | ||
37 | C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix | ||
38 | C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver | ||
39 | C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface | ||
40 | C.include "<sundials/sundials_types.h>" -- definition of type realtype | ||
41 | C.include "<sundials/sundials_math.h>" | ||
42 | C.include "../../../helpers.h" | ||
43 | |||
44 | |||
45 | -- These were semi-generated using hsc2hs with Bar.hsc as the | ||
46 | -- template. They are probably very fragile and could easily break on | ||
47 | -- different architectures and / or changes in the sundials package. | ||
48 | |||
49 | getContentPtr :: Storable a => Ptr b -> IO a | ||
50 | getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr | ||
51 | |||
52 | getData :: Storable a => Ptr b -> IO a | ||
53 | getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr | ||
54 | |||
55 | getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b) | ||
56 | getDataFromContents len ptr = do | ||
57 | qtr <- getContentPtr ptr | ||
58 | rtr <- getData qtr | ||
59 | vectorFromC len rtr | ||
60 | |||
61 | putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () | ||
62 | putDataInContents vec len ptr = do | ||
63 | qtr <- getContentPtr ptr | ||
64 | rtr <- getData qtr | ||
65 | vectorToC vec len rtr | ||
66 | |||
67 | -- Utils | ||
68 | |||
69 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | ||
70 | vectorFromC len ptr = do | ||
71 | ptr' <- newForeignPtr_ ptr | ||
72 | V.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
73 | |||
74 | vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () | ||
75 | vectorToC vec len ptr = do | ||
76 | ptr' <- newForeignPtr_ ptr | ||
77 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
78 | |||
79 | stiffish :: Double -> V.Vector Double -> V.Vector Double | ||
80 | stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] | ||
81 | where | ||
82 | u = v V.! 0 | ||
83 | lamda = -100.0 | ||
84 | |||
85 | brusselator :: Double -> V.Vector Double -> V.Vector Double | ||
86 | brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 | ||
87 | , w * u - v * u^2 | ||
88 | , (b - w) / eps - w * u | ||
89 | ] | ||
90 | where | ||
91 | a = 1.0 | ||
92 | b = 3.5 | ||
93 | eps = 5.0e-6 | ||
94 | u = x V.! 0 | ||
95 | v = x V.! 1 | ||
96 | w = x V.! 2 | ||
97 | |||
98 | |||
99 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
100 | -> [Double] -- ^ initial conditions | ||
101 | -> Vector Double -- ^ desired solution times | ||
102 | -> Matrix Double -- ^ solution | ||
103 | odeSolve = undefined | ||
104 | |||
105 | solveOdeC :: | ||
106 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
107 | -> V.Vector CDouble -- ^ Initial conditions | ||
108 | -> V.Vector CDouble -- ^ Desired solution times | ||
109 | -> Either CInt (V.Vector CDouble) -- ^ Error code or solution | ||
110 | solveOdeC fun f0 ts = unsafePerformIO $ do | ||
111 | let dim = V.length f0 | ||
112 | nEq :: CLong | ||
113 | nEq = fromIntegral dim | ||
114 | fMut <- V.thaw f0 | ||
115 | -- We need the types that sundials expects. These are tied together | ||
116 | -- in 'Types'. The Haskell type is currently empty! | ||
117 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | ||
118 | funIO x y f _ptr = do | ||
119 | -- Convert the pointer we get from C (y) to a vector, and then | ||
120 | -- apply the user-supplied function. | ||
121 | fImm <- fun x <$> getDataFromContents dim y | ||
122 | -- Fill in the provided pointer with the resulting vector. | ||
123 | putDataInContents fImm dim f | ||
124 | -- I don't understand what this comment means | ||
125 | -- Unsafe since the function will be called many times. | ||
126 | [CU.exp| int{ 0 } |] | ||
127 | res <- [C.block| int { | ||
128 | /* general problem variables */ | ||
129 | int flag; /* reusable error-checking flag */ | ||
130 | N_Vector y = NULL; /* empty vector for storing solution */ | ||
131 | SUNMatrix A = NULL; /* empty matrix for linear solver */ | ||
132 | SUNLinearSolver LS = NULL; /* empty linear solver object */ | ||
133 | void *arkode_mem = NULL; /* empty ARKode memory structure */ | ||
134 | FILE *UFID; | ||
135 | realtype t, tout; | ||
136 | long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; | ||
137 | |||
138 | /* general problem parameters */ | ||
139 | realtype T0 = RCONST(0.0); /* initial time */ | ||
140 | realtype Tf = RCONST(10.0); /* final time */ | ||
141 | realtype dTout = RCONST(1.0); /* time between outputs */ | ||
142 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ | ||
143 | realtype reltol = 1.0e-6; /* tolerances */ | ||
144 | realtype abstol = 1.0e-10; | ||
145 | |||
146 | /* Initial diagnostics output */ | ||
147 | printf("\nAnalytical ODE test problem:\n"); | ||
148 | printf(" reltol = %.1"ESYM"\n", reltol); | ||
149 | printf(" abstol = %.1"ESYM"\n\n",abstol); | ||
150 | |||
151 | /* Initialize data structures */ | ||
152 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | ||
153 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | ||
154 | int i; | ||
155 | for (i = 0; i < NEQ; i++) { | ||
156 | NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; | ||
157 | }; /* Specify initial condition */ | ||
158 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | ||
159 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; | ||
160 | |||
161 | /* Call ARKodeInit to initialize the integrator memory and specify the */ | ||
162 | /* right-hand side function in y'=f(t,y), the inital time T0, and */ | ||
163 | /* the initial dependent variable vector y. Note: since this */ | ||
164 | /* problem is fully implicit, we set f_E to NULL and f_I to f. */ | ||
165 | |||
166 | /* Here we use the C types defined in helpers.h which tie up with */ | ||
167 | /* the Haskell types defined in Types */ | ||
168 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); | ||
169 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | ||
170 | |||
171 | /* Set routines */ | ||
172 | flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */ | ||
173 | if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; | ||
174 | |||
175 | /* Initialize dense matrix data structure and solver */ | ||
176 | A = SUNDenseMatrix(NEQ, NEQ); | ||
177 | if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1; | ||
178 | LS = SUNDenseLinearSolver(y, A); | ||
179 | if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1; | ||
180 | |||
181 | /* Linear solver interface */ | ||
182 | flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ | ||
183 | /* Open output stream for results, output comment line */ | ||
184 | UFID = fopen("solution.txt","w"); | ||
185 | fprintf(UFID,"# t u\n"); | ||
186 | |||
187 | /* output initial condition to disk */ | ||
188 | fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", T0, NV_Ith_S(y,0)); | ||
189 | |||
190 | /* Main time-stepping loop: calls ARKode to perform the integration, then | ||
191 | prints results. Stops when the final time has been reached */ | ||
192 | t = T0; | ||
193 | tout = T0+dTout; | ||
194 | printf(" t u\n"); | ||
195 | printf(" ---------------------\n"); | ||
196 | while (Tf - t > 1.0e-15) { | ||
197 | |||
198 | flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */ | ||
199 | if (check_flag(&flag, "ARKode", 1)) break; | ||
200 | printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */ | ||
201 | fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", t, NV_Ith_S(y,0)); | ||
202 | if (flag >= 0) { /* successful solve: update time */ | ||
203 | tout += dTout; | ||
204 | tout = (tout > Tf) ? Tf : tout; | ||
205 | } else { /* unsuccessful solve: break */ | ||
206 | fprintf(stderr,"Solver failure, stopping integration\n"); | ||
207 | break; | ||
208 | } | ||
209 | } | ||
210 | printf(" ---------------------\n"); | ||
211 | fclose(UFID); | ||
212 | |||
213 | for (i = 0; i < NEQ; i++) { | ||
214 | ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i); | ||
215 | }; | ||
216 | |||
217 | /* Get/print some final statistics on how the solve progressed */ | ||
218 | flag = ARKodeGetNumSteps(arkode_mem, &nst); | ||
219 | check_flag(&flag, "ARKodeGetNumSteps", 1); | ||
220 | flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a); | ||
221 | check_flag(&flag, "ARKodeGetNumStepAttempts", 1); | ||
222 | flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi); | ||
223 | check_flag(&flag, "ARKodeGetNumRhsEvals", 1); | ||
224 | flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups); | ||
225 | check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1); | ||
226 | flag = ARKodeGetNumErrTestFails(arkode_mem, &netf); | ||
227 | check_flag(&flag, "ARKodeGetNumErrTestFails", 1); | ||
228 | flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni); | ||
229 | check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1); | ||
230 | flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn); | ||
231 | check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1); | ||
232 | flag = ARKDlsGetNumJacEvals(arkode_mem, &nje); | ||
233 | check_flag(&flag, "ARKDlsGetNumJacEvals", 1); | ||
234 | flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS); | ||
235 | check_flag(&flag, "ARKDlsGetNumRhsEvals", 1); | ||
236 | |||
237 | printf("\nFinal Solver Statistics:\n"); | ||
238 | printf(" Internal solver steps = %li (attempted = %li)\n", nst, nst_a); | ||
239 | printf(" Total RHS evals: Fe = %li, Fi = %li\n", nfe, nfi); | ||
240 | printf(" Total linear solver setups = %li\n", nsetups); | ||
241 | printf(" Total RHS evals for setting up the linear system = %li\n", nfeLS); | ||
242 | printf(" Total number of Jacobian evaluations = %li\n", nje); | ||
243 | printf(" Total number of Newton iterations = %li\n", nni); | ||
244 | printf(" Total number of linear solver convergence failures = %li\n", ncfn); | ||
245 | printf(" Total number of error test failures = %li\n\n", netf); | ||
246 | |||
247 | /* Clean up and return */ | ||
248 | N_VDestroy(y); /* Free y vector */ | ||
249 | ARKodeFree(&arkode_mem); /* Free integrator memory */ | ||
250 | SUNLinSolFree(LS); /* Free linear solver */ | ||
251 | SUNMatDestroy(A); /* Free A matrix */ | ||
252 | |||
253 | return flag; | ||
254 | } |] | ||
255 | if res ==0 | ||
256 | then do | ||
257 | v <- V.freeze fMut | ||
258 | return $ Right v | ||
259 | else do | ||
260 | return $ Left res | ||
261 | |||
262 | main :: IO () | ||
263 | main = do | ||
264 | let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0]) undefined | ||
265 | putStrLn $ show res | ||