summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
diff options
context:
space:
mode:
authorDominic Steinitz <dominic@steinitz.org>2018-03-21 13:57:16 +0000
committerDominic Steinitz <dominic@steinitz.org>2018-03-21 13:57:16 +0000
commit0d52842881192a627d6f52e47c2fe26592f20adb (patch)
tree1a3a740f812ec962f783e005edf848acf0ba855a /packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
parent10080499b9b1c1c01ff6d6bb4194608d2eff9eca (diff)
Also return diagnostics
Diffstat (limited to 'packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs')
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs90
1 files changed, 60 insertions, 30 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index c5d085e..630827c 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -79,6 +79,19 @@ vectorToC vec len ptr = do
79 ptr' <- newForeignPtr_ ptr 79 ptr' <- newForeignPtr_ ptr
80 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec 80 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
81 81
82data SundialsDiagnostics = SundialsDiagnostics {
83 aRKodeGetNumSteps :: Int
84 , aRKodeGetNumStepAttempts :: Int
85 , aRKodeGetNumRhsEvals_fe :: Int
86 , aRKodeGetNumRhsEvals_fi :: Int
87 , aRKodeGetNumLinSolvSetups :: Int
88 , aRKodeGetNumErrTestFails :: Int
89 , aRKodeGetNumNonlinSolvIters :: Int
90 , aRKodeGetNumNonlinSolvConvFails :: Int
91 , aRKDlsGetNumJacEvals :: Int
92 , aRKDlsGetNumRhsEvals :: Int
93 } deriving Show
94
82odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 95odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
83 -> [Double] -- ^ initial conditions 96 -> [Double] -- ^ initial conditions
84 -> Vector Double -- ^ desired solution times 97 -> Vector Double -- ^ desired solution times
@@ -89,28 +102,31 @@ solveOde ::
89 (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 102 (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
90 -> V.Vector Double -- ^ Initial conditions 103 -> V.Vector Double -- ^ Initial conditions
91 -> V.Vector Double -- ^ Desired solution times 104 -> V.Vector Double -- ^ Desired solution times
92 -> Either Int (V.Vector Double) -- ^ Error code or solution 105 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution
93solveOde f y0 tt = case solveOdeC (coerce f) (coerce y0) (coerce tt) of 106solveOde f y0 tt = case solveOdeC (coerce f) (coerce y0) (coerce tt) of
94 Left c -> Left $ fromIntegral c 107 Left c -> Left $ fromIntegral c
95 Right v -> Right $ coerce v 108 Right (v, d) -> Right (coerce v, d)
96 109
97solveOdeC :: 110solveOdeC ::
98 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 111 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
99 -> V.Vector CDouble -- ^ Initial conditions 112 -> V.Vector CDouble -- ^ Initial conditions
100 -> V.Vector CDouble -- ^ Desired solution times 113 -> V.Vector CDouble -- ^ Desired solution times
101 -> Either CInt (V.Vector CDouble) -- ^ Error code or solution 114 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution
102solveOdeC fun f0 ts = unsafePerformIO $ do 115solveOdeC fun f0 ts = unsafePerformIO $ do
103 let dim = V.length f0 116 let dim = V.length f0
104 nEq :: CLong 117 nEq :: CLong
105 nEq = fromIntegral dim 118 nEq = fromIntegral dim
106 nTs :: CInt 119 nTs :: CInt
107 nTs = fromIntegral $ V.length ts 120 nTs = fromIntegral $ V.length ts
121 -- FIXME: fMut is not actually mutatated
108 fMut <- V.thaw f0 122 fMut <- V.thaw f0
109 tMut <- V.thaw ts 123 tMut <- V.thaw ts
110 -- FIXME: I believe this gets taken from the ghc heap and so should 124 -- FIXME: I believe this gets taken from the ghc heap and so should
111 -- be subject to garbage collection. 125 -- be subject to garbage collection.
112 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) 126 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs))
113 qMatMut <- V.thaw quasiMatrixRes 127 qMatMut <- V.thaw quasiMatrixRes
128 diagnostics :: V.Vector CLong <- createVector 10 -- FIXME
129 diagMut <- V.thaw diagnostics
114 -- We need the types that sundials expects. These are tied together 130 -- We need the types that sundials expects. These are tied together
115 -- in 'Types'. FIXME: The Haskell type is currently empty! 131 -- in 'Types'. FIXME: The Haskell type is currently empty!
116 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt 132 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt
@@ -178,20 +194,19 @@ solveOdeC fun f0 ts = unsafePerformIO $ do
178 194
179 /* Linear solver interface */ 195 /* Linear solver interface */
180 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ 196 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */
181 /* Output initial conditions */ 197 /* Store initial conditions */
182 for (j = 0; j < NEQ; j++) { 198 for (j = 0; j < NEQ; j++) {
183 ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); 199 ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j);
184 } 200 }
185 201
186 /* Main time-stepping loop: calls ARKode to perform the integration, then 202 /* Main time-stepping loop: calls ARKode to perform the integration */
187 prints results. Stops when the final time has been reached */ 203 /* Stops when the final time has been reached */
188 printf(" t u\n");
189 printf(" ---------------------\n");
190 for (i = 1; i < $(int nTs); i++) { 204 for (i = 1; i < $(int nTs); i++) {
191 205
192 flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */ 206 flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */
193 if (check_flag(&flag, "ARKode", 1)) break; 207 if (check_flag(&flag, "ARKode", 1)) break;
194 printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */ 208
209 /* Store the results for Haskell */
195 for (j = 0; j < NEQ; j++) { 210 for (j = 0; j < NEQ; j++) {
196 ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j); 211 ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j);
197 } 212 }
@@ -201,42 +216,45 @@ solveOdeC fun f0 ts = unsafePerformIO $ do
201 break; 216 break;
202 } 217 }
203 } 218 }
204 printf(" ---------------------\n");
205
206 for (i = 0; i < NEQ; i++) {
207 ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i);
208 };
209 219
210 /* Get/print some final statistics on how the solve progressed */ 220 /* Get some final statistics on how the solve progressed */
211 flag = ARKodeGetNumSteps(arkode_mem, &nst); 221 flag = ARKodeGetNumSteps(arkode_mem, &nst);
212 check_flag(&flag, "ARKodeGetNumSteps", 1); 222 check_flag(&flag, "ARKodeGetNumSteps", 1);
223 ($vec-ptr:(long int *diagMut))[0] = nst;
224
213 flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a); 225 flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a);
214 check_flag(&flag, "ARKodeGetNumStepAttempts", 1); 226 check_flag(&flag, "ARKodeGetNumStepAttempts", 1);
227 ($vec-ptr:(long int *diagMut))[1] = nst_a;
228
215 flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi); 229 flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi);
216 check_flag(&flag, "ARKodeGetNumRhsEvals", 1); 230 check_flag(&flag, "ARKodeGetNumRhsEvals", 1);
231 ($vec-ptr:(long int *diagMut))[2] = nfe;
232 ($vec-ptr:(long int *diagMut))[3] = nfi;
233
217 flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups); 234 flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups);
218 check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1); 235 check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1);
236 ($vec-ptr:(long int *diagMut))[4] = nsetups;
237
219 flag = ARKodeGetNumErrTestFails(arkode_mem, &netf); 238 flag = ARKodeGetNumErrTestFails(arkode_mem, &netf);
220 check_flag(&flag, "ARKodeGetNumErrTestFails", 1); 239 check_flag(&flag, "ARKodeGetNumErrTestFails", 1);
240 ($vec-ptr:(long int *diagMut))[5] = netf;
241
221 flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni); 242 flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni);
222 check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1); 243 check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1);
244 ($vec-ptr:(long int *diagMut))[6] = nni;
245
223 flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn); 246 flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn);
224 check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1); 247 check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1);
248 ($vec-ptr:(long int *diagMut))[7] = ncfn;
249
225 flag = ARKDlsGetNumJacEvals(arkode_mem, &nje); 250 flag = ARKDlsGetNumJacEvals(arkode_mem, &nje);
226 check_flag(&flag, "ARKDlsGetNumJacEvals", 1); 251 check_flag(&flag, "ARKDlsGetNumJacEvals", 1);
252 ($vec-ptr:(long int *diagMut))[8] = ncfn;
253
227 flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS); 254 flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS);
228 check_flag(&flag, "ARKDlsGetNumRhsEvals", 1); 255 check_flag(&flag, "ARKDlsGetNumRhsEvals", 1);
229 256 ($vec-ptr:(long int *diagMut))[9] = ncfn;
230 printf("\nFinal Solver Statistics:\n"); 257
231 printf(" Internal solver steps = %li (attempted = %li)\n", nst, nst_a);
232 printf(" Total RHS evals: Fe = %li, Fi = %li\n", nfe, nfi);
233 printf(" Total linear solver setups = %li\n", nsetups);
234 printf(" Total RHS evals for setting up the linear system = %li\n", nfeLS);
235 printf(" Total number of Jacobian evaluations = %li\n", nje);
236 printf(" Total number of Newton iterations = %li\n", nni);
237 printf(" Total number of linear solver convergence failures = %li\n", ncfn);
238 printf(" Total number of error test failures = %li\n\n", netf);
239
240 /* Clean up and return */ 258 /* Clean up and return */
241 N_VDestroy(y); /* Free y vector */ 259 N_VDestroy(y); /* Free y vector */
242 ARKodeFree(&arkode_mem); /* Free integrator memory */ 260 ARKodeFree(&arkode_mem); /* Free integrator memory */
@@ -247,7 +265,19 @@ solveOdeC fun f0 ts = unsafePerformIO $ do
247 } |] 265 } |]
248 if res == 0 266 if res == 0
249 then do 267 then do
268 preD <- V.freeze diagMut
269 let d = SundialsDiagnostics (fromIntegral $ preD V.!0)
270 (fromIntegral $ preD V.!1)
271 (fromIntegral $ preD V.!2)
272 (fromIntegral $ preD V.!3)
273 (fromIntegral $ preD V.!4)
274 (fromIntegral $ preD V.!5)
275 (fromIntegral $ preD V.!6)
276 (fromIntegral $ preD V.!7)
277 (fromIntegral $ preD V.!8)
278 (fromIntegral $ preD V.!9)
250 m <- V.freeze qMatMut 279 m <- V.freeze qMatMut
251 return $ Right m 280 return $ Right (m, d)
252 else do 281 else do
253 return $ Left res 282 return $ Left res
283