summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDominic Steinitz <dominic@steinitz.org>2018-04-26 13:53:56 +0100
committerDominic Steinitz <dominic@steinitz.org>2018-04-26 13:53:56 +0100
commite8e631c7dbc7ea34b51dcce5fef5e2ec620f9458 (patch)
treef401ae26b7d65cfdc0af87f9387d361fa8d921e3
parent729eb192cf77d4cddf33d2724b4409ab7d828921 (diff)
Refactor CVODE
-rw-r--r--packages/sundials/src/Numeric/Sundials/CVode/ODE.hs75
-rw-r--r--packages/sundials/src/Numeric/Sundials/ODEOpts.hs4
2 files changed, 38 insertions, 41 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs
index d7a2b53..0871f9b 100644
--- a/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/CVode/ODE.hs
@@ -87,11 +87,12 @@ import System.IO.Unsafe (unsafePerformIO)
87 87
88import Numeric.LinearAlgebra.Devel (createVector) 88import Numeric.LinearAlgebra.Devel (createVector)
89 89
90import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), 90import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, rows,
91 rows, cols, toLists, size) 91 cols, toLists, size, reshape)
92 92
93import qualified Types as T 93import qualified Types as T
94import Arkode (cV_ADAMS, cV_BDF) 94import Arkode (cV_ADAMS, cV_BDF)
95import Numeric.Sundials.ODEOpts (ODEOpts(..), Jacobian)
95import qualified Numeric.Sundials.ODEOpts as SO 96import qualified Numeric.Sundials.ODEOpts as SO
96 97
97 98
@@ -111,8 +112,6 @@ C.include "../../../helpers.h"
111C.include "Arkode_hsc.h" 112C.include "Arkode_hsc.h"
112 113
113 114
114type Jacobian = Double -> Vector Double -> Matrix Double
115
116-- | Stepping functions 115-- | Stepping functions
117data ODEMethod = ADAMS 116data ODEMethod = ADAMS
118 | BDF 117 | BDF
@@ -140,14 +139,8 @@ odeSolveV
140 -> Vector Double -- ^ desired solution times 139 -> Vector Double -- ^ desired solution times
141 -> Matrix Double -- ^ solution 140 -> Matrix Double -- ^ solution
142odeSolveV meth hi epsAbs epsRel f y0 ts = 141odeSolveV meth hi epsAbs epsRel f y0 ts =
143 case odeSolveVWith' meth (X epsAbs epsRel) hi g y0 ts of 142 odeSolveVWith meth (X epsAbs epsRel) hi g y0 ts
144 Left c -> error $ show c -- FIXME
145 -- FIXME: Can we do better than using lists?
146 Right (v, _d) -> (nR >< nC) (V.toList v)
147 where 143 where
148 us = toList ts
149 nR = length us
150 nC = size y0
151 g t x0 = coerce $ f t x0 144 g t x0 = coerce $ f t x0
152 145
153-- | A version of 'odeSolveV' with reasonable default parameters and 146-- | A version of 'odeSolveV' with reasonable default parameters and
@@ -160,13 +153,8 @@ odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y
160 -> Matrix Double -- ^ solution 153 -> Matrix Double -- ^ solution
161odeSolve f y0 ts = 154odeSolve f y0 ts =
162 -- FIXME: These tolerances are different from the ones in GSL 155 -- FIXME: These tolerances are different from the ones in GSL
163 case odeSolveVWith' BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts) of 156 odeSolveVWith BDF (XX' 1.0e-6 1.0e-10 1 1) Nothing g (V.fromList y0) (V.fromList $ toList ts)
164 Left c -> error $ show c -- FIXME
165 Right (v, _d) -> (nR >< nC) (V.toList v)
166 where 157 where
167 us = toList ts
168 nR = length us
169 nC = length y0
170 g t x0 = V.fromList $ f t (V.toList x0) 158 g t x0 = V.fromList $ f t (V.toList x0)
171 159
172odeSolveVWith :: 160odeSolveVWith ::
@@ -183,15 +171,20 @@ odeSolveVWith ::
183 -> V.Vector Double -- ^ Desired solution times 171 -> V.Vector Double -- ^ Desired solution times
184 -> Matrix Double -- ^ Error code or solution 172 -> Matrix Double -- ^ Error code or solution
185odeSolveVWith method control initStepSize f y0 tt = 173odeSolveVWith method control initStepSize f y0 tt =
186 case odeSolveVWith' method control initStepSize f y0 tt of 174 case odeSolveVWith' opts method control initStepSize f y0 tt of
187 Left c -> error $ show c -- FIXME 175 Left c -> error $ show c -- FIXME
188 Right (v, _d) -> (nR >< nC) (V.toList v) 176 Right (v, _d) -> v
189 where 177 where
190 nR = V.length tt 178 opts = ODEOpts { maxNumSteps = 10000
191 nC = V.length y0 179 , minStep = 1.0e-12
180 , relTol = error "relTol"
181 , absTols = error "absTol"
182 , initStep = error "initStep"
183 }
192 184
193odeSolveVWith' :: 185odeSolveVWith' ::
194 ODEMethod 186 ODEOpts
187 -> ODEMethod
195 -> StepControl 188 -> StepControl
196 -> Maybe Double -- ^ initial step size - by default, ARKode 189 -> Maybe Double -- ^ initial step size - by default, ARKode
197 -- estimates the initial step size to be the 190 -- estimates the initial step size to be the
@@ -202,19 +195,21 @@ odeSolveVWith' ::
202 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 195 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
203 -> V.Vector Double -- ^ Initial conditions 196 -> V.Vector Double -- ^ Initial conditions
204 -> V.Vector Double -- ^ Desired solution times 197 -> V.Vector Double -- ^ Desired solution times
205 -> Either Int ((V.Vector Double), SO.SundialsDiagnostics) -- ^ Error code or solution 198 -> Either Int (Matrix Double, SO.SundialsDiagnostics) -- ^ Error code or solution
206odeSolveVWith' method control initStepSize f y0 tt = 199odeSolveVWith' opts method control initStepSize f y0 tt =
207 case solveOdeC (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control) 200 case solveOdeC (fromIntegral $ maxNumSteps opts) (coerce $ minStep opts)
201 (fromIntegral $ getMethod method) (coerce initStepSize) jacH (scise control)
208 (coerce f) (coerce y0) (coerce tt) of 202 (coerce f) (coerce y0) (coerce tt) of
209 Left c -> Left $ fromIntegral c 203 Left c -> Left $ fromIntegral c
210 Right (v, d) -> Right (coerce v, d) 204 Right (v, d) -> Right (reshape nC (coerce v), d)
211 where 205 where
206 nC = V.length y0
212 l = size y0 207 l = size y0
213 scise (X absTol relTol) = coerce (V.replicate l absTol, relTol) 208 scise (X aTol rTol) = coerce (V.replicate l aTol, rTol)
214 scise (X' absTol relTol) = coerce (V.replicate l absTol, relTol) 209 scise (X' aTol rTol) = coerce (V.replicate l aTol, rTol)
215 scise (XX' absTol relTol yScale _yDotScale) = coerce (V.replicate l absTol, yScale * relTol) 210 scise (XX' aTol rTol yScale _yDotScale) = coerce (V.replicate l aTol, yScale * rTol)
216 -- FIXME; Should we check that the length of ss is correct? 211 -- FIXME; Should we check that the length of ss is correct?
217 scise (ScXX' absTol relTol yScale _yDotScale ss) = coerce (V.map (* absTol) ss, yScale * relTol) 212 scise (ScXX' aTol rTol yScale _yDotScale ss) = coerce (V.map (* aTol) ss, yScale * rTol)
218 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $ 213 jacH = fmap (\g t v -> matrixToSunMatrix $ g (coerce t) (coerce v)) $
219 getJacobian method 214 getJacobian method
220 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } 215 matrixToSunMatrix m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs }
@@ -225,6 +220,8 @@ odeSolveVWith' method control initStepSize f y0 tt =
225 vs = V.fromList $ map coerce $ concat $ toLists m 220 vs = V.fromList $ map coerce $ concat $ toLists m
226 221
227solveOdeC :: 222solveOdeC ::
223 CLong ->
224 CDouble ->
228 CInt -> 225 CInt ->
229 Maybe CDouble -> 226 Maybe CDouble ->
230 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) -> 227 (Maybe (CDouble -> V.Vector CDouble -> T.SunMatrix)) ->
@@ -233,7 +230,8 @@ solveOdeC ::
233 -> V.Vector CDouble -- ^ Initial conditions 230 -> V.Vector CDouble -- ^ Initial conditions
234 -> V.Vector CDouble -- ^ Desired solution times 231 -> V.Vector CDouble -- ^ Desired solution times
235 -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution 232 -> Either CInt ((V.Vector CDouble), SO.SundialsDiagnostics) -- ^ Error code or solution
236solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do 233solveOdeC maxNumSteps_ minStep_ method initStepSize jacH (aTols, rTol) fun f0 ts =
234 unsafePerformIO $ do
237 235
238 let isInitStepSize :: CInt 236 let isInitStepSize :: CInt
239 isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize 237 isInitStepSize = fromIntegral $ fromEnum $ isJust initStepSize
@@ -249,10 +247,6 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
249 nEq = fromIntegral dim 247 nEq = fromIntegral dim
250 nTs :: CInt 248 nTs :: CInt
251 nTs = fromIntegral $ V.length ts 249 nTs = fromIntegral $ V.length ts
252 -- FIXME: tMut is not actually mutatated?
253 tMut <- V.thaw ts
254 -- FIXME: I believe this gets taken from the ghc heap and so should
255 -- be subject to garbage collection.
256 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) 250 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs))
257 qMatMut <- V.thaw quasiMatrixRes 251 qMatMut <- V.thaw quasiMatrixRes
258 diagnostics :: V.Vector CLong <- createVector 10 -- FIXME 252 diagnostics :: V.Vector CLong <- createVector 10 -- FIXME
@@ -324,18 +318,17 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
324 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1; 318 if (check_flag((void *)tv, "N_VNew_Serial", 0)) return 1;
325 /* Specify tolerances */ 319 /* Specify tolerances */
326 for (i = 0; i < NEQ; i++) { 320 for (i = 0; i < NEQ; i++) {
327 NV_Ith_S(tv,i) = ($vec-ptr:(double *absTols))[i]; 321 NV_Ith_S(tv,i) = ($vec-ptr:(double *aTols))[i];
328 }; 322 };
329 323
330 /* FIXME: A hack for initial testing */ 324 flag = CVodeSetMinStep(cvode_mem, $(double minStep_));
331 flag = CVodeSetMinStep(cvode_mem, 1.0e-12);
332 if (check_flag(&flag, "CVodeSetMinStep", 1)) return 1; 325 if (check_flag(&flag, "CVodeSetMinStep", 1)) return 1;
333 flag = CVodeSetMaxNumSteps(cvode_mem, 10000); 326 flag = CVodeSetMaxNumSteps(cvode_mem, $(long int maxNumSteps_));
334 if (check_flag(&flag, "CVodeSetMaxNumSteps", 1)) return 1; 327 if (check_flag(&flag, "CVodeSetMaxNumSteps", 1)) return 1;
335 328
336 /* Call CVodeSVtolerances to specify the scalar relative tolerance 329 /* Call CVodeSVtolerances to specify the scalar relative tolerance
337 * and vector absolute tolerances */ 330 * and vector absolute tolerances */
338 flag = CVodeSVtolerances(cvode_mem, $(double relTol), tv); 331 flag = CVodeSVtolerances(cvode_mem, $(double rTol), tv);
339 if (check_flag(&flag, "CVodeSVtolerances", 1)) return(1); 332 if (check_flag(&flag, "CVodeSVtolerances", 1)) return(1);
340 333
341 /* Initialize dense matrix data structure and solver */ 334 /* Initialize dense matrix data structure and solver */
@@ -371,7 +364,7 @@ solveOdeC method initStepSize jacH (absTols, relTol) fun f0 ts = unsafePerformIO
371 /* Stops when the final time has been reached */ 364 /* Stops when the final time has been reached */
372 for (i = 1; i < $(int nTs); i++) { 365 for (i = 1; i < $(int nTs); i++) {
373 366
374 flag = CVode(cvode_mem, ($vec-ptr:(double *tMut))[i], y, &t, CV_NORMAL); /* call integrator */ 367 flag = CVode(cvode_mem, ($vec-ptr:(double *ts))[i], y, &t, CV_NORMAL); /* call integrator */
375 if (check_flag(&flag, "CVode", 1)) break; 368 if (check_flag(&flag, "CVode", 1)) break;
376 369
377 /* Store the results for Haskell */ 370 /* Store the results for Haskell */
diff --git a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs
index e924292..538b474 100644
--- a/packages/sundials/src/Numeric/Sundials/ODEOpts.hs
+++ b/packages/sundials/src/Numeric/Sundials/ODEOpts.hs
@@ -8,9 +8,13 @@ import Foreign.C.Types
8import qualified Data.Vector.Storable as VS 8import qualified Data.Vector.Storable as VS
9import qualified Data.Vector.Storable.Mutable as VM 9import qualified Data.Vector.Storable.Mutable as VM
10 10
11import Numeric.LinearAlgebra.HMatrix (Vector, Matrix)
12
11import qualified Types as T 13import qualified Types as T
12import qualified Arkode as B 14import qualified Arkode as B
13 15
16type Jacobian = Double -> Vector Double -> Matrix Double
17
14data ODEOpts = ODEOpts { 18data ODEOpts = ODEOpts {
15 maxNumSteps :: Int32 19 maxNumSteps :: Int32
16 , minStep :: Double 20 , minStep :: Double