From edf12982f21c56c21bfc21eb2b2fcbc406838130 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Mon, 27 Oct 2008 13:03:41 +0000 Subject: added luSolve --- lib/Numeric/LinearAlgebra/Algorithms.hs | 19 ++++++++++++------ lib/Numeric/LinearAlgebra/LAPACK.hs | 29 +++++++++++++++++++-------- lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 27 ++++++++++++++++++++----- lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 3 ++- lib/Numeric/LinearAlgebra/Tests.hs | 8 +++++--- lib/Numeric/LinearAlgebra/Tests/Properties.hs | 7 +++++-- 6 files changed, 68 insertions(+), 25 deletions(-) (limited to 'lib/Numeric/LinearAlgebra') diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index e2fec9d..75f4ba3 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs @@ -40,7 +40,7 @@ module Numeric.LinearAlgebra.Algorithms ( -- ** Schur schur, -- ** LU - lu, + lu, luPacked, luSolve, -- * Matrix functions expm, sqrtm, @@ -77,8 +77,13 @@ import Foreign.C.Types class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where -- | Singular value decomposition using lapack's dgesvd or zgesvd. svd :: Matrix t -> (Matrix t, Vector Double, Matrix t) + -- | Obtains the LU decomposition of a matrix in a compact data structure suitable for 'luSolve'. luPacked :: Matrix t -> (Matrix t, [Int]) - -- | Solution of a general linear system (for several right-hand sides) using lapacks' dgesv and zgesv. + -- | Solution of a linear system (for several right hand sides) from the precomputed LU factorization + -- obtained by 'luPacked'. + luSolve :: (Matrix t, [Int]) -> Matrix t -> Matrix t + -- | Solution of a general linear system (for several right-hand sides) using lapacks' dgesv or zgesv. + -- It is similar to 'luSolve' . 'luPacked', but @linearSolve@ raises an error if called on a singular system. -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK". linearSolve :: Matrix t -> Matrix t -> Matrix t linearSolveSVD :: Matrix t -> Matrix t -> Matrix t @@ -110,13 +115,15 @@ class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where schur :: Matrix t -> (Matrix t, Matrix t) -- | Conjugate transpose. ctrans :: Matrix t -> Matrix t + -- | Matrix product. multiply :: Matrix t -> Matrix t -> Matrix t instance Field Double where svd = svdR luPacked = luR - linearSolve = linearSolveR + luSolve (l_u,perm) = lusR l_u perm + linearSolve = linearSolveR -- (luSolve . luPacked) ?? linearSolveSVD = linearSolveSVDR Nothing ctrans = trans eig = eigR @@ -130,6 +137,7 @@ instance Field Double where instance Field (Complex Double) where svd = svdC luPacked = luC + luSolve (l_u,perm) = lusC l_u perm linearSolve = linearSolveC linearSolveSVD = linearSolveSVDC Nothing ctrans = conj . trans @@ -165,7 +173,7 @@ det m | square m = s * (product $ toList $ takeDiag $ lup) where (lup,perm) = luPacked m s = signlp (rows m) perm --- | LU factorization of a general matrix using lapack's dgetrf or zgetrf. +-- | Explicit LU factorization of a general matrix using lapack's dgetrf or zgetrf. -- -- If @(l,u,p,s) = lu m@ then @m == p \<> l \<> u@, where l is lower triangular, -- u is upper triangular, p is a permutation matrix and s is the signature of the permutation. @@ -513,7 +521,7 @@ luFact (l_u,perm) | r <= c = (l ,u ,p, s) -------------------------------------------------- --- | euclidean inner product +-- | Euclidean inner product. dot :: (Field t) => Vector t -> Vector t -> t dot u v = multiply r c @@> (0,0) where r = asRow u @@ -629,5 +637,4 @@ multiplyC3 x y = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat x) (fmat free palpha free pbeta return s - -- if toLists s== toLists s then return s else error $ "HORROR " ++ (show (toLists s)) | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index d78b506..8bc2492 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs @@ -19,7 +19,7 @@ module Numeric.LinearAlgebra.LAPACK ( linearSolveR, linearSolveC, linearSolveLSR, linearSolveLSC, linearSolveSVDR, linearSolveSVDC, - luR, luC, lusR, + luR, luC, lusR, lusC, cholS, cholH, qrR, qrC, hessR, hessC, @@ -34,6 +34,7 @@ import Data.Packed.Matrix import Numeric.GSL.Vector(vectorMapValR, FunCodeSV(Scale)) import Complex import Foreign +import Foreign.C.Types (CInt) ----------------------------------------------------------------------------- foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM @@ -338,17 +339,29 @@ luAux f st a = unsafePerformIO $ do where n = rows a m = cols a - ----------------------------------------------------------------------------------- +type TW a = CInt -> PD -> a +type TQ a = CInt -> CInt -> PC -> a + foreign import ccall "LAPACK/lapack-aux.h luS_l_R" dgetrs :: TMVMM +foreign import ccall "LAPACK/lapack-aux.h luS_l_C" zgetrs :: TQ (TW (TQ (TQ (IO CInt)))) +-- | Wrapper for LAPACK's /dgetrs/, which solves a general real linear system (for several right-hand sides) from a precomputed LU decomposition. lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double -lusR a piv b = lusR' (fmat a) piv (fmat b) +lusR a piv b = lusAux dgetrs "lusR" (fmat a) piv (fmat b) + +-- | Wrapper for LAPACK's /zgetrs/, which solves a general real linear system (for several right-hand sides) from a precomputed LU decomposition. +lusC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double) +lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b) -lusR' a piv b = unsafePerformIO $ do +lusAux f st a piv b + | n1==n2 && n2==n =unsafePerformIO $ do x <- createMatrix ColumnMajor n m - app4 dgetrs mat a vec piv' mat b mat x "lusR" + app4 f mat a vec piv' mat b mat x st return x - where n = rows b - m = cols b - piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double + | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" + where n1 = rows a + n2 = cols a + n = rows b + m = cols b + piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index de3cc98..842b5ad 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c @@ -820,9 +820,8 @@ int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) { //////////////////// LU substitution ///////////////////////// -char charN = 'N'; -int luS_l_R(DMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)) { +int luS_l_R(KDMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)) { integer m = ar; integer n = ac; integer mrhs = br; @@ -836,10 +835,28 @@ int luS_l_R(DMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)) { } integer res; memcpy(xp,bp,mrhs*nrhs*sizeof(double)); - integer ldb = mrhs; - integer lda = n; - dgetrs_ (&charN,&n,&nrhs,ap,&lda,auxipiv,xp,&ldb,&res); + dgetrs_ ("N",&n,&nrhs,(/*no const (!?)*/ double*)ap,&m,auxipiv,xp,&mrhs,&res); CHECK(res,res); free(auxipiv); OK } + +int luS_l_C(KCMAT(a), KDVEC(ipiv), KCMAT(b), CMAT(x)) { + integer m = ar; + integer n = ac; + integer mrhs = br; + integer nrhs = bc; + + REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE); + integer* auxipiv = (integer*)malloc(n*sizeof(integer)); + int k; + for (k=0; k) (map rot angles) where angles = toList $ linspace n (0,1) - -- | All tests must pass with a maximum dimension of about 20 -- (some tests may fail with bigger sizes due to precision loss). runTests :: Int -- ^ maximum dimension @@ -135,10 +134,13 @@ runTests n = do putStrLn "------ lu" test (luProp . rM) test (luProp . cM) - putStrLn "------ inv" + putStrLn "------ inv (linearSolve)" test (invProp . rSqWC) test (invProp . cSqWC) - putStrLn "------ pinv" + putStrLn "------ luSolve" + test (linearSolveProp (luSolve.luPacked) . rSqWC) + test (linearSolveProp (luSolve.luPacked) . cSqWC) + putStrLn "------ pinv (linearSolveSVD)" test (pinvProp . rM) if os == "mingw32" then putStrLn "complex pinvTest skipped in this OS" diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs index b5321c2..45b03a2 100644 --- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs +++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs @@ -35,7 +35,8 @@ module Numeric.LinearAlgebra.Tests.Properties ( schurProp1, schurProp2, cholProp, expmDiagProp, - multProp1, multProp2 + multProp1, multProp2, + linearSolveProp ) where import Numeric.LinearAlgebra @@ -153,4 +154,6 @@ expmDiagProp m = expm (logm m) :~ 7 ~: complex m multProp1 (a,b) = a <> b |~| mulH a b -multProp2 (a,b) = trans (a <> b) |~| trans b <> trans a +multProp2 (a,b) = ctrans (a <> b) |~| ctrans b <> ctrans a + +linearSolveProp f m = f m m |~| ident (rows m) -- cgit v1.2.3