From d057093a7681a0ea448f8ae98e241eeafd5ad050 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Wed, 21 Mar 2018 13:12:57 +0000 Subject: Return the entire results matrix (as a vector) --- packages/sundials/src/Main.hs | 7 +++ .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 59 +++++++++++----------- 2 files changed, 37 insertions(+), 29 deletions(-) (limited to 'packages') diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index 978088b..2a561c4 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs @@ -16,7 +16,14 @@ brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 v = x V.! 1 w = x V.! 2 +stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] + where + lamda = -100.0 + u = v V.! 0 + main :: IO () main = do let res = solveOde brusselator (V.fromList [1.2, 3.1, 3.0]) (V.fromList [0.0, 1.0 .. 10.0]) putStrLn $ show res + let res = solveOde stiffish (V.fromList [1.0]) (V.fromList [0.0, 0.1 .. 10.0]) + putStrLn $ show res diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 9de20b6..c5d085e 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -10,18 +10,21 @@ module Numeric.Sundials.Arkode.ODE ( solveOde ) where import qualified Language.C.Inline as C import qualified Language.C.Inline.Unsafe as CU + import Data.Monoid ((<>)) + import Foreign.C.Types import Foreign.Ptr (Ptr) +import Foreign.ForeignPtr (newForeignPtr_) +import Foreign.Storable (Storable, peekByteOff) + import qualified Data.Vector.Storable as V +import qualified Data.Vector.Storable.Mutable as VM import Data.Coerce (coerce) -import qualified Data.Vector.Storable.Mutable as VM -import Foreign.ForeignPtr (newForeignPtr_) -import Foreign.Storable (Storable) import System.IO.Unsafe (unsafePerformIO) -import Foreign.Storable (peekByteOff) +import Numeric.LinearAlgebra.Devel (createVector) import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) @@ -104,8 +107,12 @@ solveOdeC fun f0 ts = unsafePerformIO $ do nTs = fromIntegral $ V.length ts fMut <- V.thaw f0 tMut <- V.thaw ts + -- FIXME: I believe this gets taken from the ghc heap and so should + -- be subject to garbage collection. + quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) + qMatMut <- V.thaw quasiMatrixRes -- We need the types that sundials expects. These are tied together - -- in 'Types'. The Haskell type is currently empty! + -- in 'Types'. FIXME: The Haskell type is currently empty! let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt funIO x y f _ptr = do -- Convert the pointer we get from C (y) to a vector, and then @@ -124,13 +131,12 @@ solveOdeC fun f0 ts = unsafePerformIO $ do SUNLinearSolver LS = NULL; /* empty linear solver object */ void *arkode_mem = NULL; /* empty ARKode memory structure */ FILE *UFID; - realtype t, tout; + realtype t; long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; /* general problem parameters */ realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ realtype Tf = RCONST(($vec-ptr:(double *tMut))[$(int nTs) - 1]); /* final time */ - realtype dTout = RCONST(1.0); /* time between outputs */ sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ realtype reltol = 1.0e-6; /* tolerances */ realtype abstol = 1.0e-10; @@ -143,7 +149,7 @@ solveOdeC fun f0 ts = unsafePerformIO $ do /* Initialize data structures */ y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; - int i; + int i, j; for (i = 0; i < NEQ; i++) { NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; }; /* Specify initial condition */ @@ -172,35 +178,30 @@ solveOdeC fun f0 ts = unsafePerformIO $ do /* Linear solver interface */ flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ - /* Open output stream for results, output comment line */ - UFID = fopen("solution.txt","w"); - fprintf(UFID,"# t u\n"); - - /* output initial condition to disk */ - fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", T0, NV_Ith_S(y,0)); - + /* Output initial conditions */ + for (j = 0; j < NEQ; j++) { + ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); + } + /* Main time-stepping loop: calls ARKode to perform the integration, then prints results. Stops when the final time has been reached */ - t = T0; - tout = T0+dTout; printf(" t u\n"); printf(" ---------------------\n"); - for (i = 0; i < $(int nTs); i++) { + for (i = 1; i < $(int nTs); i++) { - flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */ + flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */ if (check_flag(&flag, "ARKode", 1)) break; - printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */ - fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", t, NV_Ith_S(y,0)); - if (flag >= 0) { /* successful solve: update time */ - tout += dTout; - tout = (tout > Tf) ? Tf : tout; - } else { /* unsuccessful solve: break */ + printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */ + for (j = 0; j < NEQ; j++) { + ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j); + } + + if (flag < 0) { /* unsuccessful solve: break */ fprintf(stderr,"Solver failure, stopping integration\n"); break; } } printf(" ---------------------\n"); - fclose(UFID); for (i = 0; i < NEQ; i++) { ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i); @@ -244,9 +245,9 @@ solveOdeC fun f0 ts = unsafePerformIO $ do return flag; } |] - if res ==0 + if res == 0 then do - v <- V.freeze fMut - return $ Right v + m <- V.freeze qMatMut + return $ Right m else do return $ Left res -- cgit v1.2.3