From f5dc976be9ee31095fbcd2f825375bbb42f44b64 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Sun, 8 Apr 2018 15:04:33 +0100 Subject: Improve Butcher Tableaux --- packages/sundials/src/Main.hs | 18 +++---- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 55 ++++++++++++---------- 2 files changed, 38 insertions(+), 35 deletions(-) diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index ac19e7f..d22fafa 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs @@ -121,17 +121,17 @@ main = do -- \end{array} -- $$ - -- let res = btGet (SDIRK_2_1_2 undefined) - -- putStrLn $ show res - -- putStrLn $ butcherTableauTex res + let res = btGet (SDIRK_2_1_2 undefined) + putStrLn $ show res + putStrLn $ butcherTableauTex $ fst res - -- let res = btGet (KVAERNO_4_2_3 undefined) - -- putStrLn $ show res - -- putStrLn $ butcherTableauTex res + let res = btGet (KVAERNO_4_2_3 undefined) + putStrLn $ show res + putStrLn $ butcherTableauTex $ fst res - -- let res = btGet (SDIRK_5_3_4 undefined) - -- putStrLn $ show res - -- putStrLn $ butcherTableauTex res + let res = btGet (SDIRK_5_3_4 undefined) + putStrLn $ show res + putStrLn $ butcherTableauTex $ fst res let res1 = odeSolve brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) renderRasterific "diagrams/brusselator.png" diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 2577b8e..b6a59e2 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -101,7 +101,7 @@ import Numeric.LinearAlgebra.Devel (createVector) import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix, rows, cols, toLists, - size) + size, subVector) import qualified Types as T import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3, sDIRK_5_3_4) @@ -482,22 +482,24 @@ solveOdeC method jacH (absTols, relTol) fun f0 ts = unsafePerformIO $ do else do return $ Left res -btGet :: ODEMethod -> Matrix Double +btGet :: ODEMethod -> (Matrix Double, Vector Double) btGet method = case getBT method of Left c -> error $ show c -- FIXME - Right (v, sqp) -> subMatrix (0, 0) (s, s) $ - (B.arkSMax >< B.arkSMax) (V.toList v) + Right ((v, w), sqp) -> ( subMatrix (0, 0) (s, s) $ + (B.arkSMax >< B.arkSMax) (V.toList v) + , subVector 0 s w) where s = fromIntegral $ sqp V.! 0 -getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int) +getBT :: ODEMethod -> Either Int ((V.Vector Double, V.Vector Double), V.Vector Int) getBT method = case getButcherTable method of Left c -> Left $ fromIntegral c - Right (v, sqp) -> Right $ (coerce v, V.map fromIntegral sqp) + Right ((v, w), sqp) -> Right $ ((coerce v, coerce w), V.map fromIntegral sqp) -getButcherTable :: ODEMethod -> Either CInt ((V.Vector CDouble), V.Vector CInt) +getButcherTable :: ODEMethod + -> Either CInt ((V.Vector CDouble, V.Vector CDouble), V.Vector CInt) getButcherTable method = unsafePerformIO $ do - -- arkode seems to want an ODE in order to set and then get the + -- ARKode seems to want an ODE in order to set and then get the -- Butcher tableau so here's one to keep it happy let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble fun _t ys = V.fromList [ ys V.! 0 ] @@ -509,41 +511,37 @@ getButcherTable method = unsafePerformIO $ do mN :: CInt mN = fromIntegral $ getMethod method - -- FIXME: I believe these gets taken from the ghc heap and so should - -- be subject to garbage collection. btSQP :: V.Vector CInt <- createVector 3 btSQPMut <- V.thaw btSQP btAs :: V.Vector CDouble <- createVector (B.arkSMax * B.arkSMax) btAsMut <- V.thaw btAs - -- We need the types that sundials expects. These are tied together - -- in 'Types'. FIXME: The Haskell type is currently empty! + btCs :: V.Vector CDouble <- createVector B.arkSMax + btCsMut <- V.thaw btCs let funIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr () -> IO CInt funIO x y f _ptr = do - -- Convert the pointer we get from C (y) to a vector, and then - -- apply the user-supplied function. fImm <- fun x <$> getDataFromContents dim y - -- Fill in the provided pointer with the resulting vector. putDataInContents fImm dim f - -- I don't understand what this comment means + -- FIXME: I don't understand what this comment means -- Unsafe since the function will be called many times. [CU.exp| int{ 0 } |] res <- [C.block| int { /* general problem variables */ - int flag; /* reusable error-checking flag */ - N_Vector y = NULL; /* empty vector for storing solution */ - void *arkode_mem = NULL; /* empty ARKode memory structure */ + + int flag; /* reusable error-checking flag */ + N_Vector y = NULL; /* empty vector for storing solution */ + void *arkode_mem = NULL; /* empty ARKode memory structure */ + int i, j; /* reusable loop indices */ /* general problem parameters */ - /* initial time */ - realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); - /* number of dependent vars. */ - sunindextype NEQ = $(sunindextype nEq); + + realtype T0 = RCONST(($vec-ptr:(double *ts))[0]); /* initial time */ + sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars */ /* Initialize data structures */ - y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ + + y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; /* Specify initial condition */ - int i, j; for (i = 0; i < NEQ; i++) { NV_Ith_S(y,i) = ($vec-ptr:(double *f0))[i]; }; @@ -577,6 +575,10 @@ getButcherTable method = unsafePerformIO $ do } } + for (i = 0; i < s; i++) { + ($vec-ptr:(double *btCsMut))[i] = ci[i]; + } + /* Clean up and return */ N_VDestroy(y); /* Free y vector */ ARKodeFree(&arkode_mem); /* Free integrator memory */ @@ -587,7 +589,8 @@ getButcherTable method = unsafePerformIO $ do then do x <- V.freeze btAsMut y <- V.freeze btSQPMut - return $ Right (x, y) + z <- V.freeze btCsMut + return $ Right ((x, z), y) else do return $ Left res -- cgit v1.2.3