summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/sundials/src/Main.hs245
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs45
2 files changed, 18 insertions, 272 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs
index d1f35bd..978088b 100644
--- a/packages/sundials/src/Main.hs
+++ b/packages/sundials/src/Main.hs
@@ -1,84 +1,7 @@
1{-# OPTIONS_GHC -Wall #-} 1{-# OPTIONS_GHC -Wall #-}
2 2
3{-# LANGUAGE QuasiQuotes #-}
4{-# LANGUAGE TemplateHaskell #-}
5{-# LANGUAGE MultiWayIf #-}
6{-# LANGUAGE OverloadedStrings #-}
7{-# LANGUAGE ScopedTypeVariables #-}
8
9import qualified Language.C.Inline as C
10import qualified Language.C.Inline.Unsafe as CU
11import Data.Monoid ((<>))
12import Foreign.C.Types
13import Foreign.Ptr (Ptr)
14import qualified Data.Vector.Storable as V 3import qualified Data.Vector.Storable as V
15 4import Numeric.Sundials.Arkode.ODE
16import Data.Coerce (coerce)
17import qualified Data.Vector.Storable.Mutable as VM
18import Foreign.ForeignPtr (newForeignPtr_)
19import Foreign.Storable (Storable)
20import System.IO.Unsafe (unsafePerformIO)
21
22import Foreign.Storable (peekByteOff)
23
24import Numeric.LinearAlgebra.HMatrix (Vector, Matrix)
25
26import qualified Types as T
27
28C.context (C.baseCtx <> C.vecCtx <> C.funCtx <> T.sunCtx)
29
30-- C includes
31C.include "<stdio.h>"
32C.include "<math.h>"
33C.include "<arkode/arkode.h>" -- prototypes for ARKODE fcts., consts.
34C.include "<nvector/nvector_serial.h>" -- serial N_Vector types, fcts., macros
35C.include "<sunmatrix/sunmatrix_dense.h>" -- access to dense SUNMatrix
36C.include "<sunlinsol/sunlinsol_dense.h>" -- access to dense SUNLinearSolver
37C.include "<arkode/arkode_direct.h>" -- access to ARKDls interface
38C.include "<sundials/sundials_types.h>" -- definition of type realtype
39C.include "<sundials/sundials_math.h>"
40C.include "helpers.h"
41
42
43-- These were semi-generated using hsc2hs with Bar.hsc as the
44-- template. They are probably very fragile and could easily break on
45-- different architectures and / or changes in the sundials package.
46
47getContentPtr :: Storable a => Ptr b -> IO a
48getContentPtr ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 0)) ptr
49
50getData :: Storable a => Ptr b -> IO a
51getData ptr = ((\hsc_ptr -> peekByteOff hsc_ptr 16)) ptr
52
53getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b)
54getDataFromContents len ptr = do
55 qtr <- getContentPtr ptr
56 rtr <- getData qtr
57 vectorFromC len rtr
58
59putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO ()
60putDataInContents vec len ptr = do
61 qtr <- getContentPtr ptr
62 rtr <- getData qtr
63 vectorToC vec len rtr
64
65-- Utils
66
67vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a)
68vectorFromC len ptr = do
69 ptr' <- newForeignPtr_ ptr
70 V.freeze $ VM.unsafeFromForeignPtr0 ptr' len
71
72vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO ()
73vectorToC vec len ptr = do
74 ptr' <- newForeignPtr_ ptr
75 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
76
77stiffish :: Double -> V.Vector Double -> V.Vector Double
78stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ]
79 where
80 u = v V.! 0
81 lamda = -100.0
82 5
83brusselator :: Double -> V.Vector Double -> V.Vector Double 6brusselator :: Double -> V.Vector Double -> V.Vector Double
84brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 7brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2
@@ -93,171 +16,7 @@ brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2
93 v = x V.! 1 16 v = x V.! 1
94 w = x V.! 2 17 w = x V.! 2
95 18
96
97odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
98 -> [Double] -- ^ initial conditions
99 -> Vector Double -- ^ desired solution times
100 -> Matrix Double -- ^ solution
101odeSolve = undefined
102
103solveOdeC ::
104 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
105 -> V.Vector CDouble -- ^ Initial conditions
106 -> V.Vector CDouble -- ^ Desired solution times
107 -> Either CInt (V.Vector CDouble) -- ^ Error code or solution
108solveOdeC fun f0 ts = unsafePerformIO $ do
109 let dim = V.length f0
110 nEq :: CLong
111 nEq = fromIntegral dim
112 fMut <- V.thaw f0
113 -- We need the types that sundials expects. These are tied together
114 -- in 'Types'. The Haskell type is currently empty!
115 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt
116 funIO x y f _ptr = do
117 -- Convert the pointer we get from C (y) to a vector, and then
118 -- apply the user-supplied function.
119 fImm <- fun x <$> getDataFromContents dim y
120 -- Fill in the provided pointer with the resulting vector.
121 putDataInContents fImm dim f
122 -- I don't understand what this comment means
123 -- Unsafe since the function will be called many times.
124 [CU.exp| int{ 0 } |]
125 res <- [C.block| int {
126 /* general problem variables */
127 int flag; /* reusable error-checking flag */
128 N_Vector y = NULL; /* empty vector for storing solution */
129 SUNMatrix A = NULL; /* empty matrix for linear solver */
130 SUNLinearSolver LS = NULL; /* empty linear solver object */
131 void *arkode_mem = NULL; /* empty ARKode memory structure */
132 FILE *UFID;
133 realtype t, tout;
134 long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf;
135
136 /* general problem parameters */
137 realtype T0 = RCONST(0.0); /* initial time */
138 realtype Tf = RCONST(10.0); /* final time */
139 realtype dTout = RCONST(1.0); /* time between outputs */
140 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */
141 realtype reltol = 1.0e-6; /* tolerances */
142 realtype abstol = 1.0e-10;
143
144 /* Initial diagnostics output */
145 printf("\nAnalytical ODE test problem:\n");
146 printf(" reltol = %.1"ESYM"\n", reltol);
147 printf(" abstol = %.1"ESYM"\n\n",abstol);
148
149 /* Initialize data structures */
150 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
151 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
152 int i;
153 for (i = 0; i < NEQ; i++) {
154 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i];
155 }; /* Specify initial condition */
156 arkode_mem = ARKodeCreate(); /* Create the solver memory */
157 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
158
159 /* Call ARKodeInit to initialize the integrator memory and specify the */
160 /* right-hand side function in y'=f(t,y), the inital time T0, and */
161 /* the initial dependent variable vector y. Note: since this */
162 /* problem is fully implicit, we set f_E to NULL and f_I to f. */
163
164 /* Here we use the C types defined in helpers.h which tie up with */
165 /* the Haskell types defined in Types */
166 flag = ARKodeInit(arkode_mem, NULL, $fun:(int (* funIO) (double t, BarType y[], BarType dydt[], void * params)), T0, y);
167 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
168
169 /* Set routines */
170 flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */
171 if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1;
172
173 /* Initialize dense matrix data structure and solver */
174 A = SUNDenseMatrix(NEQ, NEQ);
175 if (check_flag((void *)A, "SUNDenseMatrix", 0)) return 1;
176 LS = SUNDenseLinearSolver(y, A);
177 if (check_flag((void *)LS, "SUNDenseLinearSolver", 0)) return 1;
178
179 /* Linear solver interface */
180 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */
181 /* Open output stream for results, output comment line */
182 UFID = fopen("solution.txt","w");
183 fprintf(UFID,"# t u\n");
184
185 /* output initial condition to disk */
186 fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", T0, NV_Ith_S(y,0));
187
188 /* Main time-stepping loop: calls ARKode to perform the integration, then
189 prints results. Stops when the final time has been reached */
190 t = T0;
191 tout = T0+dTout;
192 printf(" t u\n");
193 printf(" ---------------------\n");
194 while (Tf - t > 1.0e-15) {
195
196 flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */
197 if (check_flag(&flag, "ARKode", 1)) break;
198 printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */
199 fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", t, NV_Ith_S(y,0));
200 if (flag >= 0) { /* successful solve: update time */
201 tout += dTout;
202 tout = (tout > Tf) ? Tf : tout;
203 } else { /* unsuccessful solve: break */
204 fprintf(stderr,"Solver failure, stopping integration\n");
205 break;
206 }
207 }
208 printf(" ---------------------\n");
209 fclose(UFID);
210
211 for (i = 0; i < NEQ; i++) {
212 ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i);
213 };
214
215 /* Get/print some final statistics on how the solve progressed */
216 flag = ARKodeGetNumSteps(arkode_mem, &nst);
217 check_flag(&flag, "ARKodeGetNumSteps", 1);
218 flag = ARKodeGetNumStepAttempts(arkode_mem, &nst_a);
219 check_flag(&flag, "ARKodeGetNumStepAttempts", 1);
220 flag = ARKodeGetNumRhsEvals(arkode_mem, &nfe, &nfi);
221 check_flag(&flag, "ARKodeGetNumRhsEvals", 1);
222 flag = ARKodeGetNumLinSolvSetups(arkode_mem, &nsetups);
223 check_flag(&flag, "ARKodeGetNumLinSolvSetups", 1);
224 flag = ARKodeGetNumErrTestFails(arkode_mem, &netf);
225 check_flag(&flag, "ARKodeGetNumErrTestFails", 1);
226 flag = ARKodeGetNumNonlinSolvIters(arkode_mem, &nni);
227 check_flag(&flag, "ARKodeGetNumNonlinSolvIters", 1);
228 flag = ARKodeGetNumNonlinSolvConvFails(arkode_mem, &ncfn);
229 check_flag(&flag, "ARKodeGetNumNonlinSolvConvFails", 1);
230 flag = ARKDlsGetNumJacEvals(arkode_mem, &nje);
231 check_flag(&flag, "ARKDlsGetNumJacEvals", 1);
232 flag = ARKDlsGetNumRhsEvals(arkode_mem, &nfeLS);
233 check_flag(&flag, "ARKDlsGetNumRhsEvals", 1);
234
235 printf("\nFinal Solver Statistics:\n");
236 printf(" Internal solver steps = %li (attempted = %li)\n", nst, nst_a);
237 printf(" Total RHS evals: Fe = %li, Fi = %li\n", nfe, nfi);
238 printf(" Total linear solver setups = %li\n", nsetups);
239 printf(" Total RHS evals for setting up the linear system = %li\n", nfeLS);
240 printf(" Total number of Jacobian evaluations = %li\n", nje);
241 printf(" Total number of Newton iterations = %li\n", nni);
242 printf(" Total number of linear solver convergence failures = %li\n", ncfn);
243 printf(" Total number of error test failures = %li\n\n", netf);
244
245 /* Clean up and return */
246 N_VDestroy(y); /* Free y vector */
247 ARKodeFree(&arkode_mem); /* Free integrator memory */
248 SUNLinSolFree(LS); /* Free linear solver */
249 SUNMatDestroy(A); /* Free A matrix */
250
251 return flag;
252 } |]
253 if res ==0
254 then do
255 v <- V.freeze fMut
256 return $ Right v
257 else do
258 return $ Left res
259
260main :: IO () 19main :: IO ()
261main = do 20main = do
262 let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0]) undefined 21 let res = solveOde brusselator (V.fromList [1.2, 3.1, 3.0]) (V.fromList [0.0, 1.0 .. 10.0])
263 putStrLn $ show res 22 putStrLn $ show res
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index 58acef3..9de20b6 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -6,7 +6,7 @@
6{-# LANGUAGE OverloadedStrings #-} 6{-# LANGUAGE OverloadedStrings #-}
7{-# LANGUAGE ScopedTypeVariables #-} 7{-# LANGUAGE ScopedTypeVariables #-}
8 8
9module Numeric.Sundials.Arkode.ODE ( solveOdeC ) where 9module Numeric.Sundials.Arkode.ODE ( solveOde ) where
10 10
11import qualified Language.C.Inline as C 11import qualified Language.C.Inline as C
12import qualified Language.C.Inline.Unsafe as CU 12import qualified Language.C.Inline.Unsafe as CU
@@ -76,32 +76,21 @@ vectorToC vec len ptr = do
76 ptr' <- newForeignPtr_ ptr 76 ptr' <- newForeignPtr_ ptr
77 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec 77 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
78 78
79stiffish :: Double -> V.Vector Double -> V.Vector Double
80stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ]
81 where
82 u = v V.! 0
83 lamda = -100.0
84
85brusselator :: Double -> V.Vector Double -> V.Vector Double
86brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2
87 , w * u - v * u^2
88 , (b - w) / eps - w * u
89 ]
90 where
91 a = 1.0
92 b = 3.5
93 eps = 5.0e-6
94 u = x V.! 0
95 v = x V.! 1
96 w = x V.! 2
97
98
99odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 79odeSolve :: (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
100 -> [Double] -- ^ initial conditions 80 -> [Double] -- ^ initial conditions
101 -> Vector Double -- ^ desired solution times 81 -> Vector Double -- ^ desired solution times
102 -> Matrix Double -- ^ solution 82 -> Matrix Double -- ^ solution
103odeSolve = undefined 83odeSolve = undefined
104 84
85solveOde ::
86 (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
87 -> V.Vector Double -- ^ Initial conditions
88 -> V.Vector Double -- ^ Desired solution times
89 -> Either Int (V.Vector Double) -- ^ Error code or solution
90solveOde f y0 tt = case solveOdeC (coerce f) (coerce y0) (coerce tt) of
91 Left c -> Left $ fromIntegral c
92 Right v -> Right $ coerce v
93
105solveOdeC :: 94solveOdeC ::
106 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 95 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
107 -> V.Vector CDouble -- ^ Initial conditions 96 -> V.Vector CDouble -- ^ Initial conditions
@@ -111,7 +100,10 @@ solveOdeC fun f0 ts = unsafePerformIO $ do
111 let dim = V.length f0 100 let dim = V.length f0
112 nEq :: CLong 101 nEq :: CLong
113 nEq = fromIntegral dim 102 nEq = fromIntegral dim
103 nTs :: CInt
104 nTs = fromIntegral $ V.length ts
114 fMut <- V.thaw f0 105 fMut <- V.thaw f0
106 tMut <- V.thaw ts
115 -- We need the types that sundials expects. These are tied together 107 -- We need the types that sundials expects. These are tied together
116 -- in 'Types'. The Haskell type is currently empty! 108 -- in 'Types'. The Haskell type is currently empty!
117 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt 109 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt
@@ -136,8 +128,8 @@ solveOdeC fun f0 ts = unsafePerformIO $ do
136 long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; 128 long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf;
137 129
138 /* general problem parameters */ 130 /* general problem parameters */
139 realtype T0 = RCONST(0.0); /* initial time */ 131 realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */
140 realtype Tf = RCONST(10.0); /* final time */ 132 realtype Tf = RCONST(($vec-ptr:(double *tMut))[$(int nTs) - 1]); /* final time */
141 realtype dTout = RCONST(1.0); /* time between outputs */ 133 realtype dTout = RCONST(1.0); /* time between outputs */
142 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ 134 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */
143 realtype reltol = 1.0e-6; /* tolerances */ 135 realtype reltol = 1.0e-6; /* tolerances */
@@ -193,7 +185,7 @@ solveOdeC fun f0 ts = unsafePerformIO $ do
193 tout = T0+dTout; 185 tout = T0+dTout;
194 printf(" t u\n"); 186 printf(" t u\n");
195 printf(" ---------------------\n"); 187 printf(" ---------------------\n");
196 while (Tf - t > 1.0e-15) { 188 for (i = 0; i < $(int nTs); i++) {
197 189
198 flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */ 190 flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */
199 if (check_flag(&flag, "ARKode", 1)) break; 191 if (check_flag(&flag, "ARKode", 1)) break;
@@ -258,8 +250,3 @@ solveOdeC fun f0 ts = unsafePerformIO $ do
258 return $ Right v 250 return $ Right v
259 else do 251 else do
260 return $ Left res 252 return $ Left res
261
262main :: IO ()
263main = do
264 let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0]) undefined
265 putStrLn $ show res