diff options
author | Alberto Ruiz <aruiz@um.es> | 2014-05-27 10:41:40 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2014-05-27 10:41:40 +0200 |
commit | cf3c788f0c44577ac1a5365e8154200b53a36409 (patch) | |
tree | d667ea10609e74b69b11309bb59b7e000b240a92 /packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | |
parent | 365e2435e71de10ebe849acac5a107b6f43817c4 (diff) |
static dimensions, cont.
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Util/CG.hs')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | 86 |
1 files changed, 75 insertions, 11 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs index 5e2ea84..50372f1 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | |||
@@ -3,11 +3,14 @@ | |||
3 | 3 | ||
4 | module Numeric.LinearAlgebra.Util.CG( | 4 | module Numeric.LinearAlgebra.Util.CG( |
5 | cgSolve, cgSolve', | 5 | cgSolve, cgSolve', |
6 | CGMat, CGState(..), R, V | 6 | CGState(..), R, V |
7 | ) where | 7 | ) where |
8 | 8 | ||
9 | import Data.Packed.Numeric | 9 | import Data.Packed.Numeric |
10 | import Numeric.Sparse | ||
10 | import Numeric.Vector() | 11 | import Numeric.Vector() |
12 | import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..)) | ||
13 | import Control.Arrow((***)) | ||
11 | 14 | ||
12 | {- | 15 | {- |
13 | import Util.Misc(debug, debugMat) | 16 | import Util.Misc(debug, debugMat) |
@@ -51,7 +54,7 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx | |||
51 | rdx = norm2 dx / max 1 (norm2 x) | 54 | rdx = norm2 dx / max 1 (norm2 x) |
52 | 55 | ||
53 | conjugrad | 56 | conjugrad |
54 | :: (Transposable m, Contraction m V V) | 57 | :: (Transposable m mt, Contraction m V V, Contraction mt V V) |
55 | => Bool -> m -> V -> V -> R -> R -> [CGState] | 58 | => Bool -> m -> V -> V -> R -> R -> [CGState] |
56 | conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b | 59 | conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b |
57 | 60 | ||
@@ -82,27 +85,88 @@ takeUntil q xs = a++ take 1 b | |||
82 | where | 85 | where |
83 | (a,b) = break q xs | 86 | (a,b) = break q xs |
84 | 87 | ||
85 | class (Transposable m, Contraction m V V) => CGMat m | ||
86 | |||
87 | cgSolve | 88 | cgSolve |
88 | :: CGMat m | 89 | :: Bool -- ^ is symmetric |
89 | => Bool -- ^ is symmetric | 90 | -> GMatrix -- ^ coefficient matrix |
90 | -> m -- ^ coefficient matrix | ||
91 | -> Vector Double -- ^ right-hand side | 91 | -> Vector Double -- ^ right-hand side |
92 | -> Vector Double -- ^ solution | 92 | -> Vector Double -- ^ solution |
93 | cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 | 93 | cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 |
94 | where | 94 | where |
95 | n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) | 95 | n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) |
96 | 96 | ||
97 | cgSolve' | 97 | cgSolve' |
98 | :: CGMat m | 98 | :: Bool -- ^ symmetric |
99 | => Bool -- ^ symmetric | ||
100 | -> R -- ^ relative tolerance for the residual (e.g. 1E-4) | 99 | -> R -- ^ relative tolerance for the residual (e.g. 1E-4) |
101 | -> R -- ^ relative tolerance for δx (e.g. 1E-3) | 100 | -> R -- ^ relative tolerance for δx (e.g. 1E-3) |
102 | -> Int -- ^ maximum number of iterations | 101 | -> Int -- ^ maximum number of iterations |
103 | -> m -- ^ coefficient matrix | 102 | -> GMatrix -- ^ coefficient matrix |
104 | -> V -- ^ initial solution | 103 | -> V -- ^ initial solution |
105 | -> V -- ^ right-hand side | 104 | -> V -- ^ right-hand side |
106 | -> [CGState] -- ^ solution | 105 | -> [CGState] -- ^ solution |
107 | cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es | 106 | cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es |
108 | 107 | ||
108 | |||
109 | -------------------------------------------------------------------------------- | ||
110 | |||
111 | instance Testable GMatrix | ||
112 | where | ||
113 | checkT _ = (ok,info) | ||
114 | where | ||
115 | sma = convo2 20 3 | ||
116 | x1 = vect [1..20] | ||
117 | x2 = vect [1..40] | ||
118 | sm = mkSparse sma | ||
119 | dm = toDense sma | ||
120 | |||
121 | s1 = sm !#> x1 | ||
122 | d1 = dm #> x1 | ||
123 | |||
124 | s2 = tr sm !#> x2 | ||
125 | d2 = tr dm #> x2 | ||
126 | |||
127 | sdia = mkDiagR 40 20 (vect [1..10]) | ||
128 | s3 = sdia !#> x1 | ||
129 | s4 = tr sdia !#> x2 | ||
130 | ddia = diagRect 0 (vect [1..10]) 40 20 | ||
131 | d3 = ddia #> x1 | ||
132 | d4 = tr ddia #> x2 | ||
133 | |||
134 | v = testb 40 | ||
135 | s5 = cgSolve False sm v | ||
136 | d5 = denseSolve dm v | ||
137 | |||
138 | info = do | ||
139 | print sm | ||
140 | disp (toDense sma) | ||
141 | print s1; print d1 | ||
142 | print s2; print d2 | ||
143 | print s3; print d3 | ||
144 | print s4; print d4 | ||
145 | print s5; print d5 | ||
146 | print $ relativeError Infinity s5 d5 | ||
147 | |||
148 | ok = s1==d1 | ||
149 | && s2==d2 | ||
150 | && s3==d3 | ||
151 | && s4==d4 | ||
152 | && relativeError Infinity s5 d5 < 1E-10 | ||
153 | |||
154 | disp = putStr . dispf 2 | ||
155 | |||
156 | vect = fromList :: [Double] -> Vector Double | ||
157 | |||
158 | convomat :: Int -> Int -> AssocMatrix | ||
159 | convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] | ||
160 | |||
161 | convo2 :: Int -> Int -> AssocMatrix | ||
162 | convo2 n k = m1 ++ m2 | ||
163 | where | ||
164 | m1 = convomat n k | ||
165 | m2 = map (((+n) *** id) *** id) m1 | ||
166 | |||
167 | testb n = vect $ take n $ cycle ([0..10]++[9,8..1]) | ||
168 | |||
169 | denseSolve a = flatten . linearSolveLS a . asColumn | ||
170 | |||
171 | -- mkDiag v = mkDiagR (dim v) (dim v) v | ||
172 | |||