From c520939e33cc895febed271d5c3218457317bba9 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Mon, 3 Dec 2007 10:43:52 +0000 Subject: lapack lu --- lib/Numeric/LinearAlgebra/Algorithms.hs | 58 +++++++++++++++++++++++---- lib/Numeric/LinearAlgebra/LAPACK.hs | 21 +++++++++- lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 46 +++++++++++++++++++++ lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 13 ++---- 4 files changed, 120 insertions(+), 18 deletions(-) (limited to 'lib') diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 069d9a3..b19c0ec 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs @@ -37,6 +37,8 @@ module Numeric.LinearAlgebra.Algorithms ( hess, -- ** Schur schur, +-- ** LU + lu, -- * Matrix functions expm, sqrtm, @@ -52,11 +54,11 @@ module Numeric.LinearAlgebra.Algorithms ( -- * Util haussholder, unpackQR, unpackHess, - Field(linearSolveSVD,lu,eigSH',cholSH) + Field(linearSolveSVD,eigSH',cholSH) ) where -import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj) +import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//)) import Data.Packed import qualified Numeric.GSL.Matrix as GSL import Numeric.GSL.Vector @@ -64,12 +66,13 @@ import Numeric.LinearAlgebra.LAPACK as LAPACK import Complex import Numeric.LinearAlgebra.Linear import Data.List(foldl1') +import Data.Array -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. class (Normed (Matrix t), Linear Matrix t) => Field t where -- | Singular value decomposition using lapack's dgesvd or zgesvd. svd :: Matrix t -> (Matrix t, Vector Double, Matrix t) - lu :: Matrix t -> (Matrix t, Matrix t, [Int], t) + luPacked :: Matrix t -> (Matrix t, [Int]) -- | Solution of a general linear system (for several right-hand sides) using lapacks' dgesv and zgesv. -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK". linearSolve :: Matrix t -> Matrix t -> Matrix t @@ -106,7 +109,7 @@ class (Normed (Matrix t), Linear Matrix t) => Field t where instance Field Double where svd = svdR - lu = GSL.luR + luPacked = luR linearSolve = linearSolveR linearSolveSVD = linearSolveSVDR Nothing ctrans = trans @@ -119,7 +122,7 @@ instance Field Double where instance Field (Complex Double) where svd = svdC - lu = GSL.luC + luPacked = luC linearSolve = linearSolveC linearSolveSVD = linearSolveSVDC Nothing ctrans = conj . trans @@ -146,10 +149,19 @@ chol m | m `equal` ctrans m = cholSH m square m = rows m == cols m +-- | determinant of a square matrix, computed from the LU decomposition. det :: Field t => Matrix t -> t -det m | square m = s * (product $ toList $ takeDiag $ u) +det m | square m = s * (product $ toList $ takeDiag $ lu) | otherwise = error "det of nonsquare matrix" - where (_,u,_,s) = lu m + where (lu,perm) = luPacked m + s = signlp (rows m) perm + +-- | LU factorization of a general matrix using lapack's dgetrf or zgetrf. +-- +-- If @(l,u,p,s) = lu m@ then @m == p \<> l \<> u@, where l is lower triangular, +-- u is upper triangular, p is a permutation matrix and s is the signature of the permutation. +lu :: Field t => Matrix t -> (Matrix t, Matrix t, Matrix t, t) +lu = luFact . luPacked -- | Inverse of a square matrix using lapacks' dgesv and zgesv. inv :: Field t => Matrix t -> Matrix t @@ -457,3 +469,35 @@ sqrtmInv x = fst $ fixedPoint $ iterate f (x, ident (rows x)) (.*) = scale (|+|) = add (|-|) = sub + +------------------------------------------------------------------ + +signlp r vals = foldl f 1 (zip [0..r-1] vals) + where f s (a,b) | a /= b = -s + | otherwise = s + +swap (arr,s) (a,b) | a /= b = (arr // [(a, arr!b),(b,arr!a)],-s) + | otherwise = (arr,s) + +fixPerm r vals = (fromColumns $ elems res, sign) + where v = [0..r-1] + s = toColumns (ident r) + (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals) + +triang r c h v = reshape c $ fromList [el i j | i<-[0..r-1], j<-[0..c-1]] + where el i j = if j-i>=h then v else 1 - v + +luFact (lu,perm) | r <= c = (l ,u ,p, s) + | otherwise = (l',u',p, s) + where + r = rows lu + c = cols lu + tu = triang r c 0 1 + tl = triang r c 0 0 + l = takeColumns r (lu |*| tl) |+| diagRect (constant 1 r) r r + u = lu |*| tu + (p,s) = fixPerm r perm + l' = (lu |*| tl) |+| diagRect (constant 1 c) r c + u' = takeRows c (lu |*| tu) + (|+|) = add + (|*|) = mul diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index cacad87..83db901 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs @@ -19,6 +19,7 @@ module Numeric.LinearAlgebra.LAPACK ( linearSolveR, linearSolveC, linearSolveLSR, linearSolveLSC, linearSolveSVDR, linearSolveSVDC, + luR, luC, cholS, cholH, qrR, qrC, hessR, hessC, @@ -299,7 +300,7 @@ hessAux f st a = unsafePerformIO $ do mn = min m n ----------------------------------------------------------------------------------- -foreign import ccall safe "LAPACK/lapack-aux.h schur_l_R" dgees :: TMMM +foreign import ccall "LAPACK/lapack-aux.h schur_l_R" dgees :: TMMM foreign import ccall "LAPACK/lapack-aux.h schur_l_C" zgees :: TCMCMCM -- | Wrapper for LAPACK's /dgees/, which computes a Schur factorization of a square real matrix. @@ -318,3 +319,21 @@ schurAux f st a = unsafePerformIO $ do where n = rows a ----------------------------------------------------------------------------------- +foreign import ccall "LAPACK/lapack-aux.h lu_l_R" dgetrf :: TMVM +foreign import ccall "LAPACK/lapack-aux.h lu_l_C" zgetrf :: TCMVCM + +-- | Wrapper for LAPACK's /dgetrf/, which computes a LU factorization of a general real matrix. +luR :: Matrix Double -> (Matrix Double, [Int]) +luR = luAux dgetrf "luR" . fmat + +-- | Wrapper for LAPACK's /zgees/, which computes a Schur factorization of a square complex matrix. +luC :: Matrix (Complex Double) -> (Matrix (Complex Double), [Int]) +luC = luAux zgetrf "luC" . fmat + +luAux f st a = unsafePerformIO $ do + lu <- createMatrix ColumnMajor n m + piv <- createVector (min n m) + app3 f mat a vec piv mat lu st + return (lu, map (pred.round) (toList piv)) + where n = rows a + m = cols a diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index 8392feb..310f6ee 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c @@ -768,3 +768,49 @@ int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)) { OK #endif } + +//////////////////// LU factorization ///////////////////////// + +int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)) { + integer m = ar; + integer n = ac; + integer mn = MIN(m,n); + REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); + DEBUGMSG("lu_l_R"); + integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); + memcpy(rp,ap,m*n*sizeof(double)); + integer res; + dgetrf_ (&m,&n,rp,&m,auxipiv,&res); + if(res>0) { + res = 0; // fixme + } + CHECK(res,res); + int k; + for (k=0; k=1 && n >=1 && ipivn == mn, BAD_SIZE); + DEBUGMSG("lu_l_C"); + integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); + memcpy(rp,ap,m*n*sizeof(doublecomplex)); + integer res; + zgetrf_ (&m,&n,(doublecomplex*)rp,&m,auxipiv,&res); + if(res>0) { + res = 0; // fixme + } + CHECK(res,res); + int k; + for (k=0; k