summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
blob: 5e2ea84a1577162daaaf149ca8f7e1df8c2a659d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
{-# LANGUAGE FlexibleContexts, FlexibleInstances #-}
{-# LANGUAGE RecordWildCards #-}

module Numeric.LinearAlgebra.Util.CG(
    cgSolve, cgSolve',
    CGMat, CGState(..), R, V
) where

import Data.Packed.Numeric
import Numeric.Vector()

{-
import Util.Misc(debug, debugMat)

(//) :: Show a => a -> String -> a
infix 0 // -- , ///
a // b = debug b id a

(///) :: V -> String -> V
infix 0 ///
v /// b = debugMat b 2 asRow v
-}

-- | Scalar type used throughout the solver.
type R = Double
-- | Vector type used throughout the solver.
type V = Vector R

-- | State of the conjugate-gradient iteration, produced once per step.
data CGState = CGState
    { cgp  :: V  -- ^ current search direction (conjugate gradient direction)
    , cgr  :: V  -- ^ residual of the linear system being iterated on
    , cgr2 :: R  -- ^ squared norm of residual
    , cgx  :: V  -- ^ current solution
    , cgdx :: R  -- ^ normalized size of correction: norm of the last step
                 --   divided by @max 1@ of the previous solution's norm
                 --   (set to 1 in the initial state)
    }

-- | One conjugate-gradient step.
--
-- @cg sym at a s@ advances the state @s@, where @a@ applies the coefficient
-- operator and @at@ its transpose.  When @sym@ holds, the plain CG update
-- for the operator @a@ is used and @at@ is never called.  Otherwise the step
-- works on the normal equations: the operator image of the search direction
-- is @at (a p)@ and its curvature is @||a p||^2@.  The incoming 'cgdx' is
-- ignored; the new one is the step norm relative to @max 1 ||x||@.
cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState
cg sym at a (CGState dir res res2 sol _) =
    CGState
      { cgp  = res' + scale beta dir
      , cgr  = res'
      , cgr2 = res2'
      , cgx  = sol + step
      , cgdx = norm2 step / max 1 (norm2 sol)
      }
  where
    aDir = a dir
    -- (curvature p·Mp, image M p) for the operator M actually being solved
    (curv, mDir)
      | sym       = (dir ◇ aDir, aDir)
      | otherwise = (norm2 aDir ** 2, at aDir)
    alpha = res2 / curv
    step  = scale alpha dir
    res'  = res - scale alpha mDir
    res2' = res' ◇ res'
    beta  = res2' / res2

-- | Conjugate-gradient iterates for @a x = b@.
--
-- Arguments: symmetry flag, coefficient matrix, right-hand side, initial
-- guess, relative residual tolerance, relative step tolerance.
--
-- For a symmetric matrix the CG recurrence is run directly on @a x = b@.
-- Otherwise it is run on the normal equations @tr a (a x) = tr a b@.
--
-- 'solveG' applies its first argument to both the operator and the
-- right-hand side, while 'cg' with @sym = True@ multiplies by @a@ alone.
-- Passing 'id' in the symmetric case keeps the two consistent.  Otherwise
-- the iteration would solve @a x = tr a b@ instead of @a x = b@.
conjugrad
  :: (Transposable m, Contraction m V V)
  => Bool -> m -> V -> V -> R -> R -> [CGState]
conjugrad sym a b
    | sym       = solveG id       (a ◇) (cg sym) b
    | otherwise = solveG (tr a ◇) (a ◇) (cg sym) b

-- | Generic driver for the iterative solver.
--
-- @solveG mat ma meth rawb x0' ϵb ϵx@ iterates on the system
-- @(mat . ma) x = mat rawb@, starting from the guess @x0'@.  It repeatedly
-- applies @meth mat ma@ and stops at the first state that passes the
-- convergence test.  That state is included as the last list element.
--
-- Convergence test: the squared residual is at most
-- @ϵb^2 * ||mat rawb||^2@, or the relative correction 'cgdx' is below @ϵx@.
-- The residual test uses @<=@ so that a zero right-hand side with a zero
-- guess stops immediately.  With @<@ it would not, and the first step
-- would compute @0/0 = NaN@.
solveG
    :: (V -> V) -> (V -> V)
    -> ((V -> V) -> (V -> V) -> CGState -> CGState)
    -> V
    -> V
    -> R -> R
    -> [CGState]
solveG mat ma meth rawb x0' ϵb ϵx
    = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1
  where
    a = mat . ma
    b = mat rawb
    -- an initial guess equal to the literal 0 (as 'cgSolve' passes) is
    -- expanded to a zero vector of dimension @dim b@
    x0  = if x0' == 0 then konst 0 (dim b) else x0'
    r0  = b - a x0
    r20 = r0 ◇ r0
    p0  = r0
    nb2 = b ◇ b
    ok CGState {..}
        =  cgr2 <= nb2 * ϵb ** 2
        || cgdx < ϵx


-- | Longest prefix of the list that ends with the first element satisfying
-- the predicate (inclusive).  If no element satisfies it, the whole list is
-- returned.  Lazy, so it can consume infinite lists such as the output of
-- 'iterate'.
takeUntil :: (a -> Bool) -> [a] -> [a]
takeUntil stop = foldr step []
  where
    step y rest
      | stop y    = [y]
      | otherwise = y : rest

-- | Matrix-like types accepted by 'cgSolve' and 'cgSolve''.  The class has
-- no methods; it only bundles the required constraints (transposition and
-- contraction with a 'V').  NOTE(review): instances are not declared in
-- this module; presumably they are provided elsewhere.
class (Transposable m, Contraction m V V) => CGMat m

-- | Convenience wrapper around 'cgSolve'' that returns the 'cgx' of the
-- final state.  It uses a zero initial guess, a residual tolerance of 1E-4
-- and a step tolerance of 1E-3.  The iteration limit is
-- @max 10 (round (sqrt (dim b)))@, so it is always at least 10.
cgSolve
  :: CGMat m
  => Bool          -- ^ is symmetric
  -> m             -- ^ coefficient matrix
  -> Vector Double -- ^ right-hand side
  -> Vector Double -- ^ solution
cgSolve sym a b = cgx (last iterates)
  where
    iterates = cgSolve' sym 1E-4 1E-3 maxIter a b 0
    maxIter  = max 10 . round . sqrt $ (fromIntegral (dim b) :: Double)

-- | Conjugate-gradient iterates, truncated to at most the given number of
-- states.  The initial state counts as one of them.  Iteration also stops
-- early once the residual or step tolerance is met.
cgSolve'
  :: CGMat m
  => Bool      -- ^ symmetric
  -> R         -- ^ relative tolerance for the residual (e.g. 1E-4)
  -> R         -- ^ relative tolerance for δx (e.g. 1E-3)
  -> Int       -- ^ maximum number of iterations
  -> m         -- ^ coefficient matrix
  -> V         -- ^ right-hand side
  -> V         -- ^ initial solution (the literal 0 means a zero vector)
  -> [CGState] -- ^ successive solver states; the last holds the solution
cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es