diff options
Diffstat (limited to 'packages/sundials')
-rw-r--r-- | packages/sundials/src/Main.hs | 44 |
1 files changed, 40 insertions, 4 deletions
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs index 5972be7..3d5f941 100644 --- a/packages/sundials/src/Main.hs +++ b/packages/sundials/src/Main.hs | |||
@@ -55,9 +55,22 @@ vectorToC vec len ptr = do | |||
55 | ptr' <- newForeignPtr_ ptr | 55 | ptr' <- newForeignPtr_ ptr |
56 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | 56 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec |
57 | 57 | ||
58 | solve :: CDouble -> CInt | 58 | solve :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) -> |
59 | solve lambda = unsafePerformIO $ do | 59 | V.Vector Double -> |
60 | res <- [C.block| int { /* general problem variables */ | 60 | CDouble -> |
61 | CInt | ||
62 | solve fun f0 lambda = unsafePerformIO $ do | ||
63 | let dim = V.length f0 | ||
64 | let funIO x y f _ptr = do | ||
65 | -- Convert the pointer we get from C (y) to a vector, and then | ||
66 | -- apply the user-supplied function. | ||
67 | fImm <- fun x <$> vectorFromC dim y | ||
68 | -- Fill in the provided pointer with the resulting vector. | ||
69 | vectorToC fImm dim f | ||
70 | -- Unsafe since the function will be called many times. | ||
71 | [CU.exp| int{ 0 } |] | ||
72 | res <- [C.block| int { | ||
73 | /* general problem variables */ | ||
61 | int flag; /* reusable error-checking flag */ | 74 | int flag; /* reusable error-checking flag */ |
62 | N_Vector y = NULL; /* empty vector for storing solution */ | 75 | N_Vector y = NULL; /* empty vector for storing solution */ |
63 | SUNMatrix A = NULL; /* empty matrix for linear solver */ | 76 | SUNMatrix A = NULL; /* empty matrix for linear solver */ |
@@ -75,7 +88,30 @@ solve lambda = unsafePerformIO $ do | |||
75 | realtype reltol = 1.0e-6; /* tolerances */ | 88 | realtype reltol = 1.0e-6; /* tolerances */ |
76 | realtype abstol = 1.0e-10; | 89 | realtype abstol = 1.0e-10; |
77 | realtype lamda = -100.0; /* stiffness parameter */ | 90 | realtype lamda = -100.0; /* stiffness parameter */ |
78 | 91 | ||
92 | /* Beginning of stolen code from the Fortran interface */ | ||
93 | |||
94 | N_Vector F2C_ARKODE_vec; | ||
95 | F2C_ARKODE_vec = NULL; | ||
96 | F2C_ARKODE_vec = N_VNewEmpty_Serial(NEQ); /* was *N */ | ||
97 | if (F2C_ARKODE_vec == NULL) return 1; | ||
98 | |||
99 | /* Check for required vector operations */ | ||
100 | if(F2C_ARKODE_vec->ops->nvgetarraypointer == NULL) { | ||
101 | fprintf(stderr, "Error: getarraypointer vector operation is not implemented.\n\n"); | ||
102 | return 1; | ||
103 | } | ||
104 | if(F2C_ARKODE_vec->ops->nvsetarraypointer == NULL) { | ||
105 | fprintf(stderr, "Error: setarraypointer vector operation is not implemented.\n\n"); | ||
106 | return 1; | ||
107 | } | ||
108 | if(F2C_ARKODE_vec->ops->nvcloneempty == NULL) { | ||
109 | fprintf(stderr, "Error: cloneempty vector operation is not implemented.\n\n"); | ||
110 | return 1; | ||
111 | } | ||
112 | |||
113 | /* End of stolen code from the Fortran interface */ | ||
114 | |||
79 | /* Initial diagnostics output */ | 115 | /* Initial diagnostics output */ |
80 | printf("\nAnalytical ODE test problem:\n"); | 116 | printf("\nAnalytical ODE test problem:\n"); |
81 | printf(" lamda = %"GSYM"\n", lamda); | 117 | printf(" lamda = %"GSYM"\n", lamda); |