From 9a0c3092e572f6bd11329e9acabc6470ef438203 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Sat, 27 Mar 2010 18:36:53 +0000 Subject: cholSolve --- lib/Numeric/LinearAlgebra/Algorithms.hs | 10 +++++++- lib/Numeric/LinearAlgebra/LAPACK.hs | 16 ++++++++++-- lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 36 +++++++++++++++++++++++++++ lib/Numeric/LinearAlgebra/Tests.hs | 3 +++ 4 files changed, 62 insertions(+), 3 deletions(-) (limited to 'lib/Numeric') diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 6b0fb08..0f2ccef 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs @@ -27,6 +27,7 @@ module Numeric.LinearAlgebra.Algorithms ( -- * Linear Systems linearSolve, luSolve, + cholSolve, linearSolveLS, linearSolveSVD, inv, pinv, @@ -91,6 +92,7 @@ class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where luPacked' :: Matrix t -> (Matrix t, [Int]) luSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t linearSolve' :: Matrix t -> Matrix t -> Matrix t + cholSolve' :: Matrix t -> Matrix t -> Matrix t linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t linearSolveLS' :: Matrix t -> Matrix t -> Matrix t eig' :: Matrix t -> (Vector (Complex Double), Matrix (Complex Double)) @@ -112,6 +114,7 @@ instance Field Double where luPacked' = luR luSolve' (l_u,perm) = lusR l_u perm linearSolve' = linearSolveR -- (luSolve . luPacked) ?? + cholSolve' = cholSolveR linearSolveLS' = linearSolveLSR linearSolveSVD' = linearSolveSVDR Nothing ctrans' = trans @@ -132,6 +135,7 @@ instance Field (Complex Double) where luPacked' = luC luSolve' (l_u,perm) = lusC l_u perm linearSolve' = linearSolveC + cholSolve' = cholSolveC linearSolveLS' = linearSolveLSC linearSolveSVD' = linearSolveSVDC Nothing ctrans' = conj . trans @@ -229,6 +233,10 @@ luSolve = luSolve' linearSolve :: Field t => Matrix t -> Matrix t -> Matrix t linearSolve = linearSolve' +-- | Solve a symmetric or Hermitian positive definite linear system using a precomputed Cholesky decomposition obtained by 'chol'. +cholSolve :: Field t => Matrix t -> Matrix t -> Matrix t +cholSolve = cholSolve' + -- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value. linearSolveSVD :: Field t => Matrix t -> Matrix t -> Matrix t linearSolveSVD = linearSolveSVD' @@ -322,7 +330,7 @@ cholSH = cholSH' -- | Cholesky factorization of a positive definite hermitian or symmetric matrix. -- --- If @c = chol m@ then @m == ctrans c \<> c@. +-- If @c = chol m@ then @c@ is upper triangular and @m == ctrans c \<> c@. chol :: Field t => Matrix t -> Matrix t chol m | exactHermitian m = cholSH m | otherwise = error "chol requires positive definite complex hermitian or real symmetric matrix" diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index a1ac1cf..5d4eb0d 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs @@ -18,6 +18,7 @@ module Numeric.LinearAlgebra.LAPACK ( -- * Linear systems linearSolveR, linearSolveC, lusR, lusC, + cholSolveR, cholSolveC, linearSolveLSR, linearSolveLSC, linearSolveSVDR, linearSolveSVDC, -- * SVD @@ -312,8 +313,10 @@ eigOnlyH = vrev . fst. eigSHAux (zheev 1) "eigH'" . fmat vrev = flatten . flipud . reshape 1 ----------------------------------------------------------------------------- -foreign import ccall "LAPACK/lapack-aux.h linearSolveR_l" dgesv :: TMMM -foreign import ccall "LAPACK/lapack-aux.h linearSolveC_l" zgesv :: TCMCMCM +foreign import ccall "linearSolveR_l" dgesv :: TMMM +foreign import ccall "linearSolveC_l" zgesv :: TCMCMCM +foreign import ccall "cholSolveR_l" dpotrs :: TMMM +foreign import ccall "cholSolveC_l" zpotrs :: TCMCMCM linearSolveSQAux f st a b | n1==n2 && n1==r = unsafePerformIO $ do @@ -334,6 +337,15 @@ linearSolveR a b = linearSolveSQAux dgesv "linearSolveR" (fmat a) (fmat b) linearSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) linearSolveC a b = linearSolveSQAux zgesv "linearSolveC" (fmat a) (fmat b) + +-- | Solves a symmetric positive definite system of linear equations using a precomputed Cholesky factorization obtained by 'cholS'. +cholSolveR :: Matrix Double -> Matrix Double -> Matrix Double +cholSolveR a b = linearSolveSQAux dpotrs "cholSolveR" (fmat a) (fmat b) + +-- | Solves a Hermitian positive definite system of linear equations using a precomputed Cholesky factorization obtained by 'cholH'. +cholSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) +cholSolveC a b = linearSolveSQAux zpotrs "cholSolveC" (fmat a) (fmat b) + ----------------------------------------------------------------------------------- foreign import ccall "LAPACK/lapack-aux.h linearSolveLSR_l" dgels :: TMMM foreign import ccall "LAPACK/lapack-aux.h linearSolveLSC_l" zgels :: TCMCMCM diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index 06c2479..fd840e3 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c @@ -492,6 +492,42 @@ int linearSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)) { OK } +//////// symmetric positive definite real linear system using Cholesky //////////// + +int cholSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)) { + integer n = ar; + integer nhrs = bc; + REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); + DEBUGMSG("cholSolveR_l"); + memcpy(xp,bp,n*nhrs*sizeof(double)); + integer res; + dpotrs_ ("U", + &n,&nhrs, + (double*)ap, &n, + xp, &n, + &res); + CHECK(res,res); + OK +} + +//////// Hermitian positive definite real linear system using Cholesky //////////// + +int cholSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)) { + integer n = ar; + integer nhrs = bc; + REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); + DEBUGMSG("cholSolveC_l"); + memcpy(xp,bp,2*n*nhrs*sizeof(double)); + integer res; + zpotrs_ ("U", + &n,&nhrs, + (doublecomplex*)ap, &n, + (doublecomplex*)xp, &n, + &res); + CHECK(res,res); + OK +} + //////////////////// least squares real linear system //////////// int linearSolveLSR_l(KDMAT(a),KDMAT(b),DMAT(x)) { diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index f8f8bd5..36efab6 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs @@ -208,6 +208,9 @@ runTests n = do putStrLn "------ luSolve" test (linearSolveProp (luSolve.luPacked) . rSqWC) test (linearSolveProp (luSolve.luPacked) . cSqWC) + putStrLn "------ cholSolve" + test (linearSolveProp (cholSolve.chol) . rPosDef) + test (linearSolveProp (cholSolve.chol) . cPosDef) putStrLn "------ luSolveLS" test (linearSolveProp linearSolveLS . rSqWC) test (linearSolveProp linearSolveLS . cSqWC) -- cgit v1.2.3