From 1b64b28dfccf2cb9539cdb4344cd7ecb1c1d0a1d Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Wed, 21 Mar 2018 08:24:05 +0000 Subject: A library function + move towards the original hmatrix interface --- packages/sundials/src/Main.hs | 245 +-------------------- .../sundials/src/Numeric/Sundials/ARKode/ODE.hs | 45 ++-- 2 files changed, 18 insertions(+), 272 deletions(-) (limited to 'packages') diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index d1f35bd..978088b 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs @@ -1,84 +1,7 @@ {-# OPTIONS_GHC -Wall #-} -{-# LANGUAGE QuasiQuotes #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE ScopedTypeVariables #-} - -import qualified Language.C.Inline as C -import qualified Language.C.Inline.Unsafe as CU -import Data.Monoid ((<>)) -import Foreign.C.Types -import Foreign.Ptr (Ptr) import qualified Data.Vector.Storable as V - -import Data.Coerce (coerce) -import qualified Data.Vector.Storable.Mutable as VM -import Foreign.ForeignPtr (newForeignPtr_) -import Foreign.Storable (Storable) -import System.IO.Unsafe (unsafePerformIO) - -import Foreign.Storable (peekByteOff) - -import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) - -import qualified Types as T - -C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx) - --- C includes -C.include "" -C.include "" -C.include "" -- prototypes for ARKODE fcts., consts. -C.include "" -- serial N_Vector types, fcts., macros -C.include "" -- access to dense SUNMatrix -C.include "" -- access to dense SUNLinearSolver -C.include "" -- access to ARKDls interface -C.include "" -- definition of type realtype -C.include "" -C.include "helpers.h" - - --- These were semi-generated using hsc2hs with Bar.hsc as the --- template. They are probably very fragile and could easily break on --- different architectures and / or changes in the sundials package. - -getContentPtr :: Storable a => Ptr b -> IO a -getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr - -getData :: Storable a => Ptr b -> IO a -getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr - -getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b) -getDataFromContents len ptr = do - qtr <- getContentPtr ptr - rtr <- getData qtr - vectorFromC len rtr - -putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () -putDataInContents vec len ptr = do - qtr <- getContentPtr ptr - rtr <- getData qtr - vectorToC vec len rtr - --- Utils - -vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) -vectorFromC len ptr = do - ptr' <- newForeignPtr_ ptr - V.freeze $ VM.unsafeFromForeignPtr0 ptr' len - -vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () -vectorToC vec len ptr = do - ptr' <- newForeignPtr_ ptr - V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec - -stiffish :: Double -> V.Vector Double -> V.Vector Double -stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] - where - u = v V.! 0 - lamda = -100.0 +import Numeric.Sundials.Arkode.ODE brusselator :: Double -> V.Vector Double -> V.Vector Double brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 @@ -93,171 +16,7 @@ brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 v = x V.! 1 w = x V.! 2 - -odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) - -> [Double] -- ^ initial conditions - -> Vector Double -- ^ desired solution times - -> Matrix Double -- ^ solution -odeSolve = undefined - -solveOdeC :: - (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) - -> V.Vector CDouble -- ^ Initial conditions - -> V.Vector CDouble -- ^ Desired solution times - -> Either CInt (V.Vector CDouble) -- ^ Error code or solution -solveOdeC fun f0 ts = unsafePerformIO $ do - let dim = V.length f0 - nEq :: CLong - nEq = fromIntegral dim - fMut <- V.thaw f0 - -- We need the types that sundials expects. These are tied together - -- in 'Types'. The Haskell type is currently empty! - let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt - funIO x y f _ptr = do - -- Convert the pointer we get from C (y) to a vector, and then - -- apply the user-supplied function. - fImm <- fun x <$> getDataFromContents dim y - -- Fill in the provided pointer with the resulting vector. - putDataInContents fImm dim f - -- I don't understand what this comment means - -- Unsafe since the function will be called many times. - [CU.exp| int{ 0 } |] - res <- [C.block| int { - /* general problem variables */ - int flag; /* reusable error-checking flag */ - N_Vector y = NULL; /* empty vector for storing solution */ - SUNMatrix A = NULL; /* empty matrix for linear solver */ - SUNLinearSolver LS = NULL; /* empty linear solver object */ - void *arkode_mem = NULL; /* empty ARKode memory structure */ - FILE *UFID; - realtype t, tout; - long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; - - /* general problem parameters */ - realtype T0 = RCONST(0.0); /* initial time */ - realtype Tf = RCONST(10.0); /* final time */ - realtype dTout = RCONST(1.0); /* time between outputs */ - sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ - realtype reltol = 1.0e-6; /* tolerances */ - realtype abstol = 1.0e-10; - - /* Initial diagnostics output */ - printf("\nAnalytical ODE test problem:\n"); - printf(" reltol = %.1"ESYM"\n", reltol); - printf(" abstol = %.1"ESYM"\n\n",abstol); - - /* Initialize data structures */ - y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ - if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; - int i; - for (i = 0; i < NEQ; i++) { - NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; - }; /* Specify initial condition */ - arkode_mem = ARKodeCreate(); /* Create the solver memory */ - if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; - - /* Call ARKodeInit to initialize the integrator memory and specify the */ - /* right-hand side function in y'=f(t,y), the inital time T0, and */ - /* the initial dependent variable vector y. Note: since this */ - /* problem is fully implicit, we set f_E to NULL and f_I to f. */ - - /* Here we use the C types defined in helpers.h which tie up with */ - /* the Haskell types defined in Types */ - flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y); - if (check_flag(&flag, "ARKodeInit", 1)) return 1; - - /* Set routines */ - flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */ - if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; - - /* Initialize dense matrix data structure and solver */ - A = SUNDenseMatrix(NEQ, NEQ); - if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1; - LS = SUNDenseLinearSolver(y, A); - if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1; - - /* Linear solver interface */ - flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ - /* Open output stream for results, output comment line */ - UFID = fopen("solution.txt","w"); - fprintf(UFID,"# t u\n"); - - /* output initial condition to disk */ - fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", T0, NV_Ith_S(y,0)); - - /* Main time-stepping loop: calls ARKode to perform the integration, then - prints results. Stops when the final time has been reached */ - t = T0; - tout = T0+dTout; - printf(" t u\n"); - printf(" ---------------------\n"); - while (Tf - t > 1.0e-15) { - - flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */ - if (check_flag(&flag, "ARKode", 1)) break; - printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */ - fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", t, NV_Ith_S(y,0)); - if (flag >= 0) { /* successful solve: update time */ - tout += dTout; - tout = (tout > Tf) ? Tf : tout; - } else { /* unsuccessful solve: break */ - fprintf(stderr,"Solver failure, stopping integration\n"); - break; - } - } - printf(" ---------------------\n"); - fclose(UFID); - - for (i = 0; i < NEQ; i++) { - ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i); - }; - - /* Get/print some final statistics on how the solve progressed */ - flag = ARKodeGetNumSteps(arkode_mem, &nst); - check_flag(&flag, "ARKodeGetNumSteps", 1); - flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a); - check_flag(&flag, "ARKodeGetNumStepAttempts", 1); - flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi); - check_flag(&flag, "ARKodeGetNumRhsEvals", 1); - flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups); - check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1); - flag = ARKodeGetNumErrTestFails(arkode_mem, &netf); - check_flag(&flag, "ARKodeGetNumErrTestFails", 1); - flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni); - check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1); - flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn); - check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1); - flag = ARKDlsGetNumJacEvals(arkode_mem, &nje); - check_flag(&flag, "ARKDlsGetNumJacEvals", 1); - flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS); - check_flag(&flag, "ARKDlsGetNumRhsEvals", 1); - - printf("\nFinal Solver Statistics:\n"); - printf(" Internal solver steps = %li (attempted = %li)\n", nst, nst_a); - printf(" Total RHS evals: Fe = %li, Fi = %li\n", nfe, nfi); - printf(" Total linear solver setups = %li\n", nsetups); - printf(" Total RHS evals for setting up the linear system = %li\n", nfeLS); - printf(" Total number of Jacobian evaluations = %li\n", nje); - printf(" Total number of Newton iterations = %li\n", nni); - printf(" Total number of linear solver convergence failures = %li\n", ncfn); - printf(" Total number of error test failures = %li\n\n", netf); - - /* Clean up and return */ - N_VDestroy(y); /* Free y vector */ - ARKodeFree(&arkode_mem); /* Free integrator memory */ - SUNLinSolFree(LS); /* Free linear solver */ - SUNMatDestroy(A); /* Free A matrix */ - - return flag; - } |] - if res ==0 - then do - v <- V.freeze fMut - return $ Right v - else do - return $ Left res - main :: IO () main = do - let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0]) undefined + let res = solveOde brusselator (V.fromList [1.2, 3.1, 3.0]) (V.fromList [0.0, 1.0 .. 10.0]) putStrLn $ show res diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 58acef3..9de20b6 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs @@ -6,7 +6,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} -module Numeric.Sundials.Arkode.ODE ( solveOdeC ) where +module Numeric.Sundials.Arkode.ODE ( solveOde ) where import qualified Language.C.Inline as C import qualified Language.C.Inline.Unsafe as CU @@ -76,32 +76,21 @@ vectorToC vec len ptr = do ptr' <- newForeignPtr_ ptr V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec -stiffish :: Double -> V.Vector Double -> V.Vector Double -stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] - where - u = v V.! 0 - lamda = -100.0 - -brusselator :: Double -> V.Vector Double -> V.Vector Double -brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 - , w * u - v * u^2 - , (b - w) / eps - w * u - ] - where - a = 1.0 - b = 3.5 - eps = 5.0e-6 - u = x V.! 0 - v = x V.! 1 - w = x V.! 2 - - odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> [Double] -- ^ initial conditions -> Vector Double -- ^ desired solution times -> Matrix Double -- ^ solution odeSolve = undefined +solveOde :: + (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) + -> V.Vector Double -- ^ Initial conditions + -> V.Vector Double -- ^ Desired solution times + -> Either Int (V.Vector Double) -- ^ Error code or solution +solveOde f y0 tt = case solveOdeC (coerce f) (coerce y0) (coerce tt) of + Left c -> Left $ fromIntegral c + Right v -> Right $ coerce v + solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) -> V.Vector CDouble -- ^ Initial conditions @@ -111,7 +100,10 @@ solveOdeC fun f0 ts = unsafePerformIO $ do let dim = V.length f0 nEq :: CLong nEq = fromIntegral dim + nTs :: CInt + nTs = fromIntegral $ V.length ts fMut <- V.thaw f0 + tMut <- V.thaw ts -- We need the types that sundials expects. These are tied together -- in 'Types'. The Haskell type is currently empty! let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt @@ -136,8 +128,8 @@ solveOdeC fun f0 ts = unsafePerformIO $ do long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; /* general problem parameters */ - realtype T0 = RCONST(0.0); /* initial time */ - realtype Tf = RCONST(10.0); /* final time */ + realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ + realtype Tf = RCONST(($vec-ptr:(double *tMut))[$(int nTs) - 1]); /* final time */ realtype dTout = RCONST(1.0); /* time between outputs */ sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ realtype reltol = 1.0e-6; /* tolerances */ @@ -193,7 +185,7 @@ solveOdeC fun f0 ts = unsafePerformIO $ do tout = T0+dTout; printf(" t u\n"); printf(" ---------------------\n"); - while (Tf - t > 1.0e-15) { + for (i = 0; i < $(int nTs); i++) { flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */ if (check_flag(&flag, "ARKode", 1)) break; @@ -258,8 +250,3 @@ solveOdeC fun f0 ts = unsafePerformIO $ do return $ Right v else do return $ Left res - -main :: IO () -main = do - let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0]) undefined - putStrLn $ show res -- cgit v1.2.3