summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDominic Steinitz <dominic@steinitz.org>2018-03-20 12:03:09 +0000
committerDominic Steinitz <dominic@steinitz.org>2018-03-20 12:03:09 +0000
commit755175a557d07c6f73683f358ddd8f8ee07f26a9 (patch)
tree447227ac65e48e6d4051d3f48ed30c8fc586083f
parent17ba35af029cee0122f91fc91427a307a5f11dfa (diff)
Handle arbitrary systems
-rw-r--r--packages/sundials/hmatrix-sundials.cabal2
-rw-r--r--packages/sundials/src/Main.hs47
-rw-r--r--packages/sundials/src/helpers.c22
-rw-r--r--packages/sundials/src/helpers.h3
4 files changed, 35 insertions, 39 deletions
diff --git a/packages/sundials/hmatrix-sundials.cabal b/packages/sundials/hmatrix-sundials.cabal
index 43a83d0..d928ab1 100644
--- a/packages/sundials/hmatrix-sundials.cabal
+++ b/packages/sundials/hmatrix-sundials.cabal
@@ -1,7 +1,7 @@
1-- Initial sundials.cabal generated by cabal init. For further 1-- Initial sundials.cabal generated by cabal init. For further
2-- documentation, see http://haskell.org/cabal/users-guide/ 2-- documentation, see http://haskell.org/cabal/users-guide/
3 3
4name: sundials 4name: hmatrix-sundials
5version: 0.1.0.0 5version: 0.1.0.0
6-- synopsis: 6-- synopsis:
7-- description: 7-- description:
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs
index b3ebcb3..fc48710 100644
--- a/packages/sundials/src/Main.hs
+++ b/packages/sundials/src/Main.hs
@@ -78,11 +78,27 @@ stiffish t v = V.fromList [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ]
78 u = v V.! 0 78 u = v V.! 0
79 lamda = -100.0 79 lamda = -100.0
80 80
81brusselator :: Double -> V.Vector Double -> V.Vector Double
82brusselator _t x = V.fromList [ a - (w + 1) * u + v * u^2
83 , w * u - v * u^2
84 , (b - w) / eps - w * u
85 ]
86 where
87 a = 1.0
88 b = 3.5
89 eps = 5.0e-6
90 u = x V.! 0
91 v = x V.! 1
92 w = x V.! 2
93
81solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> 94solveOdeC :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) ->
82 V.Vector Double -> 95 V.Vector CDouble ->
83 CInt 96 Either CInt (V.Vector CDouble)
84solveOdeC fun f0 = unsafePerformIO $ do 97solveOdeC fun f0 = unsafePerformIO $ do
85 let dim = V.length f0 98 let dim = V.length f0
99 nEq :: CLong
100 nEq = fromIntegral dim
101 fMut <- V.thaw f0
86 -- We need the types that sundials expects. These are tied together 102 -- We need the types that sundials expects. These are tied together
87 -- in 'Types'. The Haskell type is currently empty! 103 -- in 'Types'. The Haskell type is currently empty!
88 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt 104 let funIO :: CDouble -> Ptr T.BarType -> Ptr T.BarType -> Ptr () -> IO CInt
@@ -110,21 +126,22 @@ solveOdeC fun f0 = unsafePerformIO $ do
110 realtype T0 = RCONST(0.0); /* initial time */ 126 realtype T0 = RCONST(0.0); /* initial time */
111 realtype Tf = RCONST(10.0); /* final time */ 127 realtype Tf = RCONST(10.0); /* final time */
112 realtype dTout = RCONST(1.0); /* time between outputs */ 128 realtype dTout = RCONST(1.0); /* time between outputs */
113 sunindextype NEQ = 1; /* number of dependent vars. */ 129 sunindextype NEQ = $(sunindextype nEq); /* number of dependent vars. */
114 realtype reltol = 1.0e-6; /* tolerances */ 130 realtype reltol = 1.0e-6; /* tolerances */
115 realtype abstol = 1.0e-10; 131 realtype abstol = 1.0e-10;
116 realtype lamda = -100.0; /* stiffness parameter */
117 132
118 /* Initial diagnostics output */ 133 /* Initial diagnostics output */
119 printf("\nAnalytical ODE test problem:\n"); 134 printf("\nAnalytical ODE test problem:\n");
120 printf(" lamda = %"GSYM"\n", lamda);
121 printf(" reltol = %.1"ESYM"\n", reltol); 135 printf(" reltol = %.1"ESYM"\n", reltol);
122 printf(" abstol = %.1"ESYM"\n\n",abstol); 136 printf(" abstol = %.1"ESYM"\n\n",abstol);
123 137
124 /* Initialize data structures */ 138 /* Initialize data structures */
125 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */ 139 y = N_VNew_Serial(NEQ); /* Create serial vector for solution */
126 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1; 140 if (check_flag((void *)y, "N_VNew_Serial", 0)) return 1;
127 N_VConst(0.0, y); /* Specify initial condition */ 141 int i;
142 for (i = 0; i < NEQ; i++) {
143 NV_Ith_S(y,i) = ($vec-ptr:(double *fMut))[i];
144 }; /* Specify initial condition */
128 arkode_mem = ARKodeCreate(); /* Create the solver memory */ 145 arkode_mem = ARKodeCreate(); /* Create the solver memory */
129 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1; 146 if (check_flag((void *)arkode_mem, "ARKodeCreate", 0)) return 1;
130 147
@@ -139,8 +156,6 @@ solveOdeC fun f0 = unsafePerformIO $ do
139 if (check_flag(&flag, "ARKodeInit", 1)) return 1; 156 if (check_flag(&flag, "ARKodeInit", 1)) return 1;
140 157
141 /* Set routines */ 158 /* Set routines */
142 flag = ARKodeSetUserData(arkode_mem, (void *) &lamda); /* Pass lamda to user functions */
143 if (check_flag(&flag, "ARKodeSetUserData", 1)) return 1;
144 flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */ 159 flag = ARKodeSStolerances(arkode_mem, reltol, abstol); /* Specify tolerances */
145 if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1; 160 if (check_flag(&flag, "ARKodeSStolerances", 1)) return 1;
146 161
@@ -182,6 +197,10 @@ solveOdeC fun f0 = unsafePerformIO $ do
182 printf(" ---------------------\n"); 197 printf(" ---------------------\n");
183 fclose(UFID); 198 fclose(UFID);
184 199
200 for (i = 0; i < NEQ; i++) {
201 ($vec-ptr:(double *fMut))[i] = NV_Ith_S(y,i);
202 };
203
185 /* Get/print some final statistics on how the solve progressed */ 204 /* Get/print some final statistics on how the solve progressed */
186 flag = ARKodeGetNumSteps(arkode_mem, &nst); 205 flag = ARKodeGetNumSteps(arkode_mem, &nst);
187 check_flag(&flag, "ARKodeGetNumSteps", 1); 206 check_flag(&flag, "ARKodeGetNumSteps", 1);
@@ -212,9 +231,6 @@ solveOdeC fun f0 = unsafePerformIO $ do
212 printf(" Total number of linear solver convergence failures = %li\n", ncfn); 231 printf(" Total number of linear solver convergence failures = %li\n", ncfn);
213 printf(" Total number of error test failures = %li\n\n", netf); 232 printf(" Total number of error test failures = %li\n\n", netf);
214 233
215 /* check the solution error */
216 flag = check_ans(y, t, reltol, abstol);
217
218 /* Clean up and return */ 234 /* Clean up and return */
219 N_VDestroy(y); /* Free y vector */ 235 N_VDestroy(y); /* Free y vector */
220 ARKodeFree(&arkode_mem); /* Free integrator memory */ 236 ARKodeFree(&arkode_mem); /* Free integrator memory */
@@ -223,9 +239,14 @@ solveOdeC fun f0 = unsafePerformIO $ do
223 239
224 return flag; 240 return flag;
225 } |] 241 } |]
226 return res 242 if res ==0
243 then do
244 v <- V.freeze fMut
245 return $ Right v
246 else do
247 return $ Left res
227 248
228main :: IO () 249main :: IO ()
229main = do 250main = do
230 let res = solveOdeC (coerce stiffish) (V.fromList [1.0]) 251 let res = solveOdeC (coerce brusselator) (V.fromList [1.2, 3.1, 3.0])
231 putStrLn $ show res 252 putStrLn $ show res
diff --git a/packages/sundials/src/helpers.c b/packages/sundials/src/helpers.c
index 420d3be..f0ca592 100644
--- a/packages/sundials/src/helpers.c
+++ b/packages/sundials/src/helpers.c
@@ -42,25 +42,3 @@ int check_flag(void *flagvalue, const char *funcname, int opt)
42 42
43 return 0; 43 return 0;
44} 44}
45
46/* check the computed solution */
47int check_ans(N_Vector y, realtype t, realtype rtol, realtype atol)
48{
49 int passfail=0; /* answer pass (0) or fail (1) flag */
50 realtype ans, err, ewt; /* answer data, error, and error weight */
51 realtype ONE=RCONST(1.0);
52
53 /* compute solution error */
54 ans = atan(t);
55 ewt = ONE / (rtol * SUNRabs(ans) + atol);
56 err = ewt * SUNRabs(NV_Ith_S(y,0) - ans);
57
58 /* is the solution within the tolerances? */
59 passfail = (err < ONE) ? 0 : 1;
60
61 if (passfail) {
62 fprintf(stdout, "\nSUNDIALS_WARNING: check_ans error=%g \n\n", err);
63 }
64
65 return(passfail);
66}
diff --git a/packages/sundials/src/helpers.h b/packages/sundials/src/helpers.h
index 69a3dfe..5c1d9f3 100644
--- a/packages/sundials/src/helpers.h
+++ b/packages/sundials/src/helpers.h
@@ -20,6 +20,3 @@ typedef struct _N_VectorContent_Serial BazType;
20 NULL pointer 20 NULL pointer
21*/ 21*/
22int check_flag(void *flagvalue, const char *funcname, int opt); 22int check_flag(void *flagvalue, const char *funcname, int opt);
23
24/* check the computed solution */
25int check_ans(N_Vector y, realtype t, realtype rtol, realtype atol);