summaryrefslogtreecommitdiff
path: root/packages/gsl/src/Numeric/GSL/ODE.hs
blob: 987d47e6f865f4277c68b5168da1b249af41c418 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
{-# LANGUAGE FlexibleContexts #-}


{- |
Module      :  Numeric.GSL.ODE
Copyright   :  (c) Alberto Ruiz 2010
License     :  GPL
Maintainer  :  Alberto Ruiz
Stability   :  provisional

Solution of ordinary differential equation (ODE) initial value problems.

<http://www.gnu.org/software/gsl/manual/html_node/Ordinary-Differential-Equations.html>

A simple example:

@
import Numeric.GSL.ODE
import Numeric.LinearAlgebra
import Graphics.Plot(mplot)

xdot t [x,v] = [v, -0.95*x - 0.1*v]

ts = linspace 100 (0,20 :: Double)

sol = odeSolve xdot [10,0] ts

main = mplot (ts : toColumns sol)
@

-}
-----------------------------------------------------------------------------

module Numeric.GSL.ODE (
    odeSolve, odeSolveV, odeSolveVWith, ODEMethod(..), Jacobian, StepControl(..)
) where

import Numeric.LinearAlgebra.HMatrix
import Numeric.GSL.Internal

import Control.Exception(bracket)
import Foreign.Ptr(FunPtr, nullFunPtr, freeHaskellFunPtr)
import Foreign.C.Types
import System.IO.Unsafe(unsafePerformIO)

-------------------------------------------------------------------------

-- Shorthands for the argument/result shapes of the foreign signatures below,
-- built from 'TV', 'TM' and 'Res' of "Numeric.GSL.Internal"
-- (NOTE(review): presumably "takes a vector" / "takes a matrix" adapters and
-- the C result code -- confirm in that module).
--
-- In this file: 'TVV' is the shape of the x'(t,x) callback after its 'Double'
-- argument, 'TVM' the shape of the Jacobian callback, and 'TVVVM' the tail of
-- 'ode_c' (three vectors followed by the result matrix). 'TVVM' is not used
-- by the code visible here.
type TVV   = TV (TV Res)
type TVM   = TV (TM Res)
type TVVM  = TV (TV (TM Res))
type TVVVM = TV (TV (TV (TM Res)))

-- | Jacobian of the system, required by the implicit stepping functions:
-- given the time @t@ and the state vector @x@ it returns a matrix.
-- For a system with @n@ state components the result must be @n@ x @n@
-- (this is checked by the solver each time the function is called).
type Jacobian = Double -> Vector Double -> Matrix Double

-- | Stepping functions.
--
-- Constructors that carry a 'Jacobian' hand that function to the solver;
-- all other methods are run without a Jacobian.
data ODEMethod = RK2 -- ^ Embedded Runge-Kutta (2, 3) method.
               | RK4 -- ^ 4th order (classical) Runge-Kutta. The error estimate is obtained by halving the step-size. For a more efficient estimate of the error, use the embedded methods.
               | RKf45 -- ^ Embedded Runge-Kutta-Fehlberg (4, 5) method. This method is a good general-purpose integrator.
               | RKck -- ^ Embedded Runge-Kutta Cash-Karp (4, 5) method.
               | RK8pd -- ^ Embedded Runge-Kutta Prince-Dormand (8,9) method.
               | RK2imp Jacobian -- ^ Implicit 2nd order Runge-Kutta at Gaussian points.
               | RK4imp Jacobian -- ^ Implicit 4th order Runge-Kutta at Gaussian points.
               | BSimp Jacobian -- ^ Implicit Bulirsch-Stoer method of Bader and Deuflhard. The method is generally suitable for stiff problems.
               | RK1imp Jacobian -- ^ Implicit Gaussian first order Runge-Kutta. Also known as implicit Euler or backward Euler method. Error estimation is carried out by the step doubling method.
               | MSAdams -- ^ A variable-coefficient linear multistep Adams method in Nordsieck form. This stepper uses explicit Adams-Bashforth (predictor) and implicit Adams-Moulton (corrector) methods in P(EC)^m functional iteration mode. Method order varies dynamically between 1 and 12.
               | MSBDF Jacobian -- ^ A variable-coefficient linear multistep backward differentiation formula (BDF) method in Nordsieck form. This stepper uses the explicit BDF formula as predictor and implicit BDF formula as corrector. A modified Newton iteration method is used to solve the system of non-linear equations. Method order varies dynamically between 1 and 5. The method is generally suitable for stiff problems.

-- | Adaptive step-size control functions.
--
-- In every constructor the first two fields are the absolute and the
-- relative tolerance, in that order. Where present, the next two fields are
-- the scaling factors a_x and a_x' applied to the relative tolerance.
data StepControl = X     Double Double -- ^ abs. and rel. tolerance for x(t); same as @XX'@ with a_x = 1, a_x' = 0
                 | X'    Double Double -- ^ abs. and rel. tolerance for x'(t); same as @XX'@ with a_x = 0, a_x' = 1
                 | XX'   Double Double Double Double -- ^ include both x(t) and x'(t) via the rel. tolerance scaling factors a_x, a_x'
                 | ScXX' Double Double Double Double (Vector Double) -- ^ as @XX'@, plus a vector scaling the abs. tolerance of each x(t) component; it must have as many components as the state vector

-- | Convenience front end to 'odeSolveV' for systems written with plain lists.
--
-- It integrates with the 'RKf45' stepper, uses @1.49012e-08@ for both the
-- absolute and the relative tolerance, and starts with a step equal to one
-- hundredth of the first interval of the requested times.
odeSolve
    :: (Double -> [Double] -> [Double])        -- ^ x'(t,x)
    -> [Double]        -- ^ initial conditions
    -> Vector Double   -- ^ desired solution times
    -> Matrix Double   -- ^ solution
odeSolve xdot xi ts =
    odeSolveV RKf45 firstStep tolerance tolerance derivative (fromList xi) ts
  where
    -- one hundredth of the gap between the first two requested times
    firstStep = (ts ! 1 - ts ! 0) / 100
    -- shared by the absolute and the relative error bound
    tolerance = 1.49012e-08
    -- vector-based view of the list-based derivative supplied by the caller
    derivative t x = fromList (xdot t (toList x))

-- | 'odeSolveVWith' specialised to the 'XX'' step control, with both
-- relative-tolerance scaling factors (a_x and a_x') fixed to 1.
odeSolveV
    :: ODEMethod
    -> Double            -- ^ initial step size
    -> Double            -- ^ absolute tolerance for the state vector
    -> Double            -- ^ relative tolerance for the state vector
    -> (Double -> Vector Double -> Vector Double)   -- ^ x'(t,x)
    -> Vector Double     -- ^ initial conditions
    -> Vector Double     -- ^ desired solution times
    -> Matrix Double     -- ^ solution
odeSolveV meth hi epsAbs epsRel = odeSolveVWith meth defaultControl hi
  where
    -- take both x(t) and x'(t) into account, with unit weights
    defaultControl = XX' epsAbs epsRel 1 1

-- | Evolution of the system with adaptive step-size control.
--
-- The stepping method and the step control are translated into the integer
-- codes and flat parameters expected by the low-level worker, which receives
-- the remaining arguments unchanged.
odeSolveVWith
    :: ODEMethod
    -> StepControl
    -> Double            -- ^ initial step size
    -> (Double -> Vector Double -> Vector Double)   -- ^ x'(t,x)
    -> Vector Double     -- ^ initial conditions
    -> Vector Double     -- ^ desired solution times
    -> Matrix Double     -- ^ solution
odeSolveVWith method control =
    odeSolveVWith' stepper jacobian ctrl epsAbs epsRel aX aX' scale
  where
    (stepper, jacobian) = encodeMethod method
    (ctrl, epsAbs, epsRel, aX, aX', scale) = encodeControl control

    -- numeric code of the stepper, together with its Jacobian (if any)
    encodeMethod RK2          = (0 , Nothing)
    encodeMethod RK4          = (1 , Nothing)
    encodeMethod RKf45        = (2 , Nothing)
    encodeMethod RKck         = (3 , Nothing)
    encodeMethod RK8pd        = (4 , Nothing)
    encodeMethod (RK2imp jac) = (5 , Just jac)
    encodeMethod (RK4imp jac) = (6 , Just jac)
    encodeMethod (BSimp  jac) = (7 , Just jac)
    encodeMethod (RK1imp jac) = (8 , Just jac)
    encodeMethod MSAdams      = (9 , Nothing)
    encodeMethod (MSBDF  jac) = (10, Just jac)

    -- control code, tolerances, scaling factors and optional scale vector;
    -- only 'ScXX'' selects code 1 and provides a scale vector
    encodeControl (X     ea er)           = (0, ea, er, 1 , 0  , Nothing)
    encodeControl (X'    ea er)           = (0, ea, er, 0 , 1  , Nothing)
    encodeControl (XX'   ea er ax ax')    = (0, ea, er, ax, ax', Nothing)
    encodeControl (ScXX' ea er ax ax' sc) = (1, ea, er, ax, ax', Just sc)

-- Low-level worker behind 'odeSolveVWith'.
--
-- Wraps the derivative (and the Jacobian, when one is given) as C callbacks,
-- calls the foreign routine 'ode_c' and returns the matrix it fills, which is
-- allocated with @size ts@ rows and @size xiv@ columns.
--
-- Calls 'error' (through 'checkdim1', 'checkdim2' and 'checkTimes') when the
-- scale vector, the result of @f@ or the result of the Jacobian have the wrong
-- dimensions, or when @ts@ is not a strictly increasing sequence of at least
-- two times.
odeSolveVWith'
    :: CInt     -- ^ stepping function
    -> Maybe (Double -> Vector Double -> Matrix Double)   -- ^ optional jacobian
    -> CInt     -- ^ step-size control function
    -> Double   -- ^ absolute tolerance for step-size control
    -> Double   -- ^ relative tolerance for step-size control
    -> Double   -- ^ scaling factor for relative tolerance of x(t)
    -> Double   -- ^ scaling factor for relative tolerance of x'(t)
    -> Maybe (Vector Double)    -- ^ optional scaling for absolute error
    -> Double   -- ^ initial step size
    -> (Double -> Vector Double -> Vector Double)        -- ^ x'(t,x)
    -> Vector Double  -- ^ initial conditions
    -> Vector Double  -- ^ desired solution times
    -> Matrix Double  -- ^ solution
odeSolveVWith' method mbjac control epsAbs epsRel aX aX' mbsc h f xiv ts =
    -- unsafePerformIO: the callbacks wrap pure Haskell functions and every
    -- resource acquired here is released before returning, so the result is
    -- determined by the arguments alone.
    -- NOTE(review): this also assumes the C routine has no other observable
    -- side effects -- confirm against its implementation.
    unsafePerformIO $ do
        let n  = size xiv
            -- When no scale vector is supplied the initial state is passed in
            -- its place, so that the foreign call still receives a vector of
            -- the right length (presumably ignored by the C side -- confirm).
            sc = case mbsc of
                Just scv -> checkdim1 n scv
                Nothing  -> xiv
            newJac = case mbjac of
                Just jac -> mkDoubleVecMatfun (\t -> aux_vTom (checkdim2 n . jac t))
                Nothing  -> return nullFunPtr
            -- a null pointer stands for "no Jacobian" and must not be freed
            freeJac jp = if jp /= nullFunPtr then freeHaskellFunPtr jp else pure ()
        -- 'bracket' guarantees that both callbacks are freed even when the
        -- dimension/time checks or the solver itself raise an exception;
        -- freeing them only on the success path would leak them otherwise.
        bracket (mkDoubleVecVecfun (\t -> aux_vTov (checkdim1 n . f t))) freeHaskellFunPtr $ \fp ->
            bracket newJac freeJac $ \jp ->
                vec sc $ \sc' -> vec xiv $ \xiv' ->
                    vec (checkTimes ts) $ \ts' -> createMIO (size ts) n
                        (ode_c method control h epsAbs epsRel aX aX' fp jp
                        // sc' // xiv' // ts' )
                        "ode"

-- Binding to the C routine exported under the name "ode".
--
-- Arguments, in the order used by the call in odeSolveVWith': stepping method
-- code, step-control code, initial step size, absolute tolerance, relative
-- tolerance, scaling factor for x(t), scaling factor for x'(t), callback for
-- x'(t,x), callback for the Jacobian (a null pointer when there is none),
-- followed by the scale vector, the initial state, the requested times and
-- the result matrix (the 'TVVVM' tail).
--
-- Imported as @safe@; the routine receives pointers to Haskell callbacks
-- (NOTE(review): presumably invoked during the integration -- confirm on the
-- C side).
foreign import ccall safe "ode"
    ode_c :: CInt -> CInt -> Double
          -> Double -> Double -> Double -> Double
          -> FunPtr (Double -> TVV) -> FunPtr (Double -> TVM) -> TVVVM

-------------------------------------------------------

-- Returns the vector unchanged when it has exactly @n@ components and calls
-- 'error' otherwise.
checkdim1 n v =
    if size v == n then v else error complaint
  where
    complaint = concat
        [ "Error: "
        , show n
        , " components expected in the result of the function supplied to odeSolve"
        ]

-- Returns the matrix unchanged when it is @n@ x @n@ and calls 'error'
-- otherwise.
checkdim2 n m =
    if hasExpectedShape then m else error complaint
  where
    hasExpectedShape = rows m == n && cols m == n
    complaint = concat
        [ "Error: "
        , show n
        , "x"
        , show n
        , " Jacobian expected in odeSolve"
        ]

-- Returns the vector of times unchanged when it holds at least two entries
-- and every entry is strictly greater than the one before it; calls 'error'
-- otherwise.
checkTimes ts =
    if size ts > 1 && all (> 0) gaps
        then ts
        else error "odeSolve requires increasing times"
  where
    times = toList ts
    -- difference between each time and its predecessor
    gaps = zipWith (-) (drop 1 times) times