diff options
Diffstat (limited to 'packages/base/src/Internal/Algorithms.hs')
-rw-r--r-- | packages/base/src/Internal/Algorithms.hs | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs index c4f1a60..d5cf98e 100644 --- a/packages/base/src/Internal/Algorithms.hs +++ b/packages/base/src/Internal/Algorithms.hs | |||
@@ -30,6 +30,7 @@ import Internal.LAPACK as LAPACK | |||
30 | import Internal.Numeric | 30 | import Internal.Numeric |
31 | import Data.List(foldl1') | 31 | import Data.List(foldl1') |
32 | import qualified Data.Array as A | 32 | import qualified Data.Array as A |
33 | import qualified Data.Vector.Storable as Vector | ||
33 | import Internal.ST | 34 | import Internal.ST |
34 | import Internal.Vectorized(range) | 35 | import Internal.Vectorized(range) |
35 | import Control.DeepSeq | 36 | import Control.DeepSeq |
@@ -475,9 +476,14 @@ instance (NFData t, Numeric t) => NFData (QR t) | |||
475 | -- | QR factorization. | 476 | -- | QR factorization. |
476 | -- | 477 | -- |
477 | -- If @(q,r) = qr m@ then @m == q \<> r@, where q is unitary and r is upper triangular. | 478 | -- If @(q,r) = qr m@ then @m == q \<> r@, where q is unitary and r is upper triangular. |
479 | -- Note: the current implementation is very slow for large matrices. 'thinQR' is much faster. | ||
478 | qr :: Field t => Matrix t -> (Matrix t, Matrix t) | 480 | qr :: Field t => Matrix t -> (Matrix t, Matrix t) |
479 | qr = {-# SCC "qr" #-} unpackQR . qr' | 481 | qr = {-# SCC "qr" #-} unpackQR . qr' |
480 | 482 | ||
483 | -- | A version of 'qr' which returns only the @min (rows m) (cols m)@ columns of @q@ and rows of @r@. | ||
484 | thinQR :: Field t => Matrix t -> (Matrix t, Matrix t) | ||
485 | thinQR = {-# SCC "thinQR" #-} thinUnpackQR . qr' | ||
486 | |||
481 | -- | Compute the QR decomposition of a matrix in compact form. | 487 | -- | Compute the QR decomposition of a matrix in compact form. |
482 | qrRaw :: Field t => Matrix t -> QR t | 488 | qrRaw :: Field t => Matrix t -> QR t |
483 | qrRaw m = QR x v | 489 | qrRaw m = QR x v |
@@ -494,9 +500,17 @@ qrgr n (QR a t) | |||
494 | -- | RQ factorization. | 500 | -- | RQ factorization. |
495 | -- | 501 | -- |
496 | -- If @(r,q) = rq m@ then @m == r \<> q@, where q is unitary and r is upper triangular. | 502 | -- If @(r,q) = rq m@ then @m == r \<> q@, where q is unitary and r is upper triangular. |
503 | -- Note: the current implementation is very slow for large matrices. 'thinRQ' is much faster. | ||
497 | rq :: Field t => Matrix t -> (Matrix t, Matrix t) | 504 | rq :: Field t => Matrix t -> (Matrix t, Matrix t) |
498 | rq m = {-# SCC "rq" #-} (r,q) where | 505 | rq = {-# SCC "rq" #-} rqFromQR qr |
499 | (q',r') = qr $ trans $ rev1 m | 506 | |
507 | -- | A version of 'rq' which returns only the @min (rows m) (cols m)@ columns of @r@ and rows of @q@. | ||
508 | thinRQ :: Field t => Matrix t -> (Matrix t, Matrix t) | ||
509 | thinRQ = {-# SCC "thinQR" #-} rqFromQR thinQR | ||
510 | |||
511 | rqFromQR :: Field t => (Matrix t -> (Matrix t, Matrix t)) -> Matrix t -> (Matrix t, Matrix t) | ||
512 | rqFromQR qr0 m = (r,q) where | ||
513 | (q',r') = qr0 $ trans $ rev1 m | ||
500 | r = rev2 (trans r') | 514 | r = rev2 (trans r') |
501 | q = rev2 (trans q') | 515 | q = rev2 (trans q') |
502 | rev1 = flipud . fliprl | 516 | rev1 = flipud . fliprl |
@@ -724,6 +738,12 @@ unpackQR (pq, tau) = {-# SCC "unpackQR" #-} (q,r) | |||
724 | hs = zipWith haussholder (toList tau) vs | 738 | hs = zipWith haussholder (toList tau) vs |
725 | q = foldl1' mXm hs | 739 | q = foldl1' mXm hs |
726 | 740 | ||
741 | thinUnpackQR :: (Field t) => (Matrix t, Vector t) -> (Matrix t, Matrix t) | ||
742 | thinUnpackQR (pq, tau) = (q, r) | ||
743 | where mn = uncurry min $ size pq | ||
744 | q = qrgr mn $ QR pq tau | ||
745 | r = fromRows $ zipWith (\i v -> Vector.replicate i 0 Vector.++ Vector.drop i v) [0..mn-1] (toRows pq) | ||
746 | |||
727 | unpackHess :: (Field t) => (Matrix t -> (Matrix t,Vector t)) -> Matrix t -> (Matrix t, Matrix t) | 747 | unpackHess :: (Field t) => (Matrix t -> (Matrix t,Vector t)) -> Matrix t -> (Matrix t, Matrix t) |
728 | unpackHess hf m | 748 | unpackHess hf m |
729 | | rows m == 1 = ((1><1)[1],m) | 749 | | rows m == 1 = ((1><1)[1],m) |