summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-07-11 14:19:21 +0200
committerAlberto Ruiz <aruiz@um.es>2015-07-11 14:19:21 +0200
commitb2341058a2214d22dc23f516b6f09d3270faa18d (patch)
tree1d0734c367f35931822264a060142421edf356df
parenta27c3e2acfb2c37e6103639a9218a4cd20b54421 (diff)
ldl factorization
-rw-r--r--packages/base/src/Internal/Algorithms.hs23
-rw-r--r--packages/base/src/Internal/C/lapack-aux.c107
-rw-r--r--packages/base/src/Internal/LAPACK.hs35
-rw-r--r--packages/base/src/Numeric/LinearAlgebra.hs16
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests.hs3
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs5
6 files changed, 181 insertions, 8 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs
index 6ce1830..c8b2d3e 100644
--- a/packages/base/src/Internal/Algorithms.hs
+++ b/packages/base/src/Internal/Algorithms.hs
@@ -55,6 +55,8 @@ class (Product t,
55 mbLinearSolve' :: Matrix t -> Matrix t -> Maybe (Matrix t) 55 mbLinearSolve' :: Matrix t -> Matrix t -> Maybe (Matrix t)
56 linearSolve' :: Matrix t -> Matrix t -> Matrix t 56 linearSolve' :: Matrix t -> Matrix t -> Matrix t
57 cholSolve' :: Matrix t -> Matrix t -> Matrix t 57 cholSolve' :: Matrix t -> Matrix t -> Matrix t
58 ldlPacked' :: Matrix t -> (Matrix t, [Int])
59 ldlSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t
58 linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t 60 linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t
59 linearSolveLS' :: Matrix t -> Matrix t -> Matrix t 61 linearSolveLS' :: Matrix t -> Matrix t -> Matrix t
60 eig' :: Matrix t -> (Vector (Complex Double), Matrix (Complex Double)) 62 eig' :: Matrix t -> (Vector (Complex Double), Matrix (Complex Double))
@@ -90,6 +92,8 @@ instance Field Double where
90 qrgr' = qrgrR 92 qrgr' = qrgrR
91 hess' = unpackHess hessR 93 hess' = unpackHess hessR
92 schur' = schurR 94 schur' = schurR
95 ldlPacked' = ldlR
96 ldlSolve'= uncurry ldlsR
93 97
94instance Field (Complex Double) where 98instance Field (Complex Double) where
95#ifdef NOZGESDD 99#ifdef NOZGESDD
@@ -117,6 +121,8 @@ instance Field (Complex Double) where
117 qrgr' = qrgrC 121 qrgr' = qrgrC
118 hess' = unpackHess hessC 122 hess' = unpackHess hessC
119 schur' = schurC 123 schur' = schurC
124 ldlPacked' = ldlC
125 ldlSolve' = uncurry ldlsC
120 126
121-------------------------------------------------------------- 127--------------------------------------------------------------
122 128
@@ -333,6 +339,23 @@ linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD'
333linearSolveLS :: Field t => Matrix t -> Matrix t -> Matrix t 339linearSolveLS :: Field t => Matrix t -> Matrix t -> Matrix t
334linearSolveLS = {-# SCC "linearSolveLS" #-} linearSolveLS' 340linearSolveLS = {-# SCC "linearSolveLS" #-} linearSolveLS'
335 341
342--------------------------------------------------------------------------------
343
344-- | Similar to 'ldlPacked', without checking that the input matrix is hermitian or symmetric. It works with the lower triangular part.
345ldlPackedSH :: Field t => Matrix t -> (Matrix t, [Int])
346ldlPackedSH = {-# SCC "ldlPacked" #-} ldlPacked'
347
348-- | Obtains the LDL decomposition of a matrix in a compact data structure suitable for 'ldlSolve'.
349ldlPacked :: Field t => Matrix t -> (Matrix t, [Int])
350ldlPacked m
351 | exactHermitian m = {-# SCC "ldlPacked" #-} ldlPackedSH m
352 | otherwise = error "ldlPacked requires complex Hermitian or real symmetrix matrix"
353
354
355-- | Solution of a linear system (for several right hand sides) from the precomputed LDL factorization obtained by 'ldlPacked'.
356ldlSolve :: Field t => (Matrix t, [Int]) -> Matrix t -> Matrix t
357ldlSolve = {-# SCC "ldlSolve" #-} ldlSolve'
358
336-------------------------------------------------------------- 359--------------------------------------------------------------
337 360
338{- | Eigenvalues (not ordered) and eigenvectors (as columns) of a general square matrix. 361{- | Eigenvalues (not ordered) and eigenvectors (as columns) of a general square matrix.
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c
index 30689bf..177d373 100644
--- a/packages/base/src/Internal/C/lapack-aux.c
+++ b/packages/base/src/Internal/C/lapack-aux.c
@@ -1086,6 +1086,113 @@ int luS_l_C(KOCMAT(a), KDVEC(ipiv), OCMAT(b)) {
1086 OK 1086 OK
1087} 1087}
1088 1088
1089
1090//////////////////// LDL factorization /////////////////////////
1091
1092int dsytrf_(char *uplo, integer *n, doublereal *a, integer *lda, integer *ipiv,
1093 doublereal *work, integer *lwork, integer *info);
1094
1095int ldl_R(DVEC(ipiv), ODMAT(r)) {
1096 integer n = rr;
1097 REQUIRES(n>=1 && rc==n && ipivn == n, BAD_SIZE);
1098 DEBUGMSG("ldl_R");
1099 integer* auxipiv = (integer*)malloc(n*sizeof(integer));
1100 integer res;
1101 integer lda = rXc;
1102 integer lwork = -1;
1103 doublereal ans;
1104 dsytrf_ ("L",&n,rp,&lda,auxipiv,&ans,&lwork,&res);
1105 lwork = ceil(ans);
1106 doublereal* work = (doublereal*)malloc(lwork*sizeof(doublereal));
1107 dsytrf_ ("L",&n,rp,&lda,auxipiv,work,&lwork,&res);
1108 CHECK(res,res);
1109 int k;
1110 for (k=0; k<n; k++) {
1111 ipivp[k] = auxipiv[k];
1112 }
1113 free(auxipiv);
1114 free(work);
1115 OK
1116}
1117
1118
1119int zhetrf_(char *uplo, integer *n, doublecomplex *a, integer *lda, integer *ipiv,
1120 doublecomplex *work, integer *lwork, integer *info);
1121
1122int ldl_C(DVEC(ipiv), OCMAT(r)) {
1123 integer n = rr;
1124 REQUIRES(n>=1 && rc==n && ipivn == n, BAD_SIZE);
1125 DEBUGMSG("ldl_R");
1126 integer* auxipiv = (integer*)malloc(n*sizeof(integer));
1127 integer res;
1128 integer lda = rXc;
1129 integer lwork = -1;
1130 doublecomplex ans;
1131 zhetrf_ ("L",&n,rp,&lda,auxipiv,&ans,&lwork,&res);
1132 lwork = ceil(ans.r);
1133 doublecomplex* work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
1134 zhetrf_ ("L",&n,rp,&lda,auxipiv,work,&lwork,&res);
1135 CHECK(res,res);
1136 int k;
1137 for (k=0; k<n; k++) {
1138 ipivp[k] = auxipiv[k];
1139 }
1140 free(auxipiv);
1141 free(work);
1142 OK
1143
1144}
1145
1146//////////////////// LDL solve /////////////////////////
1147
1148int dsytrs_(char *uplo, integer *n, integer *nrhs, doublereal *a, integer *lda,
1149 integer *ipiv, doublereal *b, integer *ldb, integer *info);
1150
1151int ldl_S_R(KODMAT(a), KDVEC(ipiv), ODMAT(b)) {
1152 integer m = ar;
1153 integer n = ac;
1154 integer lda = aXc;
1155 integer mrhs = br;
1156 integer nrhs = bc;
1157
1158 REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE);
1159 integer* auxipiv = (integer*)malloc(n*sizeof(integer));
1160 int k;
1161 for (k=0; k<n; k++) {
1162 auxipiv[k] = (integer)ipivp[k];
1163 }
1164 integer res;
1165 dsytrs_ ("L",&n,&nrhs,(/*no const (!?)*/ double*)ap,&lda,auxipiv,bp,&mrhs,&res);
1166 CHECK(res,res);
1167 free(auxipiv);
1168 OK
1169}
1170
1171
1172int zhetrs_(char *uplo, integer *n, integer *nrhs, doublecomplex *a, integer *lda,
1173 integer *ipiv, doublecomplex *b, integer *ldb, integer *info);
1174
1175int ldl_S_C(KOCMAT(a), KDVEC(ipiv), OCMAT(b)) {
1176 integer m = ar;
1177 integer n = ac;
1178 integer lda = aXc;
1179 integer mrhs = br;
1180 integer nrhs = bc;
1181
1182 REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE);
1183 integer* auxipiv = (integer*)malloc(n*sizeof(integer));
1184 int k;
1185 for (k=0; k<n; k++) {
1186 auxipiv[k] = (integer)ipivp[k];
1187 }
1188 integer res;
1189 zhetrs_ ("L",&n,&nrhs,(doublecomplex*)ap,&lda,auxipiv,bp,&mrhs,&res);
1190 CHECK(res,res);
1191 free(auxipiv);
1192 OK
1193}
1194
1195
1089//////////////////// Matrix Product ///////////////////////// 1196//////////////////// Matrix Product /////////////////////////
1090 1197
1091void dgemm_(char *, char *, integer *, integer *, integer *, 1198void dgemm_(char *, char *, integer *, integer *, integer *,
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index f2fc68d..c2c140b 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -591,7 +591,7 @@ foreign import ccall unsafe "luS_l_C" zgetrs :: C ::> R :> C ::> Ok
591lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double 591lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double
592lusR a piv b = lusAux dgetrs "lusR" (fmat a) piv b 592lusR a piv b = lusAux dgetrs "lusR" (fmat a) piv b
593 593
594-- | Solve a real linear system from a precomputed LU decomposition ('luC'), using LAPACK's /zgetrs/. 594-- | Solve a complex linear system from a precomputed LU decomposition ('luC'), using LAPACK's /zgetrs/.
595lusC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double) 595lusC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double)
596lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv b 596lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv b
597 597
@@ -600,10 +600,41 @@ lusAux f st a piv b
600 x <- copy ColumnMajor b 600 x <- copy ColumnMajor b
601 f # a # piv' # x #| st 601 f # a # piv' # x #| st
602 return x 602 return x
603 | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" 603 | otherwise = error st
604 where 604 where
605 n1 = rows a 605 n1 = rows a
606 n2 = cols a 606 n2 = cols a
607 n = rows b 607 n = rows b
608 piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double 608 piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double
609 609
610-----------------------------------------------------------------------------------
611foreign import ccall unsafe "ldl_R" dsytrf :: R :> R ::> Ok
612foreign import ccall unsafe "ldl_C" zhetrf :: R :> C ::> Ok
613
614-- | LDL factorization of a symmetric real matrix, using LAPACK's /dsytrf/.
615ldlR :: Matrix Double -> (Matrix Double, [Int])
616ldlR = ldlAux dsytrf "ldlR"
617
618-- | LDL factorization of a hermitian complex matrix, using LAPACK's /zhetrf/.
619ldlC :: Matrix (Complex Double) -> (Matrix (Complex Double), [Int])
620ldlC = ldlAux zhetrf "ldlC"
621
622ldlAux f st a = unsafePerformIO $ do
623 ldl <- copy ColumnMajor a
624 piv <- createVector (rows a)
625 f # piv # ldl #| st
626 return (ldl, map (pred.round) (toList piv))
627
628-----------------------------------------------------------------------------------
629
630foreign import ccall unsafe "ldl_S_R" dsytrs :: R ::> R :> R ::> Ok
631foreign import ccall unsafe "ldl_S_C" zsytrs :: C ::> R :> C ::> Ok
632
633-- | Solve a real linear system from a precomputed LDL decomposition ('ldlR'), using LAPACK's /dsytrs/.
634ldlsR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double
635ldlsR a piv b = lusAux dsytrs "ldlsR" (fmat a) piv b
636
637-- | Solve a complex linear system from a precomputed LDL decomposition ('ldlC'), using LAPACK's /zsytrs/.
638ldlsC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double)
639ldlsC a piv b = lusAux zsytrs "ldlsC" (fmat a) piv b
640
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs
index 0b8abbb..dd4cc67 100644
--- a/packages/base/src/Numeric/LinearAlgebra.hs
+++ b/packages/base/src/Numeric/LinearAlgebra.hs
@@ -15,7 +15,7 @@ module Numeric.LinearAlgebra (
15 -- * Basic types and data processing 15 -- * Basic types and data processing
16 module Numeric.LinearAlgebra.Data, 16 module Numeric.LinearAlgebra.Data,
17 17
18 -- * Arithmetic and numeric classes 18 -- * Numeric classes
19 -- | 19 -- |
20 -- The standard numeric classes are defined elementwise: 20 -- The standard numeric classes are defined elementwise:
21 -- 21 --
@@ -27,7 +27,9 @@ module Numeric.LinearAlgebra (
27 -- [ 1.0, 0.0, 0.0 27 -- [ 1.0, 0.0, 0.0
28 -- , 0.0, 5.0, 0.0 28 -- , 0.0, 5.0, 0.0
29 -- , 0.0, 0.0, 9.0 ] 29 -- , 0.0, 0.0, 9.0 ]
30 -- 30
31 -- * Autoconformable dimensions
32 -- |
31 -- In arithmetic operations single-element vectors and matrices 33 -- In arithmetic operations single-element vectors and matrices
32 -- (created from numeric literals or using 'scalar') automatically 34 -- (created from numeric literals or using 'scalar') automatically
33 -- expand to match the dimensions of the other operand: 35 -- expand to match the dimensions of the other operand:
@@ -79,6 +81,7 @@ module Numeric.LinearAlgebra (
79 luSolve, 81 luSolve,
80 luSolve', 82 luSolve',
81 cholSolve, 83 cholSolve,
84 ldlSolve,
82 cgSolve, 85 cgSolve,
83 cgSolve', 86 cgSolve',
84 87
@@ -115,15 +118,18 @@ module Numeric.LinearAlgebra (
115 -- * Cholesky 118 -- * Cholesky
116 chol, cholSH, mbCholSH, 119 chol, cholSH, mbCholSH,
117 120
121 -- * LU
122 lu, luPacked, luPacked', luFact,
123
124 -- * LDL
125 ldlPacked, ldlPackedSH,
126
118 -- * Hessenberg 127 -- * Hessenberg
119 hess, 128 hess,
120 129
121 -- * Schur 130 -- * Schur
122 schur, 131 schur,
123 132
124 -- * LU
125 lu, luPacked, luPacked', luFact,
126
127 -- * Matrix functions 133 -- * Matrix functions
128 expm, 134 expm,
129 sqrtm, 135 sqrtm,
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
index d9bc9a0..2ff1580 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
@@ -587,6 +587,9 @@ runTests n = do
587 putStrLn "------ luSolve" 587 putStrLn "------ luSolve"
588 test (linearSolveProp (luSolve.luPacked) . rSqWC) 588 test (linearSolveProp (luSolve.luPacked) . rSqWC)
589 test (linearSolveProp (luSolve.luPacked) . cSqWC) 589 test (linearSolveProp (luSolve.luPacked) . cSqWC)
590 putStrLn "------ ldlSolve"
591 test (linearSolveProp (ldlSolve.ldlPacked) . rSymWC)
592 test (linearSolveProp (ldlSolve.ldlPacked) . cSymWC)
590 putStrLn "------ cholSolve" 593 putStrLn "------ cholSolve"
591 test (linearSolveProp (cholSolve.chol) . rPosDef) 594 test (linearSolveProp (cholSolve.chol) . rPosDef)
592 test (linearSolveProp (cholSolve.chol) . cPosDef) 595 test (linearSolveProp (cholSolve.chol) . cPosDef)
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
index 904ae05..7c54535 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs
@@ -16,7 +16,7 @@ module Numeric.LinearAlgebra.Tests.Instances(
16 Rot(..), rRot,cRot, 16 Rot(..), rRot,cRot,
17 Her(..), rHer,cHer, 17 Her(..), rHer,cHer,
18 WC(..), rWC,cWC, 18 WC(..), rWC,cWC,
19 SqWC(..), rSqWC, cSqWC, 19 SqWC(..), rSqWC, cSqWC, rSymWC, cSymWC,
20 PosDef(..), rPosDef, cPosDef, 20 PosDef(..), rPosDef, cPosDef,
21 Consistent(..), rConsist, cConsist, 21 Consistent(..), rConsist, cConsist,
22 RM,CM, rM,cM, 22 RM,CM, rM,cM,
@@ -176,6 +176,9 @@ cWC (WC m) = m :: CM
176rSqWC (SqWC m) = m :: RM 176rSqWC (SqWC m) = m :: RM
177cSqWC (SqWC m) = m :: CM 177cSqWC (SqWC m) = m :: CM
178 178
179rSymWC (SqWC m) = m + tr m :: RM
180cSymWC (SqWC m) = m + tr m :: CM
181
179rPosDef (PosDef m) = m :: RM 182rPosDef (PosDef m) = m :: RM
180cPosDef (PosDef m) = m :: CM 183cPosDef (PosDef m) = m :: CM
181 184