summaryrefslogtreecommitdiff
path: root/lib/Numeric/LinearAlgebra/Algorithms.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Numeric/LinearAlgebra/Algorithms.hs')
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs58
1 files changed, 51 insertions, 7 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index 069d9a3..b19c0ec 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -37,6 +37,8 @@ module Numeric.LinearAlgebra.Algorithms (
37 hess, 37 hess,
38-- ** Schur 38-- ** Schur
39 schur, 39 schur,
40-- ** LU
41 lu,
40-- * Matrix functions 42-- * Matrix functions
41 expm, 43 expm,
42 sqrtm, 44 sqrtm,
@@ -52,11 +54,11 @@ module Numeric.LinearAlgebra.Algorithms (
52-- * Util 54-- * Util
53 haussholder, 55 haussholder,
54 unpackQR, unpackHess, 56 unpackQR, unpackHess,
55 Field(linearSolveSVD,lu,eigSH',cholSH) 57 Field(linearSolveSVD,eigSH',cholSH)
56) where 58) where
57 59
58 60
59import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj) 61import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//))
60import Data.Packed 62import Data.Packed
61import qualified Numeric.GSL.Matrix as GSL 63import qualified Numeric.GSL.Matrix as GSL
62import Numeric.GSL.Vector 64import Numeric.GSL.Vector
@@ -64,12 +66,13 @@ import Numeric.LinearAlgebra.LAPACK as LAPACK
64import Complex 66import Complex
65import Numeric.LinearAlgebra.Linear 67import Numeric.LinearAlgebra.Linear
66import Data.List(foldl1') 68import Data.List(foldl1')
69import Data.Array
67 70
68-- | Auxiliary typeclass used to define generic computations for both real and complex matrices. 71-- | Auxiliary typeclass used to define generic computations for both real and complex matrices.
69class (Normed (Matrix t), Linear Matrix t) => Field t where 72class (Normed (Matrix t), Linear Matrix t) => Field t where
70 -- | Singular value decomposition using lapack's dgesvd or zgesvd. 73 -- | Singular value decomposition using lapack's dgesvd or zgesvd.
71 svd :: Matrix t -> (Matrix t, Vector Double, Matrix t) 74 svd :: Matrix t -> (Matrix t, Vector Double, Matrix t)
72 lu :: Matrix t -> (Matrix t, Matrix t, [Int], t) 75 luPacked :: Matrix t -> (Matrix t, [Int])
73 -- | Solution of a general linear system (for several right-hand sides) using lapacks' dgesv and zgesv. 76 -- | Solution of a general linear system (for several right-hand sides) using lapacks' dgesv and zgesv.
74 -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK". 77 -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK".
75 linearSolve :: Matrix t -> Matrix t -> Matrix t 78 linearSolve :: Matrix t -> Matrix t -> Matrix t
@@ -106,7 +109,7 @@ class (Normed (Matrix t), Linear Matrix t) => Field t where
106 109
107instance Field Double where 110instance Field Double where
108 svd = svdR 111 svd = svdR
109 lu = GSL.luR 112 luPacked = luR
110 linearSolve = linearSolveR 113 linearSolve = linearSolveR
111 linearSolveSVD = linearSolveSVDR Nothing 114 linearSolveSVD = linearSolveSVDR Nothing
112 ctrans = trans 115 ctrans = trans
@@ -119,7 +122,7 @@ instance Field Double where
119 122
120instance Field (Complex Double) where 123instance Field (Complex Double) where
121 svd = svdC 124 svd = svdC
122 lu = GSL.luC 125 luPacked = luC
123 linearSolve = linearSolveC 126 linearSolve = linearSolveC
124 linearSolveSVD = linearSolveSVDC Nothing 127 linearSolveSVD = linearSolveSVDC Nothing
125 ctrans = conj . trans 128 ctrans = conj . trans
@@ -146,10 +149,19 @@ chol m | m `equal` ctrans m = cholSH m
146 149
147square m = rows m == cols m 150square m = rows m == cols m
148 151
152-- | determinant of a square matrix, computed from the LU decomposition.
149det :: Field t => Matrix t -> t 153det :: Field t => Matrix t -> t
150det m | square m = s * (product $ toList $ takeDiag $ u) 154det m | square m = s * (product $ toList $ takeDiag $ lu)
151 | otherwise = error "det of nonsquare matrix" 155 | otherwise = error "det of nonsquare matrix"
152 where (_,u,_,s) = lu m 156 where (lu,perm) = luPacked m
157 s = signlp (rows m) perm
158
159-- | LU factorization of a general matrix using lapack's dgetrf or zgetrf.
160--
161-- If @(l,u,p,s) = lu m@ then @m == p \<> l \<> u@, where l is lower triangular,
162-- u is upper triangular, p is a permutation matrix and s is the signature of the permutation.
163lu :: Field t => Matrix t -> (Matrix t, Matrix t, Matrix t, t)
164lu = luFact . luPacked
153 165
154-- | Inverse of a square matrix using lapacks' dgesv and zgesv. 166-- | Inverse of a square matrix using lapacks' dgesv and zgesv.
155inv :: Field t => Matrix t -> Matrix t 167inv :: Field t => Matrix t -> Matrix t
@@ -457,3 +469,35 @@ sqrtmInv x = fst $ fixedPoint $ iterate f (x, ident (rows x))
457 (.*) = scale 469 (.*) = scale
458 (|+|) = add 470 (|+|) = add
459 (|-|) = sub 471 (|-|) = sub
472
473------------------------------------------------------------------
474
475signlp r vals = foldl f 1 (zip [0..r-1] vals)
476 where f s (a,b) | a /= b = -s
477 | otherwise = s
478
479swap (arr,s) (a,b) | a /= b = (arr // [(a, arr!b),(b,arr!a)],-s)
480 | otherwise = (arr,s)
481
482fixPerm r vals = (fromColumns $ elems res, sign)
483 where v = [0..r-1]
484 s = toColumns (ident r)
485 (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals)
486
487triang r c h v = reshape c $ fromList [el i j | i<-[0..r-1], j<-[0..c-1]]
488 where el i j = if j-i>=h then v else 1 - v
489
490luFact (lu,perm) | r <= c = (l ,u ,p, s)
491 | otherwise = (l',u',p, s)
492 where
493 r = rows lu
494 c = cols lu
495 tu = triang r c 0 1
496 tl = triang r c 0 0
497 l = takeColumns r (lu |*| tl) |+| diagRect (constant 1 r) r r
498 u = lu |*| tu
499 (p,s) = fixPerm r perm
500 l' = (lu |*| tl) |+| diagRect (constant 1 c) r c
501 u' = takeRows c (lu |*| tu)
502 (|+|) = add
503 (|*|) = mul