diff options
Diffstat (limited to 'packages/sundials')
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/CVode/ODE.hs | 75 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ODEOpts.hs | 4 |
2 files changed, 38 insertions, 41 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs index d7a2b53..0871f9b 100644 --- a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs | |||
@@ -87,11 +87,12 @@ import System.IO.Unsafe (unsafePerformIO) | |||
87 | 87 | ||
88 | import Numeric.LinearAlgebra.Devel (createVector) | 88 | import Numeric.LinearAlgebra.Devel (createVector) |
89 | 89 | ||
90 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), | 90 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows, |
91 | rows, cols, toLists, size) | 91 | cols, toLists, size, reshape) |
92 | 92 | ||
93 | import qualified Types as T | 93 | import qualified Types as T |
94 | import Arkode (cV_ADAMS, cV_BDF) | 94 | import Arkode (cV_ADAMS, cV_BDF) |
95 | import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian) | ||
95 | import qualified Numeric.Sundials.ODEOpts as SO | 96 | import qualified Numeric.Sundials.ODEOpts as SO |
96 | 97 | ||
97 | 98 | ||
@@ -111,8 +112,6 @@ C.include "../../../helpers.h" | |||
111 | C.include "Arkode_hsc.h" | 112 | C.include "Arkode_hsc.h" |
112 | 113 | ||
113 | 114 | ||
114 | type Jacobian = Double -> Vector Double -> Matrix Double | ||
115 | |||
116 | -- | Stepping functions | 115 | -- | Stepping functions |
117 | data ODEMethod = ADAMS | 116 | data ODEMethod = ADAMS |
118 | | BDF | 117 | | BDF |
@@ -140,14 +139,8 @@ odeSolveV | |||
140 | -> Vector Double -- ^ desired solution times | 139 | -> Vector Double -- ^ desired solution times |
141 | -> Matrix Double -- ^ solution | 140 | -> Matrix Double -- ^ solution |
142 | odeSolveV meth hi epsAbs epsRel f y0 ts = | 141 | odeSolveV meth hi epsAbs epsRel f y0 ts = |
143 | case odeSolveVWith' meth (X epsAbs epsRel) hi g y0 ts of | 142 | odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts |
144 | Left c -> error $ show c -- FIXME | ||
145 | -- FIXME: Can we do better than using lists? | ||
146 | Right (v, _d) -> (nR >< nC) (V.toList v) | ||
147 | where | 143 | where |
148 | us = toList ts | ||
149 | nR = length us | ||
150 | nC = size y0 | ||
151 | g t x0 = coerce $ f t x0 | 144 | g t x0 = coerce $ f t x0 |
152 | 145 | ||
153 | -- | A version of 'odeSolveV' with reasonable default parameters and | 146 | -- | A version of 'odeSolveV' with reasonable default parameters and |
@@ -160,13 +153,8 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y | |||
160 | -> Matrix Double -- ^ solution | 153 | -> Matrix Double -- ^ solution |
161 | odeSolve f y0 ts = | 154 | odeSolve f y0 ts = |
162 | -- FIXME: These tolerances are different from the ones in GSL | 155 | -- FIXME: These tolerances are different from the ones in GSL |
163 | case odeSolveVWith' BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of | 156 | odeSolveVWith BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) |
164 | Left c -> error $ show c -- FIXME | ||
165 | Right (v, _d) -> (nR >< nC) (V.toList v) | ||
166 | where | 157 | where |
167 | us = toList ts | ||
168 | nR = length us | ||
169 | nC = length y0 | ||
170 | g t x0 = V.fromList $ f t (V.toList x0) | 158 | g t x0 = V.fromList $ f t (V.toList x0) |
171 | 159 | ||
172 | odeSolveVWith :: | 160 | odeSolveVWith :: |
@@ -183,15 +171,20 @@ odeSolveVWith :: | |||
183 | -> V.Vector Double -- ^ Desired solution times | 171 | -> V.Vector Double -- ^ Desired solution times |
184 | -> Matrix Double -- ^ Error code or solution | 172 | -> Matrix Double -- ^ Error code or solution |
185 | odeSolveVWith method control initStepSize f y0 tt = | 173 | odeSolveVWith method control initStepSize f y0 tt = |
186 | case odeSolveVWith' method control initStepSize f y0 tt of | 174 | case odeSolveVWith' opts method control initStepSize f y0 tt of |
187 | Left c -> error $ show c -- FIXME | 175 | Left c -> error $ show c -- FIXME |
188 | Right (v, _d) -> (nR >< nC) (V.toList v) | 176 | Right (v, _d) -> v |
189 | where | 177 | where |
190 | nR = V.length tt | 178 | opts = ODEOpts { maxNumSteps = 10000 |
191 | nC = V.length y0 | 179 | , minStep = 1.0e-12 |
180 | , relTol = error "relTol" | ||
181 | , absTols = error "absTol" | ||
182 | , initStep = error "initStep" | ||
183 | } | ||
192 | 184 | ||
193 | odeSolveVWith' :: | 185 | odeSolveVWith' :: |
194 | ODEMethod | 186 | ODEOpts |
187 | -> ODEMethod | ||
195 | -> StepControl | 188 | -> StepControl |
196 | -> Maybe Double -- ^ initial step size - by default, ARKode | 189 | -> Maybe Double -- ^ initial step size - by default, ARKode |
197 | -- estimates the initial step size to be the | 190 | -- estimates the initial step size to be the |
@@ -202,19 +195,21 @@ odeSolveVWith' :: | |||
202 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 195 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
203 | -> V.Vector Double -- ^ Initial conditions | 196 | -> V.Vector Double -- ^ Initial conditions |
204 | -> V.Vector Double -- ^ Desired solution times | 197 | -> V.Vector Double -- ^ Desired solution times |
205 | -> Either Int ((V.Vector Double), SO.SundialsDiagnostics) -- ^ Error code or solution | 198 | -> Either Int (Matrix Double, SO.SundialsDiagnostics) -- ^ Error code or solution |
206 | odeSolveVWith' method control initStepSize f y0 tt = | 199 | odeSolveVWith' opts method control initStepSize f y0 tt = |
207 | case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) | 200 | case solveOdeC (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts) |
201 | (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) | ||
208 | (coerce f) (coerce y0) (coerce tt) of | 202 | (coerce f) (coerce y0) (coerce tt) of |
209 | Left c -> Left $ fromIntegral c | 203 | Left c -> Left $ fromIntegral c |
210 | Right (v, d) -> Right (coerce v, d) | 204 | Right (v, d) -> Right (reshape nC (coerce v), d) |
211 | where | 205 | where |
206 | nC = V.length y0 | ||
212 | l = size y0 | 207 | l = size y0 |
213 | scise (X absTol relTol) = coerce (V.replicate l absTol, relTol) | 208 | scise (X aTol rTol) = coerce (V.replicate l aTol, rTol) |
214 | scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol) | 209 | scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol) |
215 | scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol) | 210 | scise (XX' aTol rTol yScale _yDotScale) = coerce (V.replicate l aTol, yScale * rTol) |
216 | -- FIXME; Should we check that the length of ss is correct? | 211 | -- FIXME; Should we check that the length of ss is correct? |
217 | scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol) | 212 | scise (ScXX' aTol rTol yScale _yDotScale ss) = coerce (V.map (* aTol) ss, yScale * rTol) |
218 | jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ | 213 | jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ |
219 | getJacobian method | 214 | getJacobian method |
220 | matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } | 215 | matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } |
@@ -225,6 +220,8 @@ odeSolveVWith' method control initStepSize f y0 tt = | |||
225 | vs = V.fromList $ map coerce $ concat $ toLists m | 220 | vs = V.fromList $ map coerce $ concat $ toLists m |
226 | 221 | ||
227 | solveOdeC :: | 222 | solveOdeC :: |
223 | CLong -> | ||
224 | CDouble -> | ||
228 | CInt -> | 225 | CInt -> |
229 | Maybe CDouble -> | 226 | Maybe CDouble -> |
230 | (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> | 227 | (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> |
@@ -233,7 +230,8 @@ solveOdeC :: | |||
233 | -> V.Vector CDouble -- ^ Initial conditions | 230 | -> V.Vector CDouble -- ^ Initial conditions |
234 | -> V.Vector CDouble -- ^ Desired solution times | 231 | -> V.Vector CDouble -- ^ Desired solution times |
235 | -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution | 232 | -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution |
236 | solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do | 233 | solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts = |
234 | unsafePerformIO $ do | ||
237 | 235 | ||
238 | let isInitStepSize :: CInt | 236 | let isInitStepSize :: CInt |
239 | isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize | 237 | isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize |
@@ -249,10 +247,6 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
249 | nEq = fromIntegral dim | 247 | nEq = fromIntegral dim |
250 | nTs :: CInt | 248 | nTs :: CInt |
251 | nTs = fromIntegral $ V.length ts | 249 | nTs = fromIntegral $ V.length ts |
252 | -- FIXME: tMut is not actually mutatated? | ||
253 | tMut <- V.thaw ts | ||
254 | -- FIXME: I believe this gets taken from the ghc heap and so should | ||
255 | -- be subject to garbage collection. | ||
256 | quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) | 250 | quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) |
257 | qMatMut <- V.thaw quasiMatrixRes | 251 | qMatMut <- V.thaw quasiMatrixRes |
258 | diagnostics :: V.Vector CLong <- createVector 10 -- FIXME | 252 | diagnostics :: V.Vector CLong <- createVector 10 -- FIXME |
@@ -324,18 +318,17 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
324 | if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; | 318 | if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; |
325 | /* Specify tolerances */ | 319 | /* Specify tolerances */ |
326 | for (i = 0; i < NEQ; i++) { | 320 | for (i = 0; i < NEQ; i++) { |
327 | NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i]; | 321 | NV_Ith_S(tv,i) = ($vec-ptr:(double *aTols))[i]; |
328 | }; | 322 | }; |
329 | 323 | ||
330 | /* FIXME: A hack for initial testing */ | 324 | flag = CVodeSetMinStep(cvode_mem, $(double minStep_)); |
331 | flag = CVodeSetMinStep(cvode_mem, 1.0e-12); | ||
332 | if (check_flag(&flag, "CVodeSetMinStep", 1)) return 1; | 325 | if (check_flag(&flag, "CVodeSetMinStep", 1)) return 1; |
333 | flag = CVodeSetMaxNumSteps(cvode_mem, 10000); | 326 | flag = CVodeSetMaxNumSteps(cvode_mem, $(long int maxNumSteps_)); |
334 | if (check_flag(&flag, "CVodeSetMaxNumSteps", 1)) return 1; | 327 | if (check_flag(&flag, "CVodeSetMaxNumSteps", 1)) return 1; |
335 | 328 | ||
336 | /* Call CVodeSVtolerances to specify the scalar relative tolerance | 329 | /* Call CVodeSVtolerances to specify the scalar relative tolerance |
337 | * and vector absolute tolerances */ | 330 | * and vector absolute tolerances */ |
338 | flag = CVodeSVtolerances(cvode_mem, $(double relTol), tv); | 331 | flag = CVodeSVtolerances(cvode_mem, $(double rTol), tv); |
339 | if (check_flag(&flag, "CVodeSVtolerances", 1)) return(1); | 332 | if (check_flag(&flag, "CVodeSVtolerances", 1)) return(1); |
340 | 333 | ||
341 | /* Initialize dense matrix data structure and solver */ | 334 | /* Initialize dense matrix data structure and solver */ |
@@ -371,7 +364,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO | |||
371 | /* Stops when the final time has been reached */ | 364 | /* Stops when the final time has been reached */ |
372 | for (i = 1; i < $(int nTs); i++) { | 365 | for (i = 1; i < $(int nTs); i++) { |
373 | 366 | ||
374 | flag = CVode(cvode_mem, ($vec-ptr:(double *tMut))[i], y, &t, CV_NORMAL); /* call integrator */ | 367 | flag = CVode(cvode_mem, ($vec-ptr:(double *ts))[i], y, &t, CV_NORMAL); /* call integrator */ |
375 | if (check_flag(&flag, "CVode", 1)) break; | 368 | if (check_flag(&flag, "CVode", 1)) break; |
376 | 369 | ||
377 | /* Store the results for Haskell */ | 370 | /* Store the results for Haskell */ |
diff --git a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs index e924292..538b474 100644 --- a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs +++ b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs | |||
@@ -8,9 +8,13 @@ import Foreign.C.Types | |||
8 | import qualified Data.Vector.Storable as VS | 8 | import qualified Data.Vector.Storable as VS |
9 | import qualified Data.Vector.Storable.Mutable as VM | 9 | import qualified Data.Vector.Storable.Mutable as VM |
10 | 10 | ||
11 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) | ||
12 | |||
11 | import qualified Types as T | 13 | import qualified Types as T |
12 | import qualified Arkode as B | 14 | import qualified Arkode as B |
13 | 15 | ||
16 | type Jacobian = Double -> Vector Double -> Matrix Double | ||
17 | |||
14 | data ODEOpts = ODEOpts { | 18 | data ODEOpts = ODEOpts { |
15 | maxNumSteps :: Int32 | 19 | maxNumSteps :: Int32 |
16 | , minStep :: Double | 20 | , minStep :: Double |