diff options
Diffstat (limited to 'packages/sundials/src/Main.hs')
-rw-r--r-- | packages/sundials/src/Main.hs | 245 |
1 files changed, 2 insertions, 243 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index d1f35bd..978088b 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs | |||
@@ -1,84 +1,7 @@ | |||
1 | {-# OPTIONS_GHC -Wall #-} | 1 | {-# OPTIONS_GHC -Wall #-} |
2 | 2 | ||
3 | {-# LANGUAGE QuasiQuotes #-} | ||
4 | {-# LANGUAGE TemplateHaskell #-} | ||
5 | {-# LANGUAGE MultiWayIf #-} | ||
6 | {-# LANGUAGE OverloadedStrings #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | |||
9 | import qualified Language.C.Inline as C | ||
10 | import qualified Language.C.Inline.Unsafe as CU | ||
11 | import Data.Monoid ((<>)) | ||
12 | import Foreign.C.Types | ||
13 | import Foreign.Ptr (Ptr) | ||
14 | import qualified Data.Vector.Storable as V | 3 | import qualified Data.Vector.Storable as V |
15 | 4 | import Numeric.Sundials.Arkode.ODE | |
16 | import Data.Coerce (coerce) | ||
17 | import qualified Data.Vector.Storable.Mutable as VM | ||
18 | import Foreign.ForeignPtr (newForeignPtr_) | ||
19 | import Foreign.Storable (Storable) | ||
20 | import System.IO.Unsafe (unsafePerformIO) | ||
21 | |||
22 | import Foreign.Storable (peekByteOff) | ||
23 | |||
24 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) | ||
25 | |||
26 | import qualified Types as T | ||
27 | |||
28 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | ||
29 | |||
30 | -- C includes | ||
31 | C.include "<stdio.h>" | ||
32 | C.include "<math.h>" | ||
33 | C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts. | ||
34 | C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros | ||
35 | C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix | ||
36 | C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver | ||
37 | C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface | ||
38 | C.include "<sundials/sundials_types.h>" -- definition of type realtype | ||
39 | C.include "<sundials/sundials_math.h>" | ||
40 | C.include "helpers.h" | ||
41 | |||
42 | |||
43 | -- These were semi-generated using hsc2hs with Bar.hsc as the | ||
44 | -- template. They are probably very fragile and could easily break on | ||
45 | -- different architectures and / or changes in the sundials package. | ||
46 | |||
47 | getContentPtr :: Storable a => Ptr b -> IO a | ||
48 | getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr | ||
49 | |||
50 | getData :: Storable a => Ptr b -> IO a | ||
51 | getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr | ||
52 | |||
53 | getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b) | ||
54 | getDataFromContents len ptr = do | ||
55 | qtr <- getContentPtr ptr | ||
56 | rtr <- getData qtr | ||
57 | vectorFromC len rtr | ||
58 | |||
59 | putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () | ||
60 | putDataInContents vec len ptr = do | ||
61 | qtr <- getContentPtr ptr | ||
62 | rtr <- getData qtr | ||
63 | vectorToC vec len rtr | ||
64 | |||
65 | -- Utils | ||
66 | |||
67 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | ||
68 | vectorFromC len ptr = do | ||
69 | ptr' <- newForeignPtr_ ptr | ||
70 | V.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
71 | |||
72 | vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () | ||
73 | vectorToC vec len ptr = do | ||
74 | ptr' <- newForeignPtr_ ptr | ||
75 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
76 | |||
77 | stiffish :: Double -> V.Vector Double -> V.Vector Double | ||
78 | stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] | ||
79 | where | ||
80 | u = v V.! 0 | ||
81 | lamda = -100.0 | ||
82 | 5 | ||
83 | brusselator :: Double -> V.Vector Double -> V.Vector Double | 6 | brusselator :: Double -> V.Vector Double -> V.Vector Double |
84 | brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 | 7 | brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 |
@@ -93,171 +16,7 @@ brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 | |||
93 | v = x V.! 1 | 16 | v = x V.! 1 |
94 | w = x V.! 2 | 17 | w = x V.! 2 |
95 | 18 | ||
96 | |||
97 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
98 | -> [Double] -- ^ initial conditions | ||
99 | -> Vector Double -- ^ desired solution times | ||
100 | -> Matrix Double -- ^ solution | ||
101 | odeSolve = undefined | ||
102 | |||
103 | solveOdeC :: | ||
104 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
105 | -> V.Vector CDouble -- ^ Initial conditions | ||
106 | -> V.Vector CDouble -- ^ Desired solution times | ||
107 | -> Either CInt (V.Vector CDouble) -- ^ Error code or solution | ||
108 | solveOdeC fun f0 ts = unsafePerformIO $ do | ||
109 | let dim = V.length f0 | ||
110 | nEq :: CLong | ||
111 | nEq = fromIntegral dim | ||
112 | fMut <- V.thaw f0 | ||
113 | -- We need the types that sundials expects. These are tied together | ||
114 | -- in 'Types'. The Haskell type is currently empty! | ||
115 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | ||
116 | funIO x y f _ptr = do | ||
117 | -- Convert the pointer we get from C (y) to a vector, and then | ||
118 | -- apply the user-supplied function. | ||
119 | fImm <- fun x <$> getDataFromContents dim y | ||
120 | -- Fill in the provided pointer with the resulting vector. | ||
121 | putDataInContents fImm dim f | ||
122 | -- I don't understand what this comment means | ||
123 | -- Unsafe since the function will be called many times. | ||
124 | [CU.exp| int{ 0 } |] | ||
125 | res <- [C.block| int { | ||
126 | /* general problem variables */ | ||
127 | int flag; /* reusable error-checking flag */ | ||
128 | N_Vector y = NULL; /* empty vector for storing solution */ | ||
129 | SUNMatrix A = NULL; /* empty matrix for linear solver */ | ||
130 | SUNLinearSolver LS = NULL; /* empty linear solver object */ | ||
131 | void *arkode_mem = NULL; /* empty ARKode memory structure */ | ||
132 | FILE *UFID; | ||
133 | realtype t, tout; | ||
134 | long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; | ||
135 | |||
136 | /* general problem parameters */ | ||
137 | realtype T0 = RCONST(0.0); /* initial time */ | ||
138 | realtype Tf = RCONST(10.0); /* final time */ | ||
139 | realtype dTout = RCONST(1.0); /* time between outputs */ | ||
140 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ | ||
141 | realtype reltol = 1.0e-6; /* tolerances */ | ||
142 | realtype abstol = 1.0e-10; | ||
143 | |||
144 | /* Initial diagnostics output */ | ||
145 | printf("\nAnalytical ODE test problem:\n"); | ||
146 | printf(" reltol = %.1"ESYM"\n", reltol); | ||
147 | printf(" abstol = %.1"ESYM"\n\n",abstol); | ||
148 | |||
149 | /* Initialize data structures */ | ||
150 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | ||
151 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | ||
152 | int i; | ||
153 | for (i = 0; i < NEQ; i++) { | ||
154 | NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; | ||
155 | }; /* Specify initial condition */ | ||
156 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | ||
157 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; | ||
158 | |||
159 | /* Call ARKodeInit to initialize the integrator memory and specify the */ | ||
160 | /* right-hand side function in y'=f(t,y), the inital time T0, and */ | ||
161 | /* the initial dependent variable vector y. Note: since this */ | ||
162 | /* problem is fully implicit, we set f_E to NULL and f_I to f. */ | ||
163 | |||
164 | /* Here we use the C types defined in helpers.h which tie up with */ | ||
165 | /* the Haskell types defined in Types */ | ||
166 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); | ||
167 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | ||
168 | |||
169 | /* Set routines */ | ||
170 | flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */ | ||
171 | if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; | ||
172 | |||
173 | /* Initialize dense matrix data structure and solver */ | ||
174 | A = SUNDenseMatrix(NEQ, NEQ); | ||
175 | if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1; | ||
176 | LS = SUNDenseLinearSolver(y, A); | ||
177 | if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1; | ||
178 | |||
179 | /* Linear solver interface */ | ||
180 | flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ | ||
181 | /* Open output stream for results, output comment line */ | ||
182 | UFID = fopen("solution.txt","w"); | ||
183 | fprintf(UFID,"# t u\n"); | ||
184 | |||
185 | /* output initial condition to disk */ | ||
186 | fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", T0, NV_Ith_S(y,0)); | ||
187 | |||
188 | /* Main time-stepping loop: calls ARKode to perform the integration, then | ||
189 | prints results. Stops when the final time has been reached */ | ||
190 | t = T0; | ||
191 | tout = T0+dTout; | ||
192 | printf(" t u\n"); | ||
193 | printf(" ---------------------\n"); | ||
194 | while (Tf - t > 1.0e-15) { | ||
195 | |||
196 | flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */ | ||
197 | if (check_flag(&flag, "ARKode", 1)) break; | ||
198 | printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */ | ||
199 | fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", t, NV_Ith_S(y,0)); | ||
200 | if (flag >= 0) { /* successful solve: update time */ | ||
201 | tout += dTout; | ||
202 | tout = (tout > Tf) ? Tf : tout; | ||
203 | } else { /* unsuccessful solve: break */ | ||
204 | fprintf(stderr,"Solver failure, stopping integration\n"); | ||
205 | break; | ||
206 | } | ||
207 | } | ||
208 | printf(" ---------------------\n"); | ||
209 | fclose(UFID); | ||
210 | |||
211 | for (i = 0; i < NEQ; i++) { | ||
212 | ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i); | ||
213 | }; | ||
214 | |||
215 | /* Get/print some final statistics on how the solve progressed */ | ||
216 | flag = ARKodeGetNumSteps(arkode_mem, &nst); | ||
217 | check_flag(&flag, "ARKodeGetNumSteps", 1); | ||
218 | flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a); | ||
219 | check_flag(&flag, "ARKodeGetNumStepAttempts", 1); | ||
220 | flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi); | ||
221 | check_flag(&flag, "ARKodeGetNumRhsEvals", 1); | ||
222 | flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups); | ||
223 | check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1); | ||
224 | flag = ARKodeGetNumErrTestFails(arkode_mem, &netf); | ||
225 | check_flag(&flag, "ARKodeGetNumErrTestFails", 1); | ||
226 | flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni); | ||
227 | check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1); | ||
228 | flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn); | ||
229 | check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1); | ||
230 | flag = ARKDlsGetNumJacEvals(arkode_mem, &nje); | ||
231 | check_flag(&flag, "ARKDlsGetNumJacEvals", 1); | ||
232 | flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS); | ||
233 | check_flag(&flag, "ARKDlsGetNumRhsEvals", 1); | ||
234 | |||
235 | printf("\nFinal Solver Statistics:\n"); | ||
236 | printf(" Internal solver steps = %li (attempted = %li)\n", nst, nst_a); | ||
237 | printf(" Total RHS evals: Fe = %li, Fi = %li\n", nfe, nfi); | ||
238 | printf(" Total linear solver setups = %li\n", nsetups); | ||
239 | printf(" Total RHS evals for setting up the linear system = %li\n", nfeLS); | ||
240 | printf(" Total number of Jacobian evaluations = %li\n", nje); | ||
241 | printf(" Total number of Newton iterations = %li\n", nni); | ||
242 | printf(" Total number of linear solver convergence failures = %li\n", ncfn); | ||
243 | printf(" Total number of error test failures = %li\n\n", netf); | ||
244 | |||
245 | /* Clean up and return */ | ||
246 | N_VDestroy(y); /* Free y vector */ | ||
247 | ARKodeFree(&arkode_mem); /* Free integrator memory */ | ||
248 | SUNLinSolFree(LS); /* Free linear solver */ | ||
249 | SUNMatDestroy(A); /* Free A matrix */ | ||
250 | |||
251 | return flag; | ||
252 | } |] | ||
253 | if res ==0 | ||
254 | then do | ||
255 | v <- V.freeze fMut | ||
256 | return $ Right v | ||
257 | else do | ||
258 | return $ Left res | ||
259 | |||
260 | main :: IO () | 19 | main :: IO () |
261 | main = do | 20 | main = do |
262 | let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0]) undefined | 21 | let res = solveOde brusselator (V.fromList [1.2, 3.1, 3.0]) (V.fromList [0.0, 1.0 .. 10.0]) |
263 | putStrLn $ show res | 22 | putStrLn $ show res |