diff options
author | Dominic Steinitz <dominic@steinitz.org> | 2018-03-31 13:06:35 +0100 |
---|---|---|
committer | Dominic Steinitz <dominic@steinitz.org> | 2018-03-31 13:06:35 +0100 |
commit | 3c4411e48cbcfaf8035e893ac63aa250fcc56d3e (patch) | |
tree | 0542a6ab2c68f2c9245dedf7e04d4e15e279cfd3 /packages/sundials/src/Numeric | |
parent | 5bcc77b1e115a8c8eb94a1aa1a441618bfeb0b54 (diff) |
Add in the Jacobian
Diffstat (limited to 'packages/sundials/src/Numeric')
-rw-r--r-- | packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | 99 |
1 files changed, 75 insertions, 24 deletions
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs index 30ff4c8..5af9e41 100644 --- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs +++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs | |||
@@ -35,6 +35,7 @@ module Numeric.Sundials.Arkode.ODE ( solveOde | |||
35 | , getBT | 35 | , getBT |
36 | , btGet | 36 | , btGet |
37 | , ODEMethod(..) | 37 | , ODEMethod(..) |
38 | , odeSolveV | ||
38 | ) where | 39 | ) where |
39 | 40 | ||
40 | import qualified Language.C.Inline as C | 41 | import qualified Language.C.Inline as C |
@@ -45,7 +46,7 @@ import Data.Monoid ((<>)) | |||
45 | import Foreign.C.Types | 46 | import Foreign.C.Types |
46 | import Foreign.Ptr (Ptr) | 47 | import Foreign.Ptr (Ptr) |
47 | import Foreign.ForeignPtr (newForeignPtr_) | 48 | import Foreign.ForeignPtr (newForeignPtr_) |
48 | import Foreign.Storable (Storable, peekByteOff) | 49 | import Foreign.Storable (Storable) |
49 | 50 | ||
50 | import qualified Data.Vector.Storable as V | 51 | import qualified Data.Vector.Storable as V |
51 | import qualified Data.Vector.Storable.Mutable as VM | 52 | import qualified Data.Vector.Storable.Mutable as VM |
@@ -55,7 +56,8 @@ import System.IO.Unsafe (unsafePerformIO) | |||
55 | 56 | ||
56 | import Numeric.LinearAlgebra.Devel (createVector) | 57 | import Numeric.LinearAlgebra.Devel (createVector) |
57 | 58 | ||
58 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix) | 59 | import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), |
60 | subMatrix, rows, cols, toLists) | ||
59 | 61 | ||
60 | import qualified Types as T | 62 | import qualified Types as T |
61 | import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3) | 63 | import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3) |
@@ -78,12 +80,34 @@ C.include "../../../helpers.h" | |||
78 | C.include "Arkode_hsc.h" | 80 | C.include "Arkode_hsc.h" |
79 | 81 | ||
80 | 82 | ||
81 | getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b) | 83 | getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble) |
82 | getDataFromContents len ptr = do | 84 | getDataFromContents len ptr = do |
83 | qtr <- B.getContentPtr ptr | 85 | qtr <- B.getContentPtr ptr |
84 | rtr <- B.getData qtr | 86 | rtr <- B.getData qtr |
85 | vectorFromC len rtr | 87 | vectorFromC len rtr |
86 | 88 | ||
89 | -- FIXME: Potentially an instance of Storable | ||
90 | getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix | ||
91 | getMatrixDataFromContents ptr = do | ||
92 | qtr <- B.getContentMatrixPtr ptr | ||
93 | rs <- B.getNRows qtr | ||
94 | cs <- B.getNCols qtr | ||
95 | rtr <- B.getMatrixData qtr | ||
96 | vs <- vectorFromC (fromIntegral $ rs * cs) rtr | ||
97 | return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs } | ||
98 | |||
99 | putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO () | ||
100 | putMatrixDataFromContents mat ptr = do | ||
101 | let rs = T.rows mat | ||
102 | cs = T.cols mat | ||
103 | vs = T.vals mat | ||
104 | qtr <- B.getContentMatrixPtr ptr | ||
105 | B.putNRows rs qtr | ||
106 | B.putNCols cs qtr | ||
107 | rtr <- B.getMatrixData qtr | ||
108 | vectorToC vs (fromIntegral $ rs * cs) rtr | ||
109 | -- FIXME: END | ||
110 | |||
87 | putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () | 111 | putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () |
88 | putDataInContents vec len ptr = do | 112 | putDataInContents vec len ptr = do |
89 | qtr <- B.getContentPtr ptr | 113 | qtr <- B.getContentPtr ptr |
@@ -103,16 +127,16 @@ vectorToC vec len ptr = do | |||
103 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | 127 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec |
104 | 128 | ||
105 | data SundialsDiagnostics = SundialsDiagnostics { | 129 | data SundialsDiagnostics = SundialsDiagnostics { |
106 | aRKodeGetNumSteps :: Int | 130 | _aRKodeGetNumSteps :: Int |
107 | , aRKodeGetNumStepAttempts :: Int | 131 | , _aRKodeGetNumStepAttempts :: Int |
108 | , aRKodeGetNumRhsEvals_fe :: Int | 132 | , _aRKodeGetNumRhsEvals_fe :: Int |
109 | , aRKodeGetNumRhsEvals_fi :: Int | 133 | , _aRKodeGetNumRhsEvals_fi :: Int |
110 | , aRKodeGetNumLinSolvSetups :: Int | 134 | , _aRKodeGetNumLinSolvSetups :: Int |
111 | , aRKodeGetNumErrTestFails :: Int | 135 | , _aRKodeGetNumErrTestFails :: Int |
112 | , aRKodeGetNumNonlinSolvIters :: Int | 136 | , _aRKodeGetNumNonlinSolvIters :: Int |
113 | , aRKodeGetNumNonlinSolvConvFails :: Int | 137 | , _aRKodeGetNumNonlinSolvConvFails :: Int |
114 | , aRKDlsGetNumJacEvals :: Int | 138 | , _aRKDlsGetNumJacEvals :: Int |
115 | , aRKDlsGetNumRhsEvals :: Int | 139 | , _aRKDlsGetNumRhsEvals :: Int |
116 | } deriving Show | 140 | } deriving Show |
117 | 141 | ||
118 | -- | Stepping functions | 142 | -- | Stepping functions |
@@ -134,15 +158,16 @@ odeSolveV | |||
134 | -> Vector Double -- ^ initial conditions | 158 | -> Vector Double -- ^ initial conditions |
135 | -> Vector Double -- ^ desired solution times | 159 | -> Vector Double -- ^ desired solution times |
136 | -> Matrix Double -- ^ solution | 160 | -> Matrix Double -- ^ solution |
137 | odeSolveV meth hi epsAbs epsRel = undefined | 161 | odeSolveV _meth _hi _epsAbs _epsRel = undefined |
138 | 162 | ||
139 | odeSolve :: ODEMethod | 163 | odeSolve :: ODEMethod |
164 | -> (Double -> Vector Double -> Matrix Double) | ||
140 | -> (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 165 | -> (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
141 | -> [Double] -- ^ initial conditions | 166 | -> [Double] -- ^ initial conditions |
142 | -> Vector Double -- ^ desired solution times | 167 | -> Vector Double -- ^ desired solution times |
143 | -> Matrix Double -- ^ solution | 168 | -> Matrix Double -- ^ solution |
144 | odeSolve method f y0 ts = | 169 | odeSolve method jac f y0 ts = |
145 | case solveOde method 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of | 170 | case solveOde method jac' 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of |
146 | Left c -> error $ show c -- FIXME | 171 | Left c -> error $ show c -- FIXME |
147 | Right (v, _) -> (nR >< nC) (V.toList v) | 172 | Right (v, _) -> (nR >< nC) (V.toList v) |
148 | where | 173 | where |
@@ -150,30 +175,40 @@ odeSolve method f y0 ts = | |||
150 | nR = length us | 175 | nR = length us |
151 | nC = length y0 | 176 | nC = length y0 |
152 | g t x0 = V.fromList $ f t (V.toList x0) | 177 | g t x0 = V.fromList $ f t (V.toList x0) |
178 | jac' t v = foo $ jac t (V.fromList $ toList v) | ||
179 | foo m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs } | ||
180 | where | ||
181 | nr = fromIntegral $ rows m | ||
182 | nc = fromIntegral $ cols m | ||
183 | vs = V.fromList $ map coerce $ concat $ toLists m | ||
153 | 184 | ||
154 | solveOde :: | 185 | solveOde :: |
155 | ODEMethod | 186 | ODEMethod |
187 | -> (Double -> V.Vector Double -> T.SunMatrix) | ||
156 | -> Double | 188 | -> Double |
157 | -> Double | 189 | -> Double |
158 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 190 | -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
159 | -> V.Vector Double -- ^ Initial conditions | 191 | -> V.Vector Double -- ^ Initial conditions |
160 | -> V.Vector Double -- ^ Desired solution times | 192 | -> V.Vector Double -- ^ Desired solution times |
161 | -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution | 193 | -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution |
162 | solveOde method relTol absTol f y0 tt = | 194 | solveOde method jac relTol absTol f y0 tt = |
163 | case solveOdeC (fromIntegral $ fromEnum method) (CDouble relTol) (CDouble absTol) | 195 | case solveOdeC (fromIntegral $ fromEnum method) jacH (CDouble relTol) (CDouble absTol) |
164 | (coerce f) (coerce y0) (coerce tt) of | 196 | (coerce f) (coerce y0) (coerce tt) of |
165 | Left c -> Left $ fromIntegral c | 197 | Left c -> Left $ fromIntegral c |
166 | Right (v, d) -> Right (coerce v, d) | 198 | Right (v, d) -> Right (coerce v, d) |
199 | where | ||
200 | jacH t v = jac (coerce t) (coerce v) | ||
167 | 201 | ||
168 | solveOdeC :: | 202 | solveOdeC :: |
169 | CInt -> | 203 | CInt -> |
204 | (CDouble -> V.Vector CDouble -> T.SunMatrix) -> | ||
170 | CDouble -> | 205 | CDouble -> |
171 | CDouble -> | 206 | CDouble -> |
172 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) | 207 | (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) |
173 | -> V.Vector CDouble -- ^ Initial conditions | 208 | -> V.Vector CDouble -- ^ Initial conditions |
174 | -> V.Vector CDouble -- ^ Desired solution times | 209 | -> V.Vector CDouble -- ^ Desired solution times |
175 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution | 210 | -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution |
176 | solveOdeC method relTol absTol fun f0 ts = unsafePerformIO $ do | 211 | solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do |
177 | let dim = V.length f0 | 212 | let dim = V.length f0 |
178 | nEq :: CLong | 213 | nEq :: CLong |
179 | nEq = fromIntegral dim | 214 | nEq = fromIntegral dim |
@@ -197,9 +232,19 @@ solveOdeC method relTol absTol fun f0 ts = unsafePerformIO $ do | |||
197 | fImm <- fun x <$> getDataFromContents dim y | 232 | fImm <- fun x <$> getDataFromContents dim y |
198 | -- Fill in the provided pointer with the resulting vector. | 233 | -- Fill in the provided pointer with the resulting vector. |
199 | putDataInContents fImm dim f | 234 | putDataInContents fImm dim f |
200 | -- I don't understand what this comment means | 235 | -- FIXME: I don't understand what this comment means |
236 | -- Unsafe since the function will be called many times. | ||
237 | [CU.exp| int{ 0 } |] | ||
238 | let jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix -> | ||
239 | Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector -> | ||
240 | IO CInt | ||
241 | jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do | ||
242 | foo <- jacH t <$> getDataFromContents dim y | ||
243 | putMatrixDataFromContents foo jacS | ||
244 | -- FIXME: I don't understand what this comment means | ||
201 | -- Unsafe since the function will be called many times. | 245 | -- Unsafe since the function will be called many times. |
202 | [CU.exp| int{ 0 } |] | 246 | [CU.exp| int{ 0 } |] |
247 | |||
203 | res <- [C.block| int { | 248 | res <- [C.block| int { |
204 | /* general problem variables */ | 249 | /* general problem variables */ |
205 | int flag; /* reusable error-checking flag */ | 250 | int flag; /* reusable error-checking flag */ |
@@ -246,6 +291,11 @@ solveOdeC method relTol absTol fun f0 ts = unsafePerformIO $ do | |||
246 | 291 | ||
247 | /* Linear solver interface */ | 292 | /* Linear solver interface */ |
248 | flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ | 293 | flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ |
294 | |||
295 | flag = ARKDlsSetJacFn(arkode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[]))); | ||
296 | if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1; | ||
297 | |||
298 | |||
249 | /* Store initial conditions */ | 299 | /* Store initial conditions */ |
250 | for (j = 0; j < NEQ; j++) { | 300 | for (j = 0; j < NEQ; j++) { |
251 | ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); | 301 | ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); |
@@ -340,7 +390,8 @@ btGet :: ODEMethod -> Matrix Double | |||
340 | btGet method = case getBT method of | 390 | btGet method = case getBT method of |
341 | Left c -> error $ show c -- FIXME | 391 | Left c -> error $ show c -- FIXME |
342 | -- FIXME | 392 | -- FIXME |
343 | Right (v, sqp) -> subMatrix (0, 0) (2, 2) $ (B.arkSMax >< B.arkSMax) (V.toList v) | 393 | Right (v, _sqp) -> subMatrix (0, 0) (2, 2) $ |
394 | (B.arkSMax >< B.arkSMax) (V.toList v) | ||
344 | 395 | ||
345 | getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int) | 396 | getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int) |
346 | getBT method = case getButcherTable method of | 397 | getBT method = case getButcherTable method of |
@@ -352,9 +403,9 @@ getButcherTable method = unsafePerformIO $ do | |||
352 | -- arkode seems to want an ODE in order to set and then get the | 403 | -- arkode seems to want an ODE in order to set and then get the |
353 | -- Butcher tableau so here's one to keep it happy | 404 | -- Butcher tableau so here's one to keep it happy |
354 | let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble | 405 | let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble |
355 | fun t ys = V.fromList [ ys V.! 0 ] | 406 | fun _t ys = V.fromList [ ys V.! 0 ] |
356 | f0 = V.fromList [ 1.0 ] | 407 | f0 = V.fromList [ 1.0 ] |
357 | ts = V.fromList [ 0.0 ] | 408 | ts = V.fromList [ 0.0 ] |
358 | dim = V.length f0 | 409 | dim = V.length f0 |
359 | nEq :: CLong | 410 | nEq :: CLong |
360 | nEq = fromIntegral dim | 411 | nEq = fromIntegral dim |