summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2008-10-27 13:03:41 +0000
committerAlberto Ruiz <aruiz@um.es>2008-10-27 13:03:41 +0000
commitedf12982f21c56c21bfc21eb2b2fcbc406838130 (patch)
tree4f9463ceccc49dc5b9dfdf77b16dccef9bc8d3e5 /lib
parentd8639b28ec9e83b54b45c987508d270d5469451c (diff)
added luSolve
Diffstat (limited to 'lib')
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs19
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK.hs29
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c27
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h3
-rw-r--r--lib/Numeric/LinearAlgebra/Tests.hs8
-rw-r--r--lib/Numeric/LinearAlgebra/Tests/Properties.hs7
6 files changed, 68 insertions, 25 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index e2fec9d..75f4ba3 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -40,7 +40,7 @@ module Numeric.LinearAlgebra.Algorithms (
40-- ** Schur 40-- ** Schur
41 schur, 41 schur,
42-- ** LU 42-- ** LU
43 lu, 43 lu, luPacked, luSolve,
44-- * Matrix functions 44-- * Matrix functions
45 expm, 45 expm,
46 sqrtm, 46 sqrtm,
@@ -77,8 +77,13 @@ import Foreign.C.Types
77class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where 77class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where
78 -- | Singular value decomposition using lapack's dgesvd or zgesvd. 78 -- | Singular value decomposition using lapack's dgesvd or zgesvd.
79 svd :: Matrix t -> (Matrix t, Vector Double, Matrix t) 79 svd :: Matrix t -> (Matrix t, Vector Double, Matrix t)
80 -- | Obtains the LU decomposition of a matrix in a compact data structure suitable for 'luSolve'.
80 luPacked :: Matrix t -> (Matrix t, [Int]) 81 luPacked :: Matrix t -> (Matrix t, [Int])
81 -- | Solution of a general linear system (for several right-hand sides) using lapacks' dgesv and zgesv. 82 -- | Solution of a linear system (for several right hand sides) from the precomputed LU factorization
83 -- obtained by 'luPacked'.
84 luSolve :: (Matrix t, [Int]) -> Matrix t -> Matrix t
85 -- | Solution of a general linear system (for several right-hand sides) using lapacks' dgesv or zgesv.
86 -- It is similar to 'luSolve' . 'luPacked', but @linearSolve@ raises an error if called on a singular system.
82 -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK". 87 -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK".
83 linearSolve :: Matrix t -> Matrix t -> Matrix t 88 linearSolve :: Matrix t -> Matrix t -> Matrix t
84 linearSolveSVD :: Matrix t -> Matrix t -> Matrix t 89 linearSolveSVD :: Matrix t -> Matrix t -> Matrix t
@@ -110,13 +115,15 @@ class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where
110 schur :: Matrix t -> (Matrix t, Matrix t) 115 schur :: Matrix t -> (Matrix t, Matrix t)
111 -- | Conjugate transpose. 116 -- | Conjugate transpose.
112 ctrans :: Matrix t -> Matrix t 117 ctrans :: Matrix t -> Matrix t
118 -- | Matrix product.
113 multiply :: Matrix t -> Matrix t -> Matrix t 119 multiply :: Matrix t -> Matrix t -> Matrix t
114 120
115 121
116instance Field Double where 122instance Field Double where
117 svd = svdR 123 svd = svdR
118 luPacked = luR 124 luPacked = luR
119 linearSolve = linearSolveR 125 luSolve (l_u,perm) = lusR l_u perm
126 linearSolve = linearSolveR -- (luSolve . luPacked) ??
120 linearSolveSVD = linearSolveSVDR Nothing 127 linearSolveSVD = linearSolveSVDR Nothing
121 ctrans = trans 128 ctrans = trans
122 eig = eigR 129 eig = eigR
@@ -130,6 +137,7 @@ instance Field Double where
130instance Field (Complex Double) where 137instance Field (Complex Double) where
131 svd = svdC 138 svd = svdC
132 luPacked = luC 139 luPacked = luC
140 luSolve (l_u,perm) = lusC l_u perm
133 linearSolve = linearSolveC 141 linearSolve = linearSolveC
134 linearSolveSVD = linearSolveSVDC Nothing 142 linearSolveSVD = linearSolveSVDC Nothing
135 ctrans = conj . trans 143 ctrans = conj . trans
@@ -165,7 +173,7 @@ det m | square m = s * (product $ toList $ takeDiag $ lup)
165 where (lup,perm) = luPacked m 173 where (lup,perm) = luPacked m
166 s = signlp (rows m) perm 174 s = signlp (rows m) perm
167 175
168-- | LU factorization of a general matrix using lapack's dgetrf or zgetrf. 176-- | Explicit LU factorization of a general matrix using lapack's dgetrf or zgetrf.
169-- 177--
170-- If @(l,u,p,s) = lu m@ then @m == p \<> l \<> u@, where l is lower triangular, 178-- If @(l,u,p,s) = lu m@ then @m == p \<> l \<> u@, where l is lower triangular,
171-- u is upper triangular, p is a permutation matrix and s is the signature of the permutation. 179-- u is upper triangular, p is a permutation matrix and s is the signature of the permutation.
@@ -513,7 +521,7 @@ luFact (l_u,perm) | r <= c = (l ,u ,p, s)
513 521
514-------------------------------------------------- 522--------------------------------------------------
515 523
516-- | euclidean inner product 524-- | Euclidean inner product.
517dot :: (Field t) => Vector t -> Vector t -> t 525dot :: (Field t) => Vector t -> Vector t -> t
518dot u v = multiply r c @@> (0,0) 526dot u v = multiply r c @@> (0,0)
519 where r = asRow u 527 where r = asRow u
@@ -629,5 +637,4 @@ multiplyC3 x y = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat x) (fmat
629 free palpha 637 free palpha
630 free pbeta 638 free pbeta
631 return s 639 return s
632 -- if toLists s== toLists s then return s else error $ "HORROR " ++ (show (toLists s))
633 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" 640 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs
index d78b506..8bc2492 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK.hs
+++ b/lib/Numeric/LinearAlgebra/LAPACK.hs
@@ -19,7 +19,7 @@ module Numeric.LinearAlgebra.LAPACK (
19 linearSolveR, linearSolveC, 19 linearSolveR, linearSolveC,
20 linearSolveLSR, linearSolveLSC, 20 linearSolveLSR, linearSolveLSC,
21 linearSolveSVDR, linearSolveSVDC, 21 linearSolveSVDR, linearSolveSVDC,
22 luR, luC, lusR, 22 luR, luC, lusR, lusC,
23 cholS, cholH, 23 cholS, cholH,
24 qrR, qrC, 24 qrR, qrC,
25 hessR, hessC, 25 hessR, hessC,
@@ -34,6 +34,7 @@ import Data.Packed.Matrix
34import Numeric.GSL.Vector(vectorMapValR, FunCodeSV(Scale)) 34import Numeric.GSL.Vector(vectorMapValR, FunCodeSV(Scale))
35import Complex 35import Complex
36import Foreign 36import Foreign
37import Foreign.C.Types (CInt)
37 38
38----------------------------------------------------------------------------- 39-----------------------------------------------------------------------------
39foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM 40foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM
@@ -338,17 +339,29 @@ luAux f st a = unsafePerformIO $ do
338 where n = rows a 339 where n = rows a
339 m = cols a 340 m = cols a
340 341
341
342----------------------------------------------------------------------------------- 342-----------------------------------------------------------------------------------
343type TW a = CInt -> PD -> a
344type TQ a = CInt -> CInt -> PC -> a
345
343foreign import ccall "LAPACK/lapack-aux.h luS_l_R" dgetrs :: TMVMM 346foreign import ccall "LAPACK/lapack-aux.h luS_l_R" dgetrs :: TMVMM
347foreign import ccall "LAPACK/lapack-aux.h luS_l_C" zgetrs :: TQ (TW (TQ (TQ (IO CInt))))
344 348
349-- | Wrapper for LAPACK's /dgetrs/, which solves a general real linear system (for several right-hand sides) from a precomputed LU decomposition.
345lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double 350lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double
346lusR a piv b = lusR' (fmat a) piv (fmat b) 351lusR a piv b = lusAux dgetrs "lusR" (fmat a) piv (fmat b)
352
353-- | Wrapper for LAPACK's /zgetrs/, which solves a general real linear system (for several right-hand sides) from a precomputed LU decomposition.
354lusC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double)
355lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b)
347 356
348lusR' a piv b = unsafePerformIO $ do 357lusAux f st a piv b
358 | n1==n2 && n2==n =unsafePerformIO $ do
349 x <- createMatrix ColumnMajor n m 359 x <- createMatrix ColumnMajor n m
350 app4 dgetrs mat a vec piv' mat b mat x "lusR" 360 app4 f mat a vec piv' mat b mat x st
351 return x 361 return x
352 where n = rows b 362 | otherwise = error $ st ++ " on LU factorization of nonsquare matrix"
353 m = cols b 363 where n1 = rows a
354 piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double 364 n2 = cols a
365 n = rows b
366 m = cols b
367 piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
index de3cc98..842b5ad 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
@@ -820,9 +820,8 @@ int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) {
820 820
821 821
822//////////////////// LU substitution ///////////////////////// 822//////////////////// LU substitution /////////////////////////
823char charN = 'N';
824 823
825int luS_l_R(DMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)) { 824int luS_l_R(KDMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)) {
826 integer m = ar; 825 integer m = ar;
827 integer n = ac; 826 integer n = ac;
828 integer mrhs = br; 827 integer mrhs = br;
@@ -836,10 +835,28 @@ int luS_l_R(DMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)) {
836 } 835 }
837 integer res; 836 integer res;
838 memcpy(xp,bp,mrhs*nrhs*sizeof(double)); 837 memcpy(xp,bp,mrhs*nrhs*sizeof(double));
839 integer ldb = mrhs; 838 dgetrs_ ("N",&n,&nrhs,(/*no const (!?)*/ double*)ap,&m,auxipiv,xp,&mrhs,&res);
840 integer lda = n;
841 dgetrs_ (&charN,&n,&nrhs,ap,&lda,auxipiv,xp,&ldb,&res);
842 CHECK(res,res); 839 CHECK(res,res);
843 free(auxipiv); 840 free(auxipiv);
844 OK 841 OK
845} 842}
843
844int luS_l_C(KCMAT(a), KDVEC(ipiv), KCMAT(b), CMAT(x)) {
845 integer m = ar;
846 integer n = ac;
847 integer mrhs = br;
848 integer nrhs = bc;
849
850 REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE);
851 integer* auxipiv = (integer*)malloc(n*sizeof(integer));
852 int k;
853 for (k=0; k<n; k++) {
854 auxipiv[k] = (integer)ipivp[k];
855 }
856 integer res;
857 memcpy(xp,bp,mrhs*nrhs*sizeof(doublecomplex));
858 zgetrs_ ("N",&n,&nrhs,(doublecomplex*)ap,&m,auxipiv,(doublecomplex*)xp,&mrhs,&res);
859 CHECK(res,res);
860 free(auxipiv);
861 OK
862}
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
index e98e9dc..23e5e28 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
@@ -85,4 +85,5 @@ int schur_l_C(KCMAT(a), CMAT(u), CMAT(s));
85int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)); 85int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r));
86int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)); 86int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r));
87 87
88int luS_l_R(DMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)); 88int luS_l_R(KDMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x));
89int luS_l_C(KCMAT(a), KDVEC(ipiv), KCMAT(b), CMAT(x));
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs
index 7ebd1f2..9617a7a 100644
--- a/lib/Numeric/LinearAlgebra/Tests.hs
+++ b/lib/Numeric/LinearAlgebra/Tests.hs
@@ -119,7 +119,6 @@ rotTest = fun (10^5) :~12~: rot 5E4
119 where fun n = foldl1' (<>) (map rot angles) 119 where fun n = foldl1' (<>) (map rot angles)
120 where angles = toList $ linspace n (0,1) 120 where angles = toList $ linspace n (0,1)
121 121
122
123-- | All tests must pass with a maximum dimension of about 20 122-- | All tests must pass with a maximum dimension of about 20
124-- (some tests may fail with bigger sizes due to precision loss). 123-- (some tests may fail with bigger sizes due to precision loss).
125runTests :: Int -- ^ maximum dimension 124runTests :: Int -- ^ maximum dimension
@@ -135,10 +134,13 @@ runTests n = do
135 putStrLn "------ lu" 134 putStrLn "------ lu"
136 test (luProp . rM) 135 test (luProp . rM)
137 test (luProp . cM) 136 test (luProp . cM)
138 putStrLn "------ inv" 137 putStrLn "------ inv (linearSolve)"
139 test (invProp . rSqWC) 138 test (invProp . rSqWC)
140 test (invProp . cSqWC) 139 test (invProp . cSqWC)
141 putStrLn "------ pinv" 140 putStrLn "------ luSolve"
141 test (linearSolveProp (luSolve.luPacked) . rSqWC)
142 test (linearSolveProp (luSolve.luPacked) . cSqWC)
143 putStrLn "------ pinv (linearSolveSVD)"
142 test (pinvProp . rM) 144 test (pinvProp . rM)
143 if os == "mingw32" 145 if os == "mingw32"
144 then putStrLn "complex pinvTest skipped in this OS" 146 then putStrLn "complex pinvTest skipped in this OS"
diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs
index b5321c2..45b03a2 100644
--- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs
+++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs
@@ -35,7 +35,8 @@ module Numeric.LinearAlgebra.Tests.Properties (
35 schurProp1, schurProp2, 35 schurProp1, schurProp2,
36 cholProp, 36 cholProp,
37 expmDiagProp, 37 expmDiagProp,
38 multProp1, multProp2 38 multProp1, multProp2,
39 linearSolveProp
39) where 40) where
40 41
41import Numeric.LinearAlgebra 42import Numeric.LinearAlgebra
@@ -153,4 +154,6 @@ expmDiagProp m = expm (logm m) :~ 7 ~: complex m
153 154
154multProp1 (a,b) = a <> b |~| mulH a b 155multProp1 (a,b) = a <> b |~| mulH a b
155 156
156multProp2 (a,b) = trans (a <> b) |~| trans b <> trans a 157multProp2 (a,b) = ctrans (a <> b) |~| ctrans b <> ctrans a
158
159linearSolveProp f m = f m m |~| ident (rows m)