diff options
Diffstat (limited to 'packages/sundials/src')
-rw-r--r-- | packages/sundials/src/Main.hs | 47 | ||||
-rw-r--r-- | packages/sundials/src/helpers.c | 22 | ||||
-rw-r--r-- | packages/sundials/src/helpers.h | 3 |
3 files changed, 34 insertions, 38 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index b3ebcb3..fc48710 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs | |||
@@ -78,11 +78,27 @@ stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] | |||
78 | u = v V.! 0 | 78 | u = v V.! 0 |
79 | lamda = -100.0 | 79 | lamda = -100.0 |
80 | 80 | ||
81 | brusselator :: Double -> V.Vector Double -> V.Vector Double | ||
82 | brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2 | ||
83 | , w * u - v * u^2 | ||
84 | , (b - w) / eps - w * u | ||
85 | ] | ||
86 | where | ||
87 | a = 1.0 | ||
88 | b = 3.5 | ||
89 | eps = 5.0e-6 | ||
90 | u = x V.! 0 | ||
91 | v = x V.! 1 | ||
92 | w = x V.! 2 | ||
93 | |||
81 | solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> | 94 | solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> |
82 | V.Vector Double -> | 95 | V.Vector CDouble -> |
83 | CInt | 96 | Either CInt (V.Vector CDouble) |
84 | solveOdeC fun f0 = unsafePerformIO $ do | 97 | solveOdeC fun f0 = unsafePerformIO $ do |
85 | let dim = V.length f0 | 98 | let dim = V.length f0 |
99 | nEq :: CLong | ||
100 | nEq = fromIntegral dim | ||
101 | fMut <- V.thaw f0 | ||
86 | -- We need the types that sundials expects. These are tied together | 102 | -- We need the types that sundials expects. These are tied together |
87 | -- in 'Types'. The Haskell type is currently empty! | 103 | -- in 'Types'. The Haskell type is currently empty! |
88 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt | 104 | let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt |
@@ -110,21 +126,22 @@ solveOdeC fun f0 = unsafePerformIO $ do | |||
110 | realtype T0 = RCONST(0.0); /* initial time */ | 126 | realtype T0 = RCONST(0.0); /* initial time */ |
111 | realtype Tf = RCONST(10.0); /* final time */ | 127 | realtype Tf = RCONST(10.0); /* final time */ |
112 | realtype dTout = RCONST(1.0); /* time between outputs */ | 128 | realtype dTout = RCONST(1.0); /* time between outputs */ |
113 | sunindextype NEQ = 1; /* number of dependent vars. */ | 129 | sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */ |
114 | realtype reltol = 1.0e-6; /* tolerances */ | 130 | realtype reltol = 1.0e-6; /* tolerances */ |
115 | realtype abstol = 1.0e-10; | 131 | realtype abstol = 1.0e-10; |
116 | realtype lamda = -100.0; /* stiffness parameter */ | ||
117 | 132 | ||
118 | /* Initial diagnostics output */ | 133 | /* Initial diagnostics output */ |
119 | printf("\nAnalytical ODE test problem:\n"); | 134 | printf("\nAnalytical ODE test problem:\n"); |
120 | printf(" lamda = %"GSYM"\n", lamda); | ||
121 | printf(" reltol = %.1"ESYM"\n", reltol); | 135 | printf(" reltol = %.1"ESYM"\n", reltol); |
122 | printf(" abstol = %.1"ESYM"\n\n",abstol); | 136 | printf(" abstol = %.1"ESYM"\n\n",abstol); |
123 | 137 | ||
124 | /* Initialize data structures */ | 138 | /* Initialize data structures */ |
125 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ | 139 | y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ |
126 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; | 140 | if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; |
127 | N_VConst(0.0, y); /* Specify initial condition */ | 141 | int i; |
142 | for (i = 0; i < NEQ; i++) { | ||
143 | NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i]; | ||
144 | }; /* Specify initial condition */ | ||
128 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ | 145 | arkode_mem = ARKodeCreate(); /* Create the solver memory */ |
129 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; | 146 | if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; |
130 | 147 | ||
@@ -139,8 +156,6 @@ solveOdeC fun f0 = unsafePerformIO $ do | |||
139 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; | 156 | if (check_flag(&flag, "ARKodeInit", 1)) return 1; |
140 | 157 | ||
141 | /* Set routines */ | 158 | /* Set routines */ |
142 | flag = ARKodeSetUserData(arkode_mem, (void *) &lamda); /* Pass lamda to user functions */ | ||
143 | if (check_flag(&flag, "ARKodeSetUserData", 1)) return 1; | ||
144 | flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */ | 159 | flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */ |
145 | if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; | 160 | if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; |
146 | 161 | ||
@@ -182,6 +197,10 @@ solveOdeC fun f0 = unsafePerformIO $ do | |||
182 | printf(" ---------------------\n"); | 197 | printf(" ---------------------\n"); |
183 | fclose(UFID); | 198 | fclose(UFID); |
184 | 199 | ||
200 | for (i = 0; i < NEQ; i++) { | ||
201 | ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i); | ||
202 | }; | ||
203 | |||
185 | /* Get/print some final statistics on how the solve progressed */ | 204 | /* Get/print some final statistics on how the solve progressed */ |
186 | flag = ARKodeGetNumSteps(arkode_mem, &nst); | 205 | flag = ARKodeGetNumSteps(arkode_mem, &nst); |
187 | check_flag(&flag, "ARKodeGetNumSteps", 1); | 206 | check_flag(&flag, "ARKodeGetNumSteps", 1); |
@@ -212,9 +231,6 @@ solveOdeC fun f0 = unsafePerformIO $ do | |||
212 | printf(" Total number of linear solver convergence failures = %li\n", ncfn); | 231 | printf(" Total number of linear solver convergence failures = %li\n", ncfn); |
213 | printf(" Total number of error test failures = %li\n\n", netf); | 232 | printf(" Total number of error test failures = %li\n\n", netf); |
214 | 233 | ||
215 | /* check the solution error */ | ||
216 | flag = check_ans(y, t, reltol, abstol); | ||
217 | |||
218 | /* Clean up and return */ | 234 | /* Clean up and return */ |
219 | N_VDestroy(y); /* Free y vector */ | 235 | N_VDestroy(y); /* Free y vector */ |
220 | ARKodeFree(&arkode_mem); /* Free integrator memory */ | 236 | ARKodeFree(&arkode_mem); /* Free integrator memory */ |
@@ -223,9 +239,14 @@ solveOdeC fun f0 = unsafePerformIO $ do | |||
223 | 239 | ||
224 | return flag; | 240 | return flag; |
225 | } |] | 241 | } |] |
226 | return res | 242 | if res ==0 |
243 | then do | ||
244 | v <- V.freeze fMut | ||
245 | return $ Right v | ||
246 | else do | ||
247 | return $ Left res | ||
227 | 248 | ||
228 | main :: IO () | 249 | main :: IO () |
229 | main = do | 250 | main = do |
230 | let res = solveOdeC (coerce stiffish) (V.fromList [1.0]) | 251 | let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0]) |
231 | putStrLn $ show res | 252 | putStrLn $ show res |
diff --git a/packages/sundials/src/helpers.c b/packages/sundials/src/helpers.c index 420d3be..f0ca592 100644 --- a/packages/sundials/src/helpers.c +++ b/packages/sundials/src/helpers.c | |||
@@ -42,25 +42,3 @@ int check_flag(void *flagvalue, const char *funcname, int opt) | |||
42 | 42 | ||
43 | return 0; | 43 | return 0; |
44 | } | 44 | } |
45 | |||
46 | /* check the computed solution */ | ||
47 | int check_ans(N_Vector y, realtype t, realtype rtol, realtype atol) | ||
48 | { | ||
49 | int passfail=0; /* answer pass (0) or fail (1) flag */ | ||
50 | realtype ans, err, ewt; /* answer data, error, and error weight */ | ||
51 | realtype ONE=RCONST(1.0); | ||
52 | |||
53 | /* compute solution error */ | ||
54 | ans = atan(t); | ||
55 | ewt = ONE / (rtol * SUNRabs(ans) + atol); | ||
56 | err = ewt * SUNRabs(NV_Ith_S(y,0) - ans); | ||
57 | |||
58 | /* is the solution within the tolerances? */ | ||
59 | passfail = (err < ONE) ? 0 : 1; | ||
60 | |||
61 | if (passfail) { | ||
62 | fprintf(stdout, "\nSUNDIALS_WARNING: check_ans error=%g \n\n", err); | ||
63 | } | ||
64 | |||
65 | return(passfail); | ||
66 | } | ||
diff --git a/packages/sundials/src/helpers.h b/packages/sundials/src/helpers.h index 69a3dfe..5c1d9f3 100644 --- a/packages/sundials/src/helpers.h +++ b/packages/sundials/src/helpers.h | |||
@@ -20,6 +20,3 @@ typedef struct _N_VectorContent_Serial BazType; | |||
20 | NULL pointer | 20 | NULL pointer |
21 | */ | 21 | */ |
22 | int check_flag(void *flagvalue, const char *funcname, int opt); | 22 | int check_flag(void *flagvalue, const char *funcname, int opt); |
23 | |||
24 | /* check the computed solution */ | ||
25 | int check_ans(N_Vector y, realtype t, realtype rtol, realtype atol); | ||