summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Util/CG.hs')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util/CG.hs86
1 files changed, 75 insertions, 11 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
index 5e2ea84..50372f1 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
@@ -3,11 +3,14 @@
3 3
4module Numeric.LinearAlgebra.Util.CG( 4module Numeric.LinearAlgebra.Util.CG(
5 cgSolve, cgSolve', 5 cgSolve, cgSolve',
6 CGMat, CGState(..), R, V 6 CGState(..), R, V
7) where 7) where
8 8
9import Data.Packed.Numeric 9import Data.Packed.Numeric
10import Numeric.Sparse
10import Numeric.Vector() 11import Numeric.Vector()
12import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..))
13import Control.Arrow((***))
11 14
12{- 15{-
13import Util.Misc(debug, debugMat) 16import Util.Misc(debug, debugMat)
@@ -51,7 +54,7 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx
51 rdx = norm2 dx / max 1 (norm2 x) 54 rdx = norm2 dx / max 1 (norm2 x)
52 55
53conjugrad 56conjugrad
54 :: (Transposable m, Contraction m V V) 57 :: (Transposable m mt, Contraction m V V, Contraction mt V V)
55 => Bool -> m -> V -> V -> R -> R -> [CGState] 58 => Bool -> m -> V -> V -> R -> R -> [CGState]
56conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b 59conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b
57 60
@@ -82,27 +85,88 @@ takeUntil q xs = a++ take 1 b
82 where 85 where
83 (a,b) = break q xs 86 (a,b) = break q xs
84 87
85class (Transposable m, Contraction m V V) => CGMat m
86
87cgSolve 88cgSolve
88 :: CGMat m 89 :: Bool -- ^ is symmetric
89 => Bool -- ^ is symmetric 90 -> GMatrix -- ^ coefficient matrix
90 -> m -- ^ coefficient matrix
91 -> Vector Double -- ^ right-hand side 91 -> Vector Double -- ^ right-hand side
92 -> Vector Double -- ^ solution 92 -> Vector Double -- ^ solution
93cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 93cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0
94 where 94 where
95 n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) 95 n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double))
96 96
97cgSolve' 97cgSolve'
98 :: CGMat m 98 :: Bool -- ^ symmetric
99 => Bool -- ^ symmetric
100 -> R -- ^ relative tolerance for the residual (e.g. 1E-4) 99 -> R -- ^ relative tolerance for the residual (e.g. 1E-4)
101 -> R -- ^ relative tolerance for δx (e.g. 1E-3) 100 -> R -- ^ relative tolerance for δx (e.g. 1E-3)
102 -> Int -- ^ maximum number of iterations 101 -> Int -- ^ maximum number of iterations
103 -> m -- ^ coefficient matrix 102 -> GMatrix -- ^ coefficient matrix
104 -> V -- ^ initial solution 103 -> V -- ^ initial solution
105 -> V -- ^ right-hand side 104 -> V -- ^ right-hand side
106 -> [CGState] -- ^ solution 105 -> [CGState] -- ^ solution
107cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es 106cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es
108 107
108
109--------------------------------------------------------------------------------
110
111instance Testable GMatrix
112 where
113 checkT _ = (ok,info)
114 where
115 sma = convo2 20 3
116 x1 = vect [1..20]
117 x2 = vect [1..40]
118 sm = mkSparse sma
119 dm = toDense sma
120
121 s1 = sm !#> x1
122 d1 = dm #> x1
123
124 s2 = tr sm !#> x2
125 d2 = tr dm #> x2
126
127 sdia = mkDiagR 40 20 (vect [1..10])
128 s3 = sdia !#> x1
129 s4 = tr sdia !#> x2
130 ddia = diagRect 0 (vect [1..10]) 40 20
131 d3 = ddia #> x1
132 d4 = tr ddia #> x2
133
134 v = testb 40
135 s5 = cgSolve False sm v
136 d5 = denseSolve dm v
137
138 info = do
139 print sm
140 disp (toDense sma)
141 print s1; print d1
142 print s2; print d2
143 print s3; print d3
144 print s4; print d4
145 print s5; print d5
146 print $ relativeError Infinity s5 d5
147
148 ok = s1==d1
149 && s2==d2
150 && s3==d3
151 && s4==d4
152 && relativeError Infinity s5 d5 < 1E-10
153
154 disp = putStr . dispf 2
155
156 vect = fromList :: [Double] -> Vector Double
157
158 convomat :: Int -> Int -> AssocMatrix
159 convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]]
160
161 convo2 :: Int -> Int -> AssocMatrix
162 convo2 n k = m1 ++ m2
163 where
164 m1 = convomat n k
165 m2 = map (((+n) *** id) *** id) m1
166
167 testb n = vect $ take n $ cycle ([0..10]++[9,8..1])
168
169 denseSolve a = flatten . linearSolveLS a . asColumn
170
171 -- mkDiag v = mkDiagR (dim v) (dim v) v
172