diff options
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Util')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | 171 |
1 files changed, 0 insertions, 171 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs deleted file mode 100644 index 899a5bf..0000000 --- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs +++ /dev/null | |||
@@ -1,171 +0,0 @@ | |||
1 | {-# LANGUAGE FlexibleContexts, FlexibleInstances #-} | ||
2 | {-# LANGUAGE RecordWildCards #-} | ||
3 | |||
4 | module Numeric.LinearAlgebra.Util.CG( | ||
5 | cgSolve, cgSolve', | ||
6 | CGState(..), R, V | ||
7 | ) where | ||
8 | |||
9 | import Data.Packed.Numeric | ||
10 | import Numeric.Sparse | ||
11 | import Numeric.Vector() | ||
12 | import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, pnorm, NormType(..)) | ||
13 | import Control.Arrow((***)) | ||
14 | |||
15 | {- | ||
16 | import Util.Misc(debug, debugMat) | ||
17 | |||
18 | (//) :: Show a => a -> String -> a | ||
19 | infix 0 // -- , /// | ||
20 | a // b = debug b id a | ||
21 | |||
22 | (///) :: V -> String -> V | ||
23 | infix 0 /// | ||
24 | v /// b = debugMat b 2 asRow v | ||
25 | -} | ||
26 | |||
27 | type R = Double | ||
28 | type V = Vector R | ||
29 | |||
30 | data CGState = CGState | ||
31 | { cgp :: V -- ^ conjugate gradient | ||
32 | , cgr :: V -- ^ residual | ||
33 | , cgr2 :: R -- ^ squared norm of residual | ||
34 | , cgx :: V -- ^ current solution | ||
35 | , cgdx :: R -- ^ normalized size of correction | ||
36 | } | ||
37 | |||
38 | cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState | ||
39 | cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx | ||
40 | where | ||
41 | ap1 = a p | ||
42 | ap | sym = ap1 | ||
43 | | otherwise = at ap1 | ||
44 | pap | sym = p <·> ap1 | ||
45 | | otherwise = norm2 ap1 ** 2 | ||
46 | alpha = r2 / pap | ||
47 | dx = scale alpha p | ||
48 | x' = x + dx | ||
49 | r' = r - scale alpha ap | ||
50 | r'2 = r' <·> r' | ||
51 | beta = r'2 / r2 | ||
52 | p' = r' + scale beta p | ||
53 | |||
54 | rdx = norm2 dx / max 1 (norm2 x) | ||
55 | |||
56 | conjugrad | ||
57 | :: Bool -> GMatrix -> V -> V -> R -> R -> [CGState] | ||
58 | conjugrad sym a b = solveG (tr a !#>) (a !#>) (cg sym) b | ||
59 | |||
60 | solveG | ||
61 | :: (V -> V) -> (V -> V) | ||
62 | -> ((V -> V) -> (V -> V) -> CGState -> CGState) | ||
63 | -> V | ||
64 | -> V | ||
65 | -> R -> R | ||
66 | -> [CGState] | ||
67 | solveG mat ma meth rawb x0' ϵb ϵx | ||
68 | = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 | ||
69 | where | ||
70 | a = mat . ma | ||
71 | b = mat rawb | ||
72 | x0 = if x0' == 0 then konst 0 (dim b) else x0' | ||
73 | r0 = b - a x0 | ||
74 | r20 = r0 <·> r0 | ||
75 | p0 = r0 | ||
76 | nb2 = b <·> b | ||
77 | ok CGState {..} | ||
78 | = cgr2 <nb2*ϵb**2 | ||
79 | || cgdx < ϵx | ||
80 | |||
81 | |||
82 | takeUntil :: (a -> Bool) -> [a] -> [a] | ||
83 | takeUntil q xs = a++ take 1 b | ||
84 | where | ||
85 | (a,b) = break q xs | ||
86 | |||
87 | cgSolve | ||
88 | :: Bool -- ^ is symmetric | ||
89 | -> GMatrix -- ^ coefficient matrix | ||
90 | -> Vector Double -- ^ right-hand side | ||
91 | -> Vector Double -- ^ solution | ||
92 | cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 | ||
93 | where | ||
94 | n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) | ||
95 | |||
96 | cgSolve' | ||
97 | :: Bool -- ^ symmetric | ||
98 | -> R -- ^ relative tolerance for the residual (e.g. 1E-4) | ||
99 | -> R -- ^ relative tolerance for δx (e.g. 1E-3) | ||
100 | -> Int -- ^ maximum number of iterations | ||
101 | -> GMatrix -- ^ coefficient matrix | ||
102 | -> V -- ^ initial solution | ||
103 | -> V -- ^ right-hand side | ||
104 | -> [CGState] -- ^ solution | ||
105 | cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es | ||
106 | |||
107 | |||
108 | -------------------------------------------------------------------------------- | ||
109 | |||
110 | instance Testable GMatrix | ||
111 | where | ||
112 | checkT _ = (ok,info) | ||
113 | where | ||
114 | sma = convo2 20 3 | ||
115 | x1 = vect [1..20] | ||
116 | x2 = vect [1..40] | ||
117 | sm = mkSparse sma | ||
118 | dm = toDense sma | ||
119 | |||
120 | s1 = sm !#> x1 | ||
121 | d1 = dm #> x1 | ||
122 | |||
123 | s2 = tr sm !#> x2 | ||
124 | d2 = tr dm #> x2 | ||
125 | |||
126 | sdia = mkDiagR 40 20 (vect [1..10]) | ||
127 | s3 = sdia !#> x1 | ||
128 | s4 = tr sdia !#> x2 | ||
129 | ddia = diagRect 0 (vect [1..10]) 40 20 | ||
130 | d3 = ddia #> x1 | ||
131 | d4 = tr ddia #> x2 | ||
132 | |||
133 | v = testb 40 | ||
134 | s5 = cgSolve False sm v | ||
135 | d5 = denseSolve dm v | ||
136 | |||
137 | info = do | ||
138 | print sm | ||
139 | disp (toDense sma) | ||
140 | print s1; print d1 | ||
141 | print s2; print d2 | ||
142 | print s3; print d3 | ||
143 | print s4; print d4 | ||
144 | print s5; print d5 | ||
145 | print $ relativeError (pnorm Infinity) s5 d5 | ||
146 | |||
147 | ok = s1==d1 | ||
148 | && s2==d2 | ||
149 | && s3==d3 | ||
150 | && s4==d4 | ||
151 | && relativeError (pnorm Infinity) s5 d5 < 1E-10 | ||
152 | |||
153 | disp = putStr . dispf 2 | ||
154 | |||
155 | vect = fromList :: [Double] -> Vector Double | ||
156 | |||
157 | convomat :: Int -> Int -> AssocMatrix | ||
158 | convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] | ||
159 | |||
160 | convo2 :: Int -> Int -> AssocMatrix | ||
161 | convo2 n k = m1 ++ m2 | ||
162 | where | ||
163 | m1 = convomat n k | ||
164 | m2 = map (((+n) *** id) *** id) m1 | ||
165 | |||
166 | testb n = vect $ take n $ cycle ([0..10]++[9,8..1]) | ||
167 | |||
168 | denseSolve a = flatten . linearSolveLS a . asColumn | ||
169 | |||
170 | -- mkDiag v = mkDiagR (dim v) (dim v) v | ||
171 | |||