From 3c4411e48cbcfaf8035e893ac63aa250fcc56d3e Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Sat, 31 Mar 2018 13:06:35 +0100 Subject: Add in the Jacobian --- packages/sundials/src/Arkode.hsc | 19 +++++ packages/sundials/src/Main.hs | 27 ++++-- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 99 ++++++++++++++++------ packages/sundials/src/Types.hs | 7 ++ 4 files changed, 122 insertions(+), 30 deletions(-) (limited to 'packages/sundials') diff --git a/packages/sundials/src/Arkode.hsc b/packages/sundials/src/Arkode.hsc index 59e701e..ae2b40f 100644 --- a/packages/sundials/src/Arkode.hsc +++ b/packages/sundials/src/Arkode.hsc @@ -9,12 +9,31 @@ import Foreign.C.String #include #include +#include #include +#include #include #def typedef struct _generic_N_Vector SunVector; #def typedef struct _N_VectorContent_Serial SunContent; +#def typedef struct _generic_SUNMatrix SunMatrix; +#def typedef struct _SUNMatrixContent_Dense SunMatrixContent; + +getContentMatrixPtr ptr = (#peek SunMatrix, content) ptr + +getNRows :: Ptr b -> IO CInt +getNRows ptr = (#peek SunMatrixContent, M) ptr +putNRows :: CInt -> Ptr b -> IO () +putNRows nr ptr = (#poke SunMatrixContent, M) ptr nr + +getNCols :: Ptr b -> IO CInt +getNCols ptr = (#peek SunMatrixContent, N) ptr +putNCols :: CInt -> Ptr b -> IO () +putNCols nc ptr = (#poke SunMatrixContent, N) ptr nc + +getMatrixData ptr = (#peek SunMatrixContent, data) ptr + getContentPtr :: Storable a => Ptr b -> IO a getContentPtr ptr = (#peek SunVector, content) ptr diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index 71bcbac..01d3595 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs @@ -27,11 +27,26 @@ brusselator _t x = [ a - (w + 1) * u + v * u^2 v = x !! 1 w = x !! 2 +brussJac _t x = (3><3) [ (-(w + 1.0)) + 2.0 * u * v, w - 2.0 * u * v, (-w) + , u * u , (-(u * u)) , 0.0 + , (-u) , u , (-1.0) / eps - u + ] + where + y = toList x + u = y !! 0 + v = y !! 1 + w = y !! 2 + eps = 5.0e-6 + stiffish t v = [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] where lamda = -100.0 u = v !! 0 +stiffJac _t _v = (1><1) [ lamda ] + where + lamda = -100.0 + lSaxis :: [[Double]] -> P.Axis B D.V2 Double lSaxis xs = P.r2Axis &~ do let ts = xs!!0 @@ -77,14 +92,14 @@ main = do putStrLn $ show res putStrLn $ butcherTableauTex res - let res = odeSolve KVAERNO_4_2_3 brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) - putStrLn $ show res + let res1 = odeSolve KVAERNO_4_2_3 brussJac brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) + putStrLn $ show res1 renderRasterific "diagrams/brusselator.png" (D.dims2D 500.0 500.0) - (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res)) + (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1)) - let res = odeSolve KVAERNO_4_2_3 stiffish [0.0] (fromList [0.0, 0.1 .. 10.0]) - putStrLn $ show res + let res2 = odeSolve KVAERNO_4_2_3 stiffJac stiffish [0.0] (fromList [0.0, 0.1 .. 10.0]) + putStrLn $ show res2 renderRasterific "diagrams/stiffish.png" (D.dims2D 500.0 500.0) - (renderAxis $ kSaxis $ zip [0.0, 0.1 .. 10.0] (concat $ toLists res)) + (renderAxis $ kSaxis $ zip [0.0, 0.1 .. 10.0] (concat $ toLists res2)) diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 30ff4c8..5af9e41 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -35,6 +35,7 @@ module Numeric.Sundials.Arkode.ODE ( solveOde , getBT , btGet , ODEMethod(..) + , odeSolveV ) where import qualified Language.C.Inline as C @@ -45,7 +46,7 @@ import Data.Monoid ((<>)) import Foreign.C.Types import Foreign.Ptr (Ptr) import Foreign.ForeignPtr (newForeignPtr_) -import Foreign.Storable (Storable, peekByteOff) +import Foreign.Storable (Storable) import qualified Data.Vector.Storable as V import qualified Data.Vector.Storable.Mutable as VM @@ -55,7 +56,8 @@ import System.IO.Unsafe (unsafePerformIO) import Numeric.LinearAlgebra.Devel (createVector) -import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix) +import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), + subMatrix, rows, cols, toLists) import qualified Types as T import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3) @@ -78,12 +80,34 @@ C.include "../../../helpers.h" C.include "Arkode_hsc.h" -getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b) +getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble) getDataFromContents len ptr = do qtr <- B.getContentPtr ptr rtr <- B.getData qtr vectorFromC len rtr +-- FIXME: Potentially an instance of Storable +getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix +getMatrixDataFromContents ptr = do + qtr <- B.getContentMatrixPtr ptr + rs <- B.getNRows qtr + cs <- B.getNCols qtr + rtr <- B.getMatrixData qtr + vs <- vectorFromC (fromIntegral $ rs * cs) rtr + return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } + +putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () +putMatrixDataFromContents mat ptr = do + let rs = T.rows mat + cs = T.cols mat + vs = T.vals mat + qtr <- B.getContentMatrixPtr ptr + B.putNRows rs qtr + B.putNCols cs qtr + rtr <- B.getMatrixData qtr + vectorToC vs (fromIntegral $ rs * cs) rtr +-- FIXME: END + putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () putDataInContents vec len ptr = do qtr <- B.getContentPtr ptr @@ -103,16 +127,16 @@ vectorToC vec len ptr = do V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec data SundialsDiagnostics = SundialsDiagnostics { - aRKodeGetNumSteps :: Int - , aRKodeGetNumStepAttempts :: Int - , aRKodeGetNumRhsEvals_fe :: Int - , aRKodeGetNumRhsEvals_fi :: Int - , aRKodeGetNumLinSolvSetups :: Int - , aRKodeGetNumErrTestFails :: Int - , aRKodeGetNumNonlinSolvIters :: Int - , aRKodeGetNumNonlinSolvConvFails :: Int - , aRKDlsGetNumJacEvals :: Int - , aRKDlsGetNumRhsEvals :: Int + _aRKodeGetNumSteps :: Int + , _aRKodeGetNumStepAttempts :: Int + , _aRKodeGetNumRhsEvals_fe :: Int + , _aRKodeGetNumRhsEvals_fi :: Int + , _aRKodeGetNumLinSolvSetups :: Int + , _aRKodeGetNumErrTestFails :: Int + , _aRKodeGetNumNonlinSolvIters :: Int + , _aRKodeGetNumNonlinSolvConvFails :: Int + , _aRKDlsGetNumJacEvals :: Int + , _aRKDlsGetNumRhsEvals :: Int } deriving Show -- | Stepping functions @@ -134,15 +158,16 @@ odeSolveV -> Vector Double -- ^ initial conditions -> Vector Double -- ^ desired solution times -> Matrix Double -- ^ solution -odeSolveV meth hi epsAbs epsRel = undefined +odeSolveV _meth _hi _epsAbs _epsRel = undefined odeSolve :: ODEMethod + -> (Double -> Vector Double -> Matrix Double) -> (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> [Double] -- ^ initial conditions -> Vector Double -- ^ desired solution times -> Matrix Double -- ^ solution -odeSolve method f y0 ts = - case solveOde method 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of +odeSolve method jac f y0 ts = + case solveOde method jac' 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of Left c -> error $ show c -- FIXME Right (v, _) -> (nR >< nC) (V.toList v) where @@ -150,30 +175,40 @@ odeSolve method f y0 ts = nR = length us nC = length y0 g t x0 = V.fromList $ f t (V.toList x0) + jac' t v = foo $ jac t (V.fromList $ toList v) + foo m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } + where + nr = fromIntegral $ rows m + nc = fromIntegral $ cols m + vs = V.fromList $ map coerce $ concat $ toLists m solveOde :: ODEMethod + -> (Double -> V.Vector Double -> T.SunMatrix) -> Double -> Double -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector Double -- ^ Initial conditions -> V.Vector Double -- ^ Desired solution times -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution -solveOde method relTol absTol f y0 tt = - case solveOdeC (fromIntegral $ fromEnum method) (CDouble relTol) (CDouble absTol) +solveOde method jac relTol absTol f y0 tt = + case solveOdeC (fromIntegral $ fromEnum method) jacH (CDouble relTol) (CDouble absTol) (coerce f) (coerce y0) (coerce tt) of Left c -> Left $ fromIntegral c Right (v, d) -> Right (coerce v, d) + where + jacH t v = jac (coerce t) (coerce v) solveOdeC :: CInt -> + (CDouble -> V.Vector CDouble -> T.SunMatrix) -> CDouble -> CDouble -> (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector CDouble -- ^ Initial conditions -> V.Vector CDouble -- ^ Desired solution times -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution -solveOdeC method relTol absTol fun f0 ts = unsafePerformIO $ do +solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do let dim = V.length f0 nEq :: CLong nEq = fromIntegral dim @@ -197,9 +232,19 @@ solveOdeC method relTol absTol fun f0 ts = unsafePerformIO $ do fImm <- fun x <$> getDataFromContents dim y -- Fill in the provided pointer with the resulting vector. putDataInContents fImm dim f - -- I don't understand what this comment means + -- FIXME: I don't understand what this comment means + -- Unsafe since the function will be called many times. + [CU.exp| int{ 0 } |] + let jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix -> + Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector -> + IO CInt + jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do + foo <- jacH t <$> getDataFromContents dim y + putMatrixDataFromContents foo jacS + -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] + res <- [C.block| int { /* general problem variables */ int flag; /* reusable error-checking flag */ @@ -246,6 +291,11 @@ solveOdeC method relTol absTol fun f0 ts = unsafePerformIO $ do /* Linear solver interface */ flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ + + flag = ARKDlsSetJacFn(arkode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[]))); + if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1; + + /* Store initial conditions */ for (j = 0; j < NEQ; j++) { ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); @@ -340,7 +390,8 @@ btGet :: ODEMethod -> Matrix Double btGet method = case getBT method of Left c -> error $ show c -- FIXME -- FIXME - Right (v, sqp) -> subMatrix (0, 0) (2, 2) $ (B.arkSMax >< B.arkSMax) (V.toList v) + Right (v, _sqp) -> subMatrix (0, 0) (2, 2) $ + (B.arkSMax >< B.arkSMax) (V.toList v) getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int) getBT method = case getButcherTable method of @@ -352,9 +403,9 @@ getButcherTable method = unsafePerformIO $ do -- arkode seems to want an ODE in order to set and then get the -- Butcher tableau so here's one to keep it happy let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble - fun t ys = V.fromList [ ys V.! 0 ] - f0 = V.fromList [ 1.0 ] - ts = V.fromList [ 0.0 ] + fun _t ys = V.fromList [ ys V.! 0 ] + f0 = V.fromList [ 1.0 ] + ts = V.fromList [ 0.0 ] dim = V.length f0 nEq :: CLong nEq = fromIntegral dim diff --git a/packages/sundials/src/Types.hs b/packages/sundials/src/Types.hs index e910c57..04e4280 100644 --- a/packages/sundials/src/Types.hs +++ b/packages/sundials/src/Types.hs @@ -15,8 +15,14 @@ import qualified Language.C.Types as CT import qualified Data.Map as Map import Language.C.Inline.Context +import qualified Data.Vector.Storable as V + data SunVector +data SunMatrix = SunMatrix { rows :: CInt + , cols :: CInt + , vals :: V.Vector CDouble + } -- FIXME: Is this true? type SunIndexType = CLong @@ -26,6 +32,7 @@ sunTypesTable = Map.fromList [ (CT.TypeName "sunindextype", [t| SunIndexType |] ) , (CT.TypeName "SunVector", [t| SunVector |] ) + , (CT.TypeName "SunMatrix", [t| SunMatrix |] ) ] sunCtx :: Context -- cgit v1.2.3