summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/Util
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-05 16:36:08 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-05 16:36:08 +0200
commit379f6a9855a36979c0670a3f89b6c7202836369c (patch)
tree0447a77cbc32fab7193b89f758e33ef4b7ed77c1 /packages/base/src/Numeric/LinearAlgebra/Util
parent2876998f04380c9e835c6177b440447368dfe623 (diff)
move cg
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Util')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util/CG.hs171
1 files changed, 0 insertions, 171 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
deleted file mode 100644
index 899a5bf..0000000
--- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
+++ /dev/null
@@ -1,171 +0,0 @@
1{-# LANGUAGE FlexibleContexts, FlexibleInstances #-}
2{-# LANGUAGE RecordWildCards #-}
3
4module Numeric.LinearAlgebra.Util.CG(
5 cgSolve, cgSolve',
6 CGState(..), R, V
7) where
8
9import Data.Packed.Numeric
10import Numeric.Sparse
11import Numeric.Vector()
12import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, pnorm, NormType(..))
13import Control.Arrow((***))
14
15{-
16import Util.Misc(debug, debugMat)
17
18(//) :: Show a => a -> String -> a
19infix 0 // -- , ///
20a // b = debug b id a
21
22(///) :: V -> String -> V
23infix 0 ///
24v /// b = debugMat b 2 asRow v
25-}
26
27type R = Double
28type V = Vector R
29
30data CGState = CGState
31 { cgp :: V -- ^ conjugate gradient
32 , cgr :: V -- ^ residual
33 , cgr2 :: R -- ^ squared norm of residual
34 , cgx :: V -- ^ current solution
35 , cgdx :: R -- ^ normalized size of correction
36 }
37
38cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState
39cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx
40 where
41 ap1 = a p
42 ap | sym = ap1
43 | otherwise = at ap1
44 pap | sym = p <·> ap1
45 | otherwise = norm2 ap1 ** 2
46 alpha = r2 / pap
47 dx = scale alpha p
48 x' = x + dx
49 r' = r - scale alpha ap
50 r'2 = r' <·> r'
51 beta = r'2 / r2
52 p' = r' + scale beta p
53
54 rdx = norm2 dx / max 1 (norm2 x)
55
56conjugrad
57 :: Bool -> GMatrix -> V -> V -> R -> R -> [CGState]
58conjugrad sym a b = solveG (tr a !#>) (a !#>) (cg sym) b
59
60solveG
61 :: (V -> V) -> (V -> V)
62 -> ((V -> V) -> (V -> V) -> CGState -> CGState)
63 -> V
64 -> V
65 -> R -> R
66 -> [CGState]
67solveG mat ma meth rawb x0' ϵb ϵx
68 = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1
69 where
70 a = mat . ma
71 b = mat rawb
72 x0 = if x0' == 0 then konst 0 (dim b) else x0'
73 r0 = b - a x0
74 r20 = r0 <·> r0
75 p0 = r0
76 nb2 = b <·> b
77 ok CGState {..}
78 = cgr2 <nb2*ϵb**2
79 || cgdx < ϵx
80
81
82takeUntil :: (a -> Bool) -> [a] -> [a]
83takeUntil q xs = a++ take 1 b
84 where
85 (a,b) = break q xs
86
87cgSolve
88 :: Bool -- ^ is symmetric
89 -> GMatrix -- ^ coefficient matrix
90 -> Vector Double -- ^ right-hand side
91 -> Vector Double -- ^ solution
92cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0
93 where
94 n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double))
95
96cgSolve'
97 :: Bool -- ^ symmetric
98 -> R -- ^ relative tolerance for the residual (e.g. 1E-4)
99 -> R -- ^ relative tolerance for δx (e.g. 1E-3)
100 -> Int -- ^ maximum number of iterations
101 -> GMatrix -- ^ coefficient matrix
102 -> V -- ^ initial solution
103 -> V -- ^ right-hand side
104 -> [CGState] -- ^ solution
105cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es
106
107
108--------------------------------------------------------------------------------
109
110instance Testable GMatrix
111 where
112 checkT _ = (ok,info)
113 where
114 sma = convo2 20 3
115 x1 = vect [1..20]
116 x2 = vect [1..40]
117 sm = mkSparse sma
118 dm = toDense sma
119
120 s1 = sm !#> x1
121 d1 = dm #> x1
122
123 s2 = tr sm !#> x2
124 d2 = tr dm #> x2
125
126 sdia = mkDiagR 40 20 (vect [1..10])
127 s3 = sdia !#> x1
128 s4 = tr sdia !#> x2
129 ddia = diagRect 0 (vect [1..10]) 40 20
130 d3 = ddia #> x1
131 d4 = tr ddia #> x2
132
133 v = testb 40
134 s5 = cgSolve False sm v
135 d5 = denseSolve dm v
136
137 info = do
138 print sm
139 disp (toDense sma)
140 print s1; print d1
141 print s2; print d2
142 print s3; print d3
143 print s4; print d4
144 print s5; print d5
145 print $ relativeError (pnorm Infinity) s5 d5
146
147 ok = s1==d1
148 && s2==d2
149 && s3==d3
150 && s4==d4
151 && relativeError (pnorm Infinity) s5 d5 < 1E-10
152
153 disp = putStr . dispf 2
154
155 vect = fromList :: [Double] -> Vector Double
156
157 convomat :: Int -> Int -> AssocMatrix
158 convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]]
159
160 convo2 :: Int -> Int -> AssocMatrix
161 convo2 n k = m1 ++ m2
162 where
163 m1 = convomat n k
164 m2 = map (((+n) *** id) *** id) m1
165
166 testb n = vect $ take n $ cycle ([0..10]++[9,8..1])
167
168 denseSolve a = flatten . linearSolveLS a . asColumn
169
170 -- mkDiag v = mkDiagR (dim v) (dim v) v
171