diff options
author | Alberto Ruiz <aruiz@um.es> | 2008-10-27 13:03:41 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2008-10-27 13:03:41 +0000 |
commit | edf12982f21c56c21bfc21eb2b2fcbc406838130 (patch) | |
tree | 4f9463ceccc49dc5b9dfdf77b16dccef9bc8d3e5 | |
parent | d8639b28ec9e83b54b45c987508d270d5469451c (diff) |
added luSolve
-rw-r--r-- | hmatrix.cabal | 2 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 19 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK.hs | 29 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 27 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 3 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests.hs | 8 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests/Properties.hs | 7 |
7 files changed, 69 insertions, 26 deletions
diff --git a/hmatrix.cabal b/hmatrix.cabal index 6bffaf9..3307870 100644 --- a/hmatrix.cabal +++ b/hmatrix.cabal | |||
@@ -1,5 +1,5 @@ | |||
1 | Name: hmatrix | 1 | Name: hmatrix |
2 | Version: 0.5.0.0 | 2 | Version: 0.5.0.1 |
3 | License: GPL | 3 | License: GPL |
4 | License-file: LICENSE | 4 | License-file: LICENSE |
5 | Author: Alberto Ruiz | 5 | Author: Alberto Ruiz |
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index e2fec9d..75f4ba3 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -40,7 +40,7 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
40 | -- ** Schur | 40 | -- ** Schur |
41 | schur, | 41 | schur, |
42 | -- ** LU | 42 | -- ** LU |
43 | lu, | 43 | lu, luPacked, luSolve, |
44 | -- * Matrix functions | 44 | -- * Matrix functions |
45 | expm, | 45 | expm, |
46 | sqrtm, | 46 | sqrtm, |
@@ -77,8 +77,13 @@ import Foreign.C.Types | |||
77 | class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where | 77 | class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where |
78 | -- | Singular value decomposition using lapack's dgesvd or zgesvd. | 78 | -- | Singular value decomposition using lapack's dgesvd or zgesvd. |
79 | svd :: Matrix t -> (Matrix t, Vector Double, Matrix t) | 79 | svd :: Matrix t -> (Matrix t, Vector Double, Matrix t) |
80 | -- | Obtains the LU decomposition of a matrix in a compact data structure suitable for 'luSolve'. | ||
80 | luPacked :: Matrix t -> (Matrix t, [Int]) | 81 | luPacked :: Matrix t -> (Matrix t, [Int]) |
81 | -- | Solution of a general linear system (for several right-hand sides) using lapacks' dgesv and zgesv. | 82 | -- | Solution of a linear system (for several right hand sides) from the precomputed LU factorization |
83 | -- obtained by 'luPacked'. | ||
84 | luSolve :: (Matrix t, [Int]) -> Matrix t -> Matrix t | ||
85 | -- | Solution of a general linear system (for several right-hand sides) using lapacks' dgesv or zgesv. | ||
86 | -- It is similar to 'luSolve' . 'luPacked', but @linearSolve@ raises an error if called on a singular system. | ||
82 | -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK". | 87 | -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK". |
83 | linearSolve :: Matrix t -> Matrix t -> Matrix t | 88 | linearSolve :: Matrix t -> Matrix t -> Matrix t |
84 | linearSolveSVD :: Matrix t -> Matrix t -> Matrix t | 89 | linearSolveSVD :: Matrix t -> Matrix t -> Matrix t |
@@ -110,13 +115,15 @@ class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where | |||
110 | schur :: Matrix t -> (Matrix t, Matrix t) | 115 | schur :: Matrix t -> (Matrix t, Matrix t) |
111 | -- | Conjugate transpose. | 116 | -- | Conjugate transpose. |
112 | ctrans :: Matrix t -> Matrix t | 117 | ctrans :: Matrix t -> Matrix t |
118 | -- | Matrix product. | ||
113 | multiply :: Matrix t -> Matrix t -> Matrix t | 119 | multiply :: Matrix t -> Matrix t -> Matrix t |
114 | 120 | ||
115 | 121 | ||
116 | instance Field Double where | 122 | instance Field Double where |
117 | svd = svdR | 123 | svd = svdR |
118 | luPacked = luR | 124 | luPacked = luR |
119 | linearSolve = linearSolveR | 125 | luSolve (l_u,perm) = lusR l_u perm |
126 | linearSolve = linearSolveR -- (luSolve . luPacked) ?? | ||
120 | linearSolveSVD = linearSolveSVDR Nothing | 127 | linearSolveSVD = linearSolveSVDR Nothing |
121 | ctrans = trans | 128 | ctrans = trans |
122 | eig = eigR | 129 | eig = eigR |
@@ -130,6 +137,7 @@ instance Field Double where | |||
130 | instance Field (Complex Double) where | 137 | instance Field (Complex Double) where |
131 | svd = svdC | 138 | svd = svdC |
132 | luPacked = luC | 139 | luPacked = luC |
140 | luSolve (l_u,perm) = lusC l_u perm | ||
133 | linearSolve = linearSolveC | 141 | linearSolve = linearSolveC |
134 | linearSolveSVD = linearSolveSVDC Nothing | 142 | linearSolveSVD = linearSolveSVDC Nothing |
135 | ctrans = conj . trans | 143 | ctrans = conj . trans |
@@ -165,7 +173,7 @@ det m | square m = s * (product $ toList $ takeDiag $ lup) | |||
165 | where (lup,perm) = luPacked m | 173 | where (lup,perm) = luPacked m |
166 | s = signlp (rows m) perm | 174 | s = signlp (rows m) perm |
167 | 175 | ||
168 | -- | LU factorization of a general matrix using lapack's dgetrf or zgetrf. | 176 | -- | Explicit LU factorization of a general matrix using lapack's dgetrf or zgetrf. |
169 | -- | 177 | -- |
170 | -- If @(l,u,p,s) = lu m@ then @m == p \<> l \<> u@, where l is lower triangular, | 178 | -- If @(l,u,p,s) = lu m@ then @m == p \<> l \<> u@, where l is lower triangular, |
171 | -- u is upper triangular, p is a permutation matrix and s is the signature of the permutation. | 179 | -- u is upper triangular, p is a permutation matrix and s is the signature of the permutation. |
@@ -513,7 +521,7 @@ luFact (l_u,perm) | r <= c = (l ,u ,p, s) | |||
513 | 521 | ||
514 | -------------------------------------------------- | 522 | -------------------------------------------------- |
515 | 523 | ||
516 | -- | euclidean inner product | 524 | -- | Euclidean inner product. |
517 | dot :: (Field t) => Vector t -> Vector t -> t | 525 | dot :: (Field t) => Vector t -> Vector t -> t |
518 | dot u v = multiply r c @@> (0,0) | 526 | dot u v = multiply r c @@> (0,0) |
519 | where r = asRow u | 527 | where r = asRow u |
@@ -629,5 +637,4 @@ multiplyC3 x y = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat x) (fmat | |||
629 | free palpha | 637 | free palpha |
630 | free pbeta | 638 | free pbeta |
631 | return s | 639 | return s |
632 | -- if toLists s== toLists s then return s else error $ "HORROR " ++ (show (toLists s)) | ||
633 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | 640 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" |
diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index d78b506..8bc2492 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs | |||
@@ -19,7 +19,7 @@ module Numeric.LinearAlgebra.LAPACK ( | |||
19 | linearSolveR, linearSolveC, | 19 | linearSolveR, linearSolveC, |
20 | linearSolveLSR, linearSolveLSC, | 20 | linearSolveLSR, linearSolveLSC, |
21 | linearSolveSVDR, linearSolveSVDC, | 21 | linearSolveSVDR, linearSolveSVDC, |
22 | luR, luC, lusR, | 22 | luR, luC, lusR, lusC, |
23 | cholS, cholH, | 23 | cholS, cholH, |
24 | qrR, qrC, | 24 | qrR, qrC, |
25 | hessR, hessC, | 25 | hessR, hessC, |
@@ -34,6 +34,7 @@ import Data.Packed.Matrix | |||
34 | import Numeric.GSL.Vector(vectorMapValR, FunCodeSV(Scale)) | 34 | import Numeric.GSL.Vector(vectorMapValR, FunCodeSV(Scale)) |
35 | import Complex | 35 | import Complex |
36 | import Foreign | 36 | import Foreign |
37 | import Foreign.C.Types (CInt) | ||
37 | 38 | ||
38 | ----------------------------------------------------------------------------- | 39 | ----------------------------------------------------------------------------- |
39 | foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM | 40 | foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM |
@@ -338,17 +339,29 @@ luAux f st a = unsafePerformIO $ do | |||
338 | where n = rows a | 339 | where n = rows a |
339 | m = cols a | 340 | m = cols a |
340 | 341 | ||
341 | |||
342 | ----------------------------------------------------------------------------------- | 342 | ----------------------------------------------------------------------------------- |
343 | type TW a = CInt -> PD -> a | ||
344 | type TQ a = CInt -> CInt -> PC -> a | ||
345 | |||
343 | foreign import ccall "LAPACK/lapack-aux.h luS_l_R" dgetrs :: TMVMM | 346 | foreign import ccall "LAPACK/lapack-aux.h luS_l_R" dgetrs :: TMVMM |
347 | foreign import ccall "LAPACK/lapack-aux.h luS_l_C" zgetrs :: TQ (TW (TQ (TQ (IO CInt)))) | ||
344 | 348 | ||
349 | -- | Wrapper for LAPACK's /dgetrs/, which solves a general real linear system (for several right-hand sides) from a precomputed LU decomposition. | ||
345 | lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double | 350 | lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double |
346 | lusR a piv b = lusR' (fmat a) piv (fmat b) | 351 | lusR a piv b = lusAux dgetrs "lusR" (fmat a) piv (fmat b) |
352 | |||
353 | -- | Wrapper for LAPACK's /zgetrs/, which solves a general real linear system (for several right-hand sides) from a precomputed LU decomposition. | ||
354 | lusC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
355 | lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b) | ||
347 | 356 | ||
348 | lusR' a piv b = unsafePerformIO $ do | 357 | lusAux f st a piv b |
358 | | n1==n2 && n2==n =unsafePerformIO $ do | ||
349 | x <- createMatrix ColumnMajor n m | 359 | x <- createMatrix ColumnMajor n m |
350 | app4 dgetrs mat a vec piv' mat b mat x "lusR" | 360 | app4 f mat a vec piv' mat b mat x st |
351 | return x | 361 | return x |
352 | where n = rows b | 362 | | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" |
353 | m = cols b | 363 | where n1 = rows a |
354 | piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double | 364 | n2 = cols a |
365 | n = rows b | ||
366 | m = cols b | ||
367 | piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double | ||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index de3cc98..842b5ad 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | |||
@@ -820,9 +820,8 @@ int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) { | |||
820 | 820 | ||
821 | 821 | ||
822 | //////////////////// LU substitution ///////////////////////// | 822 | //////////////////// LU substitution ///////////////////////// |
823 | char charN = 'N'; | ||
824 | 823 | ||
825 | int luS_l_R(DMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)) { | 824 | int luS_l_R(KDMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)) { |
826 | integer m = ar; | 825 | integer m = ar; |
827 | integer n = ac; | 826 | integer n = ac; |
828 | integer mrhs = br; | 827 | integer mrhs = br; |
@@ -836,10 +835,28 @@ int luS_l_R(DMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)) { | |||
836 | } | 835 | } |
837 | integer res; | 836 | integer res; |
838 | memcpy(xp,bp,mrhs*nrhs*sizeof(double)); | 837 | memcpy(xp,bp,mrhs*nrhs*sizeof(double)); |
839 | integer ldb = mrhs; | 838 | dgetrs_ ("N",&n,&nrhs,(/*no const (!?)*/ double*)ap,&m,auxipiv,xp,&mrhs,&res); |
840 | integer lda = n; | ||
841 | dgetrs_ (&charN,&n,&nrhs,ap,&lda,auxipiv,xp,&ldb,&res); | ||
842 | CHECK(res,res); | 839 | CHECK(res,res); |
843 | free(auxipiv); | 840 | free(auxipiv); |
844 | OK | 841 | OK |
845 | } | 842 | } |
843 | |||
844 | int luS_l_C(KCMAT(a), KDVEC(ipiv), KCMAT(b), CMAT(x)) { | ||
845 | integer m = ar; | ||
846 | integer n = ac; | ||
847 | integer mrhs = br; | ||
848 | integer nrhs = bc; | ||
849 | |||
850 | REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE); | ||
851 | integer* auxipiv = (integer*)malloc(n*sizeof(integer)); | ||
852 | int k; | ||
853 | for (k=0; k<n; k++) { | ||
854 | auxipiv[k] = (integer)ipivp[k]; | ||
855 | } | ||
856 | integer res; | ||
857 | memcpy(xp,bp,mrhs*nrhs*sizeof(doublecomplex)); | ||
858 | zgetrs_ ("N",&n,&nrhs,(doublecomplex*)ap,&m,auxipiv,(doublecomplex*)xp,&mrhs,&res); | ||
859 | CHECK(res,res); | ||
860 | free(auxipiv); | ||
861 | OK | ||
862 | } | ||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h index e98e9dc..23e5e28 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | |||
@@ -85,4 +85,5 @@ int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)); | |||
85 | int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)); | 85 | int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)); |
86 | int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)); | 86 | int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)); |
87 | 87 | ||
88 | int luS_l_R(DMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)); | 88 | int luS_l_R(KDMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)); |
89 | int luS_l_C(KCMAT(a), KDVEC(ipiv), KCMAT(b), CMAT(x)); | ||
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index 7ebd1f2..9617a7a 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs | |||
@@ -119,7 +119,6 @@ rotTest = fun (10^5) :~12~: rot 5E4 | |||
119 | where fun n = foldl1' (<>) (map rot angles) | 119 | where fun n = foldl1' (<>) (map rot angles) |
120 | where angles = toList $ linspace n (0,1) | 120 | where angles = toList $ linspace n (0,1) |
121 | 121 | ||
122 | |||
123 | -- | All tests must pass with a maximum dimension of about 20 | 122 | -- | All tests must pass with a maximum dimension of about 20 |
124 | -- (some tests may fail with bigger sizes due to precision loss). | 123 | -- (some tests may fail with bigger sizes due to precision loss). |
125 | runTests :: Int -- ^ maximum dimension | 124 | runTests :: Int -- ^ maximum dimension |
@@ -135,10 +134,13 @@ runTests n = do | |||
135 | putStrLn "------ lu" | 134 | putStrLn "------ lu" |
136 | test (luProp . rM) | 135 | test (luProp . rM) |
137 | test (luProp . cM) | 136 | test (luProp . cM) |
138 | putStrLn "------ inv" | 137 | putStrLn "------ inv (linearSolve)" |
139 | test (invProp . rSqWC) | 138 | test (invProp . rSqWC) |
140 | test (invProp . cSqWC) | 139 | test (invProp . cSqWC) |
141 | putStrLn "------ pinv" | 140 | putStrLn "------ luSolve" |
141 | test (linearSolveProp (luSolve.luPacked) . rSqWC) | ||
142 | test (linearSolveProp (luSolve.luPacked) . cSqWC) | ||
143 | putStrLn "------ pinv (linearSolveSVD)" | ||
142 | test (pinvProp . rM) | 144 | test (pinvProp . rM) |
143 | if os == "mingw32" | 145 | if os == "mingw32" |
144 | then putStrLn "complex pinvTest skipped in this OS" | 146 | then putStrLn "complex pinvTest skipped in this OS" |
diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs index b5321c2..45b03a2 100644 --- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs +++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs | |||
@@ -35,7 +35,8 @@ module Numeric.LinearAlgebra.Tests.Properties ( | |||
35 | schurProp1, schurProp2, | 35 | schurProp1, schurProp2, |
36 | cholProp, | 36 | cholProp, |
37 | expmDiagProp, | 37 | expmDiagProp, |
38 | multProp1, multProp2 | 38 | multProp1, multProp2, |
39 | linearSolveProp | ||
39 | ) where | 40 | ) where |
40 | 41 | ||
41 | import Numeric.LinearAlgebra | 42 | import Numeric.LinearAlgebra |
@@ -153,4 +154,6 @@ expmDiagProp m = expm (logm m) :~ 7 ~: complex m | |||
153 | 154 | ||
154 | multProp1 (a,b) = a <> b |~| mulH a b | 155 | multProp1 (a,b) = a <> b |~| mulH a b |
155 | 156 | ||
156 | multProp2 (a,b) = trans (a <> b) |~| trans b <> trans a | 157 | multProp2 (a,b) = ctrans (a <> b) |~| ctrans b <> ctrans a |
158 | |||
159 | linearSolveProp f m = f m m |~| ident (rows m) | ||