summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Internal/Algorithms.hs23
-rw-r--r--packages/base/src/Internal/C/lapack-aux.c107
-rw-r--r--packages/base/src/Internal/LAPACK.hs35
-rw-r--r--packages/base/src/Numeric/LinearAlgebra.hs16
4 files changed, 174 insertions, 7 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs
index 6ce1830..c8b2d3e 100644
--- a/packages/base/src/Internal/Algorithms.hs
+++ b/packages/base/src/Internal/Algorithms.hs
@@ -55,6 +55,8 @@ class (Product t,
55 mbLinearSolve' :: Matrix t -> Matrix t -> Maybe (Matrix t) 55 mbLinearSolve' :: Matrix t -> Matrix t -> Maybe (Matrix t)
56 linearSolve' :: Matrix t -> Matrix t -> Matrix t 56 linearSolve' :: Matrix t -> Matrix t -> Matrix t
57 cholSolve' :: Matrix t -> Matrix t -> Matrix t 57 cholSolve' :: Matrix t -> Matrix t -> Matrix t
58 ldlPacked' :: Matrix t -> (Matrix t, [Int])
59 ldlSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t
58 linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t 60 linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t
59 linearSolveLS' :: Matrix t -> Matrix t -> Matrix t 61 linearSolveLS' :: Matrix t -> Matrix t -> Matrix t
60 eig' :: Matrix t -> (Vector (Complex Double), Matrix (Complex Double)) 62 eig' :: Matrix t -> (Vector (Complex Double), Matrix (Complex Double))
@@ -90,6 +92,8 @@ instance Field Double where
90 qrgr' = qrgrR 92 qrgr' = qrgrR
91 hess' = unpackHess hessR 93 hess' = unpackHess hessR
92 schur' = schurR 94 schur' = schurR
95 ldlPacked' = ldlR
96 ldlSolve'= uncurry ldlsR
93 97
94instance Field (Complex Double) where 98instance Field (Complex Double) where
95#ifdef NOZGESDD 99#ifdef NOZGESDD
@@ -117,6 +121,8 @@ instance Field (Complex Double) where
117 qrgr' = qrgrC 121 qrgr' = qrgrC
118 hess' = unpackHess hessC 122 hess' = unpackHess hessC
119 schur' = schurC 123 schur' = schurC
124 ldlPacked' = ldlC
125 ldlSolve' = uncurry ldlsC
120 126
121-------------------------------------------------------------- 127--------------------------------------------------------------
122 128
@@ -333,6 +339,23 @@ linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD'
333linearSolveLS :: Field t => Matrix t -> Matrix t -> Matrix t 339linearSolveLS :: Field t => Matrix t -> Matrix t -> Matrix t
334linearSolveLS = {-# SCC "linearSolveLS" #-} linearSolveLS' 340linearSolveLS = {-# SCC "linearSolveLS" #-} linearSolveLS'
335 341
342--------------------------------------------------------------------------------
343
344-- | Similar to 'ldlPacked', without checking that the input matrix is hermitian or symmetric. It works with the lower triangular part.
345ldlPackedSH :: Field t => Matrix t -> (Matrix t, [Int])
346ldlPackedSH = {-# SCC "ldlPacked" #-} ldlPacked'
347
348-- | Obtains the LDL decomposition of a matrix in a compact data structure suitable for 'ldlSolve'.
349ldlPacked :: Field t => Matrix t -> (Matrix t, [Int])
350ldlPacked m
351 | exactHermitian m = {-# SCC "ldlPacked" #-} ldlPackedSH m
352 | otherwise = error "ldlPacked requires complex Hermitian or real symmetrix matrix"
353
354
355-- | Solution of a linear system (for several right hand sides) from the precomputed LDL factorization obtained by 'ldlPacked'.
356ldlSolve :: Field t => (Matrix t, [Int]) -> Matrix t -> Matrix t
357ldlSolve = {-# SCC "ldlSolve" #-} ldlSolve'
358
336-------------------------------------------------------------- 359--------------------------------------------------------------
337 360
338{- | Eigenvalues (not ordered) and eigenvectors (as columns) of a general square matrix. 361{- | Eigenvalues (not ordered) and eigenvectors (as columns) of a general square matrix.
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c
index 30689bf..177d373 100644
--- a/packages/base/src/Internal/C/lapack-aux.c
+++ b/packages/base/src/Internal/C/lapack-aux.c
@@ -1086,6 +1086,113 @@ int luS_l_C(KOCMAT(a), KDVEC(ipiv), OCMAT(b)) {
1086 OK 1086 OK
1087} 1087}
1088 1088
1089
1090//////////////////// LDL factorization /////////////////////////
1091
1092int dsytrf_(char *uplo, integer *n, doublereal *a, integer *lda, integer *ipiv,
1093 doublereal *work, integer *lwork, integer *info);
1094
1095int ldl_R(DVEC(ipiv), ODMAT(r)) {
1096 integer n = rr;
1097 REQUIRES(n>=1 && rc==n && ipivn == n, BAD_SIZE);
1098 DEBUGMSG("ldl_R");
1099 integer* auxipiv = (integer*)malloc(n*sizeof(integer));
1100 integer res;
1101 integer lda = rXc;
1102 integer lwork = -1;
1103 doublereal ans;
1104 dsytrf_ ("L",&n,rp,&lda,auxipiv,&ans,&lwork,&res);
1105 lwork = ceil(ans);
1106 doublereal* work = (doublereal*)malloc(lwork*sizeof(doublereal));
1107 dsytrf_ ("L",&n,rp,&lda,auxipiv,work,&lwork,&res);
1108 CHECK(res,res);
1109 int k;
1110 for (k=0; k<n; k++) {
1111 ipivp[k] = auxipiv[k];
1112 }
1113 free(auxipiv);
1114 free(work);
1115 OK
1116}
1117
1118
1119int zhetrf_(char *uplo, integer *n, doublecomplex *a, integer *lda, integer *ipiv,
1120 doublecomplex *work, integer *lwork, integer *info);
1121
1122int ldl_C(DVEC(ipiv), OCMAT(r)) {
1123 integer n = rr;
1124 REQUIRES(n>=1 && rc==n && ipivn == n, BAD_SIZE);
1125 DEBUGMSG("ldl_R");
1126 integer* auxipiv = (integer*)malloc(n*sizeof(integer));
1127 integer res;
1128 integer lda = rXc;
1129 integer lwork = -1;
1130 doublecomplex ans;
1131 zhetrf_ ("L",&n,rp,&lda,auxipiv,&ans,&lwork,&res);
1132 lwork = ceil(ans.r);
1133 doublecomplex* work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
1134 zhetrf_ ("L",&n,rp,&lda,auxipiv,work,&lwork,&res);
1135 CHECK(res,res);
1136 int k;
1137 for (k=0; k<n; k++) {
1138 ipivp[k] = auxipiv[k];
1139 }
1140 free(auxipiv);
1141 free(work);
1142 OK
1143
1144}
1145
1146//////////////////// LDL solve /////////////////////////
1147
1148int dsytrs_(char *uplo, integer *n, integer *nrhs, doublereal *a, integer *lda,
1149 integer *ipiv, doublereal *b, integer *ldb, integer *info);
1150
1151int ldl_S_R(KODMAT(a), KDVEC(ipiv), ODMAT(b)) {
1152 integer m = ar;
1153 integer n = ac;
1154 integer lda = aXc;
1155 integer mrhs = br;
1156 integer nrhs = bc;
1157
1158 REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE);
1159 integer* auxipiv = (integer*)malloc(n*sizeof(integer));
1160 int k;
1161 for (k=0; k<n; k++) {
1162 auxipiv[k] = (integer)ipivp[k];
1163 }
1164 integer res;
1165 dsytrs_ ("L",&n,&nrhs,(/*no const (!?)*/ double*)ap,&lda,auxipiv,bp,&mrhs,&res);
1166 CHECK(res,res);
1167 free(auxipiv);
1168 OK
1169}
1170
1171
1172int zhetrs_(char *uplo, integer *n, integer *nrhs, doublecomplex *a, integer *lda,
1173 integer *ipiv, doublecomplex *b, integer *ldb, integer *info);
1174
1175int ldl_S_C(KOCMAT(a), KDVEC(ipiv), OCMAT(b)) {
1176 integer m = ar;
1177 integer n = ac;
1178 integer lda = aXc;
1179 integer mrhs = br;
1180 integer nrhs = bc;
1181
1182 REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE);
1183 integer* auxipiv = (integer*)malloc(n*sizeof(integer));
1184 int k;
1185 for (k=0; k<n; k++) {
1186 auxipiv[k] = (integer)ipivp[k];
1187 }
1188 integer res;
1189 zhetrs_ ("L",&n,&nrhs,(doublecomplex*)ap,&lda,auxipiv,bp,&mrhs,&res);
1190 CHECK(res,res);
1191 free(auxipiv);
1192 OK
1193}
1194
1195
1089//////////////////// Matrix Product ///////////////////////// 1196//////////////////// Matrix Product /////////////////////////
1090 1197
1091void dgemm_(char *, char *, integer *, integer *, integer *, 1198void dgemm_(char *, char *, integer *, integer *, integer *,
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index f2fc68d..c2c140b 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -591,7 +591,7 @@ foreign import ccall unsafe "luS_l_C" zgetrs :: C ::> R :> C ::> Ok
591lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double 591lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double
592lusR a piv b = lusAux dgetrs "lusR" (fmat a) piv b 592lusR a piv b = lusAux dgetrs "lusR" (fmat a) piv b
593 593
594-- | Solve a real linear system from a precomputed LU decomposition ('luC'), using LAPACK's /zgetrs/. 594-- | Solve a complex linear system from a precomputed LU decomposition ('luC'), using LAPACK's /zgetrs/.
595lusC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double) 595lusC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double)
596lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv b 596lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv b
597 597
@@ -600,10 +600,41 @@ lusAux f st a piv b
600 x <- copy ColumnMajor b 600 x <- copy ColumnMajor b
601 f # a # piv' # x #| st 601 f # a # piv' # x #| st
602 return x 602 return x
603 | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" 603 | otherwise = error st
604 where 604 where
605 n1 = rows a 605 n1 = rows a
606 n2 = cols a 606 n2 = cols a
607 n = rows b 607 n = rows b
608 piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double 608 piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double
609 609
610-----------------------------------------------------------------------------------
611foreign import ccall unsafe "ldl_R" dsytrf :: R :> R ::> Ok
612foreign import ccall unsafe "ldl_C" zhetrf :: R :> C ::> Ok
613
614-- | LDL factorization of a symmetric real matrix, using LAPACK's /dsytrf/.
615ldlR :: Matrix Double -> (Matrix Double, [Int])
616ldlR = ldlAux dsytrf "ldlR"
617
618-- | LDL factorization of a hermitian complex matrix, using LAPACK's /zhetrf/.
619ldlC :: Matrix (Complex Double) -> (Matrix (Complex Double), [Int])
620ldlC = ldlAux zhetrf "ldlC"
621
622ldlAux f st a = unsafePerformIO $ do
623 ldl <- copy ColumnMajor a
624 piv <- createVector (rows a)
625 f # piv # ldl #| st
626 return (ldl, map (pred.round) (toList piv))
627
628-----------------------------------------------------------------------------------
629
630foreign import ccall unsafe "ldl_S_R" dsytrs :: R ::> R :> R ::> Ok
631foreign import ccall unsafe "ldl_S_C" zsytrs :: C ::> R :> C ::> Ok
632
633-- | Solve a real linear system from a precomputed LDL decomposition ('ldlR'), using LAPACK's /dsytrs/.
634ldlsR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double
635ldlsR a piv b = lusAux dsytrs "ldlsR" (fmat a) piv b
636
637-- | Solve a complex linear system from a precomputed LDL decomposition ('ldlC'), using LAPACK's /zsytrs/.
638ldlsC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double)
639ldlsC a piv b = lusAux zsytrs "ldlsC" (fmat a) piv b
640
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs
index 0b8abbb..dd4cc67 100644
--- a/packages/base/src/Numeric/LinearAlgebra.hs
+++ b/packages/base/src/Numeric/LinearAlgebra.hs
@@ -15,7 +15,7 @@ module Numeric.LinearAlgebra (
15 -- * Basic types and data processing 15 -- * Basic types and data processing
16 module Numeric.LinearAlgebra.Data, 16 module Numeric.LinearAlgebra.Data,
17 17
18 -- * Arithmetic and numeric classes 18 -- * Numeric classes
19 -- | 19 -- |
20 -- The standard numeric classes are defined elementwise: 20 -- The standard numeric classes are defined elementwise:
21 -- 21 --
@@ -27,7 +27,9 @@ module Numeric.LinearAlgebra (
27 -- [ 1.0, 0.0, 0.0 27 -- [ 1.0, 0.0, 0.0
28 -- , 0.0, 5.0, 0.0 28 -- , 0.0, 5.0, 0.0
29 -- , 0.0, 0.0, 9.0 ] 29 -- , 0.0, 0.0, 9.0 ]
30 -- 30
31 -- * Autoconformable dimensions
32 -- |
31 -- In arithmetic operations single-element vectors and matrices 33 -- In arithmetic operations single-element vectors and matrices
32 -- (created from numeric literals or using 'scalar') automatically 34 -- (created from numeric literals or using 'scalar') automatically
33 -- expand to match the dimensions of the other operand: 35 -- expand to match the dimensions of the other operand:
@@ -79,6 +81,7 @@ module Numeric.LinearAlgebra (
79 luSolve, 81 luSolve,
80 luSolve', 82 luSolve',
81 cholSolve, 83 cholSolve,
84 ldlSolve,
82 cgSolve, 85 cgSolve,
83 cgSolve', 86 cgSolve',
84 87
@@ -115,15 +118,18 @@ module Numeric.LinearAlgebra (
115 -- * Cholesky 118 -- * Cholesky
116 chol, cholSH, mbCholSH, 119 chol, cholSH, mbCholSH,
117 120
121 -- * LU
122 lu, luPacked, luPacked', luFact,
123
124 -- * LDL
125 ldlPacked, ldlPackedSH,
126
118 -- * Hessenberg 127 -- * Hessenberg
119 hess, 128 hess,
120 129
121 -- * Schur 130 -- * Schur
122 schur, 131 schur,
123 132
124 -- * LU
125 lu, luPacked, luPacked', luFact,
126
127 -- * Matrix functions 133 -- * Matrix functions
128 expm, 134 expm,
129 sqrtm, 135 sqrtm,