From 02805ad64715373347b34bac2f75cbb866563ba2 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Tue, 4 Nov 2008 09:32:35 +0000 Subject: multiply/trans ok --- lib/Numeric/LinearAlgebra/Algorithms.hs | 80 ++------------------------- lib/Numeric/LinearAlgebra/LAPACK.hs | 28 ++++++++++ lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 41 ++++++++++++++ lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 8 ++- lib/Numeric/LinearAlgebra/Tests/Properties.hs | 4 ++ 5 files changed, 83 insertions(+), 78 deletions(-) (limited to 'lib/Numeric/LinearAlgebra') diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 75f4ba3..f259db5 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs @@ -54,7 +54,6 @@ module Numeric.LinearAlgebra.Algorithms ( ctrans, eps, i, outer, kronecker, - mulH, -- * Util haussholder, unpackQR, unpackHess, @@ -70,8 +69,8 @@ import Complex import Numeric.LinearAlgebra.Linear import Data.List(foldl1') import Data.Array -import Foreign -import Foreign.C.Types + + -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where @@ -132,7 +131,7 @@ instance Field Double where qr = unpackQR . qrR hess = unpackHess hessR schur = schurR - multiply = multiplyR3 + multiply = multiplyR instance Field (Complex Double) where svd = svdC @@ -147,7 +146,7 @@ instance Field (Complex Double) where qr = unpackQR . qrC hess = unpackHess hessC schur = schurC - multiply = multiplyC3 + multiply = multiplyC -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. @@ -567,74 +566,3 @@ kronecker a b = fromBlocks . map (reshape (cols b)) . toRows $ flatten a `outer` flatten b - ---------------------------------------------------------------------- --- reference multiply ---------------------------------------------------------------------- - -mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ] - where doth u v = sum $ zipWith (*) (toList u) (toList v) - ------------------------------------------------------------------------------------ --- workaround ------------------------------------------------------------------------------------ - -mulCW a b = toComplex (rr,ri) - where rr = multiply ar br `sub` multiply ai bi - ri = multiply ar bi `add` multiply ai br - (ar,ai) = fromComplex a - (br,bi) = fromComplex b - ------------------------------------------------------------------------------------ --- Direct CBLAS ------------------------------------------------------------------------------------ - --- taken from Patrick Perry's BLAS package -newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show) -newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show) - -rowMajor, colMajor :: CBLASOrder -rowMajor = CBLASOrder 101 -colMajor = CBLASOrder 102 - -noTrans, trans', conjTrans :: CBLASTrans -noTrans = CBLASTrans 111 -trans' = CBLASTrans 112 -conjTrans = CBLASTrans 113 - -foreign import ccall "cblas.h cblas_dgemm" - dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt - -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double - -> Ptr Double -> CInt -> IO () - -multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double -multiplyR3 x y = multiply3 dgemm "cblas_dgemm" (fmat x) (fmat y) - where - multiply3 f st a b - | cols a == rows b = unsafePerformIO $ do - s <- createMatrix ColumnMajor (rows a) (cols b) - let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0 - app3 g mat a mat b mat s st - return s - | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" - - -foreign import ccall "cblas.h cblas_zgemm" - zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt - -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) - -> Ptr (Complex Double) -> CInt -> IO () - -multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) -multiplyC3 x y = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat x) (fmat y) - where - multiply3 f st a b - | cols a == rows b = do - s <- createMatrix ColumnMajor (rows a) (cols b) - palpha <- new 1 - pbeta <- new 0 - let g ar ac ap br bc bp rr _rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0 - app3 g mat a mat b mat s st - free palpha - free pbeta - return s - | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index 8bc2492..56945d7 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs @@ -14,6 +14,7 @@ ----------------------------------------------------------------------------- module Numeric.LinearAlgebra.LAPACK ( + multiplyR, multiplyC, svdR, svdRdd, svdC, eigC, eigR, eigS, eigH, eigS', eigH', linearSolveR, linearSolveC, @@ -35,6 +36,33 @@ import Numeric.GSL.Vector(vectorMapValR, FunCodeSV(Scale)) import Complex import Foreign import Foreign.C.Types (CInt) +import Control.Monad(when) + +----------------------------------------------------------------------------------- + +foreign import ccall "LAPACK/lapack-aux.h multiplyR" dgemmc :: CInt -> CInt -> TMMM +foreign import ccall "LAPACK/lapack-aux.h multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM + +isT MF{} = 0 +isT MC{} = 1 + +tt x@MF{} = x +tt x@MC{} = trans x + +multiplyAux f st a b = unsafePerformIO $ do + when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ + show (rows a,cols a) ++ " x " ++ show (rows b, cols b) + s <- createMatrix ColumnMajor (rows a) (cols b) + app3 (f (isT a) (isT b)) mat (tt a) mat (tt b) mat s st + return s + +-- | Matrix product based on BLAS's /dgemm/. +multiplyR :: Matrix Double -> Matrix Double -> Matrix Double +multiplyR a b = multiplyAux dgemmc "dgemmc" a b + +-- | Matrix product based on BLAS's /zgemm/. +multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) +multiplyC a b = multiplyAux zgemmc "zgemmc" a b ----------------------------------------------------------------------------- foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index 842b5ad..e85c1b7 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c @@ -860,3 +860,44 @@ int luS_l_C(KCMAT(a), KDVEC(ipiv), KCMAT(b), CMAT(x)) { free(auxipiv); OK } + +//////////////////// Matrix Product ///////////////////////// + +void dgemm_(char *, char *, integer *, integer *, integer *, + double *, const double *, integer *, const double *, + integer *, double *, double *, integer *); + +int multiplyR(int ta, int tb, KDMAT(a),KDMAT(b),DMAT(r)) { + //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); + integer m = ta?ac:ar; + integer n = tb?br:bc; + integer k = ta?ar:ac; + integer lda = ar; + integer ldb = br; + integer ldc = rr; + double alpha = 1; + double beta = 0; + dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); + OK +} + +void zgemm_(char *, char *, integer *, integer *, integer *, + doublecomplex *, const doublecomplex *, integer *, const doublecomplex *, + integer *, doublecomplex *, doublecomplex *, integer *); + +int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r)) { + //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); + integer m = ta?ac:ar; + integer n = tb?br:bc; + integer k = ta?ar:ac; + integer lda = ar; + integer ldb = br; + integer ldc = rr; + doublecomplex alpha = {1,0}; + doublecomplex beta = {0,0}; + zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, + (doublecomplex*)ap,&lda, + (doublecomplex*)bp,&ldb,&beta, + (doublecomplex*)rp,&ldc); + OK +} diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h index 23e5e28..3f58243 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h @@ -45,11 +45,15 @@ typedef short ftnlen; #define DMAT(A) int A##r, int A##c, double* A##p #define CMAT(A) int A##r, int A##c, double* A##p -// const pointer versions for the parameters #define KDVEC(A) int A##n, const double*A##p #define KCVEC(A) int A##n, const double*A##p #define KDMAT(A) int A##r, int A##c, const double* A##p -#define KCMAT(A) int A##r, int A##c, const double* A##p +#define KCMAT(A) int A##r, int A##c, const double* A##p + +/********************************************************/ + +int multiplyR(int ta, int tb, KDMAT(a),KDMAT(b),DMAT(r)); +int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r)); int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs index 45b03a2..ec87ad0 100644 --- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs +++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs @@ -152,6 +152,10 @@ cholProp m = m |~| ctrans c <> c && upperTriang c expmDiagProp m = expm (logm m) :~ 7 ~: complex m where logm = matFunc log +-- reference multiply +mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ] + where doth u v = sum $ zipWith (*) (toList u) (toList v) + multProp1 (a,b) = a <> b |~| mulH a b multProp2 (a,b) = ctrans (a <> b) |~| ctrans b <> ctrans a -- cgit v1.2.3