From 0d52842881192a627d6f52e47c2fe26592f20adb Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Wed, 21 Mar 2018 13:57:16 +0000 Subject: Also return diagnostics --- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 90 ++++++++++++++-------- 1 file changed, 60 insertions(+), 30 deletions(-) (limited to 'packages/sundials/src/Numeric/Sundials/ARKode') diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index c5d085e..630827c 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -79,6 +79,19 @@ vectorToC vec len ptr = do ptr' <- newForeignPtr_ ptr V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec +data SundialsDiagnostics = SundialsDiagnostics { + aRKodeGetNumSteps :: Int + , aRKodeGetNumStepAttempts :: Int + , aRKodeGetNumRhsEvals_fe :: Int + , aRKodeGetNumRhsEvals_fi :: Int + , aRKodeGetNumLinSolvSetups :: Int + , aRKodeGetNumErrTestFails :: Int + , aRKodeGetNumNonlinSolvIters :: Int + , aRKodeGetNumNonlinSolvConvFails :: Int + , aRKDlsGetNumJacEvals :: Int + , aRKDlsGetNumRhsEvals :: Int + } deriving Show + odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> [Double] -- ^ initial conditions -> Vector Double -- ^ desired solution times @@ -89,28 +102,31 @@ solveOde :: (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector Double -- ^ Initial conditions -> V.Vector Double -- ^ Desired solution times - -> Either Int (V.Vector Double) -- ^ Error code or solution + -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution solveOde f y0 tt = case solveOdeC (coerce f) (coerce y0) (coerce tt) of Left c -> Left $ fromIntegral c - Right v -> Right $ coerce v + Right (v, d) -> Right (coerce v, d) solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) - -> V.Vector CDouble -- ^ Initial conditions - -> V.Vector CDouble -- ^ Desired solution times - -> Either CInt (V.Vector CDouble) -- ^ Error code or solution + -> V.Vector CDouble -- ^ Initial conditions + -> V.Vector CDouble -- ^ Desired solution times + -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution solveOdeC fun f0 ts = unsafePerformIO $ do let dim = V.length f0 nEq :: CLong nEq = fromIntegral dim nTs :: CInt nTs = fromIntegral $ V.length ts + -- FIXME: fMut is not actually mutatated fMut <- V.thaw f0 tMut <- V.thaw ts -- FIXME: I believe this gets taken from the ghc heap and so should -- be subject to garbage collection. quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) qMatMut <- V.thaw quasiMatrixRes + diagnostics :: V.Vector CLong <- createVector 10 -- FIXME + diagMut <- V.thaw diagnostics -- We need the types that sundials expects. These are tied together -- in 'Types'. FIXME: The Haskell type is currently empty! let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt @@ -178,20 +194,19 @@ solveOdeC fun f0 ts = unsafePerformIO $ do /* Linear solver interface */ flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ - /* Output initial conditions */ + /* Store initial conditions */ for (j = 0; j < NEQ; j++) { ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); } - /* Main time-stepping loop: calls ARKode to perform the integration, then - prints results. Stops when the final time has been reached */ - printf(" t u\n"); - printf(" ---------------------\n"); + /* Main time-stepping loop: calls ARKode to perform the integration */ + /* Stops when the final time has been reached */ for (i = 1; i < $(int nTs); i++) { - flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */ + flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */ if (check_flag(&flag, "ARKode", 1)) break; - printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */ + + /* Store the results for Haskell */ for (j = 0; j < NEQ; j++) { ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j); } @@ -201,42 +216,45 @@ solveOdeC fun f0 ts = unsafePerformIO $ do break; } } - printf(" ---------------------\n"); - - for (i = 0; i < NEQ; i++) { - ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i); - }; - /* Get/print some final statistics on how the solve progressed */ + /* Get some final statistics on how the solve progressed */ flag = ARKodeGetNumSteps(arkode_mem, &nst); check_flag(&flag, "ARKodeGetNumSteps", 1); + ($vec-ptr:(long int *diagMut))[0] = nst; + flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a); check_flag(&flag, "ARKodeGetNumStepAttempts", 1); + ($vec-ptr:(long int *diagMut))[1] = nst_a; + flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi); check_flag(&flag, "ARKodeGetNumRhsEvals", 1); + ($vec-ptr:(long int *diagMut))[2] = nfe; + ($vec-ptr:(long int *diagMut))[3] = nfi; + flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups); check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1); + ($vec-ptr:(long int *diagMut))[4] = nsetups; + flag = ARKodeGetNumErrTestFails(arkode_mem, &netf); check_flag(&flag, "ARKodeGetNumErrTestFails", 1); + ($vec-ptr:(long int *diagMut))[5] = netf; + flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni); check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1); + ($vec-ptr:(long int *diagMut))[6] = nni; + flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn); check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1); + ($vec-ptr:(long int *diagMut))[7] = ncfn; + flag = ARKDlsGetNumJacEvals(arkode_mem, &nje); check_flag(&flag, "ARKDlsGetNumJacEvals", 1); + ($vec-ptr:(long int *diagMut))[8] = ncfn; + flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS); check_flag(&flag, "ARKDlsGetNumRhsEvals", 1); - - printf("\nFinal Solver Statistics:\n"); - printf(" Internal solver steps = %li (attempted = %li)\n", nst, nst_a); - printf(" Total RHS evals: Fe = %li, Fi = %li\n", nfe, nfi); - printf(" Total linear solver setups = %li\n", nsetups); - printf(" Total RHS evals for setting up the linear system = %li\n", nfeLS); - printf(" Total number of Jacobian evaluations = %li\n", nje); - printf(" Total number of Newton iterations = %li\n", nni); - printf(" Total number of linear solver convergence failures = %li\n", ncfn); - printf(" Total number of error test failures = %li\n\n", netf); - + ($vec-ptr:(long int *diagMut))[9] = ncfn; + /* Clean up and return */ N_VDestroy(y); /* Free y vector */ ARKodeFree(&arkode_mem); /* Free integrator memory */ @@ -247,7 +265,19 @@ solveOdeC fun f0 ts = unsafePerformIO $ do } |] if res == 0 then do + preD <- V.freeze diagMut + let d = SundialsDiagnostics (fromIntegral $ preD V.!0) + (fromIntegral $ preD V.!1) + (fromIntegral $ preD V.!2) + (fromIntegral $ preD V.!3) + (fromIntegral $ preD V.!4) + (fromIntegral $ preD V.!5) + (fromIntegral $ preD V.!6) + (fromIntegral $ preD V.!7) + (fromIntegral $ preD V.!8) + (fromIntegral $ preD V.!9) m <- V.freeze qMatMut - return $ Right m + return $ Right (m, d) else do return $ Left res + -- cgit v1.2.3