From e022e6b2d96f89376113241d89e31e2affd4faaf Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Mon, 9 Apr 2018 15:05:28 +0100 Subject: Improve haddock --- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 146 +++++++++++++++------ 1 file changed, 104 insertions(+), 42 deletions(-) (limited to 'packages/sundials/src/Numeric') diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index b6a59e2..67378cc 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -21,30 +21,57 @@ -- -- A simple example: -- +-- <> +-- -- @ --- import Numeric.Sundials.ARKode --- import Numeric.LinearAlgebra --- import Graphics.Plot(mplot) +-- import Numeric.Sundials.ARKode.ODE +-- import Numeric.LinearAlgebra -- --- xdot t [x,v] = [v, -0.95*x - 0.1*v] +-- import Plots as P +-- import qualified Diagrams.Prelude as D +-- import Diagrams.Backend.Rasterific -- --- ts = linspace 100 (0,20 :: Double) +-- brusselator :: Double -> [Double] -> [Double] +-- brusselator _t x = [ a - (w + 1) * u + v * u * u +-- , w * u - v * u * u +-- , (b - w) / eps - w * u +-- ] +-- where +-- a = 1.0 +-- b = 3.5 +-- eps = 5.0e-6 +-- u = x !! 0 +-- v = x !! 1 +-- w = x !! 2 -- --- sol = odeSolve xdot [10,0] ts +-- lSaxis :: [[Double]] -> P.Axis B D.V2 Double +-- lSaxis xs = P.r2Axis &~ do +-- let ts = xs!!0 +-- us = xs!!1 +-- vs = xs!!2 +-- ws = xs!!3 +-- P.linePlot' $ zip ts us +-- P.linePlot' $ zip ts vs +-- P.linePlot' $ zip ts ws -- --- main = mplot (ts : toColumns sol) +-- main = do +-- let res1 = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) +-- renderRasterific "diagrams/brusselator.png" +-- (D.dims2D 500.0 500.0) +-- (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) -- @ -- --- <> --- -- KVAERNO_4_2_3 -- -- \[ -- \begin{array}{c|cccc} --- c_1 & 0.0 & 0.0 & 0.0 & 0.0 \\ --- c_2 & 0.4358665215 & 0.4358665215 & 0.0 & 0.0 \\ --- c_3 & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\ --- c_4 & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\ +-- 0.0 & 0.0 & 0.0 & 0.0 & 0.0 \\ +-- 0.871733043 & 0.4358665215 & 0.4358665215 & 0.0 & 0.0 \\ +-- 1.0 & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\ +-- 1.0 & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\ +-- \hline +-- & 0.308809969973036 & 1.490563388254106 & -1.235239879727145 & 0.4358665215 \\ +-- & 0.490563388419108 & 7.3570090080892e-2 & 0.4358665215 & 0.0 \\ -- \end{array} -- \] -- @@ -52,8 +79,11 @@ -- -- \[ -- \begin{array}{c|cc} --- c_1 & 1.0 & 0.0 \\ --- c_2 & -1.0 & 1.0 \\ +-- 1.0 & 1.0 & 0.0 \\ +-- 0.0 & -1.0 & 1.0 \\ +-- \hline +-- & 0.5 & 0.5 \\ +-- & 1.0 & 0.0 \\ -- \end{array} -- \] -- @@ -61,20 +91,22 @@ -- -- \[ -- \begin{array}{c|ccccc} --- c_1 & 0.25 & 0.0 & 0.0 & 0.0 & 0.0 \\ --- c_2 & 0.5 & 0.25 & 0.0 & 0.0 & 0.0 \\ --- c_3 & 0.34 & -4.0e-2 & 0.25 & 0.0 & 0.0 \\ --- c_4 & 0.2727941176470588 & -5.036764705882353e-2 & 2.7573529411764705e-2 & 0.25 & 0.0 \\ --- c_5 & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\ +-- 0.25 & 0.25 & 0.0 & 0.0 & 0.0 & 0.0 \\ +-- 0.75 & 0.5 & 0.25 & 0.0 & 0.0 & 0.0 \\ +-- 0.55 & 0.34 & -4.0e-2 & 0.25 & 0.0 & 0.0 \\ +-- 0.5 & 0.2727941176470588 & -5.036764705882353e-2 & 2.7573529411764705e-2 & 0.25 & 0.0 \\ +-- 1.0 & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\ +-- \hline +-- & 1.0416666666666667 & -1.0208333333333333 & 7.8125 & -7.083333333333333 & 0.25 \\ +-- & 1.2291666666666667 & -0.17708333333333334 & 7.03125 & -7.083333333333333 & 0.0 \\ -- \end{array} -- \] ----------------------------------------------------------------------------- module Numeric.Sundials.ARKode.ODE ( odeSolve , odeSolveV , odeSolveVWith - , getButcherTable - , getBT - , btGet + , ButcherTable(..) + , butcherTable , ODEMethod(..) , StepControl(..) , SundialsDiagnostics(..) @@ -457,10 +489,11 @@ solveOdeC method jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do ($vec-ptr:(long int *diagMut))[9] = ncfn; /* Clean up and return */ - N_VDestroy(y); /* Free y vector */ + N_VDestroy(y); /* Free y vector */ + N_VDestroy(tv); /* Free tv vector */ ARKodeFree(&arkode_mem); /* Free integrator memory */ - SUNLinSolFree(LS); /* Free linear solver */ - SUNMatDestroy(A); /* Free A matrix */ + SUNLinSolFree(LS); /* Free linear solver */ + SUNMatDestroy(A); /* Free A matrix */ return flag; } |] @@ -482,22 +515,43 @@ solveOdeC method jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do else do return $ Left res -btGet :: ODEMethod -> (Matrix Double, Vector Double) -btGet method = case getBT method of - Left c -> error $ show c -- FIXME - Right ((v, w), sqp) -> ( subMatrix (0, 0) (s, s) $ - (B.arkSMax >< B.arkSMax) (V.toList v) - , subVector 0 s w) - where - s = fromIntegral $ sqp V.! 0 +data ButcherTable = ButcherTable { am :: Matrix Double + , cv :: Vector Double + , bv :: Vector Double + , b2v :: Vector Double + } + deriving Show + +data ButcherTable' a = ButcherTable' { am' :: V.Vector a + , cv' :: V.Vector a + , bv' :: V.Vector a + , b2v' :: V.Vector a + } + deriving Show + +butcherTable :: ODEMethod -> ButcherTable +butcherTable method = + case getBT method of + Left c -> error $ show c -- FIXME + Right (ButcherTable' v w x y, sqp) -> + ButcherTable { am = subMatrix (0, 0) (s, s) $ (B.arkSMax >< B.arkSMax) (V.toList v) + , cv = subVector 0 s w + , bv = subVector 0 s x + , b2v = subVector 0 s y + } + where + s = fromIntegral $ sqp V.! 0 -getBT :: ODEMethod -> Either Int ((V.Vector Double, V.Vector Double), V.Vector Int) +getBT :: ODEMethod -> Either Int (ButcherTable' Double, V.Vector Int) getBT method = case getButcherTable method of - Left c -> Left $ fromIntegral c - Right ((v, w), sqp) -> Right $ ((coerce v, coerce w), V.map fromIntegral sqp) + Left c -> + Left $ fromIntegral c + Right (ButcherTable' a b c d, sqp) -> + Right $ ( ButcherTable' (coerce a) (coerce b) (coerce c) (coerce d) + , V.map fromIntegral sqp ) getButcherTable :: ODEMethod - -> Either CInt ((V.Vector CDouble, V.Vector CDouble), V.Vector CInt) + -> Either CInt (ButcherTable' CDouble, V.Vector CInt) getButcherTable method = unsafePerformIO $ do -- ARKode seems to want an ODE in order to set and then get the -- Butcher tableau so here's one to keep it happy @@ -515,8 +569,12 @@ getButcherTable method = unsafePerformIO $ do btSQPMut <- V.thaw btSQP btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax) btAsMut <- V.thaw btAs - btCs :: V.Vector CDouble <- createVector B.arkSMax - btCsMut <- V.thaw btCs + btCs :: V.Vector CDouble <- createVector B.arkSMax + btBs :: V.Vector CDouble <- createVector B.arkSMax + btB2s :: V.Vector CDouble <- createVector B.arkSMax + btCsMut <- V.thaw btCs + btBsMut <- V.thaw btBs + btB2sMut <- V.thaw btB2s let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt funIO x y f _ptr = do fImm <- fun x <$> getDataFromContents dim y @@ -576,7 +634,9 @@ getButcherTable method = unsafePerformIO $ do } for (i = 0; i < s; i++) { - ($vec-ptr:(double *btCsMut))[i] = ci[i]; + ($vec-ptr:(double *btCsMut))[i] = ci[i]; + ($vec-ptr:(double *btBsMut))[i] = bi[i]; + ($vec-ptr:(double *btB2sMut))[i] = b2i[i]; } /* Clean up and return */ @@ -590,7 +650,9 @@ getButcherTable method = unsafePerformIO $ do x <- V.freeze btAsMut y <- V.freeze btSQPMut z <- V.freeze btCsMut - return $ Right ((x, z), y) + u <- V.freeze btBsMut + v <- V.freeze btB2sMut + return $ Right (ButcherTable' { am' = x, cv' = z, bv' = u, b2v' = v }, y) else do return $ Left res -- cgit v1.2.3