diff options
author | Dominic Steinitz <dominic@steinitz.org> | 2018-03-11 14:21:31 +0000 |
---|---|---|
committer | Dominic Steinitz <dominic@steinitz.org> | 2018-03-11 14:21:31 +0000 |
commit | a22963fa83156b76dd73777b7044897eed50e3bc (patch) | |
tree | 325c1f1d19bb0764290650770733e6bf8db171f8 /packages/sundials/src/gsl-ode.hs | |
parent | d83b17190029c11e3ab8b504e5cdc917f5863120 (diff) |
The start of an hmatrix interface to sundials
Diffstat (limited to 'packages/sundials/src/gsl-ode.hs')
-rw-r--r-- | packages/sundials/src/gsl-ode.hs | 152 |
1 files changed, 152 insertions, 0 deletions
diff --git a/packages/sundials/src/gsl-ode.hs b/packages/sundials/src/gsl-ode.hs new file mode 100644 index 0000000..045fce1 --- /dev/null +++ b/packages/sundials/src/gsl-ode.hs | |||
@@ -0,0 +1,152 @@ | |||
1 | {-# LANGUAGE CPP #-} | ||
2 | {-# LANGUAGE OverloadedStrings #-} | ||
3 | {-# LANGUAGE TemplateHaskell #-} | ||
4 | {-# LANGUAGE QuasiQuotes #-} | ||
5 | {-# LANGUAGE MultiWayIf #-} | ||
6 | import Data.Coerce (coerce) | ||
7 | import Data.Monoid ((<>)) | ||
8 | import qualified Data.Vector.Storable as V | ||
9 | import qualified Data.Vector.Storable.Mutable as VM | ||
10 | import Foreign.C.Types | ||
11 | import Foreign.ForeignPtr (newForeignPtr_) | ||
12 | import Foreign.Ptr (Ptr) | ||
13 | import Foreign.Storable (Storable) | ||
14 | -- import qualified Graphics.Rendering.Chart.Backend.Cairo as Chart | ||
15 | -- import qualified Graphics.Rendering.Chart.Easy as Chart | ||
16 | import qualified Language.C.Inline as C | ||
17 | import qualified Language.C.Inline.Unsafe as CU | ||
18 | import System.IO.Unsafe (unsafePerformIO) | ||
19 | |||
20 | -- #if __GLASGOW_HASKELL__ < 710 | ||
21 | -- import Data.Functor ((<$>)) | ||
22 | -- #endif | ||
23 | |||
24 | C.context (C.baseCtx <> C.vecCtx <> C.funCtx) | ||
25 | |||
26 | C.include "<gsl/gsl_errno.h>" | ||
27 | C.include "<gsl/gsl_matrix.h>" | ||
28 | C.include "<gsl/gsl_odeiv2.h>" | ||
29 | |||
30 | -- | Solves a system of ODEs. Every 'V.Vector' involved must be of the | ||
31 | -- same size. | ||
32 | {-# NOINLINE solveOdeC #-} | ||
33 | solveOdeC | ||
34 | :: (CDouble -> V.Vector CDouble -> V.Vector CDouble) | ||
35 | -- ^ ODE to Solve | ||
36 | -> CDouble | ||
37 | -- ^ Start | ||
38 | -> V.Vector CDouble | ||
39 | -- ^ Solution at start point | ||
40 | -> CDouble | ||
41 | -- ^ End | ||
42 | -> Either String (V.Vector CDouble) | ||
43 | -- ^ Solution at end point, or error. | ||
44 | solveOdeC fun x0 f0 xend = unsafePerformIO $ do | ||
45 | let dim = V.length f0 | ||
46 | let dim_c = fromIntegral dim -- This is in CInt | ||
47 | -- Convert the function to something of the right type to C. | ||
48 | let funIO x y f _ptr = do | ||
49 | -- Convert the pointer we get from C (y) to a vector, and then | ||
50 | -- apply the user-supplied function. | ||
51 | fImm <- fun x <$> vectorFromC dim y | ||
52 | -- Fill in the provided pointer with the resulting vector. | ||
53 | vectorToC fImm dim f | ||
54 | -- Unsafe since the function will be called many times. | ||
55 | [CU.exp| int{ GSL_SUCCESS } |] | ||
56 | -- Create a mutable vector from the initial solution. This will be | ||
57 | -- passed to the ODE solving function provided by GSL, and will | ||
58 | -- contain the final solution. | ||
59 | fMut <- V.thaw f0 | ||
60 | res <- [C.block| int { | ||
61 | gsl_odeiv2_system sys = { | ||
62 | $fun:(int (* funIO) (double t, const double y[], double dydt[], void * params)), | ||
63 | // The ODE to solve, converted to function pointer using the `fun` | ||
64 | // anti-quoter | ||
65 | NULL, // We don't provide a Jacobian | ||
66 | $(int dim_c), // The dimension | ||
67 | NULL // We don't need the parameter pointer | ||
68 | }; | ||
69 | // Create the driver, using some sensible values for the stepping | ||
70 | // function and the tolerances | ||
71 | gsl_odeiv2_driver *d = gsl_odeiv2_driver_alloc_y_new ( | ||
72 | &sys, gsl_odeiv2_step_rk8pd, 1e-6, 1e-6, 0.0); | ||
73 | // Finally, apply the driver. | ||
74 | int status = gsl_odeiv2_driver_apply( | ||
75 | d, &$(double x0), $(double xend), $vec-ptr:(double *fMut)); | ||
76 | // Free the driver | ||
77 | gsl_odeiv2_driver_free(d); | ||
78 | return status; | ||
79 | } |] | ||
80 | -- Check the error code | ||
81 | maxSteps <- [C.exp| int{ GSL_EMAXITER } |] | ||
82 | smallStep <- [C.exp| int{ GSL_ENOPROG } |] | ||
83 | good <- [C.exp| int{ GSL_SUCCESS } |] | ||
84 | if | res == good -> Right <$> V.freeze fMut | ||
85 | | res == maxSteps -> return $ Left "Too many steps" | ||
86 | | res == smallStep -> return $ Left "Step size dropped below minimum allowed size" | ||
87 | | otherwise -> return $ Left $ "Unknown error code " ++ show res | ||
88 | |||
89 | solveOde | ||
90 | :: (Double -> V.Vector Double -> V.Vector Double) | ||
91 | -- ^ ODE to Solve | ||
92 | -> Double | ||
93 | -- ^ Start | ||
94 | -> V.Vector Double | ||
95 | -- ^ Solution at start point | ||
96 | -> Double | ||
97 | -- ^ End | ||
98 | -> Either String (V.Vector Double) | ||
99 | -- ^ Solution at end point, or error. | ||
100 | solveOde fun x0 f0 xend = | ||
101 | coerce $ solveOdeC (coerce fun) (coerce x0) (coerce f0) (coerce xend) | ||
102 | |||
103 | lorenz | ||
104 | :: Double | ||
105 | -- ^ Starting point | ||
106 | -> V.Vector Double | ||
107 | -- ^ Solution at starting point | ||
108 | -> Double | ||
109 | -- ^ End point | ||
110 | -> Either String (V.Vector Double) | ||
111 | lorenz x0 f0 xend = solveOde fun x0 f0 xend | ||
112 | where | ||
113 | sigma = 10.0; | ||
114 | _R = 28.0; | ||
115 | b = 8.0 / 3.0; | ||
116 | |||
117 | fun _x y = | ||
118 | let y0 = y V.! 0 | ||
119 | y1 = y V.! 1 | ||
120 | y2 = y V.! 2 | ||
121 | in V.fromList | ||
122 | [ sigma * ( y1 - y0 ) | ||
123 | , _R * y0 - y1 - y0 * y2 | ||
124 | , -b * y2 + y0 * y1 | ||
125 | ] | ||
126 | |||
127 | main :: IO () | ||
128 | main = undefined | ||
129 | -- main = Chart.toFile Chart.def "lorenz.png" $ do | ||
130 | -- Chart.layout_title Chart..= "Lorenz" | ||
131 | -- Chart.plot $ Chart.line "curve" [pts] | ||
132 | -- where | ||
133 | -- pts = [(f V.! 0, f V.! 2) | (_x, f) <- go 0 (V.fromList [10.0 , 1.0 , 1.0])] | ||
134 | |||
135 | -- go x f | x > 40 = | ||
136 | -- [(x, f)] | ||
137 | -- go x f = | ||
138 | -- let x' = x + 0.01 | ||
139 | -- Right f' = lorenz x f x' | ||
140 | -- in (x, f) : go x' f' | ||
141 | |||
142 | -- Utils | ||
143 | |||
144 | vectorFromC :: Storable a => Int -> Ptr a -> IO (V.Vector a) | ||
145 | vectorFromC len ptr = do | ||
146 | ptr' <- newForeignPtr_ ptr | ||
147 | V.freeze $ VM.unsafeFromForeignPtr0 ptr' len | ||
148 | |||
149 | vectorToC :: Storable a => V.Vector a -> Int -> Ptr a -> IO () | ||
150 | vectorToC vec len ptr = do | ||
151 | ptr' <- newForeignPtr_ ptr | ||
152 | V.copy (VM.unsafeFromForeignPtr0 ptr' len) vec | ||