diff options
-rw-r--r-- | packages/sundials/src/Main.hs | 245 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 45 |
2 files changed, 18 insertions, 272 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index d1f35bd..978088b 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs | |||
@@ -1,84 +1,7 @@ | |||
1 | {-# OPTIONS_GHC -Wall #-} | 1 | {-# OPTIONS_GHC -Wall #-} |
2 | 2 | ||
3 | {-# LANGUAGE QuasiQuotes #-} | ||
4 | {-# LANGUAGE TemplateHaskell #-} | ||
5 | {-# LANGUAGE MultiWayIf #-} | ||
6 | {-# LANGUAGE OverloadedStrings #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | |||
9 | import qualified Language.C.Inline as C | ||
10 | import qualified Language.C.Inline.Unsafe as CU | ||
11 | import Data.Monoid ((<>)) | ||
12 | import Foreign.C.Types | ||
13 | import Foreign.Ptr (Ptr) | ||
14 | import qualified Data.Vector.Storable as V | 3 | import qualified Data.Vector.Storable as V |
15 | 4 | import Numeric.Sundials.Arkode.ODE | |
16 | import Data.Coerce (coerce) | ||
17 | import qualified Data.Vector.Storable.Mutable as VM | ||
18 | import Foreign.ForeignPtr (newForeignPtr_) | ||
19 | import Foreign.Storable (Storable) | ||
20 | import System.IO.Unsafe (unsafePerformIO) | ||
21 | |||
22 | import Foreign.Storable (peekByteOff) | ||
23 | |||
24 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) | ||
25 | |||
26 | import qualified Types as T | ||
27 | |||
28 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | ||
29 | |||
30 | -- C includes | ||
31 | C.include "<stdio.h>" | ||
32 | C.include "<math.h>" | ||
33 | C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts. | ||
34 | C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros | ||
35 | C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix | ||
36 | C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver | ||
37 | C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface | ||
38 | C.include "<sundials/sundials_types.h>" -- definition of type realtype | ||
39 | C.include "<sundials/sundials_math.h>" | ||
40 | C.include "helpers.h" | ||
41 | |||
42 | |||
43 | -- These were semi-generated using hsc2hs with Bar.hsc as the | ||
44 | -- template. They are probably very fragile and could easily break on | ||
45 | -- different architectures and / or changes in the sundials package. | ||
46 | |||
47 | getContentPtr :: Storable a => Ptr b -> IO a | ||
48 | getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr | ||
49 | |||
50 | getData :: Storable a => Ptr b -> IO a | ||
51 | getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr | ||
52 | |||
53 | getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b) | ||
54 | getDataFromContents len ptr = do | ||
55 | qtr <- getContentPtr ptr | ||
56 | rtr <- getData qtr | ||
57 | vectorFromC len rtr | ||
58 | |||
59 | putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () | ||
60 | putDataInContents vec len ptr = do | ||
61 | qtr <- getContentPtr ptr | ||
62 | rtr <- getData qtr | ||
63 | vectorToC vec len rtr | ||
64 | |||
65 | -- Utils | ||
66 | |||
67 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | ||
68 | vectorFromC len ptr = do | ||
69 | ptr' <- newForeignPtr_ ptr | ||
70 | V.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
71 | |||
72 | vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () | ||
73 | vectorToC vec len ptr = do | ||
74 | ptr' <- newForeignPtr_ ptr | ||
75 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
76 | |||
77 | stiffish :: Double -> V.Vector Double -> V.Vector Double | ||
78 | stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] | ||
79 | where | ||
80 | u = v V.! 0 | ||
81 | lamda = -100.0 | ||
82 | 5 | ||
83 | brusselator :: Double -> V.Vector Double -> V.Vector Double | 6 | brusselator :: Double -> V.Vector Double -> V.Vector Double |
84 | brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 | 7 | brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 |
@@ -93,171 +16,7 @@ brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 | |||
93 | v = x V.! 1 | 16 | v = x V.! 1 |
94 | w = x V.! 2 | 17 | w = x V.! 2 |
95 | 18 | ||
96 | |||
97 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
98 | -> [Double] -- ^ initial conditions | ||
99 | -> Vector Double -- ^ desired solution times | ||
100 | -> Matrix Double -- ^ solution | ||
101 | odeSolve = undefined | ||
102 | |||
103 | solveOdeC :: | ||
104 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
105 | -> V.Vector CDouble -- ^ Initial conditions | ||
106 | -> V.Vector CDouble -- ^ Desired solution times | ||
107 | -> Either CInt (V.Vector CDouble) -- ^ Error code or solution | ||
108 | solveOdeC fun f0 ts = unsafePerformIO $ do | ||
109 | let dim = V.length f0 | ||
110 | nEq :: CLong | ||
111 | nEq = fromIntegral dim | ||
112 | fMut <- V.thaw f0 | ||
113 | -- We need the types that sundials expects. These are tied together | ||
114 | -- in 'Types'. The Haskell type is currently empty! | ||
115 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | ||
116 | funIO x y f _ptr = do | ||
117 | -- Convert the pointer we get from C (y) to a vector, and then | ||
118 | -- apply the user-supplied function. | ||
119 | fImm <- fun x <$> getDataFromContents dim y | ||
120 | -- Fill in the provided pointer with the resulting vector. | ||
121 | putDataInContents fImm dim f | ||
122 | -- I don't understand what this comment means | ||
123 | -- Unsafe since the function will be called many times. | ||
124 | [CU.exp| int{ 0 } |] | ||
125 | res <- [C.block| int { | ||
126 | /* general problem variables */ | ||
127 | int flag; /* reusable error-checking flag */ | ||
128 | N_Vector y = NULL; /* empty vector for storing solution */ | ||
129 | SUNMatrix A = NULL; /* empty matrix for linear solver */ | ||
130 | SUNLinearSolver LS = NULL; /* empty linear solver object */ | ||
131 | void *arkode_mem = NULL; /* empty ARKode memory structure */ | ||
132 | FILE *UFID; | ||
133 | realtype t, tout; | ||
134 | long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; | ||
135 | |||
136 | /* general problem parameters */ | ||
137 | realtype T0 = RCONST(0.0); /* initial time */ | ||
138 | realtype Tf = RCONST(10.0); /* final time */ | ||
139 | realtype dTout = RCONST(1.0); /* time between outputs */ | ||
140 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ | ||
141 | realtype reltol = 1.0e-6; /* tolerances */ | ||
142 | realtype abstol = 1.0e-10; | ||
143 | |||
144 | /* Initial diagnostics output */ | ||
145 | printf("\nAnalytical ODE test problem:\n"); | ||
146 | printf(" reltol = %.1"ESYM"\n", reltol); | ||
147 | printf(" abstol = %.1"ESYM"\n\n",abstol); | ||
148 | |||
149 | /* Initialize data structures */ | ||
150 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | ||
151 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | ||
152 | int i; | ||
153 | for (i = 0; i < NEQ; i++) { | ||
154 | NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; | ||
155 | }; /* Specify initial condition */ | ||
156 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | ||
157 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; | ||
158 | |||
159 | /* Call ARKodeInit to initialize the integrator memory and specify the */ | ||
160 | /* right-hand side function in y'=f(t,y), the inital time T0, and */ | ||
161 | /* the initial dependent variable vector y. Note: since this */ | ||
162 | /* problem is fully implicit, we set f_E to NULL and f_I to f. */ | ||
163 | |||
164 | /* Here we use the C types defined in helpers.h which tie up with */ | ||
165 | /* the Haskell types defined in Types */ | ||
166 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); | ||
167 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | ||
168 | |||
169 | /* Set routines */ | ||
170 | flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */ | ||
171 | if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; | ||
172 | |||
173 | /* Initialize dense matrix data structure and solver */ | ||
174 | A = SUNDenseMatrix(NEQ, NEQ); | ||
175 | if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1; | ||
176 | LS = SUNDenseLinearSolver(y, A); | ||
177 | if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1; | ||
178 | |||
179 | /* Linear solver interface */ | ||
180 | flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ | ||
181 | /* Open output stream for results, output comment line */ | ||
182 | UFID = fopen("solution.txt","w"); | ||
183 | fprintf(UFID,"# t u\n"); | ||
184 | |||
185 | /* output initial condition to disk */ | ||
186 | fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", T0, NV_Ith_S(y,0)); | ||
187 | |||
188 | /* Main time-stepping loop: calls ARKode to perform the integration, then | ||
189 | prints results. Stops when the final time has been reached */ | ||
190 | t = T0; | ||
191 | tout = T0+dTout; | ||
192 | printf(" t u\n"); | ||
193 | printf(" ---------------------\n"); | ||
194 | while (Tf - t > 1.0e-15) { | ||
195 | |||
196 | flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */ | ||
197 | if (check_flag(&flag, "ARKode", 1)) break; | ||
198 | printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */ | ||
199 | fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", t, NV_Ith_S(y,0)); | ||
200 | if (flag >= 0) { /* successful solve: update time */ | ||
201 | tout += dTout; | ||
202 | tout = (tout > Tf) ? Tf : tout; | ||
203 | } else { /* unsuccessful solve: break */ | ||
204 | fprintf(stderr,"Solver failure, stopping integration\n"); | ||
205 | break; | ||
206 | } | ||
207 | } | ||
208 | printf(" ---------------------\n"); | ||
209 | fclose(UFID); | ||
210 | |||
211 | for (i = 0; i < NEQ; i++) { | ||
212 | ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i); | ||
213 | }; | ||
214 | |||
215 | /* Get/print some final statistics on how the solve progressed */ | ||
216 | flag = ARKodeGetNumSteps(arkode_mem, &nst); | ||
217 | check_flag(&flag, "ARKodeGetNumSteps", 1); | ||
218 | flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a); | ||
219 | check_flag(&flag, "ARKodeGetNumStepAttempts", 1); | ||
220 | flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi); | ||
221 | check_flag(&flag, "ARKodeGetNumRhsEvals", 1); | ||
222 | flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups); | ||
223 | check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1); | ||
224 | flag = ARKodeGetNumErrTestFails(arkode_mem, &netf); | ||
225 | check_flag(&flag, "ARKodeGetNumErrTestFails", 1); | ||
226 | flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni); | ||
227 | check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1); | ||
228 | flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn); | ||
229 | check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1); | ||
230 | flag = ARKDlsGetNumJacEvals(arkode_mem, &nje); | ||
231 | check_flag(&flag, "ARKDlsGetNumJacEvals", 1); | ||
232 | flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS); | ||
233 | check_flag(&flag, "ARKDlsGetNumRhsEvals", 1); | ||
234 | |||
235 | printf("\nFinal Solver Statistics:\n"); | ||
236 | printf(" Internal solver steps = %li (attempted = %li)\n", nst, nst_a); | ||
237 | printf(" Total RHS evals: Fe = %li, Fi = %li\n", nfe, nfi); | ||
238 | printf(" Total linear solver setups = %li\n", nsetups); | ||
239 | printf(" Total RHS evals for setting up the linear system = %li\n", nfeLS); | ||
240 | printf(" Total number of Jacobian evaluations = %li\n", nje); | ||
241 | printf(" Total number of Newton iterations = %li\n", nni); | ||
242 | printf(" Total number of linear solver convergence failures = %li\n", ncfn); | ||
243 | printf(" Total number of error test failures = %li\n\n", netf); | ||
244 | |||
245 | /* Clean up and return */ | ||
246 | N_VDestroy(y); /* Free y vector */ | ||
247 | ARKodeFree(&arkode_mem); /* Free integrator memory */ | ||
248 | SUNLinSolFree(LS); /* Free linear solver */ | ||
249 | SUNMatDestroy(A); /* Free A matrix */ | ||
250 | |||
251 | return flag; | ||
252 | } |] | ||
253 | if res ==0 | ||
254 | then do | ||
255 | v <- V.freeze fMut | ||
256 | return $ Right v | ||
257 | else do | ||
258 | return $ Left res | ||
259 | |||
260 | main :: IO () | 19 | main :: IO () |
261 | main = do | 20 | main = do |
262 | let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0]) undefined | 21 | let res = solveOde brusselator (V.fromList [1.2, 3.1, 3.0]) (V.fromList [0.0, 1.0 .. 10.0]) |
263 | putStrLn $ show res | 22 | putStrLn $ show res |
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 58acef3..9de20b6 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -6,7 +6,7 @@ | |||
6 | {-# LANGUAGE OverloadedStrings #-} | 6 | {-# LANGUAGE OverloadedStrings #-} |
7 | {-# LANGUAGE ScopedTypeVariables #-} | 7 | {-# LANGUAGE ScopedTypeVariables #-} |
8 | 8 | ||
9 | module Numeric.Sundials.Arkode.ODE ( solveOdeC ) where | 9 | module Numeric.Sundials.Arkode.ODE ( solveOde ) where |
10 | 10 | ||
11 | import qualified Language.C.Inline as C | 11 | import qualified Language.C.Inline as C |
12 | import qualified Language.C.Inline.Unsafe as CU | 12 | import qualified Language.C.Inline.Unsafe as CU |
@@ -76,32 +76,21 @@ vectorToC vec len ptr = do | |||
76 | ptr' <- newForeignPtr_ ptr | 76 | ptr' <- newForeignPtr_ ptr |
77 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | 77 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec |
78 | 78 | ||
79 | stiffish :: Double -> V.Vector Double -> V.Vector Double | ||
80 | stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] | ||
81 | where | ||
82 | u = v V.! 0 | ||
83 | lamda = -100.0 | ||
84 | |||
85 | brusselator :: Double -> V.Vector Double -> V.Vector Double | ||
86 | brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 | ||
87 | , w * u - v * u^2 | ||
88 | , (b - w) / eps - w * u | ||
89 | ] | ||
90 | where | ||
91 | a = 1.0 | ||
92 | b = 3.5 | ||
93 | eps = 5.0e-6 | ||
94 | u = x V.! 0 | ||
95 | v = x V.! 1 | ||
96 | w = x V.! 2 | ||
97 | |||
98 | |||
99 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 79 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
100 | -> [Double] -- ^ initial conditions | 80 | -> [Double] -- ^ initial conditions |
101 | -> Vector Double -- ^ desired solution times | 81 | -> Vector Double -- ^ desired solution times |
102 | -> Matrix Double -- ^ solution | 82 | -> Matrix Double -- ^ solution |
103 | odeSolve = undefined | 83 | odeSolve = undefined |
104 | 84 | ||
85 | solveOde :: | ||
86 | (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
87 | -> V.Vector Double -- ^ Initial conditions | ||
88 | -> V.Vector Double -- ^ Desired solution times | ||
89 | -> Either Int (V.Vector Double) -- ^ Error code or solution | ||
90 | solveOde f y0 tt = case solveOdeC (coerce f) (coerce y0) (coerce tt) of | ||
91 | Left c -> Left $ fromIntegral c | ||
92 | Right v -> Right $ coerce v | ||
93 | |||
105 | solveOdeC :: | 94 | solveOdeC :: |
106 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 95 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
107 | -> V.Vector CDouble -- ^ Initial conditions | 96 | -> V.Vector CDouble -- ^ Initial conditions |
@@ -111,7 +100,10 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
111 | let dim = V.length f0 | 100 | let dim = V.length f0 |
112 | nEq :: CLong | 101 | nEq :: CLong |
113 | nEq = fromIntegral dim | 102 | nEq = fromIntegral dim |
103 | nTs :: CInt | ||
104 | nTs = fromIntegral $ V.length ts | ||
114 | fMut <- V.thaw f0 | 105 | fMut <- V.thaw f0 |
106 | tMut <- V.thaw ts | ||
115 | -- We need the types that sundials expects. These are tied together | 107 | -- We need the types that sundials expects. These are tied together |
116 | -- in 'Types'. The Haskell type is currently empty! | 108 | -- in 'Types'. The Haskell type is currently empty! |
117 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | 109 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt |
@@ -136,8 +128,8 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
136 | long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; | 128 | long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; |
137 | 129 | ||
138 | /* general problem parameters */ | 130 | /* general problem parameters */ |
139 | realtype T0 = RCONST(0.0); /* initial time */ | 131 | realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ |
140 | realtype Tf = RCONST(10.0); /* final time */ | 132 | realtype Tf = RCONST(($vec-ptr:(double *tMut))[$(int nTs) - 1]); /* final time */ |
141 | realtype dTout = RCONST(1.0); /* time between outputs */ | 133 | realtype dTout = RCONST(1.0); /* time between outputs */ |
142 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ | 134 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ |
143 | realtype reltol = 1.0e-6; /* tolerances */ | 135 | realtype reltol = 1.0e-6; /* tolerances */ |
@@ -193,7 +185,7 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
193 | tout = T0+dTout; | 185 | tout = T0+dTout; |
194 | printf(" t u\n"); | 186 | printf(" t u\n"); |
195 | printf(" ---------------------\n"); | 187 | printf(" ---------------------\n"); |
196 | while (Tf - t > 1.0e-15) { | 188 | for (i = 0; i < $(int nTs); i++) { |
197 | 189 | ||
198 | flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */ | 190 | flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */ |
199 | if (check_flag(&flag, "ARKode", 1)) break; | 191 | if (check_flag(&flag, "ARKode", 1)) break; |
@@ -258,8 +250,3 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
258 | return $ Right v | 250 | return $ Right v |
259 | else do | 251 | else do |
260 | return $ Left res | 252 | return $ Left res |
261 | |||
262 | main :: IO () | ||
263 | main = do | ||
264 | let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0]) undefined | ||
265 | putStrLn $ show res | ||