From ff72d6c45d36306ea3fdb0587749bfa99d6802b8 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Fri, 26 Oct 2007 10:12:54 +0000 Subject: added Hessenberg factorization --- examples/tests.hs | 16 ++++++++--- lib/Numeric/LinearAlgebra/Algorithms.hs | 38 +++++++++++++++++++++----- lib/Numeric/LinearAlgebra/LAPACK.hs | 39 +++++++++++++++++++++++---- lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 38 ++++++++++++++++++++++++++ lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 4 +++ 5 files changed, 120 insertions(+), 15 deletions(-) diff --git a/examples/tests.hs b/examples/tests.hs index 9388671..e91b9f1 100644 --- a/examples/tests.hs +++ b/examples/tests.hs @@ -189,12 +189,12 @@ pinvTest m = m <> p <> m |~| m square m = rows m == cols m -orthonormal m = square m && m <> ctrans m |~| ident (rows m) +unitary m = square m && m <> ctrans m |~| ident (rows m) hermitian m = m |~| ctrans m svdTest svd m = u <> real d <> trans v |~| m - && orthonormal u && orthonormal v + && unitary u && unitary v where (u,d,v) = full svd m svdTest' svd m = m |~| 0 || u <> real (diag s) <> trans v |~| m @@ -204,7 +204,7 @@ eigTest m = complex m <> v |~| v <> diag s where (s, v) = eig m eigTestSH m = m <> v |~| v <> real (diag s) - && orthonormal v + && unitary v && m |~| v <> real (diag s) <> ctrans v where (s, v) = eigSH m @@ -274,11 +274,16 @@ cholCTest = chol ((2><2) [1,2,2,9::Complex Double]) == (2><2) [1,2,0,2.236067977 --------------------------------------------------------------------- -qrTest qr m = q <> r |~| m && q <> ctrans q |~| ident (rows m) +qrTest qr m = q <> r |~| m && unitary q where (q,r) = qr m --------------------------------------------------------------------- +hessTest m = m |~| p <> h <> ctrans p && unitary p + where (p,h) = hess m + +--------------------------------------------------------------------- + asFortran m = (rows m >|< cols m) $ toList (fdat m) asC m = (rows m >< cols m) $ toList (cdat m) @@ -338,6 +343,9 @@ tests = do quickCheck (qrTest ( unpackQR . GSL.qrPacked)) quickCheck (qrTest qr ::RM->Bool) quickCheck (qrTest qr ::CM->Bool) + putStrLn "--------- hess --------" + quickCheck (hessTest . sqm ::SqM Double->Bool) + quickCheck (hessTest . sqm ::SqM (Complex Double) -> Bool) putStrLn "--------- nullspace ------" quickCheck (nullspaceTest :: RM -> Bool) quickCheck (nullspaceTest :: CM -> Bool) diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 84c399a..1345975 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs @@ -33,6 +33,8 @@ module Numeric.LinearAlgebra.Algorithms ( qr, -- ** Cholesky chol, +-- ** Hessenberg + hess, -- * Nullspace nullspacePrec, nullVector, @@ -42,7 +44,9 @@ module Numeric.LinearAlgebra.Algorithms ( ctrans, eps, i, -- * Util - GenMat(linearSolveSVD,lu,eigSH',cholSH), unpackQR, haussholder + GenMat(linearSolveSVD,lu,eigSH',cholSH), + haussholder, + unpackQR, unpackHess ) where @@ -64,7 +68,8 @@ class (Linear Matrix t) => GenMat t where -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK". linearSolve :: Matrix t -> Matrix t -> Matrix t linearSolveSVD :: Matrix t -> Matrix t -> Matrix t - -- | Eigenvalues and eigenvectors of a general square matrix using lapack's dgeev or zgeev. If @(s,v) = eig m@ then @m \<> v == v \<> diag s@ + -- | Eigenvalues and eigenvectors of a general square matrix using lapack's dgeev or zgeev. + -- If @(s,v) = eig m@ then @m \<> v == v \<> diag s@ eig :: Matrix t -> (Vector (Complex Double), Matrix (Complex Double)) -- | Similar to eigSH without checking that the input matrix is hermitian or symmetric. eigSH' :: Matrix t -> (Vector Double, Matrix t) @@ -72,6 +77,8 @@ class (Linear Matrix t) => GenMat t where cholSH :: Matrix t -> Matrix t -- | QR factorization using lapack's dgeqr2 or zgeqr2. qr :: Matrix t -> (Matrix t, Matrix t) + -- | Hessenberg factorization using lapack's dgehrd or zgehrd. + hess :: Matrix t -> (Matrix t, Matrix t) -- | Conjugate transpose. ctrans :: Matrix t -> Matrix t @@ -85,6 +92,7 @@ instance GenMat Double where eigSH' = eigS cholSH = cholS qr = GSL.unpackQR . qrR + hess = unpackHess hessR instance GenMat (Complex Double) where svd = svdC @@ -96,6 +104,7 @@ instance GenMat (Complex Double) where eigSH' = eigH cholSH = cholH qr = unpackQR . qrC + hess = unpackHess hessC -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. If @(s,v) = eigSH m@ then @m == v \<> diag s \<> ctrans v@ eigSH :: GenMat t => Matrix t -> (Vector Double, Matrix t) @@ -277,6 +286,14 @@ haussholder :: (GenMat a) => a -> Vector a -> Matrix a haussholder tau v = ident (dim v) `sub` (tau `scale` (w `mXm` ctrans w)) where w = asColumn v + +zh k v = fromList $ replicate (k-1) 0 ++ (1:drop k xs) + where xs = toList v + +zt 0 v = v +zt k v = join [subVector 0 (dim v - k) v, constant 0 k] + + unpackQR :: (GenMat t) => (Matrix t, Vector t) -> (Matrix t, Matrix t) unpackQR (pq, tau) = (q,r) where cs = toColumns pq @@ -288,8 +305,17 @@ unpackQR (pq, tau) = (q,r) hs = zipWith haussholder (toList tau) vs q = foldl1' mXm hs - zh k v = fromList $ replicate (k-1) 0 ++ (1:drop k xs) - where xs = toList v +unpackHess :: (GenMat t) => (Matrix t -> (Matrix t,Vector t)) -> Matrix t -> (Matrix t, Matrix t) +unpackHess hf m + | rows m == 1 = ((1><1)[1],m) + | otherwise = (uH . hf) m - zt 0 v = v - zt k v = join [subVector 0 (dim v - k) v, constant 0 k] +uH (pq, tau) = (p,h) + where cs = toColumns pq + m = rows pq + n = cols pq + mn = min m n + h = fromColumns $ zipWith zt ([m-2, m-3 .. 1] ++ repeat 0) cs + vs = zipWith zh [2..mn] cs + hs = zipWith haussholder (toList tau) vs + p = foldl1' mXm hs diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index 648e59f..a84a17e 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs @@ -20,7 +20,8 @@ module Numeric.LinearAlgebra.LAPACK ( linearSolveLSR, linearSolveLSC, linearSolveSVDR, linearSolveSVDC, cholS, cholH, - qrR, qrC + qrR, qrC, + hessR, hessC ) where import Data.Packed.Internal @@ -284,7 +285,7 @@ linearSolveSVDC_l rcond a b = unsafePerformIO $ do ----------------------------------------------------------------------------------- foreign import ccall "LAPACK/lapack-aux.h chol_l_H" zpotrf :: TCMCM --- | Wrapper for LAPACK's /zpotrf/,which computes the Cholesky factorization of a +-- | Wrapper for LAPACK's /zpotrf/, which computes the Cholesky factorization of a -- complex Hermitian positive definite matrix. cholH :: Matrix (Complex Double) -> Matrix (Complex Double) cholH a = unsafePerformIO $ do @@ -296,7 +297,7 @@ cholH a = unsafePerformIO $ do ----------------------------------------------------------------------------------- foreign import ccall "LAPACK/lapack-aux.h chol_l_S" dpotrf :: TMM --- | Wrapper for LAPACK's /dpotrf/,which computes the Cholesky factorization of a +-- | Wrapper for LAPACK's /dpotrf/, which computes the Cholesky factorization of a -- real symmetric positive definite matrix. cholS :: Matrix Double -> Matrix Double cholS a = unsafePerformIO $ do @@ -308,7 +309,7 @@ cholS a = unsafePerformIO $ do ----------------------------------------------------------------------------------- foreign import ccall "LAPACK/lapack-aux.h qr_l_R" dgeqr2 :: TMVM --- | Wrapper for LAPACK's /dgeqr2/,which computes a QR factorization of a real matrix. +-- | Wrapper for LAPACK's /dgeqr2/, which computes a QR factorization of a real matrix. qrR :: Matrix Double -> (Matrix Double, Vector Double) qrR a = unsafePerformIO $ do r <- createMatrix ColumnMajor m n @@ -322,7 +323,7 @@ qrR a = unsafePerformIO $ do ----------------------------------------------------------------------------------- foreign import ccall "LAPACK/lapack-aux.h qr_l_C" zgeqr2 :: TCMCVCM --- | Wrapper for LAPACK's /zgeqr2/,which computes a QR factorization of a complex matrix. +-- | Wrapper for LAPACK's /zgeqr2/, which computes a QR factorization of a complex matrix. qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) qrC a = unsafePerformIO $ do r <- createMatrix ColumnMajor m n @@ -333,3 +334,31 @@ qrC a = unsafePerformIO $ do n = cols a mn = min m n +----------------------------------------------------------------------------------- +foreign import ccall "LAPACK/lapack-aux.h hess_l_R" dgehrd :: TMVM + +-- | Wrapper for LAPACK's /dgehrd/, which computes a Hessenberg factorization of a square real matrix. +hessR :: Matrix Double -> (Matrix Double, Vector Double) +hessR a = unsafePerformIO $ do + r <- createMatrix ColumnMajor m n + tau <- createVector (mn-1) + dgehrd // mat fdat a // vec tau // mat dat r // check "hessR" [fdat a] + return (r,tau) + where m = rows a + n = cols a + mn = min m n + +----------------------------------------------------------------------------------- +foreign import ccall "LAPACK/lapack-aux.h hess_l_C" zgehrd :: TCMCVCM + +-- | Wrapper for LAPACK's /zgeqr2/, which computes a Hessenberg factorization of a square complex matrix. +hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) +hessC a = unsafePerformIO $ do + r <- createMatrix ColumnMajor m n + tau <- createVector (mn-1) + zgehrd // mat fdat a // vec tau // mat dat r // check "hessC" [fdat a] + return (r,tau) + where m = rows a + n = cols a + mn = min m n + diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index 9b6c1db..04ef416 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c @@ -663,3 +663,41 @@ int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r)) { free(WORK); OK } + +//////////////////// Hessenberg factorization ///////////////////////// + +int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r)) { + integer m = ar; + integer n = ac; + integer mn = MIN(m,n); + REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE); + DEBUGMSG("hess_l_R"); + integer lwork = 5*n; // fixme + double *WORK = (double*)malloc(lwork*sizeof(double)); + CHECK(!WORK,MEM); + memcpy(rp,ap,m*n*sizeof(double)); + integer res; + integer one = 1; + dgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res); + CHECK(res,res); + free(WORK); + OK +} + +int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)) { + integer m = ar; + integer n = ac; + integer mn = MIN(m,n); + REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE); + DEBUGMSG("hess_l_C"); + integer lwork = 5*n; // fixme + doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); + CHECK(!WORK,MEM); + memcpy(rp,ap,m*n*sizeof(doublecomplex)); + integer res; + integer one = 1; + zgehrd_ (&n,&one,&n,(doublecomplex*)rp,&n,(doublecomplex*)taup,WORK,&lwork,&res); + CHECK(res,res); + free(WORK); + OK +} diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h index ea71847..52ac41e 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h @@ -44,3 +44,7 @@ int chol_l_S(KDMAT(a),DMAT(r)); int qr_l_R(KDMAT(a), DVEC(tau), DMAT(r)); int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r)); + +int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r)); + +int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)); -- cgit v1.2.3