summaryrefslogtreecommitdiff
path: root/packages/sundials/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/sundials/src')
-rw-r--r--packages/sundials/src/Main.hs44
1 files changed, 40 insertions, 4 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs
index 5972be7..3d5f941 100644
--- a/packages/sundials/src/Main.hs
+++ b/packages/sundials/src/Main.hs
@@ -55,9 +55,22 @@ vectorToC vec len ptr = do
55 ptr' <- newForeignPtr_ ptr 55 ptr' <- newForeignPtr_ ptr
56 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec 56 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
57 57
58solve :: CDouble -> CInt 58solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) ->
59solve lambda = unsafePerformIO $ do 59 V.Vector Double ->
60 res <- [C.block| int { /* general problem variables */ 60 CDouble ->
61 CInt
62solve fun f0 lambda = unsafePerformIO $ do
63 let dim = V.length f0
64 let funIO x y f _ptr = do
65 -- Convert the pointer we get from C (y) to a vector, and then
66 -- apply the user-supplied function.
67 fImm <- fun x <$> vectorFromC dim y
68 -- Fill in the provided pointer with the resulting vector.
69 vectorToC fImm dim f
70 -- Unsafe since the function will be called many times.
71 [CU.exp| int{ 0 } |]
72 res <- [C.block| int {
73 /* general problem variables */
61 int flag; /* reusable error-checking flag */ 74 int flag; /* reusable error-checking flag */
62 N_Vector y = NULL; /* empty vector for storing solution */ 75 N_Vector y = NULL; /* empty vector for storing solution */
63 SUNMatrix A = NULL; /* empty matrix for linear solver */ 76 SUNMatrix A = NULL; /* empty matrix for linear solver */
@@ -75,7 +88,30 @@ solve lambda = unsafePerformIO $ do
75 realtype reltol = 1.0e-6; /* tolerances */ 88 realtype reltol = 1.0e-6; /* tolerances */
76 realtype abstol = 1.0e-10; 89 realtype abstol = 1.0e-10;
77 realtype lamda = -100.0; /* stiffness parameter */ 90 realtype lamda = -100.0; /* stiffness parameter */
78 91
92 /* Beginning of stolen code from the Fortran interface */
93
94 N_Vector F2C_ARKODE_vec;
95 F2C_ARKODE_vec = NULL;
96 F2C_ARKODE_vec = N_VNewEmpty_Serial(NEQ); /* was *N */
97 if (F2C_ARKODE_vec == NULL) return 1;
98
99 /* Check for required vector operations */
100 if(F2C_ARKODE_vec->ops->nvgetarraypointer == NULL) {
101 fprintf(stderr, "Error: getarraypointer vector operation is not implemented.\n\n");
102 return 1;
103 }
104 if(F2C_ARKODE_vec->ops->nvsetarraypointer == NULL) {
105 fprintf(stderr, "Error: setarraypointer vector operation is not implemented.\n\n");
106 return 1;
107 }
108 if(F2C_ARKODE_vec->ops->nvcloneempty == NULL) {
109 fprintf(stderr, "Error: cloneempty vector operation is not implemented.\n\n");
110 return 1;
111 }
112
113 /* End of stolen code from the Fortran interface */
114
79 /* Initial diagnostics output */ 115 /* Initial diagnostics output */
80 printf("\nAnalytical ODE test problem:\n"); 116 printf("\nAnalytical ODE test problem:\n");
81 printf(" lamda = %"GSYM"\n", lamda); 117 printf(" lamda = %"GSYM"\n", lamda);