summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-06-11 10:46:39 +0000
committerAlberto Ruiz <aruiz@um.es>2007-06-11 10:46:39 +0000
commit473df6136476dfa07331dd25a6020260c4f02a9b (patch)
tree0639081371f7f0d3d03aba2a975921690c19f149
parentf2cf177e93d4578b404909c68b24625a76466ee5 (diff)
all eig
-rw-r--r--examples/tests.hs76
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs7
-rw-r--r--lib/LAPACK.hs2
-rw-r--r--lib/LAPACK/Internal.hs72
4 files changed, 129 insertions, 28 deletions
diff --git a/examples/tests.hs b/examples/tests.hs
index 5af33ba..53436c8 100644
--- a/examples/tests.hs
+++ b/examples/tests.hs
@@ -81,8 +81,6 @@ cf = mulF af bc
81 81
82r = mulC cc (trans cf) 82r = mulC cc (trans cf)
83 83
84ident n = diag (constant n 1)
85
86rd = (2><2) 84rd = (2><2)
87 [ 43492.0, 50572.0 85 [ 43492.0, 50572.0
88 , 102550.0, 119242.0 :: Double] 86 , 102550.0, 119242.0 :: Double]
@@ -126,33 +124,64 @@ instance (Field a, Arbitrary a) => Arbitrary (SqM a) where
126 return $ SqM $ (n><n) l 124 return $ SqM $ (n><n) l
127 coarbitrary = undefined 125 coarbitrary = undefined
128 126
127data Sym a = Sym (Matrix a) deriving Show
128instance (Field a, Arbitrary a, Num a) => Arbitrary (Sym a) where
129 arbitrary = do
130 SqM m <- arbitrary
131 return $ Sym (m `addM` trans m)
132 coarbitrary = undefined
129 133
130type BaseType = Double 134data Her = Her (Matrix (Complex Double)) deriving Show
135instance {-(Field a, Arbitrary a, Num a) =>-} Arbitrary Her where
136 arbitrary = do
137 SqM m <- arbitrary
138 return $ Her (m `addM` (liftMatrix conj) (trans m))
139 coarbitrary = undefined
140
141
142
143addM m1 m2 = liftMatrix2 addV m1 m2
131 144
145addV v1 v2 = fromList $ zipWith (+) (toList v1) (toList v2)
132 146
133svdTestR fun prod m = u <> s <> trans v |~| m 147
148type BaseType = Double
149
150svdTestR prod m = u <> s <> trans v |~| m
134 && u <> trans u |~| ident (rows m) 151 && u <> trans u |~| ident (rows m)
135 && v <> trans v |~| ident (cols m) 152 && v <> trans v |~| ident (cols m)
136 where (u,s,v) = fun m 153 where (u,s,v) = svdR m
137 (<>) = prod 154 (<>) = prod
138 155
139 156
140svdTestC fun prod m = u <> s' <> (trans v) |~~| m 157svdTestC prod m = u <> s' <> (trans v) |~~| m
141 && u <> (liftMatrix conj) (trans u) |~~| ident (rows m) 158 && u <> (liftMatrix conj) (trans u) |~~| ident (rows m)
142 && v <> (liftMatrix conj) (trans v) |~~| ident (cols m) 159 && v <> (liftMatrix conj) (trans v) |~~| ident (cols m)
143 where (u,s,v) = fun m 160 where (u,s,v) = svdC m
144 (<>) = prod 161 (<>) = prod
145 s' = liftMatrix comp s 162 s' = liftMatrix comp s
146 163
147eigTestC fun prod (SqM m) = (m <> v) |~~| (v <> diag s) 164eigTestC prod (SqM m) = (m <> v) |~~| (v <> diag s)
148 && takeDiag ((liftMatrix conj (trans v)) `mulC` v) ~~ constant (rows m) 1 165 && takeDiag ((liftMatrix conj (trans v)) <> v) ~~ constant (rows m) 1 --normalized
149 where (s,v) = fun m 166 where (s,v) = eigC m
167 (<>) = prod
168
169eigTestR prod (SqM m) = (liftMatrix comp m <> v) |~~| (v <> diag s)
170 -- && takeDiag ((liftMatrix conj (trans v)) <> v) ~~ constant (rows m) 1 --normalized ???
171 where (s,v) = eigR m
172 (<>) = prod
173
174eigTestS prod (Sym m) = (m <> v) |~| (v <> diag s)
175 && v <> trans v |~| ident (cols m)
176 where (s,v) = eigS m
150 (<>) = prod 177 (<>) = prod
151 178
152takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] 179eigTestH prod (Her m) = (m <> v) |~~| (v <> diag (comp s))
180 && v <> (liftMatrix conj) (trans v) |~~| ident (cols m)
181 where (s,v) = eigH m
182 (<>) = prod
153 183
154 184
155comp v = toComplex (v,constant (dim v) 0)
156 185
157main = do 186main = do
158 quickCheck $ \l -> null l || (toList . fromList) l == (l :: [BaseType]) 187 quickCheck $ \l -> null l || (toList . fromList) l == (l :: [BaseType])
@@ -162,8 +191,21 @@ main = do
162 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| mulF m1 (m2 :: Matrix BaseType) 191 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| mulF m1 (m2 :: Matrix BaseType)
163 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| trans (mulF (trans m2) (trans m1 :: Matrix BaseType)) 192 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| trans (mulF (trans m2) (trans m1 :: Matrix BaseType))
164 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| multiplyG m1 (m2 :: Matrix BaseType) 193 quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| multiplyG m1 (m2 :: Matrix BaseType)
165 quickCheck (svdTestR svdR mulC) 194 quickCheck (svdTestR mulC)
166 quickCheck (svdTestR svdR mulF) 195 quickCheck (svdTestR mulF)
167 quickCheck (svdTestC svdC mulC) 196 quickCheck (svdTestC mulC)
168 quickCheck (svdTestC svdC mulF) 197 quickCheck (svdTestC mulF)
169 quickCheck (eigTestC eigC mulC) 198 quickCheck (eigTestC mulC)
199 quickCheck (eigTestC mulF)
200 quickCheck (eigTestR mulC)
201 quickCheck (eigTestR mulF)
202 quickCheck (\(Sym m) -> m |=| (trans m:: Matrix BaseType))
203 quickCheck (eigTestS mulC)
204 quickCheck (eigTestS mulF)
205 quickCheck (eigTestH mulC)
206 quickCheck (eigTestH mulF)
207
208
209kk = (2><2)
210 [ 1.0, 0.0
211 , -1.5, 1.0 ::Double]
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index b8de245..bae56f1 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -190,6 +190,8 @@ conj :: Vector (Complex Double) -> Vector (Complex Double)
190conj v = asComplex $ cdat $ reshape 2 (asReal v) `mulC` diag (fromList [1,-1]) 190conj v = asComplex $ cdat $ reshape 2 (asReal v) `mulC` diag (fromList [1,-1])
191 where mulC = multiply RowMajor 191 where mulC = multiply RowMajor
192 192
193comp v = toComplex (v,constant (dim v) 0)
194
193------------------------------------------------------------------------------ 195------------------------------------------------------------------------------
194 196
195-- | Reverse rows 197-- | Reverse rows
@@ -203,6 +205,7 @@ fliprl m = fromColumns . reverse . toColumns $ m
203----------------------------------------------------------------- 205-----------------------------------------------------------------
204 206
205liftMatrix f m = m { dat = f (dat m), tdat = f (tdat m) } -- check sizes 207liftMatrix f m = m { dat = f (dat m), tdat = f (tdat m) } -- check sizes
208liftMatrix2 f m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) -- check sizes
206 209
207------------------------------------------------------------------ 210------------------------------------------------------------------
208 211
@@ -333,3 +336,7 @@ diagRect s r c
333 | r < c = trans $ diagRect s c r 336 | r < c = trans $ diagRect s c r
334 | r > c = joinVert [diag s , zeros (r-c,c)] 337 | r > c = joinVert [diag s , zeros (r-c,c)]
335 where zeros (r,c) = reshape c $ constant (r*c) 0 338 where zeros (r,c) = reshape c $ constant (r*c) 0
339
340takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]]
341
342ident n = diag (constant n 1)
diff --git a/lib/LAPACK.hs b/lib/LAPACK.hs
index 088b6d7..0f1a178 100644
--- a/lib/LAPACK.hs
+++ b/lib/LAPACK.hs
@@ -15,7 +15,7 @@
15module LAPACK ( 15module LAPACK (
16 --module LAPACK.Internal 16 --module LAPACK.Internal
17 svdR, svdR', svdC, svdC', 17 svdR, svdR', svdC, svdC',
18 eigC, 18 eigC, eigR, eigS, eigH,
19 linearSolveLSR 19 linearSolveLSR
20) where 20) where
21 21
diff --git a/lib/LAPACK/Internal.hs b/lib/LAPACK/Internal.hs
index e3c9927..ba50e6b 100644
--- a/lib/LAPACK/Internal.hs
+++ b/lib/LAPACK/Internal.hs
@@ -79,17 +79,48 @@ eigC :: Matrix (Complex Double) -> (Vector (Complex Double), Matrix (Complex Dou
79eigC (m@M {rows = r}) 79eigC (m@M {rows = r})
80 | r == 1 = (fromList [cdat m `at` 0], singleton 1) 80 | r == 1 = (fromList [cdat m `at` 0], singleton 1)
81 | otherwise = unsafePerformIO $ do 81 | otherwise = unsafePerformIO $ do
82 l <- createVector r 82 l <- createVector r
83 v <- createMatrix ColumnMajor r r 83 v <- createMatrix ColumnMajor r r
84 dummy <- createMatrix ColumnMajor 1 1 84 dummy <- createMatrix ColumnMajor 1 1
85 zgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigC" [fdat m] 85 zgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigC" [fdat m]
86 return (l,v) 86 return (l,v)
87 87
88----------------------------------------------------------------------------- 88-----------------------------------------------------------------------------
89-- dgeev 89-- dgeev
90foreign import ccall "lapack-aux.h eig_l_R" 90foreign import ccall "lapack-aux.h eig_l_R"
91 dgeev :: Double ::> Double ::> ((Complex Double) :> Double ::> IO Int) 91 dgeev :: Double ::> Double ::> ((Complex Double) :> Double ::> IO Int)
92 92
93-- | Wrapper for LAPACK's /dgeev/, which computes the eigenvalues and right eigenvectors of a general real matrix:
94--
95-- if @(l,v)=eigR m@ then @m \<\> v = v \<\> diag l@.
96--
97-- The eigenvectors are the columns of v.
98-- The eigenvalues are not sorted.
99eigR :: Matrix Double -> (Vector (Complex Double), Matrix (Complex Double))
100eigR (m@M {rows = r}) = (s', v'')
101 where (s,v) = eigRaux m
102 s' = toComplex (subVector 0 r (asReal s), subVector r r (asReal s))
103 v' = toRows $ trans v
104 v'' = fromColumns $ fixeig (toList s') v'
105
106eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double)
107eigRaux (m@M {rows = r})
108 | r == 1 = (fromList [(cdat m `at` 0):+0], singleton 1)
109 | otherwise = unsafePerformIO $ do
110 l <- createVector r
111 v <- createMatrix ColumnMajor r r
112 dummy <- createMatrix ColumnMajor 1 1
113 dgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigR" [fdat m]
114 return (l,v)
115
116fixeig [] _ = []
117fixeig [r] [v] = [comp v]
118fixeig ((r1:+i1):(r2:+i2):r) (v1:v2:vs)
119 | r1 == r2 && i1 == (-i2) = toComplex (v1,v2) : toComplex (v1,scale (-1) v2) : fixeig r vs
120 | otherwise = comp v1 : fixeig ((r2:+i2):r) (v2:vs)
121
122scale r v = fromList [r] `outer` v
123
93----------------------------------------------------------------------------- 124-----------------------------------------------------------------------------
94-- dsyev 125-- dsyev
95foreign import ccall "lapack-aux.h eig_l_S" 126foreign import ccall "lapack-aux.h eig_l_S"
@@ -106,17 +137,38 @@ eigS m = (s', fliprl v)
106 where (s,v) = eigS' m 137 where (s,v) = eigS' m
107 s' = fromList . reverse . toList $ s 138 s' = fromList . reverse . toList $ s
108 139
109eigS' (m@M {rows = r}) = unsafePerformIO $ do 140eigS' (m@M {rows = r})
110 l <- createVector r 141 | r == 1 = (fromList [cdat m `at` 0], singleton 1)
111 v <- createMatrix ColumnMajor r r 142 | otherwise = unsafePerformIO $ do
112 dsyev // mat fdat m // vec l // mat dat v // check "eigS" [fdat m] 143 l <- createVector r
113 return (l,v) 144 v <- createMatrix ColumnMajor r r
145 dsyev // mat fdat m // vec l // mat dat v // check "eigS" [fdat m]
146 return (l,v)
114 147
115----------------------------------------------------------------------------- 148-----------------------------------------------------------------------------
116-- zheev 149-- zheev
117foreign import ccall "lapack-aux.h eig_l_H" 150foreign import ccall "lapack-aux.h eig_l_H"
118 zheev :: (Complex Double) ::> (Double :> (Complex Double) ::> IO Int) 151 zheev :: (Complex Double) ::> (Double :> (Complex Double) ::> IO Int)
119 152
153-- | Wrapper for LAPACK's /zheev/, which computes the eigenvalues and right eigenvectors of a hermitian complex matrix:
154--
155-- if @(l,v)=eigH m@ then @m \<\> s v = v \<\> diag l@.
156--
157-- The eigenvectors are the columns of v.
158-- The eigenvalues are sorted in descending order.
159eigH :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double))
160eigH m = (s', fliprl v)
161 where (s,v) = eigH' m
162 s' = fromList . reverse . toList $ s
163
164eigH' (m@M {rows = r})
165 | r == 1 = (fromList [realPart (cdat m `at` 0)], singleton 1)
166 | otherwise = unsafePerformIO $ do
167 l <- createVector r
168 v <- createMatrix ColumnMajor r r
169 zheev // mat fdat m // vec l // mat dat v // check "eigH" [fdat m]
170 return (l,v)
171
120----------------------------------------------------------------------------- 172-----------------------------------------------------------------------------
121-- dgesv 173-- dgesv
122foreign import ccall "lapack-aux.h linearSolveR_l" 174foreign import ccall "lapack-aux.h linearSolveR_l"