diff options
Diffstat (limited to 'packages')
-rw-r--r-- | packages/base/src/Internal/Algorithms.hs | 23 | ||||
-rw-r--r-- | packages/base/src/Internal/C/lapack-aux.c | 107 | ||||
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs | 35 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra.hs | 16 | ||||
-rw-r--r-- | packages/tests/src/Numeric/LinearAlgebra/Tests.hs | 3 | ||||
-rw-r--r-- | packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs | 5 |
6 files changed, 181 insertions, 8 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs index 6ce1830..c8b2d3e 100644 --- a/packages/base/src/Internal/Algorithms.hs +++ b/packages/base/src/Internal/Algorithms.hs | |||
@@ -55,6 +55,8 @@ class (Product t, | |||
55 | mbLinearSolve' :: Matrix t -> Matrix t -> Maybe (Matrix t) | 55 | mbLinearSolve' :: Matrix t -> Matrix t -> Maybe (Matrix t) |
56 | linearSolve' :: Matrix t -> Matrix t -> Matrix t | 56 | linearSolve' :: Matrix t -> Matrix t -> Matrix t |
57 | cholSolve' :: Matrix t -> Matrix t -> Matrix t | 57 | cholSolve' :: Matrix t -> Matrix t -> Matrix t |
58 | ldlPacked' :: Matrix t -> (Matrix t, [Int]) | ||
59 | ldlSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t | ||
58 | linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t | 60 | linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t |
59 | linearSolveLS' :: Matrix t -> Matrix t -> Matrix t | 61 | linearSolveLS' :: Matrix t -> Matrix t -> Matrix t |
60 | eig' :: Matrix t -> (Vector (Complex Double), Matrix (Complex Double)) | 62 | eig' :: Matrix t -> (Vector (Complex Double), Matrix (Complex Double)) |
@@ -90,6 +92,8 @@ instance Field Double where | |||
90 | qrgr' = qrgrR | 92 | qrgr' = qrgrR |
91 | hess' = unpackHess hessR | 93 | hess' = unpackHess hessR |
92 | schur' = schurR | 94 | schur' = schurR |
95 | ldlPacked' = ldlR | ||
96 | ldlSolve'= uncurry ldlsR | ||
93 | 97 | ||
94 | instance Field (Complex Double) where | 98 | instance Field (Complex Double) where |
95 | #ifdef NOZGESDD | 99 | #ifdef NOZGESDD |
@@ -117,6 +121,8 @@ instance Field (Complex Double) where | |||
117 | qrgr' = qrgrC | 121 | qrgr' = qrgrC |
118 | hess' = unpackHess hessC | 122 | hess' = unpackHess hessC |
119 | schur' = schurC | 123 | schur' = schurC |
124 | ldlPacked' = ldlC | ||
125 | ldlSolve' = uncurry ldlsC | ||
120 | 126 | ||
121 | -------------------------------------------------------------- | 127 | -------------------------------------------------------------- |
122 | 128 | ||
@@ -333,6 +339,23 @@ linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD' | |||
333 | linearSolveLS :: Field t => Matrix t -> Matrix t -> Matrix t | 339 | linearSolveLS :: Field t => Matrix t -> Matrix t -> Matrix t |
334 | linearSolveLS = {-# SCC "linearSolveLS" #-} linearSolveLS' | 340 | linearSolveLS = {-# SCC "linearSolveLS" #-} linearSolveLS' |
335 | 341 | ||
342 | -------------------------------------------------------------------------------- | ||
343 | |||
344 | -- | Similar to 'ldlPacked', without checking that the input matrix is hermitian or symmetric. It works with the lower triangular part. | ||
345 | ldlPackedSH :: Field t => Matrix t -> (Matrix t, [Int]) | ||
346 | ldlPackedSH = {-# SCC "ldlPacked" #-} ldlPacked' | ||
347 | |||
348 | -- | Obtains the LDL decomposition of a matrix in a compact data structure suitable for 'ldlSolve'. | ||
349 | ldlPacked :: Field t => Matrix t -> (Matrix t, [Int]) | ||
350 | ldlPacked m | ||
351 | | exactHermitian m = {-# SCC "ldlPacked" #-} ldlPackedSH m | ||
352 | | otherwise = error "ldlPacked requires complex Hermitian or real symmetrix matrix" | ||
353 | |||
354 | |||
355 | -- | Solution of a linear system (for several right hand sides) from the precomputed LDL factorization obtained by 'ldlPacked'. | ||
356 | ldlSolve :: Field t => (Matrix t, [Int]) -> Matrix t -> Matrix t | ||
357 | ldlSolve = {-# SCC "ldlSolve" #-} ldlSolve' | ||
358 | |||
336 | -------------------------------------------------------------- | 359 | -------------------------------------------------------------- |
337 | 360 | ||
338 | {- | Eigenvalues (not ordered) and eigenvectors (as columns) of a general square matrix. | 361 | {- | Eigenvalues (not ordered) and eigenvectors (as columns) of a general square matrix. |
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index 30689bf..177d373 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c | |||
@@ -1086,6 +1086,113 @@ int luS_l_C(KOCMAT(a), KDVEC(ipiv), OCMAT(b)) { | |||
1086 | OK | 1086 | OK |
1087 | } | 1087 | } |
1088 | 1088 | ||
1089 | |||
1090 | //////////////////// LDL factorization ///////////////////////// | ||
1091 | |||
1092 | int dsytrf_(char *uplo, integer *n, doublereal *a, integer *lda, integer *ipiv, | ||
1093 | doublereal *work, integer *lwork, integer *info); | ||
1094 | |||
1095 | int ldl_R(DVEC(ipiv), ODMAT(r)) { | ||
1096 | integer n = rr; | ||
1097 | REQUIRES(n>=1 && rc==n && ipivn == n, BAD_SIZE); | ||
1098 | DEBUGMSG("ldl_R"); | ||
1099 | integer* auxipiv = (integer*)malloc(n*sizeof(integer)); | ||
1100 | integer res; | ||
1101 | integer lda = rXc; | ||
1102 | integer lwork = -1; | ||
1103 | doublereal ans; | ||
1104 | dsytrf_ ("L",&n,rp,&lda,auxipiv,&ans,&lwork,&res); | ||
1105 | lwork = ceil(ans); | ||
1106 | doublereal* work = (doublereal*)malloc(lwork*sizeof(doublereal)); | ||
1107 | dsytrf_ ("L",&n,rp,&lda,auxipiv,work,&lwork,&res); | ||
1108 | CHECK(res,res); | ||
1109 | int k; | ||
1110 | for (k=0; k<n; k++) { | ||
1111 | ipivp[k] = auxipiv[k]; | ||
1112 | } | ||
1113 | free(auxipiv); | ||
1114 | free(work); | ||
1115 | OK | ||
1116 | } | ||
1117 | |||
1118 | |||
1119 | int zhetrf_(char *uplo, integer *n, doublecomplex *a, integer *lda, integer *ipiv, | ||
1120 | doublecomplex *work, integer *lwork, integer *info); | ||
1121 | |||
1122 | int ldl_C(DVEC(ipiv), OCMAT(r)) { | ||
1123 | integer n = rr; | ||
1124 | REQUIRES(n>=1 && rc==n && ipivn == n, BAD_SIZE); | ||
1125 | DEBUGMSG("ldl_R"); | ||
1126 | integer* auxipiv = (integer*)malloc(n*sizeof(integer)); | ||
1127 | integer res; | ||
1128 | integer lda = rXc; | ||
1129 | integer lwork = -1; | ||
1130 | doublecomplex ans; | ||
1131 | zhetrf_ ("L",&n,rp,&lda,auxipiv,&ans,&lwork,&res); | ||
1132 | lwork = ceil(ans.r); | ||
1133 | doublecomplex* work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | ||
1134 | zhetrf_ ("L",&n,rp,&lda,auxipiv,work,&lwork,&res); | ||
1135 | CHECK(res,res); | ||
1136 | int k; | ||
1137 | for (k=0; k<n; k++) { | ||
1138 | ipivp[k] = auxipiv[k]; | ||
1139 | } | ||
1140 | free(auxipiv); | ||
1141 | free(work); | ||
1142 | OK | ||
1143 | |||
1144 | } | ||
1145 | |||
1146 | //////////////////// LDL solve ///////////////////////// | ||
1147 | |||
1148 | int dsytrs_(char *uplo, integer *n, integer *nrhs, doublereal *a, integer *lda, | ||
1149 | integer *ipiv, doublereal *b, integer *ldb, integer *info); | ||
1150 | |||
1151 | int ldl_S_R(KODMAT(a), KDVEC(ipiv), ODMAT(b)) { | ||
1152 | integer m = ar; | ||
1153 | integer n = ac; | ||
1154 | integer lda = aXc; | ||
1155 | integer mrhs = br; | ||
1156 | integer nrhs = bc; | ||
1157 | |||
1158 | REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE); | ||
1159 | integer* auxipiv = (integer*)malloc(n*sizeof(integer)); | ||
1160 | int k; | ||
1161 | for (k=0; k<n; k++) { | ||
1162 | auxipiv[k] = (integer)ipivp[k]; | ||
1163 | } | ||
1164 | integer res; | ||
1165 | dsytrs_ ("L",&n,&nrhs,(/*no const (!?)*/ double*)ap,&lda,auxipiv,bp,&mrhs,&res); | ||
1166 | CHECK(res,res); | ||
1167 | free(auxipiv); | ||
1168 | OK | ||
1169 | } | ||
1170 | |||
1171 | |||
1172 | int zhetrs_(char *uplo, integer *n, integer *nrhs, doublecomplex *a, integer *lda, | ||
1173 | integer *ipiv, doublecomplex *b, integer *ldb, integer *info); | ||
1174 | |||
1175 | int ldl_S_C(KOCMAT(a), KDVEC(ipiv), OCMAT(b)) { | ||
1176 | integer m = ar; | ||
1177 | integer n = ac; | ||
1178 | integer lda = aXc; | ||
1179 | integer mrhs = br; | ||
1180 | integer nrhs = bc; | ||
1181 | |||
1182 | REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE); | ||
1183 | integer* auxipiv = (integer*)malloc(n*sizeof(integer)); | ||
1184 | int k; | ||
1185 | for (k=0; k<n; k++) { | ||
1186 | auxipiv[k] = (integer)ipivp[k]; | ||
1187 | } | ||
1188 | integer res; | ||
1189 | zhetrs_ ("L",&n,&nrhs,(doublecomplex*)ap,&lda,auxipiv,bp,&mrhs,&res); | ||
1190 | CHECK(res,res); | ||
1191 | free(auxipiv); | ||
1192 | OK | ||
1193 | } | ||
1194 | |||
1195 | |||
1089 | //////////////////// Matrix Product ///////////////////////// | 1196 | //////////////////// Matrix Product ///////////////////////// |
1090 | 1197 | ||
1091 | void dgemm_(char *, char *, integer *, integer *, integer *, | 1198 | void dgemm_(char *, char *, integer *, integer *, integer *, |
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index f2fc68d..c2c140b 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -591,7 +591,7 @@ foreign import ccall unsafe "luS_l_C" zgetrs :: C ::> R :> C ::> Ok | |||
591 | lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double | 591 | lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double |
592 | lusR a piv b = lusAux dgetrs "lusR" (fmat a) piv b | 592 | lusR a piv b = lusAux dgetrs "lusR" (fmat a) piv b |
593 | 593 | ||
594 | -- | Solve a real linear system from a precomputed LU decomposition ('luC'), using LAPACK's /zgetrs/. | 594 | -- | Solve a complex linear system from a precomputed LU decomposition ('luC'), using LAPACK's /zgetrs/. |
595 | lusC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double) | 595 | lusC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double) |
596 | lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv b | 596 | lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv b |
597 | 597 | ||
@@ -600,10 +600,41 @@ lusAux f st a piv b | |||
600 | x <- copy ColumnMajor b | 600 | x <- copy ColumnMajor b |
601 | f # a # piv' # x #| st | 601 | f # a # piv' # x #| st |
602 | return x | 602 | return x |
603 | | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" | 603 | | otherwise = error st |
604 | where | 604 | where |
605 | n1 = rows a | 605 | n1 = rows a |
606 | n2 = cols a | 606 | n2 = cols a |
607 | n = rows b | 607 | n = rows b |
608 | piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double | 608 | piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double |
609 | 609 | ||
610 | ----------------------------------------------------------------------------------- | ||
611 | foreign import ccall unsafe "ldl_R" dsytrf :: R :> R ::> Ok | ||
612 | foreign import ccall unsafe "ldl_C" zhetrf :: R :> C ::> Ok | ||
613 | |||
614 | -- | LDL factorization of a symmetric real matrix, using LAPACK's /dsytrf/. | ||
615 | ldlR :: Matrix Double -> (Matrix Double, [Int]) | ||
616 | ldlR = ldlAux dsytrf "ldlR" | ||
617 | |||
618 | -- | LDL factorization of a hermitian complex matrix, using LAPACK's /zhetrf/. | ||
619 | ldlC :: Matrix (Complex Double) -> (Matrix (Complex Double), [Int]) | ||
620 | ldlC = ldlAux zhetrf "ldlC" | ||
621 | |||
622 | ldlAux f st a = unsafePerformIO $ do | ||
623 | ldl <- copy ColumnMajor a | ||
624 | piv <- createVector (rows a) | ||
625 | f # piv # ldl #| st | ||
626 | return (ldl, map (pred.round) (toList piv)) | ||
627 | |||
628 | ----------------------------------------------------------------------------------- | ||
629 | |||
630 | foreign import ccall unsafe "ldl_S_R" dsytrs :: R ::> R :> R ::> Ok | ||
631 | foreign import ccall unsafe "ldl_S_C" zsytrs :: C ::> R :> C ::> Ok | ||
632 | |||
633 | -- | Solve a real linear system from a precomputed LDL decomposition ('ldlR'), using LAPACK's /dsytrs/. | ||
634 | ldlsR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double | ||
635 | ldlsR a piv b = lusAux dsytrs "ldlsR" (fmat a) piv b | ||
636 | |||
637 | -- | Solve a complex linear system from a precomputed LDL decomposition ('ldlC'), using LAPACK's /zsytrs/. | ||
638 | ldlsC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
639 | ldlsC a piv b = lusAux zsytrs "ldlsC" (fmat a) piv b | ||
640 | |||
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index 0b8abbb..dd4cc67 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs | |||
@@ -15,7 +15,7 @@ module Numeric.LinearAlgebra ( | |||
15 | -- * Basic types and data processing | 15 | -- * Basic types and data processing |
16 | module Numeric.LinearAlgebra.Data, | 16 | module Numeric.LinearAlgebra.Data, |
17 | 17 | ||
18 | -- * Arithmetic and numeric classes | 18 | -- * Numeric classes |
19 | -- | | 19 | -- | |
20 | -- The standard numeric classes are defined elementwise: | 20 | -- The standard numeric classes are defined elementwise: |
21 | -- | 21 | -- |
@@ -27,7 +27,9 @@ module Numeric.LinearAlgebra ( | |||
27 | -- [ 1.0, 0.0, 0.0 | 27 | -- [ 1.0, 0.0, 0.0 |
28 | -- , 0.0, 5.0, 0.0 | 28 | -- , 0.0, 5.0, 0.0 |
29 | -- , 0.0, 0.0, 9.0 ] | 29 | -- , 0.0, 0.0, 9.0 ] |
30 | -- | 30 | |
31 | -- * Autoconformable dimensions | ||
32 | -- | | ||
31 | -- In arithmetic operations single-element vectors and matrices | 33 | -- In arithmetic operations single-element vectors and matrices |
32 | -- (created from numeric literals or using 'scalar') automatically | 34 | -- (created from numeric literals or using 'scalar') automatically |
33 | -- expand to match the dimensions of the other operand: | 35 | -- expand to match the dimensions of the other operand: |
@@ -79,6 +81,7 @@ module Numeric.LinearAlgebra ( | |||
79 | luSolve, | 81 | luSolve, |
80 | luSolve', | 82 | luSolve', |
81 | cholSolve, | 83 | cholSolve, |
84 | ldlSolve, | ||
82 | cgSolve, | 85 | cgSolve, |
83 | cgSolve', | 86 | cgSolve', |
84 | 87 | ||
@@ -115,15 +118,18 @@ module Numeric.LinearAlgebra ( | |||
115 | -- * Cholesky | 118 | -- * Cholesky |
116 | chol, cholSH, mbCholSH, | 119 | chol, cholSH, mbCholSH, |
117 | 120 | ||
121 | -- * LU | ||
122 | lu, luPacked, luPacked', luFact, | ||
123 | |||
124 | -- * LDL | ||
125 | ldlPacked, ldlPackedSH, | ||
126 | |||
118 | -- * Hessenberg | 127 | -- * Hessenberg |
119 | hess, | 128 | hess, |
120 | 129 | ||
121 | -- * Schur | 130 | -- * Schur |
122 | schur, | 131 | schur, |
123 | 132 | ||
124 | -- * LU | ||
125 | lu, luPacked, luPacked', luFact, | ||
126 | |||
127 | -- * Matrix functions | 133 | -- * Matrix functions |
128 | expm, | 134 | expm, |
129 | sqrtm, | 135 | sqrtm, |
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs index d9bc9a0..2ff1580 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs | |||
@@ -587,6 +587,9 @@ runTests n = do | |||
587 | putStrLn "------ luSolve" | 587 | putStrLn "------ luSolve" |
588 | test (linearSolveProp (luSolve.luPacked) . rSqWC) | 588 | test (linearSolveProp (luSolve.luPacked) . rSqWC) |
589 | test (linearSolveProp (luSolve.luPacked) . cSqWC) | 589 | test (linearSolveProp (luSolve.luPacked) . cSqWC) |
590 | putStrLn "------ ldlSolve" | ||
591 | test (linearSolveProp (ldlSolve.ldlPacked) . rSymWC) | ||
592 | test (linearSolveProp (ldlSolve.ldlPacked) . cSymWC) | ||
590 | putStrLn "------ cholSolve" | 593 | putStrLn "------ cholSolve" |
591 | test (linearSolveProp (cholSolve.chol) . rPosDef) | 594 | test (linearSolveProp (cholSolve.chol) . rPosDef) |
592 | test (linearSolveProp (cholSolve.chol) . cPosDef) | 595 | test (linearSolveProp (cholSolve.chol) . cPosDef) |
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs index 904ae05..7c54535 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests/Instances.hs | |||
@@ -16,7 +16,7 @@ module Numeric.LinearAlgebra.Tests.Instances( | |||
16 | Rot(..), rRot,cRot, | 16 | Rot(..), rRot,cRot, |
17 | Her(..), rHer,cHer, | 17 | Her(..), rHer,cHer, |
18 | WC(..), rWC,cWC, | 18 | WC(..), rWC,cWC, |
19 | SqWC(..), rSqWC, cSqWC, | 19 | SqWC(..), rSqWC, cSqWC, rSymWC, cSymWC, |
20 | PosDef(..), rPosDef, cPosDef, | 20 | PosDef(..), rPosDef, cPosDef, |
21 | Consistent(..), rConsist, cConsist, | 21 | Consistent(..), rConsist, cConsist, |
22 | RM,CM, rM,cM, | 22 | RM,CM, rM,cM, |
@@ -176,6 +176,9 @@ cWC (WC m) = m :: CM | |||
176 | rSqWC (SqWC m) = m :: RM | 176 | rSqWC (SqWC m) = m :: RM |
177 | cSqWC (SqWC m) = m :: CM | 177 | cSqWC (SqWC m) = m :: CM |
178 | 178 | ||
179 | rSymWC (SqWC m) = m + tr m :: RM | ||
180 | cSymWC (SqWC m) = m + tr m :: CM | ||
181 | |||
179 | rPosDef (PosDef m) = m :: RM | 182 | rPosDef (PosDef m) = m :: RM |
180 | cPosDef (PosDef m) = m :: CM | 183 | cPosDef (PosDef m) = m :: CM |
181 | 184 | ||