diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-10-26 16:18:28 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-10-26 16:18:28 +0000 |
commit | 71320675021472b2f97191ba514c651ceb8a1617 (patch) | |
tree | 421fbf8f7d7d3e3d9c7fa5fdb87d2d9eb9ce0d96 | |
parent | 86406ad682436d55932318b85123fe1afc865bbf (diff) |
added Schur factorization
-rw-r--r-- | examples/tests.hs | 23 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 6 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK.hs | 28 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 49 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 4 |
5 files changed, 106 insertions, 4 deletions
diff --git a/examples/tests.hs b/examples/tests.hs index e91b9f1..8224255 100644 --- a/examples/tests.hs +++ b/examples/tests.hs | |||
@@ -193,6 +193,14 @@ unitary m = square m && m <> ctrans m |~| ident (rows m) | |||
193 | 193 | ||
194 | hermitian m = m |~| ctrans m | 194 | hermitian m = m |~| ctrans m |
195 | 195 | ||
196 | upperTriang m = rows m == 1 || down == z | ||
197 | where down = fromList $ concat $ zipWith drop [1..] (toLists (ctrans m)) | ||
198 | z = constant 0 (dim down) | ||
199 | |||
200 | upperHessenberg m = rows m < 3 || down == z | ||
201 | where down = fromList $ concat $ zipWith drop [2..] (toLists (ctrans m)) | ||
202 | z = constant 0 (dim down) | ||
203 | |||
196 | svdTest svd m = u <> real d <> trans v |~| m | 204 | svdTest svd m = u <> real d <> trans v |~| m |
197 | && unitary u && unitary v | 205 | && unitary u && unitary v |
198 | where (u,d,v) = full svd m | 206 | where (u,d,v) = full svd m |
@@ -274,16 +282,24 @@ cholCTest = chol ((2><2) [1,2,2,9::Complex Double]) == (2><2) [1,2,0,2.236067977 | |||
274 | 282 | ||
275 | --------------------------------------------------------------------- | 283 | --------------------------------------------------------------------- |
276 | 284 | ||
277 | qrTest qr m = q <> r |~| m && unitary q | 285 | qrTest qr m = q <> r |~| m && unitary q && upperTriang r |
278 | where (q,r) = qr m | 286 | where (q,r) = qr m |
279 | 287 | ||
280 | --------------------------------------------------------------------- | 288 | --------------------------------------------------------------------- |
281 | 289 | ||
282 | hessTest m = m |~| p <> h <> ctrans p && unitary p | 290 | hessTest m = m |~| p <> h <> ctrans p && unitary p && upperHessenberg h |
283 | where (p,h) = hess m | 291 | where (p,h) = hess m |
284 | 292 | ||
285 | --------------------------------------------------------------------- | 293 | --------------------------------------------------------------------- |
286 | 294 | ||
295 | schurTest1 m = m |~| u <> s <> ctrans u && unitary u && upperTriang s | ||
296 | where (u,s) = schur m | ||
297 | |||
298 | schurTest2 m = m |~| u <> s <> ctrans u && unitary u && upperHessenberg s -- fixme | ||
299 | where (u,s) = schur m | ||
300 | |||
301 | --------------------------------------------------------------------- | ||
302 | |||
287 | asFortran m = (rows m >|< cols m) $ toList (fdat m) | 303 | asFortran m = (rows m >|< cols m) $ toList (fdat m) |
288 | asC m = (rows m >< cols m) $ toList (cdat m) | 304 | asC m = (rows m >< cols m) $ toList (cdat m) |
289 | 305 | ||
@@ -346,6 +362,9 @@ tests = do | |||
346 | putStrLn "--------- hess --------" | 362 | putStrLn "--------- hess --------" |
347 | quickCheck (hessTest . sqm ::SqM Double->Bool) | 363 | quickCheck (hessTest . sqm ::SqM Double->Bool) |
348 | quickCheck (hessTest . sqm ::SqM (Complex Double) -> Bool) | 364 | quickCheck (hessTest . sqm ::SqM (Complex Double) -> Bool) |
365 | putStrLn "--------- schur --------" | ||
366 | quickCheck (schurTest2 . sqm ::SqM Double->Bool) | ||
367 | quickCheck (schurTest1 . sqm ::SqM (Complex Double) -> Bool) | ||
349 | putStrLn "--------- nullspace ------" | 368 | putStrLn "--------- nullspace ------" |
350 | quickCheck (nullspaceTest :: RM -> Bool) | 369 | quickCheck (nullspaceTest :: RM -> Bool) |
351 | quickCheck (nullspaceTest :: CM -> Bool) | 370 | quickCheck (nullspaceTest :: CM -> Bool) |
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 1345975..7f8e84c 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -35,6 +35,8 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
35 | chol, | 35 | chol, |
36 | -- ** Hessenberg | 36 | -- ** Hessenberg |
37 | hess, | 37 | hess, |
38 | -- ** Schur | ||
39 | schur, | ||
38 | -- * Nullspace | 40 | -- * Nullspace |
39 | nullspacePrec, | 41 | nullspacePrec, |
40 | nullVector, | 42 | nullVector, |
@@ -79,6 +81,8 @@ class (Linear Matrix t) => GenMat t where | |||
79 | qr :: Matrix t -> (Matrix t, Matrix t) | 81 | qr :: Matrix t -> (Matrix t, Matrix t) |
80 | -- | Hessenberg factorization using lapack's dgehrd or zgehrd. | 82 | -- | Hessenberg factorization using lapack's dgehrd or zgehrd. |
81 | hess :: Matrix t -> (Matrix t, Matrix t) | 83 | hess :: Matrix t -> (Matrix t, Matrix t) |
84 | -- | Schur factorization using lapack's dgees or zgees. | ||
85 | schur :: Matrix t -> (Matrix t, Matrix t) | ||
82 | -- | Conjugate transpose. | 86 | -- | Conjugate transpose. |
83 | ctrans :: Matrix t -> Matrix t | 87 | ctrans :: Matrix t -> Matrix t |
84 | 88 | ||
@@ -93,6 +97,7 @@ instance GenMat Double where | |||
93 | cholSH = cholS | 97 | cholSH = cholS |
94 | qr = GSL.unpackQR . qrR | 98 | qr = GSL.unpackQR . qrR |
95 | hess = unpackHess hessR | 99 | hess = unpackHess hessR |
100 | schur = schurR | ||
96 | 101 | ||
97 | instance GenMat (Complex Double) where | 102 | instance GenMat (Complex Double) where |
98 | svd = svdC | 103 | svd = svdC |
@@ -105,6 +110,7 @@ instance GenMat (Complex Double) where | |||
105 | cholSH = cholH | 110 | cholSH = cholH |
106 | qr = unpackQR . qrC | 111 | qr = unpackQR . qrC |
107 | hess = unpackHess hessC | 112 | hess = unpackHess hessC |
113 | schur = schurC | ||
108 | 114 | ||
109 | -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. If @(s,v) = eigSH m@ then @m == v \<> diag s \<> ctrans v@ | 115 | -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. If @(s,v) = eigSH m@ then @m == v \<> diag s \<> ctrans v@ |
110 | eigSH :: GenMat t => Matrix t -> (Vector Double, Matrix t) | 116 | eigSH :: GenMat t => Matrix t -> (Vector Double, Matrix t) |
diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index a84a17e..628d4f8 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs | |||
@@ -21,7 +21,8 @@ module Numeric.LinearAlgebra.LAPACK ( | |||
21 | linearSolveSVDR, linearSolveSVDC, | 21 | linearSolveSVDR, linearSolveSVDC, |
22 | cholS, cholH, | 22 | cholS, cholH, |
23 | qrR, qrC, | 23 | qrR, qrC, |
24 | hessR, hessC | 24 | hessR, hessC, |
25 | schurR, schurC | ||
25 | ) where | 26 | ) where |
26 | 27 | ||
27 | import Data.Packed.Internal | 28 | import Data.Packed.Internal |
@@ -351,7 +352,7 @@ hessR a = unsafePerformIO $ do | |||
351 | ----------------------------------------------------------------------------------- | 352 | ----------------------------------------------------------------------------------- |
352 | foreign import ccall "LAPACK/lapack-aux.h hess_l_C" zgehrd :: TCMCVCM | 353 | foreign import ccall "LAPACK/lapack-aux.h hess_l_C" zgehrd :: TCMCVCM |
353 | 354 | ||
354 | -- | Wrapper for LAPACK's /zgeqr2/, which computes a Hessenberg factorization of a square complex matrix. | 355 | -- | Wrapper for LAPACK's /zgehrd/, which computes a Hessenberg factorization of a square complex matrix. |
355 | hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) | 356 | hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) |
356 | hessC a = unsafePerformIO $ do | 357 | hessC a = unsafePerformIO $ do |
357 | r <- createMatrix ColumnMajor m n | 358 | r <- createMatrix ColumnMajor m n |
@@ -362,3 +363,26 @@ hessC a = unsafePerformIO $ do | |||
362 | n = cols a | 363 | n = cols a |
363 | mn = min m n | 364 | mn = min m n |
364 | 365 | ||
366 | ----------------------------------------------------------------------------------- | ||
367 | foreign import ccall "LAPACK/lapack-aux.h schur_l_R" dgees :: TMMM | ||
368 | |||
369 | -- | Wrapper for LAPACK's /dgees/, which computes a Schur factorization of a square real matrix. | ||
370 | schurR :: Matrix Double -> (Matrix Double, Matrix Double) | ||
371 | schurR a = unsafePerformIO $ do | ||
372 | u <- createMatrix ColumnMajor n n | ||
373 | s <- createMatrix ColumnMajor n n | ||
374 | dgees // mat fdat a // mat dat u // mat dat s // check "schurR" [fdat a] | ||
375 | return (u,s) | ||
376 | where n = rows a | ||
377 | |||
378 | ----------------------------------------------------------------------------------- | ||
379 | foreign import ccall "LAPACK/lapack-aux.h schur_l_C" zgees :: TCMCMCM | ||
380 | |||
381 | -- | Wrapper for LAPACK's /zgees/, which computes a Schur factorization of a square complex matrix. | ||
382 | schurC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double)) | ||
383 | schurC a = unsafePerformIO $ do | ||
384 | u <- createMatrix ColumnMajor n n | ||
385 | s <- createMatrix ColumnMajor n n | ||
386 | zgees // mat fdat a // mat dat u // mat dat s // check "schurC" [fdat a] | ||
387 | return (u,s) | ||
388 | where n = rows a | ||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index 04ef416..cab3f5b 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | |||
@@ -701,3 +701,52 @@ int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)) { | |||
701 | free(WORK); | 701 | free(WORK); |
702 | OK | 702 | OK |
703 | } | 703 | } |
704 | |||
705 | //////////////////// Schur factorization ///////////////////////// | ||
706 | |||
707 | int schur_l_R(KDMAT(a), DMAT(u), DMAT(s)) { | ||
708 | integer m = ar; | ||
709 | integer n = ac; | ||
710 | REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE); | ||
711 | DEBUGMSG("schur_l_R"); | ||
712 | memcpy(sp,ap,n*n*sizeof(double)); | ||
713 | integer lwork = 6*n; // fixme | ||
714 | double *WORK = (double*)malloc(lwork*sizeof(double)); | ||
715 | double *WR = (double*)malloc(n*sizeof(double)); | ||
716 | double *WI = (double*)malloc(n*sizeof(double)); | ||
717 | // WR and WI not really required in this call | ||
718 | logical *BWORK = (logical*)malloc(n*sizeof(logical)); | ||
719 | integer res; | ||
720 | integer sdim; | ||
721 | dgees_ ("V","N",NULL,&n,sp,&n,&sdim,WR,WI,up,&n,WORK,&lwork,BWORK,&res); | ||
722 | CHECK(res,res); | ||
723 | free(WR); | ||
724 | free(WI); | ||
725 | free(BWORK); | ||
726 | free(WORK); | ||
727 | OK | ||
728 | } | ||
729 | |||
730 | int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)) { | ||
731 | integer m = ar; | ||
732 | integer n = ac; | ||
733 | REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE); | ||
734 | DEBUGMSG("schur_l_C"); | ||
735 | memcpy(sp,ap,n*n*sizeof(doublecomplex)); | ||
736 | integer lwork = 6*n; // fixme | ||
737 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | ||
738 | doublecomplex *W = (doublecomplex*)malloc(n*sizeof(doublecomplex)); | ||
739 | // W not really required in this call | ||
740 | logical *BWORK = (logical*)malloc(n*sizeof(logical)); | ||
741 | double *RWORK = (double*)malloc(n*sizeof(double)); | ||
742 | integer res; | ||
743 | integer sdim; | ||
744 | zgees_ ("V","N",NULL,&n,(doublecomplex*)sp,&n,&sdim,W, | ||
745 | (doublecomplex*)up,&n, | ||
746 | WORK,&lwork,RWORK,BWORK,&res); | ||
747 | CHECK(res,res); | ||
748 | free(W); | ||
749 | free(BWORK); | ||
750 | free(WORK); | ||
751 | OK | ||
752 | } | ||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h index 52ac41e..e5d74d7 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | |||
@@ -48,3 +48,7 @@ int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r)); | |||
48 | int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r)); | 48 | int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r)); |
49 | 49 | ||
50 | int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)); | 50 | int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)); |
51 | |||
52 | int schur_l_R(KDMAT(a), DMAT(u), DMAT(s)); | ||
53 | |||
54 | int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)); | ||