From a22963fa83156b76dd73777b7044897eed50e3bc Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Sun, 11 Mar 2018 14:21:31 +0000 Subject: The start of an hmatrix interface to sundials --- packages/sundials/src/Test.hs | 164 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 packages/sundials/src/Test.hs (limited to 'packages/sundials/src/Test.hs') diff --git a/packages/sundials/src/Test.hs b/packages/sundials/src/Test.hs new file mode 100644 index 0000000..a99582a --- /dev/null +++ b/packages/sundials/src/Test.hs @@ -0,0 +1,164 @@ +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE OverloadedStrings #-} + +import qualified Language.C.Inline as C +import qualified Language.C.Inline.Unsafe as CU +import Data.Monoid ((<>)) +import Foreign.C.Types +import Foreign.Ptr (Ptr) +import Foreign.Marshal.Array +import qualified Data.Vector.Storable as V + +import Data.Coerce (coerce) +import Data.Monoid ((<>)) +import qualified Data.Vector.Storable as V +import qualified Data.Vector.Storable.Mutable as VM +import Foreign.C.Types +import Foreign.ForeignPtr (newForeignPtr_) +import Foreign.Ptr (Ptr) +import Foreign.Storable (Storable) +import qualified Language.C.Inline as C +import qualified Language.C.Inline.Unsafe as CU +import System.IO.Unsafe (unsafePerformIO) + +import qualified Language.Haskell.TH as TH +import qualified Language.C.Types as CT +import qualified Data.Map as Map +import Language.C.Inline.Context + +C.context (C.baseCtx <> C.vecCtx <> C.funCtx) + +-- C includes +C.include "" +C.include "" +C.include "" -- prototypes for ARKODE fcts., consts. +C.include "" -- serial N_Vector types, fcts., macros +C.include "" -- access to dense SUNMatrix +C.include "" -- access to dense SUNLinearSolver +C.include "" -- access to ARKDls interface +C.include "" -- definition of type realtype +C.include "" +C.include "helpers.h" + +-- | Solves a system of ODEs. Every 'V.Vector' involved must be of the +-- same size. +-- {-# NOINLINE solveOdeC #-} +-- solveOdeC +-- :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) +-- -- ^ ODE to Solve +-- -> CDouble +-- -- ^ Start +-- -> V.Vector CDouble +-- -- ^ Solution at start point +-- -> CDouble +-- -- ^ End +-- -> Either String (V.Vector CDouble) +-- -- ^ Solution at end point, or error. +-- solveOdeC fun x0 f0 xend = unsafePerformIO $ do +-- let dim = V.length f0 +-- let dim_c = fromIntegral dim -- This is in CInt +-- -- Convert the function to something of the right type to C. +-- let funIO x y f _ptr = do +-- -- Convert the pointer we get from C (y) to a vector, and then +-- -- apply the user-supplied function. +-- fImm <- fun x <$> vectorFromC dim y +-- -- Fill in the provided pointer with the resulting vector. +-- vectorToC fImm dim f +-- -- Unsafe since the function will be called many times. +-- [CU.exp| int{ GSL_SUCCESS } |] +-- -- Create a mutable vector from the initial solution. This will be +-- -- passed to the ODE solving function provided by GSL, and will +-- -- contain the final solution. +-- fMut <- V.thaw f0 +-- res <- [C.block| int { +-- gsl_odeiv2_system sys = { +-- $fun:(int (* funIO) (double t, const double y[], double dydt[], void * params)), +-- // The ODE to solve, converted to function pointer using the `fun` +-- // anti-quoter +-- NULL, // We don't provide a Jacobian +-- $(int dim_c), // The dimension +-- NULL // We don't need the parameter pointer +-- }; +-- // Create the driver, using some sensible values for the stepping +-- // function and the tolerances +-- gsl_odeiv2_driver *d = gsl_odeiv2_driver_alloc_y_new ( +-- &sys, gsl_odeiv2_step_rk8pd, 1e-6, 1e-6, 0.0); +-- // Finally, apply the driver. +-- int status = gsl_odeiv2_driver_apply( +-- d, &$(double x0), $(double xend), $vec-ptr:(double *fMut)); +-- // Free the driver +-- gsl_odeiv2_driver_free(d); +-- return status; +-- } |] +-- -- Check the error code +-- maxSteps <- [C.exp| int{ GSL_EMAXITER } |] +-- smallStep <- [C.exp| int{ GSL_ENOPROG } |] +-- good <- [C.exp| int{ GSL_SUCCESS } |] +-- if | res == good -> Right <$> V.freeze fMut +-- | res == maxSteps -> return $ Left "Too many steps" +-- | res == smallStep -> return $ Left "Step size dropped below minimum allowed size" +-- | otherwise -> return $ Left $ "Unknown error code " ++ show res + +-- -- Utils + +-- vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) +-- vectorFromC len ptr = do +-- ptr' <- newForeignPtr_ ptr +-- V.freeze $ VM.unsafeFromForeignPtr0 ptr' len + +-- vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () +-- vectorToC vec len ptr = do +-- ptr' <- newForeignPtr_ ptr +-- V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec + + +-- /* Check function return value... +-- opt == 0 means SUNDIALS function allocates memory so check if +-- returned NULL pointer +-- opt == 1 means SUNDIALS function returns a flag so check if +-- flag >= 0 +-- opt == 2 means function allocates memory so check if returned +-- NULL pointer +-- */ +-- static int check_flag(void *flagvalue, const char *funcname, int opt) +-- { +-- int *errflag; + +-- /* Check if SUNDIALS function returned NULL pointer - no memory allocated */ +-- if (opt == 0 && flagvalue == NULL) { +-- fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n", +-- funcname); +-- return 1; } + +-- /* Check if flag < 0 */ +-- else if (opt == 1) { +-- errflag = (int *) flagvalue; +-- if (*errflag < 0) { +-- fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with flag = %d\n\n", +-- funcname, *errflag); +-- return 1; }} + +-- /* Check if function returned NULL pointer - no memory allocated */ +-- else if (opt == 2 && flagvalue == NULL) { +-- fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n", +-- funcname); +-- return 1; } + +-- return 0; +-- } + +main = do + res <- [C.block| int { sunindextype NEQ = 1; /* number of dependent vars. */ + N_Vector y = NULL; /* empty vector for storing solution */ + void *arkode_mem = NULL; /* empty ARKode memory structure */ + y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ + if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; + + N_VConst(0.0, y); /* Specify initial condition */ + arkode_mem = ARKodeCreate(); /* Create the solver memory */ + if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; + return 0; + } |] + putStrLn $ show res -- cgit v1.2.3