diff options
-rw-r--r-- | examples/tests.hs | 76 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 7 | ||||
-rw-r--r-- | lib/LAPACK.hs | 2 | ||||
-rw-r--r-- | lib/LAPACK/Internal.hs | 72 |
4 files changed, 129 insertions, 28 deletions
diff --git a/examples/tests.hs b/examples/tests.hs index 5af33ba..53436c8 100644 --- a/examples/tests.hs +++ b/examples/tests.hs | |||
@@ -81,8 +81,6 @@ cf = mulF af bc | |||
81 | 81 | ||
82 | r = mulC cc (trans cf) | 82 | r = mulC cc (trans cf) |
83 | 83 | ||
84 | ident n = diag (constant n 1) | ||
85 | |||
86 | rd = (2><2) | 84 | rd = (2><2) |
87 | [ 43492.0, 50572.0 | 85 | [ 43492.0, 50572.0 |
88 | , 102550.0, 119242.0 :: Double] | 86 | , 102550.0, 119242.0 :: Double] |
@@ -126,33 +124,64 @@ instance (Field a, Arbitrary a) => Arbitrary (SqM a) where | |||
126 | return $ SqM $ (n><n) l | 124 | return $ SqM $ (n><n) l |
127 | coarbitrary = undefined | 125 | coarbitrary = undefined |
128 | 126 | ||
127 | data Sym a = Sym (Matrix a) deriving Show | ||
128 | instance (Field a, Arbitrary a, Num a) => Arbitrary (Sym a) where | ||
129 | arbitrary = do | ||
130 | SqM m <- arbitrary | ||
131 | return $ Sym (m `addM` trans m) | ||
132 | coarbitrary = undefined | ||
129 | 133 | ||
130 | type BaseType = Double | 134 | data Her = Her (Matrix (Complex Double)) deriving Show |
135 | instance {-(Field a, Arbitrary a, Num a) =>-} Arbitrary Her where | ||
136 | arbitrary = do | ||
137 | SqM m <- arbitrary | ||
138 | return $ Her (m `addM` (liftMatrix conj) (trans m)) | ||
139 | coarbitrary = undefined | ||
140 | |||
141 | |||
142 | |||
143 | addM m1 m2 = liftMatrix2 addV m1 m2 | ||
131 | 144 | ||
145 | addV v1 v2 = fromList $ zipWith (+) (toList v1) (toList v2) | ||
132 | 146 | ||
133 | svdTestR fun prod m = u <> s <> trans v |~| m | 147 | |
148 | type BaseType = Double | ||
149 | |||
150 | svdTestR prod m = u <> s <> trans v |~| m | ||
134 | && u <> trans u |~| ident (rows m) | 151 | && u <> trans u |~| ident (rows m) |
135 | && v <> trans v |~| ident (cols m) | 152 | && v <> trans v |~| ident (cols m) |
136 | where (u,s,v) = fun m | 153 | where (u,s,v) = svdR m |
137 | (<>) = prod | 154 | (<>) = prod |
138 | 155 | ||
139 | 156 | ||
140 | svdTestC fun prod m = u <> s' <> (trans v) |~~| m | 157 | svdTestC prod m = u <> s' <> (trans v) |~~| m |
141 | && u <> (liftMatrix conj) (trans u) |~~| ident (rows m) | 158 | && u <> (liftMatrix conj) (trans u) |~~| ident (rows m) |
142 | && v <> (liftMatrix conj) (trans v) |~~| ident (cols m) | 159 | && v <> (liftMatrix conj) (trans v) |~~| ident (cols m) |
143 | where (u,s,v) = fun m | 160 | where (u,s,v) = svdC m |
144 | (<>) = prod | 161 | (<>) = prod |
145 | s' = liftMatrix comp s | 162 | s' = liftMatrix comp s |
146 | 163 | ||
147 | eigTestC fun prod (SqM m) = (m <> v) |~~| (v <> diag s) | 164 | eigTestC prod (SqM m) = (m <> v) |~~| (v <> diag s) |
148 | && takeDiag ((liftMatrix conj (trans v)) `mulC` v) ~~ constant (rows m) 1 | 165 | && takeDiag ((liftMatrix conj (trans v)) <> v) ~~ constant (rows m) 1 --normalized |
149 | where (s,v) = fun m | 166 | where (s,v) = eigC m |
167 | (<>) = prod | ||
168 | |||
169 | eigTestR prod (SqM m) = (liftMatrix comp m <> v) |~~| (v <> diag s) | ||
170 | -- && takeDiag ((liftMatrix conj (trans v)) <> v) ~~ constant (rows m) 1 --normalized ??? | ||
171 | where (s,v) = eigR m | ||
172 | (<>) = prod | ||
173 | |||
174 | eigTestS prod (Sym m) = (m <> v) |~| (v <> diag s) | ||
175 | && v <> trans v |~| ident (cols m) | ||
176 | where (s,v) = eigS m | ||
150 | (<>) = prod | 177 | (<>) = prod |
151 | 178 | ||
152 | takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] | 179 | eigTestH prod (Her m) = (m <> v) |~~| (v <> diag (comp s)) |
180 | && v <> (liftMatrix conj) (trans v) |~~| ident (cols m) | ||
181 | where (s,v) = eigH m | ||
182 | (<>) = prod | ||
153 | 183 | ||
154 | 184 | ||
155 | comp v = toComplex (v,constant (dim v) 0) | ||
156 | 185 | ||
157 | main = do | 186 | main = do |
158 | quickCheck $ \l -> null l || (toList . fromList) l == (l :: [BaseType]) | 187 | quickCheck $ \l -> null l || (toList . fromList) l == (l :: [BaseType]) |
@@ -162,8 +191,21 @@ main = do | |||
162 | quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| mulF m1 (m2 :: Matrix BaseType) | 191 | quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| mulF m1 (m2 :: Matrix BaseType) |
163 | quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| trans (mulF (trans m2) (trans m1 :: Matrix BaseType)) | 192 | quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| trans (mulF (trans m2) (trans m1 :: Matrix BaseType)) |
164 | quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| multiplyG m1 (m2 :: Matrix BaseType) | 193 | quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| multiplyG m1 (m2 :: Matrix BaseType) |
165 | quickCheck (svdTestR svdR mulC) | 194 | quickCheck (svdTestR mulC) |
166 | quickCheck (svdTestR svdR mulF) | 195 | quickCheck (svdTestR mulF) |
167 | quickCheck (svdTestC svdC mulC) | 196 | quickCheck (svdTestC mulC) |
168 | quickCheck (svdTestC svdC mulF) | 197 | quickCheck (svdTestC mulF) |
169 | quickCheck (eigTestC eigC mulC) | 198 | quickCheck (eigTestC mulC) |
199 | quickCheck (eigTestC mulF) | ||
200 | quickCheck (eigTestR mulC) | ||
201 | quickCheck (eigTestR mulF) | ||
202 | quickCheck (\(Sym m) -> m |=| (trans m:: Matrix BaseType)) | ||
203 | quickCheck (eigTestS mulC) | ||
204 | quickCheck (eigTestS mulF) | ||
205 | quickCheck (eigTestH mulC) | ||
206 | quickCheck (eigTestH mulF) | ||
207 | |||
208 | |||
209 | kk = (2><2) | ||
210 | [ 1.0, 0.0 | ||
211 | , -1.5, 1.0 ::Double] | ||
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index b8de245..bae56f1 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -190,6 +190,8 @@ conj :: Vector (Complex Double) -> Vector (Complex Double) | |||
190 | conj v = asComplex $ cdat $ reshape 2 (asReal v) `mulC` diag (fromList [1,-1]) | 190 | conj v = asComplex $ cdat $ reshape 2 (asReal v) `mulC` diag (fromList [1,-1]) |
191 | where mulC = multiply RowMajor | 191 | where mulC = multiply RowMajor |
192 | 192 | ||
193 | comp v = toComplex (v,constant (dim v) 0) | ||
194 | |||
193 | ------------------------------------------------------------------------------ | 195 | ------------------------------------------------------------------------------ |
194 | 196 | ||
195 | -- | Reverse rows | 197 | -- | Reverse rows |
@@ -203,6 +205,7 @@ fliprl m = fromColumns . reverse . toColumns $ m | |||
203 | ----------------------------------------------------------------- | 205 | ----------------------------------------------------------------- |
204 | 206 | ||
205 | liftMatrix f m = m { dat = f (dat m), tdat = f (tdat m) } -- check sizes | 207 | liftMatrix f m = m { dat = f (dat m), tdat = f (tdat m) } -- check sizes |
208 | liftMatrix2 f m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) -- check sizes | ||
206 | 209 | ||
207 | ------------------------------------------------------------------ | 210 | ------------------------------------------------------------------ |
208 | 211 | ||
@@ -333,3 +336,7 @@ diagRect s r c | |||
333 | | r < c = trans $ diagRect s c r | 336 | | r < c = trans $ diagRect s c r |
334 | | r > c = joinVert [diag s , zeros (r-c,c)] | 337 | | r > c = joinVert [diag s , zeros (r-c,c)] |
335 | where zeros (r,c) = reshape c $ constant (r*c) 0 | 338 | where zeros (r,c) = reshape c $ constant (r*c) 0 |
339 | |||
340 | takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] | ||
341 | |||
342 | ident n = diag (constant n 1) | ||
diff --git a/lib/LAPACK.hs b/lib/LAPACK.hs index 088b6d7..0f1a178 100644 --- a/lib/LAPACK.hs +++ b/lib/LAPACK.hs | |||
@@ -15,7 +15,7 @@ | |||
15 | module LAPACK ( | 15 | module LAPACK ( |
16 | --module LAPACK.Internal | 16 | --module LAPACK.Internal |
17 | svdR, svdR', svdC, svdC', | 17 | svdR, svdR', svdC, svdC', |
18 | eigC, | 18 | eigC, eigR, eigS, eigH, |
19 | linearSolveLSR | 19 | linearSolveLSR |
20 | ) where | 20 | ) where |
21 | 21 | ||
diff --git a/lib/LAPACK/Internal.hs b/lib/LAPACK/Internal.hs index e3c9927..ba50e6b 100644 --- a/lib/LAPACK/Internal.hs +++ b/lib/LAPACK/Internal.hs | |||
@@ -79,17 +79,48 @@ eigC :: Matrix (Complex Double) -> (Vector (Complex Double), Matrix (Complex Dou | |||
79 | eigC (m@M {rows = r}) | 79 | eigC (m@M {rows = r}) |
80 | | r == 1 = (fromList [cdat m `at` 0], singleton 1) | 80 | | r == 1 = (fromList [cdat m `at` 0], singleton 1) |
81 | | otherwise = unsafePerformIO $ do | 81 | | otherwise = unsafePerformIO $ do |
82 | l <- createVector r | 82 | l <- createVector r |
83 | v <- createMatrix ColumnMajor r r | 83 | v <- createMatrix ColumnMajor r r |
84 | dummy <- createMatrix ColumnMajor 1 1 | 84 | dummy <- createMatrix ColumnMajor 1 1 |
85 | zgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigC" [fdat m] | 85 | zgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigC" [fdat m] |
86 | return (l,v) | 86 | return (l,v) |
87 | 87 | ||
88 | ----------------------------------------------------------------------------- | 88 | ----------------------------------------------------------------------------- |
89 | -- dgeev | 89 | -- dgeev |
90 | foreign import ccall "lapack-aux.h eig_l_R" | 90 | foreign import ccall "lapack-aux.h eig_l_R" |
91 | dgeev :: Double ::> Double ::> ((Complex Double) :> Double ::> IO Int) | 91 | dgeev :: Double ::> Double ::> ((Complex Double) :> Double ::> IO Int) |
92 | 92 | ||
93 | -- | Wrapper for LAPACK's /dgeev/, which computes the eigenvalues and right eigenvectors of a general real matrix: | ||
94 | -- | ||
95 | -- if @(l,v)=eigR m@ then @m \<\> v = v \<\> diag l@. | ||
96 | -- | ||
97 | -- The eigenvectors are the columns of v. | ||
98 | -- The eigenvalues are not sorted. | ||
99 | eigR :: Matrix Double -> (Vector (Complex Double), Matrix (Complex Double)) | ||
100 | eigR (m@M {rows = r}) = (s', v'') | ||
101 | where (s,v) = eigRaux m | ||
102 | s' = toComplex (subVector 0 r (asReal s), subVector r r (asReal s)) | ||
103 | v' = toRows $ trans v | ||
104 | v'' = fromColumns $ fixeig (toList s') v' | ||
105 | |||
106 | eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double) | ||
107 | eigRaux (m@M {rows = r}) | ||
108 | | r == 1 = (fromList [(cdat m `at` 0):+0], singleton 1) | ||
109 | | otherwise = unsafePerformIO $ do | ||
110 | l <- createVector r | ||
111 | v <- createMatrix ColumnMajor r r | ||
112 | dummy <- createMatrix ColumnMajor 1 1 | ||
113 | dgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigR" [fdat m] | ||
114 | return (l,v) | ||
115 | |||
116 | fixeig [] _ = [] | ||
117 | fixeig [r] [v] = [comp v] | ||
118 | fixeig ((r1:+i1):(r2:+i2):r) (v1:v2:vs) | ||
119 | | r1 == r2 && i1 == (-i2) = toComplex (v1,v2) : toComplex (v1,scale (-1) v2) : fixeig r vs | ||
120 | | otherwise = comp v1 : fixeig ((r2:+i2):r) (v2:vs) | ||
121 | |||
122 | scale r v = fromList [r] `outer` v | ||
123 | |||
93 | ----------------------------------------------------------------------------- | 124 | ----------------------------------------------------------------------------- |
94 | -- dsyev | 125 | -- dsyev |
95 | foreign import ccall "lapack-aux.h eig_l_S" | 126 | foreign import ccall "lapack-aux.h eig_l_S" |
@@ -106,17 +137,38 @@ eigS m = (s', fliprl v) | |||
106 | where (s,v) = eigS' m | 137 | where (s,v) = eigS' m |
107 | s' = fromList . reverse . toList $ s | 138 | s' = fromList . reverse . toList $ s |
108 | 139 | ||
109 | eigS' (m@M {rows = r}) = unsafePerformIO $ do | 140 | eigS' (m@M {rows = r}) |
110 | l <- createVector r | 141 | | r == 1 = (fromList [cdat m `at` 0], singleton 1) |
111 | v <- createMatrix ColumnMajor r r | 142 | | otherwise = unsafePerformIO $ do |
112 | dsyev // mat fdat m // vec l // mat dat v // check "eigS" [fdat m] | 143 | l <- createVector r |
113 | return (l,v) | 144 | v <- createMatrix ColumnMajor r r |
145 | dsyev // mat fdat m // vec l // mat dat v // check "eigS" [fdat m] | ||
146 | return (l,v) | ||
114 | 147 | ||
115 | ----------------------------------------------------------------------------- | 148 | ----------------------------------------------------------------------------- |
116 | -- zheev | 149 | -- zheev |
117 | foreign import ccall "lapack-aux.h eig_l_H" | 150 | foreign import ccall "lapack-aux.h eig_l_H" |
118 | zheev :: (Complex Double) ::> (Double :> (Complex Double) ::> IO Int) | 151 | zheev :: (Complex Double) ::> (Double :> (Complex Double) ::> IO Int) |
119 | 152 | ||
153 | -- | Wrapper for LAPACK's /zheev/, which computes the eigenvalues and right eigenvectors of a hermitian complex matrix: | ||
154 | -- | ||
155 | -- if @(l,v)=eigH m@ then @m \<\> s v = v \<\> diag l@. | ||
156 | -- | ||
157 | -- The eigenvectors are the columns of v. | ||
158 | -- The eigenvalues are sorted in descending order. | ||
159 | eigH :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double)) | ||
160 | eigH m = (s', fliprl v) | ||
161 | where (s,v) = eigH' m | ||
162 | s' = fromList . reverse . toList $ s | ||
163 | |||
164 | eigH' (m@M {rows = r}) | ||
165 | | r == 1 = (fromList [realPart (cdat m `at` 0)], singleton 1) | ||
166 | | otherwise = unsafePerformIO $ do | ||
167 | l <- createVector r | ||
168 | v <- createMatrix ColumnMajor r r | ||
169 | zheev // mat fdat m // vec l // mat dat v // check "eigH" [fdat m] | ||
170 | return (l,v) | ||
171 | |||
120 | ----------------------------------------------------------------------------- | 172 | ----------------------------------------------------------------------------- |
121 | -- dgesv | 173 | -- dgesv |
122 | foreign import ccall "lapack-aux.h linearSolveR_l" | 174 | foreign import ccall "lapack-aux.h linearSolveR_l" |