diff options
Diffstat (limited to 'packages/sundials')
-rw-r--r-- | packages/sundials/src/Main.hs | 7 | ||||
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 59 |
2 files changed, 37 insertions, 29 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index 978088b..2a561c4 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs | |||
@@ -16,7 +16,14 @@ brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 | |||
16 | v = x V.! 1 | 16 | v = x V.! 1 |
17 | w = x V.! 2 | 17 | w = x V.! 2 |
18 | 18 | ||
19 | stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] | ||
20 | where | ||
21 | lamda = -100.0 | ||
22 | u = v V.! 0 | ||
23 | |||
19 | main :: IO () | 24 | main :: IO () |
20 | main = do | 25 | main = do |
21 | let res = solveOde brusselator (V.fromList [1.2, 3.1, 3.0]) (V.fromList [0.0, 1.0 .. 10.0]) | 26 | let res = solveOde brusselator (V.fromList [1.2, 3.1, 3.0]) (V.fromList [0.0, 1.0 .. 10.0]) |
22 | putStrLn $ show res | 27 | putStrLn $ show res |
28 | let res = solveOde stiffish (V.fromList [1.0]) (V.fromList [0.0, 0.1 .. 10.0]) | ||
29 | putStrLn $ show res | ||
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 9de20b6..c5d085e 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -10,18 +10,21 @@ module Numeric.Sundials.Arkode.ODE ( solveOde ) where | |||
10 | 10 | ||
11 | import qualified Language.C.Inline as C | 11 | import qualified Language.C.Inline as C |
12 | import qualified Language.C.Inline.Unsafe as CU | 12 | import qualified Language.C.Inline.Unsafe as CU |
13 | |||
13 | import Data.Monoid ((<>)) | 14 | import Data.Monoid ((<>)) |
15 | |||
14 | import Foreign.C.Types | 16 | import Foreign.C.Types |
15 | import Foreign.Ptr (Ptr) | 17 | import Foreign.Ptr (Ptr) |
18 | import Foreign.ForeignPtr (newForeignPtr_) | ||
19 | import Foreign.Storable (Storable, peekByteOff) | ||
20 | |||
16 | import qualified Data.Vector.Storable as V | 21 | import qualified Data.Vector.Storable as V |
22 | import qualified Data.Vector.Storable.Mutable as VM | ||
17 | 23 | ||
18 | import Data.Coerce (coerce) | 24 | import Data.Coerce (coerce) |
19 | import qualified Data.Vector.Storable.Mutable as VM | ||
20 | import Foreign.ForeignPtr (newForeignPtr_) | ||
21 | import Foreign.Storable (Storable) | ||
22 | import System.IO.Unsafe (unsafePerformIO) | 25 | import System.IO.Unsafe (unsafePerformIO) |
23 | 26 | ||
24 | import Foreign.Storable (peekByteOff) | 27 | import Numeric.LinearAlgebra.Devel (createVector) |
25 | 28 | ||
26 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) | 29 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix) |
27 | 30 | ||
@@ -104,8 +107,12 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
104 | nTs = fromIntegral $ V.length ts | 107 | nTs = fromIntegral $ V.length ts |
105 | fMut <- V.thaw f0 | 108 | fMut <- V.thaw f0 |
106 | tMut <- V.thaw ts | 109 | tMut <- V.thaw ts |
110 | -- FIXME: I believe this gets taken from the ghc heap and so should | ||
111 | -- be subject to garbage collection. | ||
112 | quasiMatrixRes <- createVector ((fromIntegral dim) * (fromIntegral nTs)) | ||
113 | qMatMut <- V.thaw quasiMatrixRes | ||
107 | -- We need the types that sundials expects. These are tied together | 114 | -- We need the types that sundials expects. These are tied together |
108 | -- in 'Types'. The Haskell type is currently empty! | 115 | -- in 'Types'. FIXME: The Haskell type is currently empty! |
109 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | 116 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt |
110 | funIO x y f _ptr = do | 117 | funIO x y f _ptr = do |
111 | -- Convert the pointer we get from C (y) to a vector, and then | 118 | -- Convert the pointer we get from C (y) to a vector, and then |
@@ -124,13 +131,12 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
124 | SUNLinearSolver LS = NULL; /* empty linear solver object */ | 131 | SUNLinearSolver LS = NULL; /* empty linear solver object */ |
125 | void *arkode_mem = NULL; /* empty ARKode memory structure */ | 132 | void *arkode_mem = NULL; /* empty ARKode memory structure */ |
126 | FILE *UFID; | 133 | FILE *UFID; |
127 | realtype t, tout; | 134 | realtype t; |
128 | long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; | 135 | long int nst, nst_a, nfe, nfi, nsetups, nje, nfeLS, nni, ncfn, netf; |
129 | 136 | ||
130 | /* general problem parameters */ | 137 | /* general problem parameters */ |
131 | realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ | 138 | realtype T0 = RCONST(($vec-ptr:(double *tMut))[0]); /* initial time */ |
132 | realtype Tf = RCONST(($vec-ptr:(double *tMut))[$(int nTs) - 1]); /* final time */ | 139 | realtype Tf = RCONST(($vec-ptr:(double *tMut))[$(int nTs) - 1]); /* final time */ |
133 | realtype dTout = RCONST(1.0); /* time between outputs */ | ||
134 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ | 140 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ |
135 | realtype reltol = 1.0e-6; /* tolerances */ | 141 | realtype reltol = 1.0e-6; /* tolerances */ |
136 | realtype abstol = 1.0e-10; | 142 | realtype abstol = 1.0e-10; |
@@ -143,7 +149,7 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
143 | /* Initialize data structures */ | 149 | /* Initialize data structures */ |
144 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | 150 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ |
145 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | 151 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; |
146 | int i; | 152 | int i, j; |
147 | for (i = 0; i < NEQ; i++) { | 153 | for (i = 0; i < NEQ; i++) { |
148 | NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; | 154 | NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; |
149 | }; /* Specify initial condition */ | 155 | }; /* Specify initial condition */ |
@@ -172,35 +178,30 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
172 | 178 | ||
173 | /* Linear solver interface */ | 179 | /* Linear solver interface */ |
174 | flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ | 180 | flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ |
175 | /* Open output stream for results, output comment line */ | 181 | /* Output initial conditions */ |
176 | UFID = fopen("solution.txt","w"); | 182 | for (j = 0; j < NEQ; j++) { |
177 | fprintf(UFID,"# t u\n"); | 183 | ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); |
178 | 184 | } | |
179 | /* output initial condition to disk */ | 185 | |
180 | fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", T0, NV_Ith_S(y,0)); | ||
181 | |||
182 | /* Main time-stepping loop: calls ARKode to perform the integration, then | 186 | /* Main time-stepping loop: calls ARKode to perform the integration, then |
183 | prints results. Stops when the final time has been reached */ | 187 | prints results. Stops when the final time has been reached */ |
184 | t = T0; | ||
185 | tout = T0+dTout; | ||
186 | printf(" t u\n"); | 188 | printf(" t u\n"); |
187 | printf(" ---------------------\n"); | 189 | printf(" ---------------------\n"); |
188 | for (i = 0; i < $(int nTs); i++) { | 190 | for (i = 1; i < $(int nTs); i++) { |
189 | 191 | ||
190 | flag = ARKode(arkode_mem, tout, y, &t, ARK_NORMAL); /* call integrator */ | 192 | flag = ARKode(arkode_mem, ($vec-ptr:(double *tMut))[i], y, &t, ARK_NORMAL); /* call integrator */ |
191 | if (check_flag(&flag, "ARKode", 1)) break; | 193 | if (check_flag(&flag, "ARKode", 1)) break; |
192 | printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */ | 194 | printf(" %10.6"FSYM" %10.6"FSYM"\n", t, NV_Ith_S(y,0)); /* access/print solution */ |
193 | fprintf(UFID," %.16"ESYM" %.16"ESYM"\n", t, NV_Ith_S(y,0)); | 195 | for (j = 0; j < NEQ; j++) { |
194 | if (flag >= 0) { /* successful solve: update time */ | 196 | ($vec-ptr:(double *qMatMut))[i * NEQ + j] = NV_Ith_S(y,j); |
195 | tout += dTout; | 197 | } |
196 | tout = (tout > Tf) ? Tf : tout; | 198 | |
197 | } else { /* unsuccessful solve: break */ | 199 | if (flag < 0) { /* unsuccessful solve: break */ |
198 | fprintf(stderr,"Solver failure, stopping integration\n"); | 200 | fprintf(stderr,"Solver failure, stopping integration\n"); |
199 | break; | 201 | break; |
200 | } | 202 | } |
201 | } | 203 | } |
202 | printf(" ---------------------\n"); | 204 | printf(" ---------------------\n"); |
203 | fclose(UFID); | ||
204 | 205 | ||
205 | for (i = 0; i < NEQ; i++) { | 206 | for (i = 0; i < NEQ; i++) { |
206 | ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i); | 207 | ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i); |
@@ -244,9 +245,9 @@ solveOdeC fun f0 ts = unsafePerformIO $ do | |||
244 | 245 | ||
245 | return flag; | 246 | return flag; |
246 | } |] | 247 | } |] |
247 | if res ==0 | 248 | if res == 0 |
248 | then do | 249 | then do |
249 | v <- V.freeze fMut | 250 | m <- V.freeze qMatMut |
250 | return $ Right v | 251 | return $ Right m |
251 | else do | 252 | else do |
252 | return $ Left res | 253 | return $ Left res |