summaryrefslogtreecommitdiff
path: root/lib/Numeric
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Numeric')
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs58
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK.hs21
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c46
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h13
4 files changed, 120 insertions, 18 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index 069d9a3..b19c0ec 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -37,6 +37,8 @@ module Numeric.LinearAlgebra.Algorithms (
37 hess, 37 hess,
38-- ** Schur 38-- ** Schur
39 schur, 39 schur,
40-- ** LU
41 lu,
40-- * Matrix functions 42-- * Matrix functions
41 expm, 43 expm,
42 sqrtm, 44 sqrtm,
@@ -52,11 +54,11 @@ module Numeric.LinearAlgebra.Algorithms (
52-- * Util 54-- * Util
53 haussholder, 55 haussholder,
54 unpackQR, unpackHess, 56 unpackQR, unpackHess,
55 Field(linearSolveSVD,lu,eigSH',cholSH) 57 Field(linearSolveSVD,eigSH',cholSH)
56) where 58) where
57 59
58 60
59import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj) 61import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//))
60import Data.Packed 62import Data.Packed
61import qualified Numeric.GSL.Matrix as GSL 63import qualified Numeric.GSL.Matrix as GSL
62import Numeric.GSL.Vector 64import Numeric.GSL.Vector
@@ -64,12 +66,13 @@ import Numeric.LinearAlgebra.LAPACK as LAPACK
64import Complex 66import Complex
65import Numeric.LinearAlgebra.Linear 67import Numeric.LinearAlgebra.Linear
66import Data.List(foldl1') 68import Data.List(foldl1')
69import Data.Array
67 70
68-- | Auxiliary typeclass used to define generic computations for both real and complex matrices. 71-- | Auxiliary typeclass used to define generic computations for both real and complex matrices.
69class (Normed (Matrix t), Linear Matrix t) => Field t where 72class (Normed (Matrix t), Linear Matrix t) => Field t where
70 -- | Singular value decomposition using lapack's dgesvd or zgesvd. 73 -- | Singular value decomposition using lapack's dgesvd or zgesvd.
71 svd :: Matrix t -> (Matrix t, Vector Double, Matrix t) 74 svd :: Matrix t -> (Matrix t, Vector Double, Matrix t)
72 lu :: Matrix t -> (Matrix t, Matrix t, [Int], t) 75 luPacked :: Matrix t -> (Matrix t, [Int])
73 -- | Solution of a general linear system (for several right-hand sides) using lapacks' dgesv and zgesv. 76 -- | Solution of a general linear system (for several right-hand sides) using lapacks' dgesv and zgesv.
74 -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK". 77 -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK".
75 linearSolve :: Matrix t -> Matrix t -> Matrix t 78 linearSolve :: Matrix t -> Matrix t -> Matrix t
@@ -106,7 +109,7 @@ class (Normed (Matrix t), Linear Matrix t) => Field t where
106 109
107instance Field Double where 110instance Field Double where
108 svd = svdR 111 svd = svdR
109 lu = GSL.luR 112 luPacked = luR
110 linearSolve = linearSolveR 113 linearSolve = linearSolveR
111 linearSolveSVD = linearSolveSVDR Nothing 114 linearSolveSVD = linearSolveSVDR Nothing
112 ctrans = trans 115 ctrans = trans
@@ -119,7 +122,7 @@ instance Field Double where
119 122
120instance Field (Complex Double) where 123instance Field (Complex Double) where
121 svd = svdC 124 svd = svdC
122 lu = GSL.luC 125 luPacked = luC
123 linearSolve = linearSolveC 126 linearSolve = linearSolveC
124 linearSolveSVD = linearSolveSVDC Nothing 127 linearSolveSVD = linearSolveSVDC Nothing
125 ctrans = conj . trans 128 ctrans = conj . trans
@@ -146,10 +149,19 @@ chol m | m `equal` ctrans m = cholSH m
146 149
147square m = rows m == cols m 150square m = rows m == cols m
148 151
152-- | determinant of a square matrix, computed from the LU decomposition.
149det :: Field t => Matrix t -> t 153det :: Field t => Matrix t -> t
150det m | square m = s * (product $ toList $ takeDiag $ u) 154det m | square m = s * (product $ toList $ takeDiag $ lu)
151 | otherwise = error "det of nonsquare matrix" 155 | otherwise = error "det of nonsquare matrix"
152 where (_,u,_,s) = lu m 156 where (lu,perm) = luPacked m
157 s = signlp (rows m) perm
158
159-- | LU factorization of a general matrix using lapack's dgetrf or zgetrf.
160--
161-- If @(l,u,p,s) = lu m@ then @m == p \<> l \<> u@, where l is lower triangular,
162-- u is upper triangular, p is a permutation matrix and s is the signature of the permutation.
163lu :: Field t => Matrix t -> (Matrix t, Matrix t, Matrix t, t)
164lu = luFact . luPacked
153 165
154-- | Inverse of a square matrix using lapacks' dgesv and zgesv. 166-- | Inverse of a square matrix using lapacks' dgesv and zgesv.
155inv :: Field t => Matrix t -> Matrix t 167inv :: Field t => Matrix t -> Matrix t
@@ -457,3 +469,35 @@ sqrtmInv x = fst $ fixedPoint $ iterate f (x, ident (rows x))
457 (.*) = scale 469 (.*) = scale
458 (|+|) = add 470 (|+|) = add
459 (|-|) = sub 471 (|-|) = sub
472
473------------------------------------------------------------------
474
475signlp r vals = foldl f 1 (zip [0..r-1] vals)
476 where f s (a,b) | a /= b = -s
477 | otherwise = s
478
479swap (arr,s) (a,b) | a /= b = (arr // [(a, arr!b),(b,arr!a)],-s)
480 | otherwise = (arr,s)
481
482fixPerm r vals = (fromColumns $ elems res, sign)
483 where v = [0..r-1]
484 s = toColumns (ident r)
485 (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals)
486
487triang r c h v = reshape c $ fromList [el i j | i<-[0..r-1], j<-[0..c-1]]
488 where el i j = if j-i>=h then v else 1 - v
489
490luFact (lu,perm) | r <= c = (l ,u ,p, s)
491 | otherwise = (l',u',p, s)
492 where
493 r = rows lu
494 c = cols lu
495 tu = triang r c 0 1
496 tl = triang r c 0 0
497 l = takeColumns r (lu |*| tl) |+| diagRect (constant 1 r) r r
498 u = lu |*| tu
499 (p,s) = fixPerm r perm
500 l' = (lu |*| tl) |+| diagRect (constant 1 c) r c
501 u' = takeRows c (lu |*| tu)
502 (|+|) = add
503 (|*|) = mul
diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs
index cacad87..83db901 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK.hs
+++ b/lib/Numeric/LinearAlgebra/LAPACK.hs
@@ -19,6 +19,7 @@ module Numeric.LinearAlgebra.LAPACK (
19 linearSolveR, linearSolveC, 19 linearSolveR, linearSolveC,
20 linearSolveLSR, linearSolveLSC, 20 linearSolveLSR, linearSolveLSC,
21 linearSolveSVDR, linearSolveSVDC, 21 linearSolveSVDR, linearSolveSVDC,
22 luR, luC,
22 cholS, cholH, 23 cholS, cholH,
23 qrR, qrC, 24 qrR, qrC,
24 hessR, hessC, 25 hessR, hessC,
@@ -299,7 +300,7 @@ hessAux f st a = unsafePerformIO $ do
299 mn = min m n 300 mn = min m n
300 301
301----------------------------------------------------------------------------------- 302-----------------------------------------------------------------------------------
302foreign import ccall safe "LAPACK/lapack-aux.h schur_l_R" dgees :: TMMM 303foreign import ccall "LAPACK/lapack-aux.h schur_l_R" dgees :: TMMM
303foreign import ccall "LAPACK/lapack-aux.h schur_l_C" zgees :: TCMCMCM 304foreign import ccall "LAPACK/lapack-aux.h schur_l_C" zgees :: TCMCMCM
304 305
305-- | Wrapper for LAPACK's /dgees/, which computes a Schur factorization of a square real matrix. 306-- | Wrapper for LAPACK's /dgees/, which computes a Schur factorization of a square real matrix.
@@ -318,3 +319,21 @@ schurAux f st a = unsafePerformIO $ do
318 where n = rows a 319 where n = rows a
319 320
320----------------------------------------------------------------------------------- 321-----------------------------------------------------------------------------------
322foreign import ccall "LAPACK/lapack-aux.h lu_l_R" dgetrf :: TMVM
323foreign import ccall "LAPACK/lapack-aux.h lu_l_C" zgetrf :: TCMVCM
324
325-- | Wrapper for LAPACK's /dgetrf/, which computes a LU factorization of a general real matrix.
326luR :: Matrix Double -> (Matrix Double, [Int])
327luR = luAux dgetrf "luR" . fmat
328
329-- | Wrapper for LAPACK's /zgees/, which computes a Schur factorization of a square complex matrix.
330luC :: Matrix (Complex Double) -> (Matrix (Complex Double), [Int])
331luC = luAux zgetrf "luC" . fmat
332
333luAux f st a = unsafePerformIO $ do
334 lu <- createMatrix ColumnMajor n m
335 piv <- createVector (min n m)
336 app3 f mat a vec piv mat lu st
337 return (lu, map (pred.round) (toList piv))
338 where n = rows a
339 m = cols a
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
index 8392feb..310f6ee 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
@@ -768,3 +768,49 @@ int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)) {
768 OK 768 OK
769 #endif 769 #endif
770} 770}
771
772//////////////////// LU factorization /////////////////////////
773
774int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)) {
775 integer m = ar;
776 integer n = ac;
777 integer mn = MIN(m,n);
778 REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE);
779 DEBUGMSG("lu_l_R");
780 integer* auxipiv = (integer*)malloc(mn*sizeof(integer));
781 memcpy(rp,ap,m*n*sizeof(double));
782 integer res;
783 dgetrf_ (&m,&n,rp,&m,auxipiv,&res);
784 if(res>0) {
785 res = 0; // fixme
786 }
787 CHECK(res,res);
788 int k;
789 for (k=0; k<mn; k++) {
790 ipivp[k] = auxipiv[k];
791 }
792 free(auxipiv);
793 OK
794}
795
796int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) {
797 integer m = ar;
798 integer n = ac;
799 integer mn = MIN(m,n);
800 REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE);
801 DEBUGMSG("lu_l_C");
802 integer* auxipiv = (integer*)malloc(mn*sizeof(integer));
803 memcpy(rp,ap,m*n*sizeof(doublecomplex));
804 integer res;
805 zgetrf_ (&m,&n,(doublecomplex*)rp,&m,auxipiv,&res);
806 if(res>0) {
807 res = 0; // fixme
808 }
809 CHECK(res,res);
810 int k;
811 for (k=0; k<mn; k++) {
812 ipivp[k] = auxipiv[k];
813 }
814 free(auxipiv);
815 OK
816}
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
index e5d74d7..bccd4b8 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
@@ -14,41 +14,34 @@
14 14
15int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); 15int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v));
16int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); 16int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v));
17
18int svd_l_C(KCMAT(a),CMAT(u),DVEC(s),CMAT(v)); 17int svd_l_C(KCMAT(a),CMAT(u),DVEC(s),CMAT(v));
19 18
20int eig_l_C(KCMAT(a),CMAT(u),CVEC(s),CMAT(v)); 19int eig_l_C(KCMAT(a),CMAT(u),CVEC(s),CMAT(v));
21
22int eig_l_R(KDMAT(a),DMAT(u),CVEC(s),DMAT(v)); 20int eig_l_R(KDMAT(a),DMAT(u),CVEC(s),DMAT(v));
23 21
24int eig_l_S(KDMAT(a),DVEC(s),DMAT(v)); 22int eig_l_S(KDMAT(a),DVEC(s),DMAT(v));
25
26int eig_l_H(KCMAT(a),DVEC(s),CMAT(v)); 23int eig_l_H(KCMAT(a),DVEC(s),CMAT(v));
27 24
28int linearSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)); 25int linearSolveR_l(KDMAT(a),KDMAT(b),DMAT(x));
29
30int linearSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)); 26int linearSolveC_l(KCMAT(a),KCMAT(b),CMAT(x));
31 27
32int linearSolveLSR_l(KDMAT(a),KDMAT(b),DMAT(x)); 28int linearSolveLSR_l(KDMAT(a),KDMAT(b),DMAT(x));
33
34int linearSolveLSC_l(KCMAT(a),KCMAT(b),CMAT(x)); 29int linearSolveLSC_l(KCMAT(a),KCMAT(b),CMAT(x));
35 30
36int linearSolveSVDR_l(double,KDMAT(a),KDMAT(b),DMAT(x)); 31int linearSolveSVDR_l(double,KDMAT(a),KDMAT(b),DMAT(x));
37
38int linearSolveSVDC_l(double,KCMAT(a),KCMAT(b),CMAT(x)); 32int linearSolveSVDC_l(double,KCMAT(a),KCMAT(b),CMAT(x));
39 33
40int chol_l_H(KCMAT(a),CMAT(r)); 34int chol_l_H(KCMAT(a),CMAT(r));
41
42int chol_l_S(KDMAT(a),DMAT(r)); 35int chol_l_S(KDMAT(a),DMAT(r));
43 36
44int qr_l_R(KDMAT(a), DVEC(tau), DMAT(r)); 37int qr_l_R(KDMAT(a), DVEC(tau), DMAT(r));
45
46int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r)); 38int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r));
47 39
48int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r)); 40int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r));
49
50int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)); 41int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r));
51 42
52int schur_l_R(KDMAT(a), DMAT(u), DMAT(s)); 43int schur_l_R(KDMAT(a), DMAT(u), DMAT(s));
53
54int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)); 44int schur_l_C(KCMAT(a), CMAT(u), CMAT(s));
45
46int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r));
47int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r));