diff options
Diffstat (limited to 'packages/base/src/Internal/CG.hs')
-rw-r--r-- | packages/base/src/Internal/CG.hs | 177 |
1 files changed, 177 insertions, 0 deletions
diff --git a/packages/base/src/Internal/CG.hs b/packages/base/src/Internal/CG.hs new file mode 100644 index 0000000..1193b18 --- /dev/null +++ b/packages/base/src/Internal/CG.hs | |||
@@ -0,0 +1,177 @@ | |||
1 | {-# LANGUAGE FlexibleContexts, FlexibleInstances #-} | ||
2 | {-# LANGUAGE RecordWildCards #-} | ||
3 | |||
4 | module Internal.CG( | ||
5 | cgSolve, cgSolve', | ||
6 | CGState(..), R, V | ||
7 | ) where | ||
8 | |||
9 | import Internal.Vector | ||
10 | import Internal.Matrix hiding (mat) | ||
11 | import Internal.Numeric | ||
12 | import Internal.Element | ||
13 | import Internal.IO | ||
14 | import Internal.Container | ||
15 | import Internal.Sparse | ||
16 | import Numeric.Vector() | ||
17 | import Internal.Algorithms(linearSolveLS, relativeError, pnorm, NormType(..)) | ||
18 | import Control.Arrow((***)) | ||
19 | import Data.Vector.Storable(fromList) | ||
20 | |||
21 | {- | ||
22 | import Util.Misc(debug, debugMat) | ||
23 | |||
24 | (//) :: Show a => a -> String -> a | ||
25 | infix 0 // -- , /// | ||
26 | a // b = debug b id a | ||
27 | |||
28 | (///) :: V -> String -> V | ||
29 | infix 0 /// | ||
30 | v /// b = debugMat b 2 asRow v | ||
31 | -} | ||
32 | |||
33 | type R = Double | ||
34 | type V = Vector R | ||
35 | |||
36 | data CGState = CGState | ||
37 | { cgp :: V -- ^ conjugate gradient | ||
38 | , cgr :: V -- ^ residual | ||
39 | , cgr2 :: R -- ^ squared norm of residual | ||
40 | , cgx :: V -- ^ current solution | ||
41 | , cgdx :: R -- ^ normalized size of correction | ||
42 | } | ||
43 | |||
44 | cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState | ||
45 | cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx | ||
46 | where | ||
47 | ap1 = a p | ||
48 | ap | sym = ap1 | ||
49 | | otherwise = at ap1 | ||
50 | pap | sym = p <·> ap1 | ||
51 | | otherwise = norm2 ap1 ** 2 | ||
52 | alpha = r2 / pap | ||
53 | dx = scale alpha p | ||
54 | x' = x + dx | ||
55 | r' = r - scale alpha ap | ||
56 | r'2 = r' <·> r' | ||
57 | beta = r'2 / r2 | ||
58 | p' = r' + scale beta p | ||
59 | |||
60 | rdx = norm2 dx / max 1 (norm2 x) | ||
61 | |||
62 | conjugrad | ||
63 | :: Bool -> GMatrix -> V -> V -> R -> R -> [CGState] | ||
64 | conjugrad sym a b = solveG (tr a !#>) (a !#>) (cg sym) b | ||
65 | |||
66 | solveG | ||
67 | :: (V -> V) -> (V -> V) | ||
68 | -> ((V -> V) -> (V -> V) -> CGState -> CGState) | ||
69 | -> V | ||
70 | -> V | ||
71 | -> R -> R | ||
72 | -> [CGState] | ||
73 | solveG mat ma meth rawb x0' ϵb ϵx | ||
74 | = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 | ||
75 | where | ||
76 | a = mat . ma | ||
77 | b = mat rawb | ||
78 | x0 = if x0' == 0 then konst 0 (dim b) else x0' | ||
79 | r0 = b - a x0 | ||
80 | r20 = r0 <·> r0 | ||
81 | p0 = r0 | ||
82 | nb2 = b <·> b | ||
83 | ok CGState {..} | ||
84 | = cgr2 <nb2*ϵb**2 | ||
85 | || cgdx < ϵx | ||
86 | |||
87 | |||
88 | takeUntil :: (a -> Bool) -> [a] -> [a] | ||
89 | takeUntil q xs = a++ take 1 b | ||
90 | where | ||
91 | (a,b) = break q xs | ||
92 | |||
93 | cgSolve | ||
94 | :: Bool -- ^ is symmetric | ||
95 | -> GMatrix -- ^ coefficient matrix | ||
96 | -> Vector Double -- ^ right-hand side | ||
97 | -> Vector Double -- ^ solution | ||
98 | cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 | ||
99 | where | ||
100 | n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) | ||
101 | |||
102 | cgSolve' | ||
103 | :: Bool -- ^ symmetric | ||
104 | -> R -- ^ relative tolerance for the residual (e.g. 1E-4) | ||
105 | -> R -- ^ relative tolerance for δx (e.g. 1E-3) | ||
106 | -> Int -- ^ maximum number of iterations | ||
107 | -> GMatrix -- ^ coefficient matrix | ||
108 | -> V -- ^ initial solution | ||
109 | -> V -- ^ right-hand side | ||
110 | -> [CGState] -- ^ solution | ||
111 | cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es | ||
112 | |||
113 | |||
114 | -------------------------------------------------------------------------------- | ||
115 | |||
116 | instance Testable GMatrix | ||
117 | where | ||
118 | checkT _ = (ok,info) | ||
119 | where | ||
120 | sma = convo2 20 3 | ||
121 | x1 = vect [1..20] | ||
122 | x2 = vect [1..40] | ||
123 | sm = mkSparse sma | ||
124 | dm = toDense sma | ||
125 | |||
126 | s1 = sm !#> x1 | ||
127 | d1 = dm #> x1 | ||
128 | |||
129 | s2 = tr sm !#> x2 | ||
130 | d2 = tr dm #> x2 | ||
131 | |||
132 | sdia = mkDiagR 40 20 (vect [1..10]) | ||
133 | s3 = sdia !#> x1 | ||
134 | s4 = tr sdia !#> x2 | ||
135 | ddia = diagRect 0 (vect [1..10]) 40 20 | ||
136 | d3 = ddia #> x1 | ||
137 | d4 = tr ddia #> x2 | ||
138 | |||
139 | v = testb 40 | ||
140 | s5 = cgSolve False sm v | ||
141 | d5 = denseSolve dm v | ||
142 | |||
143 | info = do | ||
144 | print sm | ||
145 | disp (toDense sma) | ||
146 | print s1; print d1 | ||
147 | print s2; print d2 | ||
148 | print s3; print d3 | ||
149 | print s4; print d4 | ||
150 | print s5; print d5 | ||
151 | print $ relativeError (pnorm Infinity) s5 d5 | ||
152 | |||
153 | ok = s1==d1 | ||
154 | && s2==d2 | ||
155 | && s3==d3 | ||
156 | && s4==d4 | ||
157 | && relativeError (pnorm Infinity) s5 d5 < 1E-10 | ||
158 | |||
159 | disp = putStr . dispf 2 | ||
160 | |||
161 | vect = fromList :: [Double] -> Vector Double | ||
162 | |||
163 | convomat :: Int -> Int -> AssocMatrix | ||
164 | convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] | ||
165 | |||
166 | convo2 :: Int -> Int -> AssocMatrix | ||
167 | convo2 n k = m1 ++ m2 | ||
168 | where | ||
169 | m1 = convomat n k | ||
170 | m2 = map (((+n) *** id) *** id) m1 | ||
171 | |||
172 | testb n = vect $ take n $ cycle ([0..10]++[9,8..1]) | ||
173 | |||
174 | denseSolve a = flatten . linearSolveLS a . asColumn | ||
175 | |||
176 | -- mkDiag v = mkDiagR (dim v) (dim v) v | ||
177 | |||