summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Numeric
diff options
context:
space:
mode:
authorDominic Steinitz <dominic@steinitz.org>2018-03-21 13:12:57 +0000
committerDominic Steinitz <dominic@steinitz.org>2018-03-21 13:12:57 +0000
commitd057093a7681a0ea448f8ae98e241eeafd5ad050 (patch)
treee3f91821ebed46631cdd0293aa3c8cacec1ddec4 /packages/sundials/src/Numeric
parent1b64b28dfccf2cb9539cdb4344cd7ecb1c1d0a1d (diff)
Return the entire results matrix (as a vector)
Diffstat (limited to 'packages/sundials/src/Numeric')
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs59
1 files changed, 30 insertions, 29 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index 9de20b6..c5d085e 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -10,18 +10,21 @@ module Numeric.Sundials.Arkode.ODE ( solveOde ) where
10 10
11import qualified Language.C.Inline as C 11import qualified Language.C.Inline as C
12import qualified Language.C.Inline.Unsafe as CU 12import qualified Language.C.Inline.Unsafe as CU
13
13import Data.Monoid ((<>)) 14import Data.Monoid ((<>))
15
14import Foreign.C.Types 16import Foreign.C.Types
15import Foreign.Ptr (Ptr) 17import Foreign.Ptr (Ptr)
18import Foreign.ForeignPtr (newForeignPtr_)
19import Foreign.Storable (Storable, peekByteOff)
20
16import qualified Data.Vector.Storable as V 21import qualified Data.Vector.Storable as V
22import qualified Data.Vector.Storable.Mutable as VM
17 23
18import Data.Coerce (coerce) 24import Data.Coerce (coerce)
19import qualified Data.Vector.Storable.Mutable as VM
20import Foreign.ForeignPtr (newForeignPtr_)
21import Foreign.Storable (Storable)
22import System.IO.Unsafe (unsafePerformIO) 25import System.IO.Unsafe (unsafePerformIO)
23 26
24import Foreign.Storable (peekByteOff) 27import Numeric.LinearAlgebra.Devel (createVector)
25 28
26import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) 29import Numeric.LinearAlgebra.HMatrix (Vector, Matrix)
27 30
@@ -104,8 +107,12 @@ solveOdeC fun f0 ts = unsafePerformIO $ do
104 nTs = fromIntegral $ V.length ts 107 nTs = fromIntegral $ V.length ts
105 fMut <- V.thaw f0 108 fMut <- V.thaw f0
106 tMut <- V.thaw ts 109 tMut <- V.thaw ts
110 -- FIXME: I believe this gets taken from the ghc heap and so should
111 -- be subject to garbage collection.
112 quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs))
113 qMatMut <- V.thaw quasiMatrixRes
107 -- We need the types that sundials expects. These are tied together 114 -- We need the types that sundials expects. These are tied together
108 -- in 'Types'. The Haskell type is currently empty! 115 -- in 'Types'. FIXME: The Haskell type is currently empty!
109 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt 116 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt
110 funIO x y f _ptr = do 117 funIO x y f _ptr = do
111 -- Convert the pointer we get from C (y) to a vector, and then 118 -- Convert the pointer we get from C (y) to a vector, and then
@@ -124,13 +131,12 @@ solveOdeC fun f0 ts = unsafePerformIO $ do
124 SUNLinearSolver LS = NULL; /* empty linear solver object */ 131 SUNLinearSolver LS = NULL; /* empty linear solver object */
125 void *arkode_mem = NULL; /* empty ARKode memory structure */ 132 void *arkode_mem = NULL; /* empty ARKode memory structure */
126 FILE *UFID; 133 FILE *UFID;
127 realtype t, tout; 134 realtype t;
128 long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; 135 long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf;
129 136
130 /* general problem parameters */ 137 /* general problem parameters */
131 realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ 138 realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */
132 realtype Tf = RCONST(($vec-ptr:(double *tMut))[$(int nTs) - 1]); /* final time */ 139 realtype Tf = RCONST(($vec-ptr:(double *tMut))[$(int nTs) - 1]); /* final time */
133 realtype dTout = RCONST(1.0); /* time between outputs */
134 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ 140 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */
135 realtype reltol = 1.0e-6; /* tolerances */ 141 realtype reltol = 1.0e-6; /* tolerances */
136 realtype abstol = 1.0e-10; 142 realtype abstol = 1.0e-10;
@@ -143,7 +149,7 @@ solveOdeC fun f0 ts = unsafePerformIO $ do
143 /* Initialize data structures */ 149 /* Initialize data structures */
144 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ 150 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
145 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; 151 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
146 int i; 152 int i, j;
147 for (i = 0; i < NEQ; i++) { 153 for (i = 0; i < NEQ; i++) {
148 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; 154 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i];
149 }; /* Specify initial condition */ 155 }; /* Specify initial condition */
@@ -172,35 +178,30 @@ solveOdeC fun f0 ts = unsafePerformIO $ do
172 178
173 /* Linear solver interface */ 179 /* Linear solver interface */
174 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ 180 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */
175 /* Open output stream for results, output comment line */ 181 /* Output initial conditions */
176 UFID = fopen("solution.txt","w"); 182 for (j = 0; j < NEQ; j++) {
177 fprintf(UFID,"# t u\n"); 183 ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j);
178 184 }
179 /* output initial condition to disk */ 185
180 fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", T0, NV_Ith_S(y,0));
181
182 /* Main time-stepping loop: calls ARKode to perform the integration, then 186 /* Main time-stepping loop: calls ARKode to perform the integration, then
183 prints results. Stops when the final time has been reached */ 187 prints results. Stops when the final time has been reached */
184 t = T0;
185 tout = T0+dTout;
186 printf(" t u\n"); 188 printf(" t u\n");
187 printf(" ---------------------\n"); 189 printf(" ---------------------\n");
188 for (i = 0; i < $(int nTs); i++) { 190 for (i = 1; i < $(int nTs); i++) {
189 191
190 flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */ 192 flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */
191 if (check_flag(&flag, "ARKode", 1)) break; 193 if (check_flag(&flag, "ARKode", 1)) break;
192 printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */ 194 printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */
193 fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", t, NV_Ith_S(y,0)); 195 for (j = 0; j < NEQ; j++) {
194 if (flag >= 0) { /* successful solve: update time */ 196 ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j);
195 tout += dTout; 197 }
196 tout = (tout > Tf) ? Tf : tout; 198
197 } else { /* unsuccessful solve: break */ 199 if (flag < 0) { /* unsuccessful solve: break */
198 fprintf(stderr,"Solver failure, stopping integration\n"); 200 fprintf(stderr,"Solver failure, stopping integration\n");
199 break; 201 break;
200 } 202 }
201 } 203 }
202 printf(" ---------------------\n"); 204 printf(" ---------------------\n");
203 fclose(UFID);
204 205
205 for (i = 0; i < NEQ; i++) { 206 for (i = 0; i < NEQ; i++) {
206 ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i); 207 ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i);
@@ -244,9 +245,9 @@ solveOdeC fun f0 ts = unsafePerformIO $ do
244 245
245 return flag; 246 return flag;
246 } |] 247 } |]
247 if res ==0 248 if res == 0
248 then do 249 then do
249 v <- V.freeze fMut 250 m <- V.freeze qMatMut
250 return $ Right v 251 return $ Right m
251 else do 252 else do
252 return $ Left res 253 return $ Left res