summaryrefslogtreecommitdiff
path: root/packages/sundials/src/Main.hs
diff options
context:
space:
mode:
authorDominic Steinitz <dominic@steinitz.org>2018-03-20 12:03:09 +0000
committerDominic Steinitz <dominic@steinitz.org>2018-03-20 12:03:09 +0000
commit755175a557d07c6f73683f358ddd8f8ee07f26a9 (patch)
tree447227ac65e48e6d4051d3f48ed30c8fc586083f /packages/sundials/src/Main.hs
parent17ba35af029cee0122f91fc91427a307a5f11dfa (diff)
Handle arbitrary systems
Diffstat (limited to 'packages/sundials/src/Main.hs')
-rw-r--r--packages/sundials/src/Main.hs47
1 files changed, 34 insertions, 13 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs
index b3ebcb3..fc48710 100644
--- a/packages/sundials/src/Main.hs
+++ b/packages/sundials/src/Main.hs
@@ -78,11 +78,27 @@ stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ]
78 u = v V.! 0 78 u = v V.! 0
79 lamda = -100.0 79 lamda = -100.0
80 80
81brusselator :: Double -> V.Vector Double -> V.Vector Double
82brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2
83 , w * u - v * u^2
84 , (b - w) / eps - w * u
85 ]
86 where
87 a = 1.0
88 b = 3.5
89 eps = 5.0e-6
90 u = x V.! 0
91 v = x V.! 1
92 w = x V.! 2
93
81solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> 94solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) ->
82 V.Vector Double -> 95 V.Vector CDouble ->
83 CInt 96 Either CInt (V.Vector CDouble)
84solveOdeC fun f0 = unsafePerformIO $ do 97solveOdeC fun f0 = unsafePerformIO $ do
85 let dim = V.length f0 98 let dim = V.length f0
99 nEq :: CLong
100 nEq = fromIntegral dim
101 fMut <- V.thaw f0
86 -- We need the types that sundials expects. These are tied together 102 -- We need the types that sundials expects. These are tied together
87 -- in 'Types'. The Haskell type is currently empty! 103 -- in 'Types'. The Haskell type is currently empty!
88 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt 104 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt
@@ -110,21 +126,22 @@ solveOdeC fun f0 = unsafePerformIO $ do
110 realtype T0 = RCONST(0.0); /* initial time */ 126 realtype T0 = RCONST(0.0); /* initial time */
111 realtype Tf = RCONST(10.0); /* final time */ 127 realtype Tf = RCONST(10.0); /* final time */
112 realtype dTout = RCONST(1.0); /* time between outputs */ 128 realtype dTout = RCONST(1.0); /* time between outputs */
113 sunindextype NEQ = 1; /* number of dependent vars. */ 129 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */
114 realtype reltol = 1.0e-6; /* tolerances */ 130 realtype reltol = 1.0e-6; /* tolerances */
115 realtype abstol = 1.0e-10; 131 realtype abstol = 1.0e-10;
116 realtype lamda = -100.0; /* stiffness parameter */
117 132
118 /* Initial diagnostics output */ 133 /* Initial diagnostics output */
119 printf("\nAnalytical ODE test problem:\n"); 134 printf("\nAnalytical ODE test problem:\n");
120 printf(" lamda = %"GSYM"\n", lamda);
121 printf(" reltol = %.1"ESYM"\n", reltol); 135 printf(" reltol = %.1"ESYM"\n", reltol);
122 printf(" abstol = %.1"ESYM"\n\n",abstol); 136 printf(" abstol = %.1"ESYM"\n\n",abstol);
123 137
124 /* Initialize data structures */ 138 /* Initialize data structures */
125 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ 139 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
126 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; 140 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
127 N_VConst(0.0, y); /* Specify initial condition */ 141 int i;
142 for (i = 0; i < NEQ; i++) {
143 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i];
144 }; /* Specify initial condition */
128 arkode_mem = ARKodeCreate(); /* Create the solver memory */ 145 arkode_mem = ARKodeCreate(); /* Create the solver memory */
129 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; 146 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
130 147
@@ -139,8 +156,6 @@ solveOdeC fun f0 = unsafePerformIO $ do
139 if (check_flag(&flag, "ARKodeInit", 1)) return 1; 156 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
140 157
141 /* Set routines */ 158 /* Set routines */
142 flag = ARKodeSetUserData(arkode_mem, (void *) &lamda); /* Pass lamda to user functions */
143 if (check_flag(&flag, "ARKodeSetUserData", 1)) return 1;
144 flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */ 159 flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */
145 if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; 160 if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1;
146 161
@@ -182,6 +197,10 @@ solveOdeC fun f0 = unsafePerformIO $ do
182 printf(" ---------------------\n"); 197 printf(" ---------------------\n");
183 fclose(UFID); 198 fclose(UFID);
184 199
200 for (i = 0; i < NEQ; i++) {
201 ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i);
202 };
203
185 /* Get/print some final statistics on how the solve progressed */ 204 /* Get/print some final statistics on how the solve progressed */
186 flag = ARKodeGetNumSteps(arkode_mem, &nst); 205 flag = ARKodeGetNumSteps(arkode_mem, &nst);
187 check_flag(&flag, "ARKodeGetNumSteps", 1); 206 check_flag(&flag, "ARKodeGetNumSteps", 1);
@@ -212,9 +231,6 @@ solveOdeC fun f0 = unsafePerformIO $ do
212 printf(" Total number of linear solver convergence failures = %li\n", ncfn); 231 printf(" Total number of linear solver convergence failures = %li\n", ncfn);
213 printf(" Total number of error test failures = %li\n\n", netf); 232 printf(" Total number of error test failures = %li\n\n", netf);
214 233
215 /* check the solution error */
216 flag = check_ans(y, t, reltol, abstol);
217
218 /* Clean up and return */ 234 /* Clean up and return */
219 N_VDestroy(y); /* Free y vector */ 235 N_VDestroy(y); /* Free y vector */
220 ARKodeFree(&arkode_mem); /* Free integrator memory */ 236 ARKodeFree(&arkode_mem); /* Free integrator memory */
@@ -223,9 +239,14 @@ solveOdeC fun f0 = unsafePerformIO $ do
223 239
224 return flag; 240 return flag;
225 } |] 241 } |]
226 return res 242 if res ==0
243 then do
244 v <- V.freeze fMut
245 return $ Right v
246 else do
247 return $ Left res
227 248
228main :: IO () 249main :: IO ()
229main = do 250main = do
230 let res = solveOdeC (coerce stiffish) (V.fromList [1.0]) 251 let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0])
231 putStrLn $ show res 252 putStrLn $ show res