diff options
Diffstat (limited to 'packages/sundials/src')
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 90 |
1 files changed, 60 insertions, 30 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index c5d085e..630827c 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -79,6 +79,19 @@ vectorToC vec len ptr = do | |||
79 | ptr' <- newForeignPtr_ ptr | 79 | ptr' <- newForeignPtr_ ptr |
80 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | 80 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec |
81 | 81 | ||
82 | data SundialsDiagnostics = SundialsDiagnostics { | ||
83 | aRKodeGetNumSteps :: Int | ||
84 | , aRKodeGetNumStepAttempts :: Int | ||
85 | , aRKodeGetNumRhsEvals_fe :: Int | ||
86 | , aRKodeGetNumRhsEvals_fi :: Int | ||
87 | , aRKodeGetNumLinSolvSetups :: Int | ||
88 | , aRKodeGetNumErrTestFails :: Int | ||
89 | , aRKodeGetNumNonlinSolvIters :: Int | ||
90 | , aRKodeGetNumNonlinSolvConvFails :: Int | ||
91 | , aRKDlsGetNumJacEvals :: Int | ||
92 | , aRKDlsGetNumRhsEvals :: Int | ||
93 | } deriving Show | ||
94 | |||
82 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 95 | odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
83 | -> [Double] -- ^ initial conditions | 96 | -> [Double] -- ^ initial conditions |
84 | -> Vector Double -- ^ desired solution times | 97 | -> Vector Double -- ^ desired solution times |
@@ -89,28 +102,31 @@ solveOde :: | |||
89 | (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 102 | (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
90 | -> V.Vector Double -- ^ Initial conditions | 103 | -> V.Vector Double -- ^ Initial conditions |
91 | -> V.Vector Double -- ^ Desired solution times | 104 | -> V.Vector Double -- ^ Desired solution times |
92 | -> Either Int (V.Vector Double) -- ^ Error code or solution | 105 | -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution |
93 | solveOde f y0 tt = case solveOdeC (coerce f) (coerce y0) (coerce tt) of | 106 | solveOde f y0 tt = case solveOdeC (coerce f) (coerce y0) (coerce tt) of |
94 | Left c -> Left $ fromIntegral c | 107 | Left c -> Left $ fromIntegral c |
95 | Right v -> Right $ coerce v | 108 | Right (v, d) -> Right (coerce v, d) |
96 | 109 | ||
97 | solveOdeC :: | 110 | solveOdeC :: |
98 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 111 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
99 | -> V.Vector CDouble -- ^ Initial conditions | 112 | -> V.Vector CDouble -- ^ Initial conditions |
100 | -> V.Vector CDouble -- ^ Desired solution times | 113 | -> V.Vector CDouble -- ^ Desired solution times |
101 | -> Either CInt (V.Vector CDouble) -- ^ Error code or solution | 114 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution |
102 | solveOdeC fun f0 ts = unsafePerformIO $ do | 115 | solveOdeC fun f0 ts = unsafePerformIO $ do |
103 | let dim = V.length f0 | 116 | let dim = V.length f0 |
104 | nEq :: CLong | 117 | nEq :: CLong |
105 | nEq = fromIntegral dim | 118 | nEq = fromIntegral dim |
106 | nTs :: CInt | 119 | nTs :: CInt |
107 | nTs = fromIntegral $ V.length ts | 120 | nTs = fromIntegral $ V.length ts |
121 | -- FIXME: fMut is not actually mutatated | ||
108 | fMut <- V.thaw f0 | 122 | fMut <- V.thaw f0 |
109 | tMut <- V.thaw ts | 123 | tMut <- V.thaw ts |
110 | -- FIXME: I believe this gets taken from the ghc heap and so should | 124 | -- FIXME: I believe this gets taken from the ghc heap and so should |
111 | -- be subject to garbage collection. | 125 | -- be subject to garbage collection. |
112 | quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) | 126 | quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) |
113 | qMatMut <- V.thaw quasiMatrixRes | 127 | qMatMut <- V.thaw quasiMatrixRes |
128 | diagnostics :: V.Vector CLong <- createVector 10 -- FIXME | ||
129 | diagMut <- V.thaw diagnostics | ||
114 | -- We need the types that sundials expects. These are tied together | 130 | -- We need the types that sundials expects. These are tied together |
115 | -- in 'Types'. FIXME: The Haskell type is currently empty! | 131 | -- in 'Types'. FIXME: The Haskell type is currently empty! |
116 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | 132 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt |
@@ -178,20 +194,19 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
178 | 194 | ||
179 | /* Linear solver interface */ | 195 | /* Linear solver interface */ |
180 | flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ | 196 | flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ |
181 | /* Output initial conditions */ | 197 | /* Store initial conditions */ |
182 | for (j = 0; j < NEQ; j++) { | 198 | for (j = 0; j < NEQ; j++) { |
183 | ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); | 199 | ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); |
184 | } | 200 | } |
185 | 201 | ||
186 | /* Main time-stepping loop: calls ARKode to perform the integration, then | 202 | /* Main time-stepping loop: calls ARKode to perform the integration */ |
187 | prints results. Stops when the final time has been reached */ | 203 | /* Stops when the final time has been reached */ |
188 | printf(" t u\n"); | ||
189 | printf(" ---------------------\n"); | ||
190 | for (i = 1; i < $(int nTs); i++) { | 204 | for (i = 1; i < $(int nTs); i++) { |
191 | 205 | ||
192 | flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */ | 206 | flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */ |
193 | if (check_flag(&flag, "ARKode", 1)) break; | 207 | if (check_flag(&flag, "ARKode", 1)) break; |
194 | printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */ | 208 | |
209 | /* Store the results for Haskell */ | ||
195 | for (j = 0; j < NEQ; j++) { | 210 | for (j = 0; j < NEQ; j++) { |
196 | ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j); | 211 | ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j); |
197 | } | 212 | } |
@@ -201,42 +216,45 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
201 | break; | 216 | break; |
202 | } | 217 | } |
203 | } | 218 | } |
204 | printf(" ---------------------\n"); | ||
205 | |||
206 | for (i = 0; i < NEQ; i++) { | ||
207 | ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i); | ||
208 | }; | ||
209 | 219 | ||
210 | /* Get/print some final statistics on how the solve progressed */ | 220 | /* Get some final statistics on how the solve progressed */ |
211 | flag = ARKodeGetNumSteps(arkode_mem, &nst); | 221 | flag = ARKodeGetNumSteps(arkode_mem, &nst); |
212 | check_flag(&flag, "ARKodeGetNumSteps", 1); | 222 | check_flag(&flag, "ARKodeGetNumSteps", 1); |
223 | ($vec-ptr:(long int *diagMut))[0] = nst; | ||
224 | |||
213 | flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a); | 225 | flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a); |
214 | check_flag(&flag, "ARKodeGetNumStepAttempts", 1); | 226 | check_flag(&flag, "ARKodeGetNumStepAttempts", 1); |
227 | ($vec-ptr:(long int *diagMut))[1] = nst_a; | ||
228 | |||
215 | flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi); | 229 | flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi); |
216 | check_flag(&flag, "ARKodeGetNumRhsEvals", 1); | 230 | check_flag(&flag, "ARKodeGetNumRhsEvals", 1); |
231 | ($vec-ptr:(long int *diagMut))[2] = nfe; | ||
232 | ($vec-ptr:(long int *diagMut))[3] = nfi; | ||
233 | |||
217 | flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups); | 234 | flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups); |
218 | check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1); | 235 | check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1); |
236 | ($vec-ptr:(long int *diagMut))[4] = nsetups; | ||
237 | |||
219 | flag = ARKodeGetNumErrTestFails(arkode_mem, &netf); | 238 | flag = ARKodeGetNumErrTestFails(arkode_mem, &netf); |
220 | check_flag(&flag, "ARKodeGetNumErrTestFails", 1); | 239 | check_flag(&flag, "ARKodeGetNumErrTestFails", 1); |
240 | ($vec-ptr:(long int *diagMut))[5] = netf; | ||
241 | |||
221 | flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni); | 242 | flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni); |
222 | check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1); | 243 | check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1); |
244 | ($vec-ptr:(long int *diagMut))[6] = nni; | ||
245 | |||
223 | flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn); | 246 | flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn); |
224 | check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1); | 247 | check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1); |
248 | ($vec-ptr:(long int *diagMut))[7] = ncfn; | ||
249 | |||
225 | flag = ARKDlsGetNumJacEvals(arkode_mem, &nje); | 250 | flag = ARKDlsGetNumJacEvals(arkode_mem, &nje); |
226 | check_flag(&flag, "ARKDlsGetNumJacEvals", 1); | 251 | check_flag(&flag, "ARKDlsGetNumJacEvals", 1); |
252 | ($vec-ptr:(long int *diagMut))[8] = ncfn; | ||
253 | |||
227 | flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS); | 254 | flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS); |
228 | check_flag(&flag, "ARKDlsGetNumRhsEvals", 1); | 255 | check_flag(&flag, "ARKDlsGetNumRhsEvals", 1); |
229 | 256 | ($vec-ptr:(long int *diagMut))[9] = ncfn; | |
230 | printf("\nFinal Solver Statistics:\n"); | 257 | |
231 | printf(" Internal solver steps = %li (attempted = %li)\n", nst, nst_a); | ||
232 | printf(" Total RHS evals: Fe = %li, Fi = %li\n", nfe, nfi); | ||
233 | printf(" Total linear solver setups = %li\n", nsetups); | ||
234 | printf(" Total RHS evals for setting up the linear system = %li\n", nfeLS); | ||
235 | printf(" Total number of Jacobian evaluations = %li\n", nje); | ||
236 | printf(" Total number of Newton iterations = %li\n", nni); | ||
237 | printf(" Total number of linear solver convergence failures = %li\n", ncfn); | ||
238 | printf(" Total number of error test failures = %li\n\n", netf); | ||
239 | |||
240 | /* Clean up and return */ | 258 | /* Clean up and return */ |
241 | N_VDestroy(y); /* Free y vector */ | 259 | N_VDestroy(y); /* Free y vector */ |
242 | ARKodeFree(&arkode_mem); /* Free integrator memory */ | 260 | ARKodeFree(&arkode_mem); /* Free integrator memory */ |
@@ -247,7 +265,19 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
247 | } |] | 265 | } |] |
248 | if res == 0 | 266 | if res == 0 |
249 | then do | 267 | then do |
268 | preD <- V.freeze diagMut | ||
269 | let d = SundialsDiagnostics (fromIntegral $ preD V.!0) | ||
270 | (fromIntegral $ preD V.!1) | ||
271 | (fromIntegral $ preD V.!2) | ||
272 | (fromIntegral $ preD V.!3) | ||
273 | (fromIntegral $ preD V.!4) | ||
274 | (fromIntegral $ preD V.!5) | ||
275 | (fromIntegral $ preD V.!6) | ||
276 | (fromIntegral $ preD V.!7) | ||
277 | (fromIntegral $ preD V.!8) | ||
278 | (fromIntegral $ preD V.!9) | ||
250 | m <- V.freeze qMatMut | 279 | m <- V.freeze qMatMut |
251 | return $ Right m | 280 | return $ Right (m, d) |
252 | else do | 281 | else do |
253 | return $ Left res | 282 | return $ Left res |
283 | |||