diff options
Diffstat (limited to 'lib/Numeric/LinearAlgebra/Algorithms.hs')
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 58 |
1 files changed, 51 insertions, 7 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 069d9a3..b19c0ec 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -37,6 +37,8 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
37 | hess, | 37 | hess, |
38 | -- ** Schur | 38 | -- ** Schur |
39 | schur, | 39 | schur, |
40 | -- ** LU | ||
41 | lu, | ||
40 | -- * Matrix functions | 42 | -- * Matrix functions |
41 | expm, | 43 | expm, |
42 | sqrtm, | 44 | sqrtm, |
@@ -52,11 +54,11 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
52 | -- * Util | 54 | -- * Util |
53 | haussholder, | 55 | haussholder, |
54 | unpackQR, unpackHess, | 56 | unpackQR, unpackHess, |
55 | Field(linearSolveSVD,lu,eigSH',cholSH) | 57 | Field(linearSolveSVD,eigSH',cholSH) |
56 | ) where | 58 | ) where |
57 | 59 | ||
58 | 60 | ||
59 | import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj) | 61 | import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//)) |
60 | import Data.Packed | 62 | import Data.Packed |
61 | import qualified Numeric.GSL.Matrix as GSL | 63 | import qualified Numeric.GSL.Matrix as GSL |
62 | import Numeric.GSL.Vector | 64 | import Numeric.GSL.Vector |
@@ -64,12 +66,13 @@ import Numeric.LinearAlgebra.LAPACK as LAPACK | |||
64 | import Complex | 66 | import Complex |
65 | import Numeric.LinearAlgebra.Linear | 67 | import Numeric.LinearAlgebra.Linear |
66 | import Data.List(foldl1') | 68 | import Data.List(foldl1') |
69 | import Data.Array | ||
67 | 70 | ||
68 | -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. | 71 | -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. |
69 | class (Normed (Matrix t), Linear Matrix t) => Field t where | 72 | class (Normed (Matrix t), Linear Matrix t) => Field t where |
70 | -- | Singular value decomposition using lapack's dgesvd or zgesvd. | 73 | -- | Singular value decomposition using lapack's dgesvd or zgesvd. |
71 | svd :: Matrix t -> (Matrix t, Vector Double, Matrix t) | 74 | svd :: Matrix t -> (Matrix t, Vector Double, Matrix t) |
72 | lu :: Matrix t -> (Matrix t, Matrix t, [Int], t) | 75 | luPacked :: Matrix t -> (Matrix t, [Int]) |
73 | -- | Solution of a general linear system (for several right-hand sides) using lapacks' dgesv and zgesv. | 76 | -- | Solution of a general linear system (for several right-hand sides) using lapacks' dgesv and zgesv. |
74 | -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK". | 77 | -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK". |
75 | linearSolve :: Matrix t -> Matrix t -> Matrix t | 78 | linearSolve :: Matrix t -> Matrix t -> Matrix t |
@@ -106,7 +109,7 @@ class (Normed (Matrix t), Linear Matrix t) => Field t where | |||
106 | 109 | ||
107 | instance Field Double where | 110 | instance Field Double where |
108 | svd = svdR | 111 | svd = svdR |
109 | lu = GSL.luR | 112 | luPacked = luR |
110 | linearSolve = linearSolveR | 113 | linearSolve = linearSolveR |
111 | linearSolveSVD = linearSolveSVDR Nothing | 114 | linearSolveSVD = linearSolveSVDR Nothing |
112 | ctrans = trans | 115 | ctrans = trans |
@@ -119,7 +122,7 @@ instance Field Double where | |||
119 | 122 | ||
120 | instance Field (Complex Double) where | 123 | instance Field (Complex Double) where |
121 | svd = svdC | 124 | svd = svdC |
122 | lu = GSL.luC | 125 | luPacked = luC |
123 | linearSolve = linearSolveC | 126 | linearSolve = linearSolveC |
124 | linearSolveSVD = linearSolveSVDC Nothing | 127 | linearSolveSVD = linearSolveSVDC Nothing |
125 | ctrans = conj . trans | 128 | ctrans = conj . trans |
@@ -146,10 +149,19 @@ chol m | m `equal` ctrans m = cholSH m | |||
146 | 149 | ||
147 | square m = rows m == cols m | 150 | square m = rows m == cols m |
148 | 151 | ||
152 | -- | determinant of a square matrix, computed from the LU decomposition. | ||
149 | det :: Field t => Matrix t -> t | 153 | det :: Field t => Matrix t -> t |
150 | det m | square m = s * (product $ toList $ takeDiag $ u) | 154 | det m | square m = s * (product $ toList $ takeDiag $ lu) |
151 | | otherwise = error "det of nonsquare matrix" | 155 | | otherwise = error "det of nonsquare matrix" |
152 | where (_,u,_,s) = lu m | 156 | where (lu,perm) = luPacked m |
157 | s = signlp (rows m) perm | ||
158 | |||
159 | -- | LU factorization of a general matrix using lapack's dgetrf or zgetrf. | ||
160 | -- | ||
161 | -- If @(l,u,p,s) = lu m@ then @m == p \<> l \<> u@, where l is lower triangular, | ||
162 | -- u is upper triangular, p is a permutation matrix and s is the signature of the permutation. | ||
163 | lu :: Field t => Matrix t -> (Matrix t, Matrix t, Matrix t, t) | ||
164 | lu = luFact . luPacked | ||
153 | 165 | ||
154 | -- | Inverse of a square matrix using lapacks' dgesv and zgesv. | 166 | -- | Inverse of a square matrix using lapacks' dgesv and zgesv. |
155 | inv :: Field t => Matrix t -> Matrix t | 167 | inv :: Field t => Matrix t -> Matrix t |
@@ -457,3 +469,35 @@ sqrtmInv x = fst $ fixedPoint $ iterate f (x, ident (rows x)) | |||
457 | (.*) = scale | 469 | (.*) = scale |
458 | (|+|) = add | 470 | (|+|) = add |
459 | (|-|) = sub | 471 | (|-|) = sub |
472 | |||
473 | ------------------------------------------------------------------ | ||
474 | |||
475 | signlp r vals = foldl f 1 (zip [0..r-1] vals) | ||
476 | where f s (a,b) | a /= b = -s | ||
477 | | otherwise = s | ||
478 | |||
479 | swap (arr,s) (a,b) | a /= b = (arr // [(a, arr!b),(b,arr!a)],-s) | ||
480 | | otherwise = (arr,s) | ||
481 | |||
482 | fixPerm r vals = (fromColumns $ elems res, sign) | ||
483 | where v = [0..r-1] | ||
484 | s = toColumns (ident r) | ||
485 | (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals) | ||
486 | |||
487 | triang r c h v = reshape c $ fromList [el i j | i<-[0..r-1], j<-[0..c-1]] | ||
488 | where el i j = if j-i>=h then v else 1 - v | ||
489 | |||
490 | luFact (lu,perm) | r <= c = (l ,u ,p, s) | ||
491 | | otherwise = (l',u',p, s) | ||
492 | where | ||
493 | r = rows lu | ||
494 | c = cols lu | ||
495 | tu = triang r c 0 1 | ||
496 | tl = triang r c 0 0 | ||
497 | l = takeColumns r (lu |*| tl) |+| diagRect (constant 1 r) r r | ||
498 | u = lu |*| tu | ||
499 | (p,s) = fixPerm r perm | ||
500 | l' = (lu |*| tl) |+| diagRect (constant 1 c) r c | ||
501 | u' = takeRows c (lu |*| tu) | ||
502 | (|+|) = add | ||
503 | (|*|) = mul | ||