summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/CG.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-05 16:36:08 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-05 16:36:08 +0200
commit379f6a9855a36979c0670a3f89b6c7202836369c (patch)
tree0447a77cbc32fab7193b89f758e33ef4b7ed77c1 /packages/base/src/Internal/CG.hs
parent2876998f04380c9e835c6177b440447368dfe623 (diff)
move cg
Diffstat (limited to 'packages/base/src/Internal/CG.hs')
-rw-r--r--packages/base/src/Internal/CG.hs177
1 files changed, 177 insertions, 0 deletions
diff --git a/packages/base/src/Internal/CG.hs b/packages/base/src/Internal/CG.hs
new file mode 100644
index 0000000..1193b18
--- /dev/null
+++ b/packages/base/src/Internal/CG.hs
@@ -0,0 +1,177 @@
1{-# LANGUAGE FlexibleContexts, FlexibleInstances #-}
2{-# LANGUAGE RecordWildCards #-}
3
4module Internal.CG(
5 cgSolve, cgSolve',
6 CGState(..), R, V
7) where
8
9import Internal.Vector
10import Internal.Matrix hiding (mat)
11import Internal.Numeric
12import Internal.Element
13import Internal.IO
14import Internal.Container
15import Internal.Sparse
16import Numeric.Vector()
17import Internal.Algorithms(linearSolveLS, relativeError, pnorm, NormType(..))
18import Control.Arrow((***))
19import Data.Vector.Storable(fromList)
20
21{-
22import Util.Misc(debug, debugMat)
23
24(//) :: Show a => a -> String -> a
25infix 0 // -- , ///
26a // b = debug b id a
27
28(///) :: V -> String -> V
29infix 0 ///
30v /// b = debugMat b 2 asRow v
31-}
32
33type R = Double
34type V = Vector R
35
36data CGState = CGState
37 { cgp :: V -- ^ conjugate gradient
38 , cgr :: V -- ^ residual
39 , cgr2 :: R -- ^ squared norm of residual
40 , cgx :: V -- ^ current solution
41 , cgdx :: R -- ^ normalized size of correction
42 }
43
44cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState
45cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx
46 where
47 ap1 = a p
48 ap | sym = ap1
49 | otherwise = at ap1
50 pap | sym = p <·> ap1
51 | otherwise = norm2 ap1 ** 2
52 alpha = r2 / pap
53 dx = scale alpha p
54 x' = x + dx
55 r' = r - scale alpha ap
56 r'2 = r' <·> r'
57 beta = r'2 / r2
58 p' = r' + scale beta p
59
60 rdx = norm2 dx / max 1 (norm2 x)
61
62conjugrad
63 :: Bool -> GMatrix -> V -> V -> R -> R -> [CGState]
64conjugrad sym a b = solveG (tr a !#>) (a !#>) (cg sym) b
65
66solveG
67 :: (V -> V) -> (V -> V)
68 -> ((V -> V) -> (V -> V) -> CGState -> CGState)
69 -> V
70 -> V
71 -> R -> R
72 -> [CGState]
73solveG mat ma meth rawb x0' ϵb ϵx
74 = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1
75 where
76 a = mat . ma
77 b = mat rawb
78 x0 = if x0' == 0 then konst 0 (dim b) else x0'
79 r0 = b - a x0
80 r20 = r0 <·> r0
81 p0 = r0
82 nb2 = b <·> b
83 ok CGState {..}
84 = cgr2 <nb2*ϵb**2
85 || cgdx < ϵx
86
87
88takeUntil :: (a -> Bool) -> [a] -> [a]
89takeUntil q xs = a++ take 1 b
90 where
91 (a,b) = break q xs
92
93cgSolve
94 :: Bool -- ^ is symmetric
95 -> GMatrix -- ^ coefficient matrix
96 -> Vector Double -- ^ right-hand side
97 -> Vector Double -- ^ solution
98cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0
99 where
100 n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double))
101
102cgSolve'
103 :: Bool -- ^ symmetric
104 -> R -- ^ relative tolerance for the residual (e.g. 1E-4)
105 -> R -- ^ relative tolerance for δx (e.g. 1E-3)
106 -> Int -- ^ maximum number of iterations
107 -> GMatrix -- ^ coefficient matrix
108 -> V -- ^ initial solution
109 -> V -- ^ right-hand side
110 -> [CGState] -- ^ solution
111cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es
112
113
114--------------------------------------------------------------------------------
115
116instance Testable GMatrix
117 where
118 checkT _ = (ok,info)
119 where
120 sma = convo2 20 3
121 x1 = vect [1..20]
122 x2 = vect [1..40]
123 sm = mkSparse sma
124 dm = toDense sma
125
126 s1 = sm !#> x1
127 d1 = dm #> x1
128
129 s2 = tr sm !#> x2
130 d2 = tr dm #> x2
131
132 sdia = mkDiagR 40 20 (vect [1..10])
133 s3 = sdia !#> x1
134 s4 = tr sdia !#> x2
135 ddia = diagRect 0 (vect [1..10]) 40 20
136 d3 = ddia #> x1
137 d4 = tr ddia #> x2
138
139 v = testb 40
140 s5 = cgSolve False sm v
141 d5 = denseSolve dm v
142
143 info = do
144 print sm
145 disp (toDense sma)
146 print s1; print d1
147 print s2; print d2
148 print s3; print d3
149 print s4; print d4
150 print s5; print d5
151 print $ relativeError (pnorm Infinity) s5 d5
152
153 ok = s1==d1
154 && s2==d2
155 && s3==d3
156 && s4==d4
157 && relativeError (pnorm Infinity) s5 d5 < 1E-10
158
159 disp = putStr . dispf 2
160
161 vect = fromList :: [Double] -> Vector Double
162
163 convomat :: Int -> Int -> AssocMatrix
164 convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]]
165
166 convo2 :: Int -> Int -> AssocMatrix
167 convo2 n k = m1 ++ m2
168 where
169 m1 = convomat n k
170 m2 = map (((+n) *** id) *** id) m1
171
172 testb n = vect $ take n $ cycle ([0..10]++[9,8..1])
173
174 denseSolve a = flatten . linearSolveLS a . asColumn
175
176 -- mkDiag v = mkDiagR (dim v) (dim v) v
177