From 473df6136476dfa07331dd25a6020260c4f02a9b Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Mon, 11 Jun 2007 10:46:39 +0000 Subject: all eig --- examples/tests.hs | 76 +++++++++++++++++++++++++++++--------- lib/Data/Packed/Internal/Matrix.hs | 7 ++++ lib/LAPACK.hs | 2 +- lib/LAPACK/Internal.hs | 72 +++++++++++++++++++++++++++++++----- 4 files changed, 129 insertions(+), 28 deletions(-) diff --git a/examples/tests.hs b/examples/tests.hs index 5af33ba..53436c8 100644 --- a/examples/tests.hs +++ b/examples/tests.hs @@ -81,8 +81,6 @@ cf = mulF af bc r = mulC cc (trans cf) -ident n = diag (constant n 1) - rd = (2><2) [ 43492.0, 50572.0 , 102550.0, 119242.0 :: Double] @@ -126,33 +124,64 @@ instance (Field a, Arbitrary a) => Arbitrary (SqM a) where return $ SqM $ (n> Arbitrary (Sym a) where + arbitrary = do + SqM m <- arbitrary + return $ Sym (m `addM` trans m) + coarbitrary = undefined -type BaseType = Double +data Her = Her (Matrix (Complex Double)) deriving Show +instance {-(Field a, Arbitrary a, Num a) =>-} Arbitrary Her where + arbitrary = do + SqM m <- arbitrary + return $ Her (m `addM` (liftMatrix conj) (trans m)) + coarbitrary = undefined + + + +addM m1 m2 = liftMatrix2 addV m1 m2 +addV v1 v2 = fromList $ zipWith (+) (toList v1) (toList v2) -svdTestR fun prod m = u <> s <> trans v |~| m + +type BaseType = Double + +svdTestR prod m = u <> s <> trans v |~| m && u <> trans u |~| ident (rows m) && v <> trans v |~| ident (cols m) - where (u,s,v) = fun m + where (u,s,v) = svdR m (<>) = prod -svdTestC fun prod m = u <> s' <> (trans v) |~~| m +svdTestC prod m = u <> s' <> (trans v) |~~| m && u <> (liftMatrix conj) (trans u) |~~| ident (rows m) && v <> (liftMatrix conj) (trans v) |~~| ident (cols m) - where (u,s,v) = fun m + where (u,s,v) = svdC m (<>) = prod s' = liftMatrix comp s -eigTestC fun prod (SqM m) = (m <> v) |~~| (v <> diag s) - && takeDiag ((liftMatrix conj (trans v)) `mulC` v) ~~ constant (rows m) 1 - where (s,v) = fun m +eigTestC prod (SqM m) = (m <> v) |~~| (v <> diag s) + && takeDiag ((liftMatrix conj (trans v)) <> v) ~~ constant (rows m) 1 --normalized + where (s,v) = eigC m + (<>) = prod + +eigTestR prod (SqM m) = (liftMatrix comp m <> v) |~~| (v <> diag s) + -- && takeDiag ((liftMatrix conj (trans v)) <> v) ~~ constant (rows m) 1 --normalized ??? + where (s,v) = eigR m + (<>) = prod + +eigTestS prod (Sym m) = (m <> v) |~| (v <> diag s) + && v <> trans v |~| ident (cols m) + where (s,v) = eigS m (<>) = prod -takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] +eigTestH prod (Her m) = (m <> v) |~~| (v <> diag (comp s)) + && v <> (liftMatrix conj) (trans v) |~~| ident (cols m) + where (s,v) = eigH m + (<>) = prod -comp v = toComplex (v,constant (dim v) 0) main = do quickCheck $ \l -> null l || (toList . fromList) l == (l :: [BaseType]) @@ -162,8 +191,21 @@ main = do quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| mulF m1 (m2 :: Matrix BaseType) quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| trans (mulF (trans m2) (trans m1 :: Matrix BaseType)) quickCheck $ \(PairM m1 m2) -> mulC m1 m2 |=| multiplyG m1 (m2 :: Matrix BaseType) - quickCheck (svdTestR svdR mulC) - quickCheck (svdTestR svdR mulF) - quickCheck (svdTestC svdC mulC) - quickCheck (svdTestC svdC mulF) - quickCheck (eigTestC eigC mulC) + quickCheck (svdTestR mulC) + quickCheck (svdTestR mulF) + quickCheck (svdTestC mulC) + quickCheck (svdTestC mulF) + quickCheck (eigTestC mulC) + quickCheck (eigTestC mulF) + quickCheck (eigTestR mulC) + quickCheck (eigTestR mulF) + quickCheck (\(Sym m) -> m |=| (trans m:: Matrix BaseType)) + quickCheck (eigTestS mulC) + quickCheck (eigTestS mulF) + quickCheck (eigTestH mulC) + quickCheck (eigTestH mulF) + + +kk = (2><2) + [ 1.0, 0.0 + , -1.5, 1.0 ::Double] diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index b8de245..bae56f1 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs @@ -190,6 +190,8 @@ conj :: Vector (Complex Double) -> Vector (Complex Double) conj v = asComplex $ cdat $ reshape 2 (asReal v) `mulC` diag (fromList [1,-1]) where mulC = multiply RowMajor +comp v = toComplex (v,constant (dim v) 0) + ------------------------------------------------------------------------------ -- | Reverse rows @@ -203,6 +205,7 @@ fliprl m = fromColumns . reverse . toColumns $ m ----------------------------------------------------------------- liftMatrix f m = m { dat = f (dat m), tdat = f (tdat m) } -- check sizes +liftMatrix2 f m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) -- check sizes ------------------------------------------------------------------ @@ -333,3 +336,7 @@ diagRect s r c | r < c = trans $ diagRect s c r | r > c = joinVert [diag s , zeros (r-c,c)] where zeros (r,c) = reshape c $ constant (r*c) 0 + +takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] + +ident n = diag (constant n 1) diff --git a/lib/LAPACK.hs b/lib/LAPACK.hs index 088b6d7..0f1a178 100644 --- a/lib/LAPACK.hs +++ b/lib/LAPACK.hs @@ -15,7 +15,7 @@ module LAPACK ( --module LAPACK.Internal svdR, svdR', svdC, svdC', - eigC, + eigC, eigR, eigS, eigH, linearSolveLSR ) where diff --git a/lib/LAPACK/Internal.hs b/lib/LAPACK/Internal.hs index e3c9927..ba50e6b 100644 --- a/lib/LAPACK/Internal.hs +++ b/lib/LAPACK/Internal.hs @@ -79,17 +79,48 @@ eigC :: Matrix (Complex Double) -> (Vector (Complex Double), Matrix (Complex Dou eigC (m@M {rows = r}) | r == 1 = (fromList [cdat m `at` 0], singleton 1) | otherwise = unsafePerformIO $ do - l <- createVector r - v <- createMatrix ColumnMajor r r - dummy <- createMatrix ColumnMajor 1 1 - zgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigC" [fdat m] - return (l,v) + l <- createVector r + v <- createMatrix ColumnMajor r r + dummy <- createMatrix ColumnMajor 1 1 + zgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigC" [fdat m] + return (l,v) ----------------------------------------------------------------------------- -- dgeev foreign import ccall "lapack-aux.h eig_l_R" dgeev :: Double ::> Double ::> ((Complex Double) :> Double ::> IO Int) +-- | Wrapper for LAPACK's /dgeev/, which computes the eigenvalues and right eigenvectors of a general real matrix: +-- +-- if @(l,v)=eigR m@ then @m \<\> v = v \<\> diag l@. +-- +-- The eigenvectors are the columns of v. +-- The eigenvalues are not sorted. +eigR :: Matrix Double -> (Vector (Complex Double), Matrix (Complex Double)) +eigR (m@M {rows = r}) = (s', v'') + where (s,v) = eigRaux m + s' = toComplex (subVector 0 r (asReal s), subVector r r (asReal s)) + v' = toRows $ trans v + v'' = fromColumns $ fixeig (toList s') v' + +eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double) +eigRaux (m@M {rows = r}) + | r == 1 = (fromList [(cdat m `at` 0):+0], singleton 1) + | otherwise = unsafePerformIO $ do + l <- createVector r + v <- createMatrix ColumnMajor r r + dummy <- createMatrix ColumnMajor 1 1 + dgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigR" [fdat m] + return (l,v) + +fixeig [] _ = [] +fixeig [r] [v] = [comp v] +fixeig ((r1:+i1):(r2:+i2):r) (v1:v2:vs) + | r1 == r2 && i1 == (-i2) = toComplex (v1,v2) : toComplex (v1,scale (-1) v2) : fixeig r vs + | otherwise = comp v1 : fixeig ((r2:+i2):r) (v2:vs) + +scale r v = fromList [r] `outer` v + ----------------------------------------------------------------------------- -- dsyev foreign import ccall "lapack-aux.h eig_l_S" @@ -106,17 +137,38 @@ eigS m = (s', fliprl v) where (s,v) = eigS' m s' = fromList . reverse . toList $ s -eigS' (m@M {rows = r}) = unsafePerformIO $ do - l <- createVector r - v <- createMatrix ColumnMajor r r - dsyev // mat fdat m // vec l // mat dat v // check "eigS" [fdat m] - return (l,v) +eigS' (m@M {rows = r}) + | r == 1 = (fromList [cdat m `at` 0], singleton 1) + | otherwise = unsafePerformIO $ do + l <- createVector r + v <- createMatrix ColumnMajor r r + dsyev // mat fdat m // vec l // mat dat v // check "eigS" [fdat m] + return (l,v) ----------------------------------------------------------------------------- -- zheev foreign import ccall "lapack-aux.h eig_l_H" zheev :: (Complex Double) ::> (Double :> (Complex Double) ::> IO Int) +-- | Wrapper for LAPACK's /zheev/, which computes the eigenvalues and right eigenvectors of a hermitian complex matrix: +-- +-- if @(l,v)=eigH m@ then @m \<\> s v = v \<\> diag l@. +-- +-- The eigenvectors are the columns of v. +-- The eigenvalues are sorted in descending order. +eigH :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double)) +eigH m = (s', fliprl v) + where (s,v) = eigH' m + s' = fromList . reverse . toList $ s + +eigH' (m@M {rows = r}) + | r == 1 = (fromList [realPart (cdat m `at` 0)], singleton 1) + | otherwise = unsafePerformIO $ do + l <- createVector r + v <- createMatrix ColumnMajor r r + zheev // mat fdat m // vec l // mat dat v // check "eigH" [fdat m] + return (l,v) + ----------------------------------------------------------------------------- -- dgesv foreign import ccall "lapack-aux.h linearSolveR_l" -- cgit v1.2.3