summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDominic Steinitz <dominic@steinitz.org>2018-03-31 13:06:35 +0100
committerDominic Steinitz <dominic@steinitz.org>2018-03-31 13:06:35 +0100
commit3c4411e48cbcfaf8035e893ac63aa250fcc56d3e (patch)
tree0542a6ab2c68f2c9245dedf7e04d4e15e279cfd3
parent5bcc77b1e115a8c8eb94a1aa1a441618bfeb0b54 (diff)
Add in the Jacobian
-rw-r--r--packages/sundials/src/Arkode.hsc19
-rw-r--r--packages/sundials/src/Main.hs27
-rw-r--r--packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs99
-rw-r--r--packages/sundials/src/Types.hs7
4 files changed, 122 insertions, 30 deletions
diff --git a/packages/sundials/src/Arkode.hsc b/packages/sundials/src/Arkode.hsc
index 59e701e..ae2b40f 100644
--- a/packages/sundials/src/Arkode.hsc
+++ b/packages/sundials/src/Arkode.hsc
@@ -9,12 +9,31 @@ import Foreign.C.String
9 9
10#include <stdio.h> 10#include <stdio.h>
11#include <sundials/sundials_nvector.h> 11#include <sundials/sundials_nvector.h>
12#include <sundials/sundials_matrix.h>
12#include <nvector/nvector_serial.h> 13#include <nvector/nvector_serial.h>
14#include <sunmatrix/sunmatrix_dense.h>
13#include <arkode/arkode.h> 15#include <arkode/arkode.h>
14 16
15#def typedef struct _generic_N_Vector SunVector; 17#def typedef struct _generic_N_Vector SunVector;
16#def typedef struct _N_VectorContent_Serial SunContent; 18#def typedef struct _N_VectorContent_Serial SunContent;
17 19
20#def typedef struct _generic_SUNMatrix SunMatrix;
21#def typedef struct _SUNMatrixContent_Dense SunMatrixContent;
22
23getContentMatrixPtr ptr = (#peek SunMatrix, content) ptr
24
25getNRows :: Ptr b -> IO CInt
26getNRows ptr = (#peek SunMatrixContent, M) ptr
27putNRows :: CInt -> Ptr b -> IO ()
28putNRows nr ptr = (#poke SunMatrixContent, M) ptr nr
29
30getNCols :: Ptr b -> IO CInt
31getNCols ptr = (#peek SunMatrixContent, N) ptr
32putNCols :: CInt -> Ptr b -> IO ()
33putNCols nc ptr = (#poke SunMatrixContent, N) ptr nc
34
35getMatrixData ptr = (#peek SunMatrixContent, data) ptr
36
18getContentPtr :: Storable a => Ptr b -> IO a 37getContentPtr :: Storable a => Ptr b -> IO a
19getContentPtr ptr = (#peek SunVector, content) ptr 38getContentPtr ptr = (#peek SunVector, content) ptr
20 39
diff --git a/packages/sundials/src/Main.hs b/packages/sundials/src/Main.hs
index 71bcbac..01d3595 100644
--- a/packages/sundials/src/Main.hs
+++ b/packages/sundials/src/Main.hs
@@ -27,11 +27,26 @@ brusselator _t x = [ a - (w + 1) * u + v * u^2
27 v = x !! 1 27 v = x !! 1
28 w = x !! 2 28 w = x !! 2
29 29
30brussJac _t x = (3><3) [ (-(w + 1.0)) + 2.0 * u * v, w - 2.0 * u * v, (-w)
31 , u * u , (-(u * u)) , 0.0
32 , (-u) , u , (-1.0) / eps - u
33 ]
34 where
35 y = toList x
36 u = y !! 0
37 v = y !! 1
38 w = y !! 2
39 eps = 5.0e-6
40
30stiffish t v = [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ] 41stiffish t v = [ lamda * u + 1.0 / (1.0 + t * t) - lamda * atan t ]
31 where 42 where
32 lamda = -100.0 43 lamda = -100.0
33 u = v !! 0 44 u = v !! 0
34 45
46stiffJac _t _v = (1><1) [ lamda ]
47 where
48 lamda = -100.0
49
35lSaxis :: [[Double]] -> P.Axis B D.V2 Double 50lSaxis :: [[Double]] -> P.Axis B D.V2 Double
36lSaxis xs = P.r2Axis &~ do 51lSaxis xs = P.r2Axis &~ do
37 let ts = xs!!0 52 let ts = xs!!0
@@ -77,14 +92,14 @@ main = do
77 putStrLn $ show res 92 putStrLn $ show res
78 putStrLn $ butcherTableauTex res 93 putStrLn $ butcherTableauTex res
79 94
80 let res = odeSolve KVAERNO_4_2_3 brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0]) 95 let res1 = odeSolve KVAERNO_4_2_3 brussJac brusselator [1.2, 3.1, 3.0] (fromList [0.0, 0.1 .. 10.0])
81 putStrLn $ show res 96 putStrLn $ show res1
82 renderRasterific "diagrams/brusselator.png" 97 renderRasterific "diagrams/brusselator.png"
83 (D.dims2D 500.0 500.0) 98 (D.dims2D 500.0 500.0)
84 (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res)) 99 (renderAxis $ lSaxis $ [0.0, 0.1 .. 10.0]:(toLists $ tr res1))
85 100
86 let res = odeSolve KVAERNO_4_2_3 stiffish [0.0] (fromList [0.0, 0.1 .. 10.0]) 101 let res2 = odeSolve KVAERNO_4_2_3 stiffJac stiffish [0.0] (fromList [0.0, 0.1 .. 10.0])
87 putStrLn $ show res 102 putStrLn $ show res2
88 renderRasterific "diagrams/stiffish.png" 103 renderRasterific "diagrams/stiffish.png"
89 (D.dims2D 500.0 500.0) 104 (D.dims2D 500.0 500.0)
90 (renderAxis $ kSaxis $ zip [0.0, 0.1 .. 10.0] (concat $ toLists res)) 105 (renderAxis $ kSaxis $ zip [0.0, 0.1 .. 10.0] (concat $ toLists res2))
diff --git a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
index 30ff4c8..5af9e41 100644
--- a/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
+++ b/packages/sundials/src/Numeric/Sundials/ARKode/ODE.hs
@@ -35,6 +35,7 @@ module Numeric.Sundials.Arkode.ODE ( solveOde
35 , getBT 35 , getBT
36 , btGet 36 , btGet
37 , ODEMethod(..) 37 , ODEMethod(..)
38 , odeSolveV
38 ) where 39 ) where
39 40
40import qualified Language.C.Inline as C 41import qualified Language.C.Inline as C
@@ -45,7 +46,7 @@ import Data.Monoid ((<>))
45import Foreign.C.Types 46import Foreign.C.Types
46import Foreign.Ptr (Ptr) 47import Foreign.Ptr (Ptr)
47import Foreign.ForeignPtr (newForeignPtr_) 48import Foreign.ForeignPtr (newForeignPtr_)
48import Foreign.Storable (Storable, peekByteOff) 49import Foreign.Storable (Storable)
49 50
50import qualified Data.Vector.Storable as V 51import qualified Data.Vector.Storable as V
51import qualified Data.Vector.Storable.Mutable as VM 52import qualified Data.Vector.Storable.Mutable as VM
@@ -55,7 +56,8 @@ import System.IO.Unsafe (unsafePerformIO)
55 56
56import Numeric.LinearAlgebra.Devel (createVector) 57import Numeric.LinearAlgebra.Devel (createVector)
57 58
58import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><), subMatrix) 59import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, toList, (><),
60 subMatrix, rows, cols, toLists)
59 61
60import qualified Types as T 62import qualified Types as T
61import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3) 63import Arkode (sDIRK_2_1_2, kVAERNO_4_2_3)
@@ -78,12 +80,34 @@ C.include "../../../helpers.h"
78C.include "Arkode_hsc.h" 80C.include "Arkode_hsc.h"
79 81
80 82
81getDataFromContents :: Storable b => Int -> Ptr a -> IO (V.Vector b) 83getDataFromContents :: Int -> Ptr T.SunVector -> IO (V.Vector CDouble)
82getDataFromContents len ptr = do 84getDataFromContents len ptr = do
83 qtr <- B.getContentPtr ptr 85 qtr <- B.getContentPtr ptr
84 rtr <- B.getData qtr 86 rtr <- B.getData qtr
85 vectorFromC len rtr 87 vectorFromC len rtr
86 88
89-- FIXME: Potentially an instance of Storable
90getMatrixDataFromContents :: Ptr T.SunMatrix -> IO T.SunMatrix
91getMatrixDataFromContents ptr = do
92 qtr <- B.getContentMatrixPtr ptr
93 rs <- B.getNRows qtr
94 cs <- B.getNCols qtr
95 rtr <- B.getMatrixData qtr
96 vs <- vectorFromC (fromIntegral $ rs * cs) rtr
97 return $ T.SunMatrix { T.rows = rs, T.cols = cs, T.vals = vs }
98
99putMatrixDataFromContents :: T.SunMatrix -> Ptr T.SunMatrix -> IO ()
100putMatrixDataFromContents mat ptr = do
101 let rs = T.rows mat
102 cs = T.cols mat
103 vs = T.vals mat
104 qtr <- B.getContentMatrixPtr ptr
105 B.putNRows rs qtr
106 B.putNCols cs qtr
107 rtr <- B.getMatrixData qtr
108 vectorToC vs (fromIntegral $ rs * cs) rtr
109-- FIXME: END
110
87putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO () 111putDataInContents :: Storable a => V.Vector a -> Int -> Ptr b -> IO ()
88putDataInContents vec len ptr = do 112putDataInContents vec len ptr = do
89 qtr <- B.getContentPtr ptr 113 qtr <- B.getContentPtr ptr
@@ -103,16 +127,16 @@ vectorToC vec len ptr = do
103 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec 127 V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec
104 128
105data SundialsDiagnostics = SundialsDiagnostics { 129data SundialsDiagnostics = SundialsDiagnostics {
106 aRKodeGetNumSteps :: Int 130 _aRKodeGetNumSteps :: Int
107 , aRKodeGetNumStepAttempts :: Int 131 , _aRKodeGetNumStepAttempts :: Int
108 , aRKodeGetNumRhsEvals_fe :: Int 132 , _aRKodeGetNumRhsEvals_fe :: Int
109 , aRKodeGetNumRhsEvals_fi :: Int 133 , _aRKodeGetNumRhsEvals_fi :: Int
110 , aRKodeGetNumLinSolvSetups :: Int 134 , _aRKodeGetNumLinSolvSetups :: Int
111 , aRKodeGetNumErrTestFails :: Int 135 , _aRKodeGetNumErrTestFails :: Int
112 , aRKodeGetNumNonlinSolvIters :: Int 136 , _aRKodeGetNumNonlinSolvIters :: Int
113 , aRKodeGetNumNonlinSolvConvFails :: Int 137 , _aRKodeGetNumNonlinSolvConvFails :: Int
114 , aRKDlsGetNumJacEvals :: Int 138 , _aRKDlsGetNumJacEvals :: Int
115 , aRKDlsGetNumRhsEvals :: Int 139 , _aRKDlsGetNumRhsEvals :: Int
116 } deriving Show 140 } deriving Show
117 141
118-- | Stepping functions 142-- | Stepping functions
@@ -134,15 +158,16 @@ odeSolveV
134 -> Vector Double -- ^ initial conditions 158 -> Vector Double -- ^ initial conditions
135 -> Vector Double -- ^ desired solution times 159 -> Vector Double -- ^ desired solution times
136 -> Matrix Double -- ^ solution 160 -> Matrix Double -- ^ solution
137odeSolveV meth hi epsAbs epsRel = undefined 161odeSolveV _meth _hi _epsAbs _epsRel = undefined
138 162
139odeSolve :: ODEMethod 163odeSolve :: ODEMethod
164 -> (Double -> Vector Double -> Matrix Double)
140 -> (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 165 -> (Double -> [Double] -> [Double]) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
141 -> [Double] -- ^ initial conditions 166 -> [Double] -- ^ initial conditions
142 -> Vector Double -- ^ desired solution times 167 -> Vector Double -- ^ desired solution times
143 -> Matrix Double -- ^ solution 168 -> Matrix Double -- ^ solution
144odeSolve method f y0 ts = 169odeSolve method jac f y0 ts =
145 case solveOde method 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of 170 case solveOde method jac' 1.0e-6 1.0e-10 g (V.fromList y0) (V.fromList $ toList ts) of
146 Left c -> error $ show c -- FIXME 171 Left c -> error $ show c -- FIXME
147 Right (v, _) -> (nR >< nC) (V.toList v) 172 Right (v, _) -> (nR >< nC) (V.toList v)
148 where 173 where
@@ -150,30 +175,40 @@ odeSolve method f y0 ts =
150 nR = length us 175 nR = length us
151 nC = length y0 176 nC = length y0
152 g t x0 = V.fromList $ f t (V.toList x0) 177 g t x0 = V.fromList $ f t (V.toList x0)
178 jac' t v = foo $ jac t (V.fromList $ toList v)
179 foo m = T.SunMatrix { T.rows = nr, T.cols = nc, T.vals = vs }
180 where
181 nr = fromIntegral $ rows m
182 nc = fromIntegral $ cols m
183 vs = V.fromList $ map coerce $ concat $ toLists m
153 184
154solveOde :: 185solveOde ::
155 ODEMethod 186 ODEMethod
187 -> (Double -> V.Vector Double -> T.SunMatrix)
156 -> Double 188 -> Double
157 -> Double 189 -> Double
158 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 190 -> (Double -> V.Vector Double -> V.Vector Double) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
159 -> V.Vector Double -- ^ Initial conditions 191 -> V.Vector Double -- ^ Initial conditions
160 -> V.Vector Double -- ^ Desired solution times 192 -> V.Vector Double -- ^ Desired solution times
161 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution 193 -> Either Int ((V.Vector Double), SundialsDiagnostics) -- ^ Error code or solution
162solveOde method relTol absTol f y0 tt = 194solveOde method jac relTol absTol f y0 tt =
163 case solveOdeC (fromIntegral $ fromEnum method) (CDouble relTol) (CDouble absTol) 195 case solveOdeC (fromIntegral $ fromEnum method) jacH (CDouble relTol) (CDouble absTol)
164 (coerce f) (coerce y0) (coerce tt) of 196 (coerce f) (coerce y0) (coerce tt) of
165 Left c -> Left $ fromIntegral c 197 Left c -> Left $ fromIntegral c
166 Right (v, d) -> Right (coerce v, d) 198 Right (v, d) -> Right (coerce v, d)
199 where
200 jacH t v = jac (coerce t) (coerce v)
167 201
168solveOdeC :: 202solveOdeC ::
169 CInt -> 203 CInt ->
204 (CDouble -> V.Vector CDouble -> T.SunMatrix) ->
170 CDouble -> 205 CDouble ->
171 CDouble -> 206 CDouble ->
172 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\) 207 (CDouble -> V.Vector CDouble -> V.Vector CDouble) -- ^ The RHS of the system \(\dot{y} = f(t,y)\)
173 -> V.Vector CDouble -- ^ Initial conditions 208 -> V.Vector CDouble -- ^ Initial conditions
174 -> V.Vector CDouble -- ^ Desired solution times 209 -> V.Vector CDouble -- ^ Desired solution times
175 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution 210 -> Either CInt ((V.Vector CDouble), SundialsDiagnostics) -- ^ Error code or solution
176solveOdeC method relTol absTol fun f0 ts = unsafePerformIO $ do 211solveOdeC method jacH relTol absTol fun f0 ts = unsafePerformIO $ do
177 let dim = V.length f0 212 let dim = V.length f0
178 nEq :: CLong 213 nEq :: CLong
179 nEq = fromIntegral dim 214 nEq = fromIntegral dim
@@ -197,9 +232,19 @@ solveOdeC method relTol absTol fun f0 ts = unsafePerformIO $ do
197 fImm <- fun x <$> getDataFromContents dim y 232 fImm <- fun x <$> getDataFromContents dim y
198 -- Fill in the provided pointer with the resulting vector. 233 -- Fill in the provided pointer with the resulting vector.
199 putDataInContents fImm dim f 234 putDataInContents fImm dim f
200 -- I don't understand what this comment means 235 -- FIXME: I don't understand what this comment means
236 -- Unsafe since the function will be called many times.
237 [CU.exp| int{ 0 } |]
238 let jacIO :: CDouble -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunMatrix ->
239 Ptr () -> Ptr T.SunVector -> Ptr T.SunVector -> Ptr T.SunVector ->
240 IO CInt
241 jacIO t y _fy jacS _ptr _tmp1 _tmp2 _tmp3 = do
242 foo <- jacH t <$> getDataFromContents dim y
243 putMatrixDataFromContents foo jacS
244 -- FIXME: I don't understand what this comment means
201 -- Unsafe since the function will be called many times. 245 -- Unsafe since the function will be called many times.
202 [CU.exp| int{ 0 } |] 246 [CU.exp| int{ 0 } |]
247
203 res <- [C.block| int { 248 res <- [C.block| int {
204 /* general problem variables */ 249 /* general problem variables */
205 int flag; /* reusable error-checking flag */ 250 int flag; /* reusable error-checking flag */
@@ -246,6 +291,11 @@ solveOdeC method relTol absTol fun f0 ts = unsafePerformIO $ do
246 291
247 /* Linear solver interface */ 292 /* Linear solver interface */
248 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */ 293 flag = ARKDlsSetLinearSolver(arkode_mem, LS, A); /* Attach matrix and linear solver */
294
295 flag = ARKDlsSetJacFn(arkode_mem, $fun:(int (* jacIO) (double t, SunVector y[], SunVector fy[], SunMatrix Jac[], void * params, SunVector tmp1[], SunVector tmp2[], SunVector tmp3[])));
296 if (check_flag(&flag, "ARKDlsSetJacFn", 1)) return 1;
297
298
249 /* Store initial conditions */ 299 /* Store initial conditions */
250 for (j = 0; j < NEQ; j++) { 300 for (j = 0; j < NEQ; j++) {
251 ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j); 301 ($vec-ptr:(double *qMatMut))[0 * $(int nTs) + j] = NV_Ith_S(y,j);
@@ -340,7 +390,8 @@ btGet :: ODEMethod -> Matrix Double
340btGet method = case getBT method of 390btGet method = case getBT method of
341 Left c -> error $ show c -- FIXME 391 Left c -> error $ show c -- FIXME
342 -- FIXME 392 -- FIXME
343 Right (v, sqp) -> subMatrix (0, 0) (2, 2) $ (B.arkSMax >< B.arkSMax) (V.toList v) 393 Right (v, _sqp) -> subMatrix (0, 0) (2, 2) $
394 (B.arkSMax >< B.arkSMax) (V.toList v)
344 395
345getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int) 396getBT :: ODEMethod -> Either Int (V.Vector Double, V.Vector Int)
346getBT method = case getButcherTable method of 397getBT method = case getButcherTable method of
@@ -352,9 +403,9 @@ getButcherTable method = unsafePerformIO $ do
352 -- arkode seems to want an ODE in order to set and then get the 403 -- arkode seems to want an ODE in order to set and then get the
353 -- Butcher tableau so here's one to keep it happy 404 -- Butcher tableau so here's one to keep it happy
354 let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble 405 let fun :: CDouble -> V.Vector CDouble -> V.Vector CDouble
355 fun t ys = V.fromList [ ys V.! 0 ] 406 fun _t ys = V.fromList [ ys V.! 0 ]
356 f0 = V.fromList [ 1.0 ] 407 f0 = V.fromList [ 1.0 ]
357 ts = V.fromList [ 0.0 ] 408 ts = V.fromList [ 0.0 ]
358 dim = V.length f0 409 dim = V.length f0
359 nEq :: CLong 410 nEq :: CLong
360 nEq = fromIntegral dim 411 nEq = fromIntegral dim
diff --git a/packages/sundials/src/Types.hs b/packages/sundials/src/Types.hs
index e910c57..04e4280 100644
--- a/packages/sundials/src/Types.hs
+++ b/packages/sundials/src/Types.hs
@@ -15,8 +15,14 @@ import qualified Language.C.Types as CT
15import qualified Data.Map as Map 15import qualified Data.Map as Map
16import Language.C.Inline.Context 16import Language.C.Inline.Context
17 17
18import qualified Data.Vector.Storable as V
19
18 20
19data SunVector 21data SunVector
22data SunMatrix = SunMatrix { rows :: CInt
23 , cols :: CInt
24 , vals :: V.Vector CDouble
25 }
20 26
21-- FIXME: Is this true? 27-- FIXME: Is this true?
22type SunIndexType = CLong 28type SunIndexType = CLong
@@ -26,6 +32,7 @@ sunTypesTable = Map.fromList
26 [ 32 [
27 (CT.TypeName "sunindextype", [t| SunIndexType |] ) 33 (CT.TypeName "sunindextype", [t| SunIndexType |] )
28 , (CT.TypeName "SunVector", [t| SunVector |] ) 34 , (CT.TypeName "SunVector", [t| SunVector |] )
35 , (CT.TypeName "SunMatrix", [t| SunMatrix |] )
29 ] 36 ]
30 37
31sunCtx :: Context 38sunCtx :: Context