summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Test.hs
blob: a99582aa092209d9bafc91837c1df50506ea91e2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE OverloadedStrings #-}

import qualified Language.C.Inline as C
import qualified Language.C.Inline.Unsafe as CU
import           Data.Monoid ((<>))
import           Foreign.C.Types
import           Foreign.Ptr (Ptr)
import           Foreign.Marshal.Array
import qualified Data.Vector.Storable as V

import           Data.Coerce (coerce)
import           Data.Monoid ((<>))
import qualified Data.Vector.Storable as V
import qualified Data.Vector.Storable.Mutable as VM
import           Foreign.C.Types
import           Foreign.ForeignPtr (newForeignPtr_)
import           Foreign.Ptr (Ptr)
import           Foreign.Storable (Storable)
import qualified Language.C.Inline as C
import qualified Language.C.Inline.Unsafe as CU
import           System.IO.Unsafe (unsafePerformIO)

import qualified Language.Haskell.TH as TH
import qualified Language.C.Types as CT
import qualified Data.Map as Map
import           Language.C.Inline.Context

-- Configure inline-c: the base context plus the antiquoters for storable
-- vectors ('C.vecCtx', e.g. @$vec-ptr:@) and Haskell function pointers
-- ('C.funCtx', e.g. @$fun:@).  'main' as written uses no antiquoters; the
-- vector/function ones are referenced by the commented-out solver below.
-- This splice must precede every inline-c quasiquote in the module.
C.context (C.baseCtx <> C.vecCtx <> C.funCtx)

-- C includes (emitted, in this order, into the generated C file)
C.include "<stdio.h>"
C.include "<math.h>"
C.include "<arkode/arkode.h>"                 -- prototypes for ARKODE fcts., consts.
C.include "<nvector/nvector_serial.h>"        -- serial N_Vector types, fcts., macros
C.include "<sunmatrix/sunmatrix_dense.h>"     -- access to dense SUNMatrix
C.include "<sunlinsol/sunlinsol_dense.h>"     -- access to dense SUNLinearSolver
C.include "<arkode/arkode_direct.h>"          -- access to ARKDls interface
C.include "<sundials/sundials_types.h>"       -- definition of type realtype
C.include "<sundials/sundials_math.h>"
C.include "helpers.h"                         -- project-local header (contents not visible here)

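-- Solves a system of ODEs.  Every 'V.Vector' involved must be of the
-- same size.
-- NOTE: the definition below is commented out and still targets GSL's
-- gsl_odeiv2 driver rather than SUNDIALS.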
-- {-# NOINLINE solveOdeC #-}
-- solveOdeC
--   :: (CDouble -> V.Vector CDouble -> V.Vector CDouble)
--   -- ^ ODE to Solve
--   -> CDouble
--   -- ^ Start
--   -> V.Vector CDouble
--   -- ^ Solution at start point
--   -> CDouble
--   -- ^ End
--   -> Either String (V.Vector CDouble)
--   -- ^ Solution at end point, or error.
-- solveOdeC fun x0 f0 xend = unsafePerformIO $ do
--   let dim = V.length f0
--   let dim_c = fromIntegral dim -- This is in CInt
--   -- Convert the function to something of the right type to C.
--   let funIO x y f _ptr = do
--         -- Convert the pointer we get from C (y) to a vector, and then
--         -- apply the user-supplied function.
--         fImm <- fun x <$> vectorFromC dim y
--         -- Fill in the provided pointer with the resulting vector.
--         vectorToC fImm dim f
--         -- Unsafe since the function will be called many times.
--         [CU.exp| int{ GSL_SUCCESS } |]
--   -- Create a mutable vector from the initial solution.  This will be
--   -- passed to the ODE solving function provided by GSL, and will
--   -- contain the final solution.
--   fMut <- V.thaw f0
--   res <- [C.block| int {
--       gsl_odeiv2_system sys = {
--         $fun:(int (* funIO) (double t, const double y[], double dydt[], void * params)),
--         // The ODE to solve, converted to function pointer using the `fun`
--         // anti-quoter
--         NULL,                   // We don't provide a Jacobian
--         $(int dim_c),           // The dimension
--         NULL                    // We don't need the parameter pointer
--       };
--       // Create the driver, using some sensible values for the stepping
--       // function and the tolerances
--       gsl_odeiv2_driver *d = gsl_odeiv2_driver_alloc_y_new (
--         &sys, gsl_odeiv2_step_rk8pd, 1e-6, 1e-6, 0.0);
--       // Finally, apply the driver.
--       int status = gsl_odeiv2_driver_apply(
--         d, &$(double x0), $(double xend), $vec-ptr:(double *fMut));
--       // Free the driver
--       gsl_odeiv2_driver_free(d);
--       return status;
--     } |]
--   -- Check the error code
--   maxSteps <- [C.exp| int{ GSL_EMAXITER } |]
--   smallStep <- [C.exp| int{ GSL_ENOPROG } |]
--   good <- [C.exp| int{ GSL_SUCCESS } |]
--   if | res == good -> Right <$> V.freeze fMut
--      | res == maxSteps -> return $ Left "Too many steps"
--      | res == smallStep -> return $ Left "Step size dropped below minimum allowed size"
--      | otherwise -> return $ Left $ "Unknown error code " ++ show res

-- -- Utils

-- vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a)
-- vectorFromC len ptr = do
--   ptr' <- newForeignPtr_ ptr
--   V.freeze $ VM.unsafeFromForeignPtr0 ptr' len

-- vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO ()
-- vectorToC vec len ptr = do
--   ptr' <- newForeignPtr_ ptr
--   V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec


-- /* Check function return value...
--     opt == 0 means SUNDIALS function allocates memory so check if
--              returned NULL pointer
--     opt == 1 means SUNDIALS function returns a flag so check if
--              flag >= 0
--     opt == 2 means function allocates memory so check if returned
--              NULL pointer  
-- */
-- static int check_flag(void *flagvalue, const char *funcname, int opt)
-- {
--   int *errflag;

--   /* Check if SUNDIALS function returned NULL pointer - no memory allocated */
--   if (opt == 0 && flagvalue == NULL) {
--     fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
-- 	    funcname);
--     return 1; }

--   /* Check if flag < 0 */
--   else if (opt == 1) {
--     errflag = (int *) flagvalue;
--     if (*errflag < 0) {
--       fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with flag = %d\n\n",
-- 	      funcname, *errflag);
--       return 1; }}

--   /* Check if function returned NULL pointer - no memory allocated */
--   else if (opt == 2 && flagvalue == NULL) {
--     fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n",
-- 	    funcname);
--     return 1; }

--   return 0;
-- }

-- | Smoke test for the SUNDIALS ARKode bindings.
--
-- Allocates a one-element serial @N_Vector@, sets it to zero, creates an
-- ARKode memory block, then releases both and prints the status returned
-- from C: @0@ on success, @1@ if @check_flag@ reported a failed allocation.
-- Every allocation is released on every path, including the early return
-- taken when @ARKodeCreate@ fails.
--
-- NOTE(review): @check_flag@ is not defined in this file (the copy above is
-- commented out); presumably it is supplied by "helpers.h" — confirm.
main :: IO ()
main = do
  res <- [C.block| int {
    sunindextype NEQ = 1;        /* number of dependent vars. */
    N_Vector y = NULL;           /* empty vector for storing solution */
    void *arkode_mem = NULL;     /* empty ARKode memory structure */

    y = N_VNew_Serial(NEQ);      /* Create serial vector for solution */
    if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;

    N_VConst(0.0, y);            /* Specify initial condition */

    arkode_mem = ARKodeCreate(); /* Create the solver memory */
    if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) {
      N_VDestroy(y);             /* don't leak y on the error path */
      return 1;
    }

    /* Release everything allocated above before reporting success. */
    ARKodeFree(&arkode_mem);
    N_VDestroy(y);
    return 0;
  } |]
  print res