diff options
-rw-r--r-- | packages/sundials/hmatrix-sundials.cabal | 16 | ||||
-rw-r--r-- | packages/sundials/src/Main.hs | 21 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 265 | ||||
-rw-r--r-- | packages/sundials/src/Types.hs | 1 |
4 files changed, 297 insertions, 6 deletions
diff --git a/packages/sundials/hmatrix-sundials.cabal b/packages/sundials/hmatrix-sundials.cabal index d928ab1..331e6c4 100644 --- a/packages/sundials/hmatrix-sundials.cabal +++ b/packages/sundials/hmatrix-sundials.cabal | |||
@@ -18,6 +18,19 @@ cabal-version: >=1.10 | |||
18 | 18 | ||
19 | extra-source-files: src/helpers.c, src/helpers.h | 19 | extra-source-files: src/helpers.c, src/helpers.h |
20 | 20 | ||
21 | |||
22 | library | ||
23 | build-depends: base >=4.10 && <4.11, | ||
24 | inline-c >=0.6 && <0.7, | ||
25 | vector >=0.12 && <0.13, | ||
26 | template-haskell >=2.12 && <2.13, | ||
27 | containers >=0.5 && <0.6, | ||
28 | hmatrix>=0.18 | ||
29 | other-extensions: QuasiQuotes, TemplateHaskell, MultiWayIf, OverloadedStrings | ||
30 | hs-source-dirs: src | ||
31 | exposed-modules: Numeric.Sundials.Arkode.ODE | ||
32 | other-modules: Types | ||
33 | |||
21 | executable sundials | 34 | executable sundials |
22 | main-is: Main.hs | 35 | main-is: Main.hs |
23 | other-modules: Types | 36 | other-modules: Types |
@@ -26,7 +39,8 @@ executable sundials | |||
26 | inline-c >=0.6 && <0.7, | 39 | inline-c >=0.6 && <0.7, |
27 | vector >=0.12 && <0.13, | 40 | vector >=0.12 && <0.13, |
28 | template-haskell >=2.12 && <2.13, | 41 | template-haskell >=2.12 && <2.13, |
29 | containers >=0.5 && <0.6 | 42 | containers >=0.5 && <0.6, |
43 | hmatrix>=0.18 | ||
30 | hs-source-dirs: src | 44 | hs-source-dirs: src |
31 | default-language: Haskell2010 | 45 | default-language: Haskell2010 |
32 | extra-libraries: sundials_arkode | 46 | extra-libraries: sundials_arkode |
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index fc48710..d1f35bd 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs | |||
@@ -21,6 +21,8 @@ import System.IO.Unsafe (unsafePerformIO) | |||
21 | 21 | ||
22 | import Foreign.Storable (peekByteOff) | 22 | import Foreign.Storable (peekByteOff) |
23 | 23 | ||
24 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) | ||
25 | |||
24 | import qualified Types as T | 26 | import qualified Types as T |
25 | 27 | ||
26 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | 28 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) |
@@ -91,10 +93,19 @@ brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 | |||
91 | v = x V.! 1 | 93 | v = x V.! 1 |
92 | w = x V.! 2 | 94 | w = x V.! 2 |
93 | 95 | ||
94 | solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> | 96 | |
95 | V.Vector CDouble -> | 97 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
96 | Either CInt (V.Vector CDouble) | 98 | -> [Double] -- ^ initial conditions |
97 | solveOdeC fun f0 = unsafePerformIO $ do | 99 | -> Vector Double -- ^ desired solution times |
100 | -> Matrix Double -- ^ solution | ||
101 | odeSolve = undefined | ||
102 | |||
103 | solveOdeC :: | ||
104 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
105 | -> V.Vector CDouble -- ^ Initial conditions | ||
106 | -> V.Vector CDouble -- ^ Desired solution times | ||
107 | -> Either CInt (V.Vector CDouble) -- ^ Error code or solution | ||
108 | solveOdeC fun f0 ts = unsafePerformIO $ do | ||
98 | let dim = V.length f0 | 109 | let dim = V.length f0 |
99 | nEq :: CLong | 110 | nEq :: CLong |
100 | nEq = fromIntegral dim | 111 | nEq = fromIntegral dim |
@@ -248,5 +259,5 @@ solveOdeC fun f0 = unsafePerformIO $ do | |||
248 | 259 | ||
249 | main :: IO () | 260 | main :: IO () |
250 | main = do | 261 | main = do |
251 | let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0]) | 262 | let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0]) undefined |
252 | putStrLn $ show res | 263 | putStrLn $ show res |
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs new file mode 100644 index 0000000..58acef3 --- /dev/null +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -0,0 +1,265 @@ | |||
1 | {-# OPTIONS_GHC -Wall #-} | ||
2 | |||
3 | {-# LANGUAGE QuasiQuotes #-} | ||
4 | {-# LANGUAGE TemplateHaskell #-} | ||
5 | {-# LANGUAGE MultiWayIf #-} | ||
6 | {-# LANGUAGE OverloadedStrings #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | |||
9 | module Numeric.Sundials.Arkode.ODE ( solveOdeC ) where | ||
10 | |||
11 | import qualified Language.C.Inline as C | ||
12 | import qualified Language.C.Inline.Unsafe as CU | ||
13 | import Data.Monoid ((<>)) | ||
14 | import Foreign.C.Types | ||
15 | import Foreign.Ptr (Ptr) | ||
16 | import qualified Data.Vector.Storable as V | ||
17 | |||
18 | import Data.Coerce (coerce) | ||
19 | import qualified Data.Vector.Storable.Mutable as VM | ||
20 | import Foreign.ForeignPtr (newForeignPtr_) | ||
21 | import Foreign.Storable (Storable) | ||
22 | import System.IO.Unsafe (unsafePerformIO) | ||
23 | |||
24 | import Foreign.Storable (peekByteOff) | ||
25 | |||
26 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) | ||
27 | |||
28 | import qualified Types as T | ||
29 | |||
30 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) | ||
31 | |||
32 | -- C includes | ||
33 | C.include "<stdio.h>" | ||
34 | C.include "<math.h>" | ||
35 | C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts. | ||
36 | C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros | ||
37 | C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix | ||
38 | C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver | ||
39 | C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface | ||
40 | C.include "<sundials/sundials_types.h>" -- definition of type realtype | ||
41 | C.include "<sundials/sundials_math.h>" | ||
42 | C.include "../../../helpers.h" | ||
43 | |||
44 | |||
45 | -- These were semi-generated using hsc2hs with Bar.hsc as the | ||
46 | -- template. They are probably very fragile and could easily break on | ||
47 | -- different architectures and / or changes in the sundials package. | ||
48 | |||
49 | getContentPtr :: Storable a => Ptr b -> IO a | ||
50 | getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr | ||
51 | |||
52 | getData :: Storable a => Ptr b -> IO a | ||
53 | getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr | ||
54 | |||
55 | getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b) | ||
56 | getDataFromContents len ptr = do | ||
57 | qtr <- getContentPtr ptr | ||
58 | rtr <- getData qtr | ||
59 | vectorFromC len rtr | ||
60 | |||
61 | putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () | ||
62 | putDataInContents vec len ptr = do | ||
63 | qtr <- getContentPtr ptr | ||
64 | rtr <- getData qtr | ||
65 | vectorToC vec len rtr | ||
66 | |||
67 | -- Utils | ||
68 | |||
69 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | ||
70 | vectorFromC len ptr = do | ||
71 | ptr' <- newForeignPtr_ ptr | ||
72 | V.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
73 | |||
74 | vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () | ||
75 | vectorToC vec len ptr = do | ||
76 | ptr' <- newForeignPtr_ ptr | ||
77 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||
78 | |||
79 | stiffish :: Double -> V.Vector Double -> V.Vector Double | ||
80 | stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] | ||
81 | where | ||
82 | u = v V.! 0 | ||
83 | lamda = -100.0 | ||
84 | |||
85 | brusselator :: Double -> V.Vector Double -> V.Vector Double | ||
86 | brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 | ||
87 | , w * u - v * u^2 | ||
88 | , (b - w) / eps - w * u | ||
89 | ] | ||
90 | where | ||
91 | a = 1.0 | ||
92 | b = 3.5 | ||
93 | eps = 5.0e-6 | ||
94 | u = x V.! 0 | ||
95 | v = x V.! 1 | ||
96 | w = x V.! 2 | ||
97 | |||
98 | |||
99 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
100 | -> [Double] -- ^ initial conditions | ||
101 | -> Vector Double -- ^ desired solution times | ||
102 | -> Matrix Double -- ^ solution | ||
103 | odeSolve = undefined | ||
104 | |||
105 | solveOdeC :: | ||
106 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | ||
107 | -> V.Vector CDouble -- ^ Initial conditions | ||
108 | -> V.Vector CDouble -- ^ Desired solution times | ||
109 | -> Either CInt (V.Vector CDouble) -- ^ Error code or solution | ||
110 | solveOdeC fun f0 ts = unsafePerformIO $ do | ||
111 | let dim = V.length f0 | ||
112 | nEq :: CLong | ||
113 | nEq = fromIntegral dim | ||
114 | fMut <- V.thaw f0 | ||
115 | -- We need the types that sundials expects. These are tied together | ||
116 | -- in 'Types'. The Haskell type is currently empty! | ||
117 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | ||
118 | funIO x y f _ptr = do | ||
119 | -- Convert the pointer we get from C (y) to a vector, and then | ||
120 | -- apply the user-supplied function. | ||
121 | fImm <- fun x <$> getDataFromContents dim y | ||
122 | -- Fill in the provided pointer with the resulting vector. | ||
123 | putDataInContents fImm dim f | ||
124 | -- I don't understand what this comment means | ||
125 | -- Unsafe since the function will be called many times. | ||
126 | [CU.exp| int{ 0 } |] | ||
127 | res <- [C.block| int { | ||
128 | /* general problem variables */ | ||
129 | int flag; /* reusable error-checking flag */ | ||
130 | N_Vector y = NULL; /* empty vector for storing solution */ | ||
131 | SUNMatrix A = NULL; /* empty matrix for linear solver */ | ||
132 | SUNLinearSolver LS = NULL; /* empty linear solver object */ | ||
133 | void *arkode_mem = NULL; /* empty ARKode memory structure */ | ||
134 | FILE *UFID; | ||
135 | realtype t, tout; | ||
136 | long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; | ||
137 | |||
138 | /* general problem parameters */ | ||
139 | realtype T0 = RCONST(0.0); /* initial time */ | ||
140 | realtype Tf = RCONST(10.0); /* final time */ | ||
141 | realtype dTout = RCONST(1.0); /* time between outputs */ | ||
142 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ | ||
143 | realtype reltol = 1.0e-6; /* tolerances */ | ||
144 | realtype abstol = 1.0e-10; | ||
145 | |||
146 | /* Initial diagnostics output */ | ||
147 | printf("\nAnalytical ODE test problem:\n"); | ||
148 | printf(" reltol = %.1"ESYM"\n", reltol); | ||
149 | printf(" abstol = %.1"ESYM"\n\n",abstol); | ||
150 | |||
151 | /* Initialize data structures */ | ||
152 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | ||
153 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | ||
154 | int i; | ||
155 | for (i = 0; i < NEQ; i++) { | ||
156 | NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; | ||
157 | }; /* Specify initial condition */ | ||
158 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | ||
159 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; | ||
160 | |||
161 | /* Call ARKodeInit to initialize the integrator memory and specify the */ | ||
162 | /* right-hand side function in y'=f(t,y), the inital time T0, and */ | ||
163 | /* the initial dependent variable vector y. Note: since this */ | ||
164 | /* problem is fully implicit, we set f_E to NULL and f_I to f. */ | ||
165 | |||
166 | /* Here we use the C types defined in helpers.h which tie up with */ | ||
167 | /* the Haskell types defined in Types */ | ||
168 | flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); | ||
169 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | ||
170 | |||
171 | /* Set routines */ | ||
172 | flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */ | ||
173 | if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; | ||
174 | |||
175 | /* Initialize dense matrix data structure and solver */ | ||
176 | A = SUNDenseMatrix(NEQ, NEQ); | ||
177 | if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1; | ||
178 | LS = SUNDenseLinearSolver(y, A); | ||
179 | if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1; | ||
180 | |||
181 | /* Linear solver interface */ | ||
182 | flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ | ||
183 | /* Open output stream for results, output comment line */ | ||
184 | UFID = fopen("solution.txt","w"); | ||
185 | fprintf(UFID,"# t u\n"); | ||
186 | |||
187 | /* output initial condition to disk */ | ||
188 | fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", T0, NV_Ith_S(y,0)); | ||
189 | |||
190 | /* Main time-stepping loop: calls ARKode to perform the integration, then | ||
191 | prints results. Stops when the final time has been reached */ | ||
192 | t = T0; | ||
193 | tout = T0+dTout; | ||
194 | printf(" t u\n"); | ||
195 | printf(" ---------------------\n"); | ||
196 | while (Tf - t > 1.0e-15) { | ||
197 | |||
198 | flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */ | ||
199 | if (check_flag(&flag, "ARKode", 1)) break; | ||
200 | printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */ | ||
201 | fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", t, NV_Ith_S(y,0)); | ||
202 | if (flag >= 0) { /* successful solve: update time */ | ||
203 | tout += dTout; | ||
204 | tout = (tout > Tf) ? Tf : tout; | ||
205 | } else { /* unsuccessful solve: break */ | ||
206 | fprintf(stderr,"Solver failure, stopping integration\n"); | ||
207 | break; | ||
208 | } | ||
209 | } | ||
210 | printf(" ---------------------\n"); | ||
211 | fclose(UFID); | ||
212 | |||
213 | for (i = 0; i < NEQ; i++) { | ||
214 | ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i); | ||
215 | }; | ||
216 | |||
217 | /* Get/print some final statistics on how the solve progressed */ | ||
218 | flag = ARKodeGetNumSteps(arkode_mem, &nst); | ||
219 | check_flag(&flag, "ARKodeGetNumSteps", 1); | ||
220 | flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a); | ||
221 | check_flag(&flag, "ARKodeGetNumStepAttempts", 1); | ||
222 | flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi); | ||
223 | check_flag(&flag, "ARKodeGetNumRhsEvals", 1); | ||
224 | flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups); | ||
225 | check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1); | ||
226 | flag = ARKodeGetNumErrTestFails(arkode_mem, &netf); | ||
227 | check_flag(&flag, "ARKodeGetNumErrTestFails", 1); | ||
228 | flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni); | ||
229 | check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1); | ||
230 | flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn); | ||
231 | check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1); | ||
232 | flag = ARKDlsGetNumJacEvals(arkode_mem, &nje); | ||
233 | check_flag(&flag, "ARKDlsGetNumJacEvals", 1); | ||
234 | flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS); | ||
235 | check_flag(&flag, "ARKDlsGetNumRhsEvals", 1); | ||
236 | |||
237 | printf("\nFinal Solver Statistics:\n"); | ||
238 | printf(" Internal solver steps = %li (attempted = %li)\n", nst, nst_a); | ||
239 | printf(" Total RHS evals: Fe = %li, Fi = %li\n", nfe, nfi); | ||
240 | printf(" Total linear solver setups = %li\n", nsetups); | ||
241 | printf(" Total RHS evals for setting up the linear system = %li\n", nfeLS); | ||
242 | printf(" Total number of Jacobian evaluations = %li\n", nje); | ||
243 | printf(" Total number of Newton iterations = %li\n", nni); | ||
244 | printf(" Total number of linear solver convergence failures = %li\n", ncfn); | ||
245 | printf(" Total number of error test failures = %li\n\n", netf); | ||
246 | |||
247 | /* Clean up and return */ | ||
248 | N_VDestroy(y); /* Free y vector */ | ||
249 | ARKodeFree(&arkode_mem); /* Free integrator memory */ | ||
250 | SUNLinSolFree(LS); /* Free linear solver */ | ||
251 | SUNMatDestroy(A); /* Free A matrix */ | ||
252 | |||
253 | return flag; | ||
254 | } |] | ||
255 | if res ==0 | ||
256 | then do | ||
257 | v <- V.freeze fMut | ||
258 | return $ Right v | ||
259 | else do | ||
260 | return $ Left res | ||
261 | |||
262 | main :: IO () | ||
263 | main = do | ||
264 | let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0]) undefined | ||
265 | putStrLn $ show res | ||
diff --git a/packages/sundials/src/Types.hs b/packages/sundials/src/Types.hs index c42c34e..9654527 100644 --- a/packages/sundials/src/Types.hs +++ b/packages/sundials/src/Types.hs | |||
@@ -4,6 +4,7 @@ | |||
4 | {-# LANGUAGE TemplateHaskell #-} | 4 | {-# LANGUAGE TemplateHaskell #-} |
5 | {-# LANGUAGE MultiWayIf #-} | 5 | {-# LANGUAGE MultiWayIf #-} |
6 | {-# LANGUAGE OverloadedStrings #-} | 6 | {-# LANGUAGE OverloadedStrings #-} |
7 | {-# LANGUAGE EmptyDataDecls #-} | ||
7 | 8 | ||
8 | module Types where | 9 | module Types where |
9 | 10 | ||