diff options
-rw-r--r-- | examples/tests.hs | 16 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 38 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK.hs | 39 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 38 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 4 |
5 files changed, 120 insertions, 15 deletions
diff --git a/examples/tests.hs b/examples/tests.hs index 9388671..e91b9f1 100644 --- a/examples/tests.hs +++ b/examples/tests.hs | |||
@@ -189,12 +189,12 @@ pinvTest m = m <> p <> m |~| m | |||
189 | 189 | ||
190 | square m = rows m == cols m | 190 | square m = rows m == cols m |
191 | 191 | ||
192 | orthonormal m = square m && m <> ctrans m |~| ident (rows m) | 192 | unitary m = square m && m <> ctrans m |~| ident (rows m) |
193 | 193 | ||
194 | hermitian m = m |~| ctrans m | 194 | hermitian m = m |~| ctrans m |
195 | 195 | ||
196 | svdTest svd m = u <> real d <> trans v |~| m | 196 | svdTest svd m = u <> real d <> trans v |~| m |
197 | && orthonormal u && orthonormal v | 197 | && unitary u && unitary v |
198 | where (u,d,v) = full svd m | 198 | where (u,d,v) = full svd m |
199 | 199 | ||
200 | svdTest' svd m = m |~| 0 || u <> real (diag s) <> trans v |~| m | 200 | svdTest' svd m = m |~| 0 || u <> real (diag s) <> trans v |~| m |
@@ -204,7 +204,7 @@ eigTest m = complex m <> v |~| v <> diag s | |||
204 | where (s, v) = eig m | 204 | where (s, v) = eig m |
205 | 205 | ||
206 | eigTestSH m = m <> v |~| v <> real (diag s) | 206 | eigTestSH m = m <> v |~| v <> real (diag s) |
207 | && orthonormal v | 207 | && unitary v |
208 | && m |~| v <> real (diag s) <> ctrans v | 208 | && m |~| v <> real (diag s) <> ctrans v |
209 | where (s, v) = eigSH m | 209 | where (s, v) = eigSH m |
210 | 210 | ||
@@ -274,11 +274,16 @@ cholCTest = chol ((2><2) [1,2,2,9::Complex Double]) == (2><2) [1,2,0,2.236067977 | |||
274 | 274 | ||
275 | --------------------------------------------------------------------- | 275 | --------------------------------------------------------------------- |
276 | 276 | ||
277 | qrTest qr m = q <> r |~| m && q <> ctrans q |~| ident (rows m) | 277 | qrTest qr m = q <> r |~| m && unitary q |
278 | where (q,r) = qr m | 278 | where (q,r) = qr m |
279 | 279 | ||
280 | --------------------------------------------------------------------- | 280 | --------------------------------------------------------------------- |
281 | 281 | ||
282 | hessTest m = m |~| p <> h <> ctrans p && unitary p | ||
283 | where (p,h) = hess m | ||
284 | |||
285 | --------------------------------------------------------------------- | ||
286 | |||
282 | asFortran m = (rows m >|< cols m) $ toList (fdat m) | 287 | asFortran m = (rows m >|< cols m) $ toList (fdat m) |
283 | asC m = (rows m >< cols m) $ toList (cdat m) | 288 | asC m = (rows m >< cols m) $ toList (cdat m) |
284 | 289 | ||
@@ -338,6 +343,9 @@ tests = do | |||
338 | quickCheck (qrTest ( unpackQR . GSL.qrPacked)) | 343 | quickCheck (qrTest ( unpackQR . GSL.qrPacked)) |
339 | quickCheck (qrTest qr ::RM->Bool) | 344 | quickCheck (qrTest qr ::RM->Bool) |
340 | quickCheck (qrTest qr ::CM->Bool) | 345 | quickCheck (qrTest qr ::CM->Bool) |
346 | putStrLn "--------- hess --------" | ||
347 | quickCheck (hessTest . sqm ::SqM Double->Bool) | ||
348 | quickCheck (hessTest . sqm ::SqM (Complex Double) -> Bool) | ||
341 | putStrLn "--------- nullspace ------" | 349 | putStrLn "--------- nullspace ------" |
342 | quickCheck (nullspaceTest :: RM -> Bool) | 350 | quickCheck (nullspaceTest :: RM -> Bool) |
343 | quickCheck (nullspaceTest :: CM -> Bool) | 351 | quickCheck (nullspaceTest :: CM -> Bool) |
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 84c399a..1345975 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -33,6 +33,8 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
33 | qr, | 33 | qr, |
34 | -- ** Cholesky | 34 | -- ** Cholesky |
35 | chol, | 35 | chol, |
36 | -- ** Hessenberg | ||
37 | hess, | ||
36 | -- * Nullspace | 38 | -- * Nullspace |
37 | nullspacePrec, | 39 | nullspacePrec, |
38 | nullVector, | 40 | nullVector, |
@@ -42,7 +44,9 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
42 | ctrans, | 44 | ctrans, |
43 | eps, i, | 45 | eps, i, |
44 | -- * Util | 46 | -- * Util |
45 | GenMat(linearSolveSVD,lu,eigSH',cholSH), unpackQR, haussholder | 47 | GenMat(linearSolveSVD,lu,eigSH',cholSH), |
48 | haussholder, | ||
49 | unpackQR, unpackHess | ||
46 | ) where | 50 | ) where |
47 | 51 | ||
48 | 52 | ||
@@ -64,7 +68,8 @@ class (Linear Matrix t) => GenMat t where | |||
64 | -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK". | 68 | -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK". |
65 | linearSolve :: Matrix t -> Matrix t -> Matrix t | 69 | linearSolve :: Matrix t -> Matrix t -> Matrix t |
66 | linearSolveSVD :: Matrix t -> Matrix t -> Matrix t | 70 | linearSolveSVD :: Matrix t -> Matrix t -> Matrix t |
67 | -- | Eigenvalues and eigenvectors of a general square matrix using lapack's dgeev or zgeev. If @(s,v) = eig m@ then @m \<> v == v \<> diag s@ | 71 | -- | Eigenvalues and eigenvectors of a general square matrix using lapack's dgeev or zgeev. |
72 | -- If @(s,v) = eig m@ then @m \<> v == v \<> diag s@ | ||
68 | eig :: Matrix t -> (Vector (Complex Double), Matrix (Complex Double)) | 73 | eig :: Matrix t -> (Vector (Complex Double), Matrix (Complex Double)) |
69 | -- | Similar to eigSH without checking that the input matrix is hermitian or symmetric. | 74 | -- | Similar to eigSH without checking that the input matrix is hermitian or symmetric. |
70 | eigSH' :: Matrix t -> (Vector Double, Matrix t) | 75 | eigSH' :: Matrix t -> (Vector Double, Matrix t) |
@@ -72,6 +77,8 @@ class (Linear Matrix t) => GenMat t where | |||
72 | cholSH :: Matrix t -> Matrix t | 77 | cholSH :: Matrix t -> Matrix t |
73 | -- | QR factorization using lapack's dgeqr2 or zgeqr2. | 78 | -- | QR factorization using lapack's dgeqr2 or zgeqr2. |
74 | qr :: Matrix t -> (Matrix t, Matrix t) | 79 | qr :: Matrix t -> (Matrix t, Matrix t) |
80 | -- | Hessenberg factorization using lapack's dgehrd or zgehrd. | ||
81 | hess :: Matrix t -> (Matrix t, Matrix t) | ||
75 | -- | Conjugate transpose. | 82 | -- | Conjugate transpose. |
76 | ctrans :: Matrix t -> Matrix t | 83 | ctrans :: Matrix t -> Matrix t |
77 | 84 | ||
@@ -85,6 +92,7 @@ instance GenMat Double where | |||
85 | eigSH' = eigS | 92 | eigSH' = eigS |
86 | cholSH = cholS | 93 | cholSH = cholS |
87 | qr = GSL.unpackQR . qrR | 94 | qr = GSL.unpackQR . qrR |
95 | hess = unpackHess hessR | ||
88 | 96 | ||
89 | instance GenMat (Complex Double) where | 97 | instance GenMat (Complex Double) where |
90 | svd = svdC | 98 | svd = svdC |
@@ -96,6 +104,7 @@ instance GenMat (Complex Double) where | |||
96 | eigSH' = eigH | 104 | eigSH' = eigH |
97 | cholSH = cholH | 105 | cholSH = cholH |
98 | qr = unpackQR . qrC | 106 | qr = unpackQR . qrC |
107 | hess = unpackHess hessC | ||
99 | 108 | ||
100 | -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. If @(s,v) = eigSH m@ then @m == v \<> diag s \<> ctrans v@ | 109 | -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. If @(s,v) = eigSH m@ then @m == v \<> diag s \<> ctrans v@ |
101 | eigSH :: GenMat t => Matrix t -> (Vector Double, Matrix t) | 110 | eigSH :: GenMat t => Matrix t -> (Vector Double, Matrix t) |
@@ -277,6 +286,14 @@ haussholder :: (GenMat a) => a -> Vector a -> Matrix a | |||
277 | haussholder tau v = ident (dim v) `sub` (tau `scale` (w `mXm` ctrans w)) | 286 | haussholder tau v = ident (dim v) `sub` (tau `scale` (w `mXm` ctrans w)) |
278 | where w = asColumn v | 287 | where w = asColumn v |
279 | 288 | ||
289 | |||
290 | zh k v = fromList $ replicate (k-1) 0 ++ (1:drop k xs) | ||
291 | where xs = toList v | ||
292 | |||
293 | zt 0 v = v | ||
294 | zt k v = join [subVector 0 (dim v - k) v, constant 0 k] | ||
295 | |||
296 | |||
280 | unpackQR :: (GenMat t) => (Matrix t, Vector t) -> (Matrix t, Matrix t) | 297 | unpackQR :: (GenMat t) => (Matrix t, Vector t) -> (Matrix t, Matrix t) |
281 | unpackQR (pq, tau) = (q,r) | 298 | unpackQR (pq, tau) = (q,r) |
282 | where cs = toColumns pq | 299 | where cs = toColumns pq |
@@ -288,8 +305,17 @@ unpackQR (pq, tau) = (q,r) | |||
288 | hs = zipWith haussholder (toList tau) vs | 305 | hs = zipWith haussholder (toList tau) vs |
289 | q = foldl1' mXm hs | 306 | q = foldl1' mXm hs |
290 | 307 | ||
291 | zh k v = fromList $ replicate (k-1) 0 ++ (1:drop k xs) | 308 | unpackHess :: (GenMat t) => (Matrix t -> (Matrix t,Vector t)) -> Matrix t -> (Matrix t, Matrix t) |
292 | where xs = toList v | 309 | unpackHess hf m |
310 | | rows m == 1 = ((1><1)[1],m) | ||
311 | | otherwise = (uH . hf) m | ||
293 | 312 | ||
294 | zt 0 v = v | 313 | uH (pq, tau) = (p,h) |
295 | zt k v = join [subVector 0 (dim v - k) v, constant 0 k] | 314 | where cs = toColumns pq |
315 | m = rows pq | ||
316 | n = cols pq | ||
317 | mn = min m n | ||
318 | h = fromColumns $ zipWith zt ([m-2, m-3 .. 1] ++ repeat 0) cs | ||
319 | vs = zipWith zh [2..mn] cs | ||
320 | hs = zipWith haussholder (toList tau) vs | ||
321 | p = foldl1' mXm hs | ||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index 648e59f..a84a17e 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs | |||
@@ -20,7 +20,8 @@ module Numeric.LinearAlgebra.LAPACK ( | |||
20 | linearSolveLSR, linearSolveLSC, | 20 | linearSolveLSR, linearSolveLSC, |
21 | linearSolveSVDR, linearSolveSVDC, | 21 | linearSolveSVDR, linearSolveSVDC, |
22 | cholS, cholH, | 22 | cholS, cholH, |
23 | qrR, qrC | 23 | qrR, qrC, |
24 | hessR, hessC | ||
24 | ) where | 25 | ) where |
25 | 26 | ||
26 | import Data.Packed.Internal | 27 | import Data.Packed.Internal |
@@ -284,7 +285,7 @@ linearSolveSVDC_l rcond a b = unsafePerformIO $ do | |||
284 | ----------------------------------------------------------------------------------- | 285 | ----------------------------------------------------------------------------------- |
285 | foreign import ccall "LAPACK/lapack-aux.h chol_l_H" zpotrf :: TCMCM | 286 | foreign import ccall "LAPACK/lapack-aux.h chol_l_H" zpotrf :: TCMCM |
286 | 287 | ||
287 | -- | Wrapper for LAPACK's /zpotrf/,which computes the Cholesky factorization of a | 288 | -- | Wrapper for LAPACK's /zpotrf/, which computes the Cholesky factorization of a |
288 | -- complex Hermitian positive definite matrix. | 289 | -- complex Hermitian positive definite matrix. |
289 | cholH :: Matrix (Complex Double) -> Matrix (Complex Double) | 290 | cholH :: Matrix (Complex Double) -> Matrix (Complex Double) |
290 | cholH a = unsafePerformIO $ do | 291 | cholH a = unsafePerformIO $ do |
@@ -296,7 +297,7 @@ cholH a = unsafePerformIO $ do | |||
296 | ----------------------------------------------------------------------------------- | 297 | ----------------------------------------------------------------------------------- |
297 | foreign import ccall "LAPACK/lapack-aux.h chol_l_S" dpotrf :: TMM | 298 | foreign import ccall "LAPACK/lapack-aux.h chol_l_S" dpotrf :: TMM |
298 | 299 | ||
299 | -- | Wrapper for LAPACK's /dpotrf/,which computes the Cholesky factorization of a | 300 | -- | Wrapper for LAPACK's /dpotrf/, which computes the Cholesky factorization of a |
300 | -- real symmetric positive definite matrix. | 301 | -- real symmetric positive definite matrix. |
301 | cholS :: Matrix Double -> Matrix Double | 302 | cholS :: Matrix Double -> Matrix Double |
302 | cholS a = unsafePerformIO $ do | 303 | cholS a = unsafePerformIO $ do |
@@ -308,7 +309,7 @@ cholS a = unsafePerformIO $ do | |||
308 | ----------------------------------------------------------------------------------- | 309 | ----------------------------------------------------------------------------------- |
309 | foreign import ccall "LAPACK/lapack-aux.h qr_l_R" dgeqr2 :: TMVM | 310 | foreign import ccall "LAPACK/lapack-aux.h qr_l_R" dgeqr2 :: TMVM |
310 | 311 | ||
311 | -- | Wrapper for LAPACK's /dgeqr2/,which computes a QR factorization of a real matrix. | 312 | -- | Wrapper for LAPACK's /dgeqr2/, which computes a QR factorization of a real matrix. |
312 | qrR :: Matrix Double -> (Matrix Double, Vector Double) | 313 | qrR :: Matrix Double -> (Matrix Double, Vector Double) |
313 | qrR a = unsafePerformIO $ do | 314 | qrR a = unsafePerformIO $ do |
314 | r <- createMatrix ColumnMajor m n | 315 | r <- createMatrix ColumnMajor m n |
@@ -322,7 +323,7 @@ qrR a = unsafePerformIO $ do | |||
322 | ----------------------------------------------------------------------------------- | 323 | ----------------------------------------------------------------------------------- |
323 | foreign import ccall "LAPACK/lapack-aux.h qr_l_C" zgeqr2 :: TCMCVCM | 324 | foreign import ccall "LAPACK/lapack-aux.h qr_l_C" zgeqr2 :: TCMCVCM |
324 | 325 | ||
325 | -- | Wrapper for LAPACK's /zgeqr2/,which computes a QR factorization of a complex matrix. | 326 | -- | Wrapper for LAPACK's /zgeqr2/, which computes a QR factorization of a complex matrix. |
326 | qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) | 327 | qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) |
327 | qrC a = unsafePerformIO $ do | 328 | qrC a = unsafePerformIO $ do |
328 | r <- createMatrix ColumnMajor m n | 329 | r <- createMatrix ColumnMajor m n |
@@ -333,3 +334,31 @@ qrC a = unsafePerformIO $ do | |||
333 | n = cols a | 334 | n = cols a |
334 | mn = min m n | 335 | mn = min m n |
335 | 336 | ||
337 | ----------------------------------------------------------------------------------- | ||
338 | foreign import ccall "LAPACK/lapack-aux.h hess_l_R" dgehrd :: TMVM | ||
339 | |||
340 | -- | Wrapper for LAPACK's /dgehrd/, which computes a Hessenberg factorization of a square real matrix. | ||
341 | hessR :: Matrix Double -> (Matrix Double, Vector Double) | ||
342 | hessR a = unsafePerformIO $ do | ||
343 | r <- createMatrix ColumnMajor m n | ||
344 | tau <- createVector (mn-1) | ||
345 | dgehrd // mat fdat a // vec tau // mat dat r // check "hessR" [fdat a] | ||
346 | return (r,tau) | ||
347 | where m = rows a | ||
348 | n = cols a | ||
349 | mn = min m n | ||
350 | |||
351 | ----------------------------------------------------------------------------------- | ||
352 | foreign import ccall "LAPACK/lapack-aux.h hess_l_C" zgehrd :: TCMCVCM | ||
353 | |||
354 | -- | Wrapper for LAPACK's /zgeqr2/, which computes a Hessenberg factorization of a square complex matrix. | ||
355 | hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) | ||
356 | hessC a = unsafePerformIO $ do | ||
357 | r <- createMatrix ColumnMajor m n | ||
358 | tau <- createVector (mn-1) | ||
359 | zgehrd // mat fdat a // vec tau // mat dat r // check "hessC" [fdat a] | ||
360 | return (r,tau) | ||
361 | where m = rows a | ||
362 | n = cols a | ||
363 | mn = min m n | ||
364 | |||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index 9b6c1db..04ef416 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | |||
@@ -663,3 +663,41 @@ int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r)) { | |||
663 | free(WORK); | 663 | free(WORK); |
664 | OK | 664 | OK |
665 | } | 665 | } |
666 | |||
667 | //////////////////// Hessenberg factorization ///////////////////////// | ||
668 | |||
669 | int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r)) { | ||
670 | integer m = ar; | ||
671 | integer n = ac; | ||
672 | integer mn = MIN(m,n); | ||
673 | REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE); | ||
674 | DEBUGMSG("hess_l_R"); | ||
675 | integer lwork = 5*n; // fixme | ||
676 | double *WORK = (double*)malloc(lwork*sizeof(double)); | ||
677 | CHECK(!WORK,MEM); | ||
678 | memcpy(rp,ap,m*n*sizeof(double)); | ||
679 | integer res; | ||
680 | integer one = 1; | ||
681 | dgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res); | ||
682 | CHECK(res,res); | ||
683 | free(WORK); | ||
684 | OK | ||
685 | } | ||
686 | |||
687 | int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)) { | ||
688 | integer m = ar; | ||
689 | integer n = ac; | ||
690 | integer mn = MIN(m,n); | ||
691 | REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE); | ||
692 | DEBUGMSG("hess_l_C"); | ||
693 | integer lwork = 5*n; // fixme | ||
694 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | ||
695 | CHECK(!WORK,MEM); | ||
696 | memcpy(rp,ap,m*n*sizeof(doublecomplex)); | ||
697 | integer res; | ||
698 | integer one = 1; | ||
699 | zgehrd_ (&n,&one,&n,(doublecomplex*)rp,&n,(doublecomplex*)taup,WORK,&lwork,&res); | ||
700 | CHECK(res,res); | ||
701 | free(WORK); | ||
702 | OK | ||
703 | } | ||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h index ea71847..52ac41e 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | |||
@@ -44,3 +44,7 @@ int chol_l_S(KDMAT(a),DMAT(r)); | |||
44 | int qr_l_R(KDMAT(a), DVEC(tau), DMAT(r)); | 44 | int qr_l_R(KDMAT(a), DVEC(tau), DMAT(r)); |
45 | 45 | ||
46 | int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r)); | 46 | int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r)); |
47 | |||
48 | int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r)); | ||
49 | |||
50 | int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)); | ||