From 192ac5f4b98517862c37ecf161505396ad223cd8 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Thu, 2 Oct 2008 15:53:10 +0000 Subject: alternative multiply versions --- lib/Data/Packed/Internal/Matrix.hs | 37 +-- lib/Data/Packed/Internal/auxi.c | 90 +------- lib/Data/Packed/Internal/auxi.h | 6 +- lib/Graphics/Plot.hs | 2 +- lib/Numeric/GSL/Matrix.hs | 311 -------------------------- lib/Numeric/GSL/gsl-aux.c | 286 ----------------------- lib/Numeric/GSL/gsl-aux.h | 19 -- lib/Numeric/LinearAlgebra/Algorithms.hs | 171 +++++++++++++- lib/Numeric/LinearAlgebra/Interface.hs | 4 +- lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 74 ++++++ lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 6 + lib/Numeric/LinearAlgebra/Linear.hs | 54 +---- lib/Numeric/LinearAlgebra/Tests.hs | 5 + lib/Numeric/LinearAlgebra/Tests/Instances.hs | 16 ++ lib/Numeric/LinearAlgebra/Tests/Properties.hs | 6 +- 15 files changed, 300 insertions(+), 787 deletions(-) delete mode 100644 lib/Numeric/GSL/Matrix.hs (limited to 'lib') diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index caf3699..45a3955 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs @@ -212,7 +212,6 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 class (Storable a, Floating a) => Element a where constantD :: a -> Int -> Vector a transdata :: Int -> Vector a -> Int -> Vector a - multiplyD :: Matrix a -> Matrix a -> Matrix a subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix -> Matrix a -> Matrix a @@ -221,14 +220,12 @@ class (Storable a, Floating a) => Element a where instance Element Double where constantD = constantR transdata = transdataR - multiplyD = multiplyR subMatrixD = subMatrixR diagD = diagR instance Element (Complex Double) where constantD = constantC transdata = transdataC - multiplyD = multiplyC subMatrixD = subMatrixC diagD = diagC @@ -266,33 +263,6 @@ transdataAux fun c1 d c2 = foreign import ccall "auxi.h transR" ctransR :: TMM foreign import ccall "auxi.h transC" ctransC :: TCMCM ------------------------------------------------------------------- - -gmatC MF { rows = r, cols = c } p f = f 1 (fi c) (fi r) p -gmatC MC { rows = r, cols = c } p f = f 0 (fi r) (fi c) p - -dtt MC { cdat = d } = d -dtt MF { fdat = d } = d - -multiplyAux fun a b = unsafePerformIO $ do - when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ - show (rows a,cols a) ++ " x " ++ show (rows b, cols b) - r <- createMatrix RowMajor (rows a) (cols b) - withForeignPtr (fptr (dtt a)) $ \pa -> withForeignPtr (fptr (dtt b)) $ \pb -> - withMatrix r $ \r' -> - fun // gmatC a pa // gmatC b pb // r' // check "multiplyAux" - return r - -multiplyR = multiplyAux cmultiplyR -foreign import ccall "auxi.h multiplyR" cmultiplyR :: TauxMul Double - -multiplyC = multiplyAux cmultiplyC -foreign import ccall "auxi.h multiplyC" cmultiplyC :: TauxMul (Complex Double) - --- | matrix product -multiply :: (Element a) => Matrix a -> Matrix a -> Matrix a -multiply = multiplyD - ---------------------------------------------------------------------- -- | extraction of a submatrix from a real matrix @@ -370,7 +340,12 @@ constant = constantD -- | obtains the complex conjugate of a complex vector conj :: Vector (Complex Double) -> Vector (Complex Double) -conj v = asComplex $ flatten $ reshape 2 (asReal v) `multiply` diag (fromList [1,-1]) +conj v = unsafePerformIO $ do + r <- createVector (dim v) + app2 cconjugate vec v vec r "cconjugate" + return r +foreign import ccall "auxi.h conjugate" cconjugate :: TCVCV + -- | creates a complex vector from vectors with real and imaginary parts toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) diff --git a/lib/Data/Packed/Internal/auxi.c b/lib/Data/Packed/Internal/auxi.c index 7f83bcf..04dc7ad 100644 --- a/lib/Data/Packed/Internal/auxi.c +++ b/lib/Data/Packed/Internal/auxi.c @@ -4,14 +4,9 @@ #include #include #include -#include -#include -#include -#include -#include -#include #include #include +#include #include #include @@ -118,78 +113,6 @@ int constantC(gsl_complex* pval, CVEC(r)) { } -int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)) { - //printf("%d %d %d %d %d %d\n",ar,ac,br,bc,rr,rc); - //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); - DEBUGMSG("multiplyR (gsl_blas_dgemm)"); - KDMVIEW(a); - KDMVIEW(b); - DMVIEW(r); - int k; - for(k=0;k --- Stability : provisional --- Portability : portable (uses FFI) --- --- A few linear algebra computations based on GSL. --- ------------------------------------------------------------------------------ --- #hide - -module Numeric.GSL.Matrix( - eigSg, eigHg, - svdg, - qr, qrPacked, unpackQR, - cholR, cholC, - luSolveR, luSolveC, - luR, luC -) where - -import Data.Packed.Internal -import Data.Packed.Matrix(ident) -import Numeric.GSL.Vector -import Foreign -import Complex - -{- | eigendecomposition of a real symmetric matrix using /gsl_eigen_symmv/. - -> > let (l,v) = eigS $ 'fromLists' [[1,2],[2,1]] -> > l -> 3.000 -1.000 -> -> > v -> 0.707 -0.707 -> 0.707 0.707 -> -> > v <> diag l <> trans v -> 1.000 2.000 -> 2.000 1.000 - --} -eigSg :: Matrix Double -> (Vector Double, Matrix Double) -eigSg = eigSg' . cmat - -eigSg' m - | r == 1 = (fromList [cdat m `at` 0], singleton 1) - | otherwise = unsafePerformIO $ do - l <- createVector r - v <- createMatrix RowMajor r r - app3 c_eigS mat m vec l mat v "eigSg" - return (l,v) - where r = rows m -foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM - ------------------------------------------------------------------- - - - -{- | eigendecomposition of a complex hermitian matrix using /gsl_eigen_hermv/ - -> > let (l,v) = eigH $ 'fromLists' [[1,2+i],[2-i,3]] -> -> > l -> 4.449 -0.449 -> -> > v -> -0.544 0.839 -> (-0.751,0.375) (-0.487,0.243) -> -> > v <> diag l <> (conjTrans) v -> 1.000 (2.000,1.000) -> (2.000,-1.000) 3.000 - --} -eigHg :: Matrix (Complex Double)-> (Vector Double, Matrix (Complex Double)) -eigHg = eigHg' . cmat - -eigHg' m - | r == 1 = (fromList [realPart $ cdat m `at` 0], singleton 1) - | otherwise = unsafePerformIO $ do - l <- createVector r - v <- createMatrix RowMajor r r - app3 c_eigH mat m vec l mat v "eigHg" - return (l,v) - where r = rows m -foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM - - -{- | Singular value decomposition of a real matrix, using /gsl_linalg_SV_decomp_mod/: - - --} -svdg :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) -svdg x = if rows x >= cols x - then svd' (cmat x) - else (v, s, u) where (u,s,v) = svd' (cmat (trans x)) - -svd' x = unsafePerformIO $ do - u <- createMatrix RowMajor r c - s <- createVector c - v <- createMatrix RowMajor c c - app4 c_svd mat x mat u vec s mat v "svdg" - return (u,s,v) - where r = rows x - c = cols x -foreign import ccall "gsl-aux.h svd" c_svd :: TMMVM - -{- | QR decomposition of a real matrix using /gsl_linalg_QR_decomp/ and /gsl_linalg_QR_unpack/. - --} -qr :: Matrix Double -> (Matrix Double, Matrix Double) -qr = qr' . cmat - -qr' x = unsafePerformIO $ do - q <- createMatrix RowMajor r r - rot <- createMatrix RowMajor r c - app3 c_qr mat x mat q mat rot "qr" - return (q,rot) - where r = rows x - c = cols x -foreign import ccall "gsl-aux.h QR" c_qr :: TMMM - -qrPacked :: Matrix Double -> (Matrix Double, Vector Double) -qrPacked = qrPacked' . cmat - -qrPacked' x = unsafePerformIO $ do - qrp <- createMatrix RowMajor r c - tau <- createVector (min r c) - app3 c_qrPacked mat x mat qrp vec tau "qrUnpacked" - return (qrp,tau) - where r = rows x - c = cols x -foreign import ccall "gsl-aux.h QRpacked" c_qrPacked :: TMMV - -unpackQR :: (Matrix Double, Vector Double) -> (Matrix Double, Matrix Double) -unpackQR (qrp,tau) = unpackQR' (cmat qrp, tau) - -unpackQR' (qrp,tau) = unsafePerformIO $ do - q <- createMatrix RowMajor r r - res <- createMatrix RowMajor r c - app4 c_qrUnpack mat qrp vec tau mat q mat res "qrUnpack" - return (q,res) - where r = rows qrp - c = cols qrp -foreign import ccall "gsl-aux.h QRunpack" c_qrUnpack :: TMVMM - -{- | Cholesky decomposition of a symmetric positive definite real matrix using /gsl_linalg_cholesky_decomp/. - -@\> chol $ (2><2) [1,2, - 2,9::Double] -(2><2) - [ 1.0, 0.0 - , 2.0, 2.23606797749979 ]@ - --} -cholR :: Matrix Double -> Matrix Double -cholR = cholR' . cmat - -cholR' x = unsafePerformIO $ do - r <- createMatrix RowMajor n n - app2 c_cholR mat x mat r "cholR" - return r - where n = rows x -foreign import ccall "gsl-aux.h cholR" c_cholR :: TMM - -cholC :: Matrix (Complex Double) -> Matrix (Complex Double) -cholC = cholC' . cmat - -cholC' x = unsafePerformIO $ do - r <- createMatrix RowMajor n n - app2 c_cholC mat x mat r "cholC" - return r - where n = rows x -foreign import ccall "gsl-aux.h cholC" c_cholC :: TCMCM - - --------------------------------------------------------- - -{- -| efficient multiplication by the inverse of a matrix (for real matrices) --} -luSolveR :: Matrix Double -> Matrix Double -> Matrix Double -luSolveR a b = luSolveR' (cmat a) (cmat b) - -luSolveR' a b - | n1==n2 && n1==r = unsafePerformIO $ do - s <- createMatrix RowMajor r c - app3 c_luSolveR mat a mat b mat s "luSolveR" - return s - | otherwise = error "luSolveR of nonsquare matrix" - where n1 = rows a - n2 = cols a - r = rows b - c = cols b -foreign import ccall "gsl-aux.h luSolveR" c_luSolveR :: TMMM - -{- -| efficient multiplication by the inverse of a matrix (for complex matrices). --} -luSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) -luSolveC a b = luSolveC' (cmat a) (cmat b) - -luSolveC' a b - | n1==n2 && n1==r = unsafePerformIO $ do - s <- createMatrix RowMajor r c - app3 c_luSolveC mat a mat b mat s "luSolveC" - return s - | otherwise = error "luSolveC of nonsquare matrix" - where n1 = rows a - n2 = cols a - r = rows b - c = cols b -foreign import ccall "gsl-aux.h luSolveC" c_luSolveC :: TCMCMCM - -{- | lu decomposition of real matrix (packed as a vector including l, u, the permutation and sign) --} -luRaux :: Matrix Double -> Vector Double -luRaux = luRaux' . cmat - -luRaux' x = unsafePerformIO $ do - res <- createVector (r*r+r+1) - app2 c_luRaux mat x vec res "luRaux" - return res - where r = rows x -foreign import ccall "gsl-aux.h luRaux" c_luRaux :: TMV - -{- | lu decomposition of complex matrix (packed as a vector including l, u, the permutation and sign) --} -luCaux :: Matrix (Complex Double) -> Vector (Complex Double) -luCaux = luCaux' . cmat - -luCaux' x = unsafePerformIO $ do - res <- createVector (r*r+r+1) - app2 c_luCaux mat x vec res "luCaux" - return res - where r = rows x -foreign import ccall "gsl-aux.h luCaux" c_luCaux :: TCMCV - -{- | The LU decomposition of a square matrix. Is based on /gsl_linalg_LU_decomp/ and /gsl_linalg_complex_LU_decomp/ as described in . - -@\> let m = 'fromLists' [[1,2,-3],[2+3*i,-7,0],[1,-i,2*i]] -\> let (l,u,p,s) = luR m@ - -L is the lower triangular: - -@\> l - 1. 0. 0. -0.154-0.231i 1. 0. -0.154-0.231i 0.624-0.522i 1.@ - -U is the upper triangular: - -@\> u -2.+3.i -7. 0. - 0. 3.077-1.615i -3. - 0. 0. 1.873+0.433i@ - -p is a permutation: - -@\> p -[1,0,2]@ - -L \* U obtains a permuted version of the original matrix: - -@\> extractRows p m - 2.+3.i -7. 0. - 1. 2. -3. - 1. -1.i 2.i -\ -- CPP -\> l \<\> u - 2.+3.i -7. 0. - 1. 2. -3. - 1. -1.i 2.i@ - -s is the sign of the permutation, required to obtain sign of the determinant: - -@\> s * product ('toList' $ 'takeDiag' u) -(-18.0) :+ (-16.000000000000004) -\> 'LinearAlgebra.Algorithms.det' m -(-18.0) :+ (-16.000000000000004)@ - - -} -luR :: Matrix Double -> (Matrix Double, Matrix Double, [Int], Double) -luR m = (l,u,p, fromIntegral s') where - r = rows m - v = luRaux m - lu = reshape r $ subVector 0 (r*r) v - s':p = map round . toList . subVector (r*r) (r+1) $ v - u = triang r r 0 1`mul` lu - l = (triang r r 0 0 `mul` lu) `add` ident r - add = liftMatrix2 $ vectorZipR Add - mul = liftMatrix2 $ vectorZipR Mul - --- | Complex version of 'luR'. -luC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double), [Int], Complex Double) -luC m = (l,u,p, fromIntegral s') where - r = rows m - v = luCaux m - lu = reshape r $ subVector 0 (r*r) v - s':p = map (round.realPart) . toList . subVector (r*r) (r+1) $ v - u = triang r r 0 1 `mul` lu - l = (triang r r 0 0 `mul` lu) `add` liftMatrix comp (ident r) - add = liftMatrix2 $ vectorZipC Add - mul = liftMatrix2 $ vectorZipC Mul - -{- auxiliary function to get triangular matrices --} -triang r c h v = reshape c $ fromList [el i j | i<-[0..r-1], j<-[0..c-1]] - where el i j = if j-i>=h then v else 1 - v diff --git a/lib/Numeric/GSL/gsl-aux.c b/lib/Numeric/GSL/gsl-aux.c index bd0a6bd..052cafd 100644 --- a/lib/Numeric/GSL/gsl-aux.c +++ b/lib/Numeric/GSL/gsl-aux.c @@ -1,11 +1,8 @@ #include "gsl-aux.h" #include -#include -#include #include #include #include -#include #include #include #include @@ -161,47 +158,6 @@ int mapC(int code, KCVEC(x), CVEC(r)) { } -/* -int scaleR(double* alpha, KRVEC(x), RVEC(r)) { - REQUIRES(xn == rn,BAD_SIZE); - DEBUGMSG("scaleR"); - KDVVIEW(x); - DVVIEW(r); - CHECK( gsl_vector_memcpy(V(r),V(x)) , MEM); - int res = gsl_vector_scale(V(r),*alpha); - CHECK(res,res); - OK -} - -int scaleC(gsl_complex *alpha, KCVEC(x), CVEC(r)) { - REQUIRES(xn == rn,BAD_SIZE); - DEBUGMSG("scaleC"); - //KCVVIEW(x); - CVVIEW(r); - gsl_vector_const_view vrx = gsl_vector_const_view_array((double*)xp,xn*2); - gsl_vector_view vrr = gsl_vector_view_array((double*)rp,rn*2); - CHECK(gsl_vector_memcpy(V(vrr),V(vrx)) , MEM); - gsl_blas_zscal(*alpha,V(r)); // void ! - int res = 0; - CHECK(res,res); - OK -} - -int addConstantR(double offs, KRVEC(x), RVEC(r)) { - REQUIRES(xn == rn,BAD_SIZE); - DEBUGMSG("addConstantR"); - KDVVIEW(x); - DVVIEW(r); - CHECK(gsl_vector_memcpy(V(r),V(x)), MEM); - int res = gsl_vector_add_constant(V(r),offs); - CHECK(res,res); - OK -} - -*/ - - - int mapValR(int code, double* pval, KRVEC(x), RVEC(r)) { int k; double val = *pval; @@ -291,248 +247,6 @@ int zipC(int code, KCVEC(a), KCVEC(b), CVEC(r)) { - -int luSolveR(KRMAT(a),KRMAT(b),RMAT(r)) { - REQUIRES(ar==ac && ac==br && ar==rr && bc==rc,BAD_SIZE); - DEBUGMSG("luSolveR"); - KDMVIEW(a); - KDMVIEW(b); - DMVIEW(r); - int res; - gsl_matrix *LU = gsl_matrix_alloc(ar,ar); - CHECK(!LU,MEM); - int s; - gsl_permutation * p = gsl_permutation_alloc (ar); - CHECK(!p,MEM); - CHECK(gsl_matrix_memcpy(LU,M(a)),MEM); - res = gsl_linalg_LU_decomp(LU, p, &s); - CHECK(res,res); - int c; - - for (c=0; c Field t where @@ -105,6 +109,7 @@ class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where schur :: Matrix t -> (Matrix t, Matrix t) -- | Conjugate transpose. ctrans :: Matrix t -> Matrix t + multiply :: Matrix t -> Matrix t -> Matrix t instance Field Double where @@ -116,9 +121,10 @@ instance Field Double where eig = eigR eigSH' = eigS cholSH = cholS - qr = GSL.unpackQR . qrR + qr = unpackQR . qrR hess = unpackHess hessR schur = schurR + multiply = multiplyR3 instance Field (Complex Double) where svd = svdC @@ -132,6 +138,8 @@ instance Field (Complex Double) where qr = unpackQR . qrC hess = unpackHess hessC schur = schurC + multiply = mulCW -- workaround + -- multiplyC3 -- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix using lapack's dsyev or zheev. -- @@ -501,3 +509,162 @@ luFact (lu,perm) | r <= c = (l ,u ,p, s) u' = takeRows c (lu |*| tu) (|+|) = add (|*|) = mul + +-------------------------------------------------- + +-- | euclidean inner product +dot :: (Field t) => Vector t -> Vector t -> t +dot u v = multiply r c @@> (0,0) + where r = asRow u + c = asColumn v + + +{- | Outer product of two vectors. + +@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] +(3><3) + [ 5.0, 2.0, 3.0 + , 10.0, 4.0, 6.0 + , 15.0, 6.0, 9.0 ]@ +-} +outer :: (Field t) => Vector t -> Vector t -> Matrix t +outer u v = asColumn u `multiply` asRow v + +{- | Kronecker product of two matrices. + +@m1=(2><3) + [ 1.0, 2.0, 0.0 + , 0.0, -1.0, 3.0 ] +m2=(4><3) + [ 1.0, 2.0, 3.0 + , 4.0, 5.0, 6.0 + , 7.0, 8.0, 9.0 + , 10.0, 11.0, 12.0 ]@ + +@\> kronecker m1 m2 +(8><9) + [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0 + , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0 + , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0 + , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0 + , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0 + , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0 + , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0 + , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@ +-} +kronecker :: (Field t) => Matrix t -> Matrix t -> Matrix t +kronecker a b = fromBlocks + . partit (cols a) + . map (reshape (cols b)) + . toRows + $ flatten a `outer` flatten b + +--------------------------------------------------------------------- +-- reference multiply +--------------------------------------------------------------------- + +mulH a b = fromLists [[ dot ai bj | bj <- toColumns b] | ai <- toRows a ] + where dot u v = sum $ zipWith (*) (toList u) (toList v) + +----------------------------------------------------------------------------------- +-- workaround +----------------------------------------------------------------------------------- + +mulCW a b = toComplex (rr,ri) + where rr = multiply ar br `sub` multiply ai bi + ri = multiply ar bi `add` multiply ai br + (ar,ai) = fromComplex a + (br,bi) = fromComplex b + +----------------------------------------------------------------------------------- +-- Direct CBLAS +----------------------------------------------------------------------------------- + +newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show) +newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show) + +rowMajor, colMajor :: CBLASOrder +rowMajor = CBLASOrder 101 +colMajor = CBLASOrder 102 + +noTrans, trans', conjTrans :: CBLASTrans +noTrans = CBLASTrans 111 +trans' = CBLASTrans 112 +conjTrans = CBLASTrans 113 + +foreign import ccall "cblas.h cblas_dgemm" + dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double -> Ptr Double -> CInt -> IO () + + +multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double +multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) + where + multiply3 f st a b + | cols a == rows b = unsafePerformIO $ do + s <- createMatrix ColumnMajor (rows a) (cols b) + let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac 1 ap ar bp br 0 rp rr >> return 0 + app3 g mat a mat b mat s st + return s + | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" + + +foreign import ccall "cblas.h cblas_zgemm" + zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> IO () + +multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) +multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b) + where + multiply3 f st a b + | cols a == rows b = do + s <- createMatrix ColumnMajor (rows a) (cols b) + palpha <- new 1 + pbeta <- new 0 + let g ar ac ap br bc bp rr rc rp = f colMajor noTrans noTrans ar bc ac palpha ap ar bp br pbeta rp rr >> return 0 + app3 g mat a mat b mat s st + free palpha + free pbeta + return s + -- if toLists s== toLists s then return s else error $ "HORROR " ++ (show (toLists s)) + | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" + +----------------------------------------------------------------------------------- +-- BLAS via auxiliary C +----------------------------------------------------------------------------------- + +foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM +foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM + +multiply2 f st a b + | cols a == rows b = unsafePerformIO $ do + s <- createMatrix ColumnMajor (rows a) (cols b) + app3 f mat a mat b mat s st + if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s)) + | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" + +multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double +multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b) + +multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) +multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b) + +----------------------------------------------------------------------------------- +-- direct C multiplication +----------------------------------------------------------------------------------- + +foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM +foreign import ccall "multiply.h multiplyC" cmultiplyC :: TCMCMCM + +cmultiply f st a b +-- | cols a == rows b = + = unsafePerformIO $ do + s <- createMatrix RowMajor (rows a) (cols b) + app3 f mat a mat b mat s st + if toLists s== toLists s then return s else error $ "BRUTAL " ++ (show (toLists s)) + -- return s +-- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" + +multiplyR :: Matrix Double -> Matrix Double -> Matrix Double +multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b) + +multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) +multiplyC a b = cmultiply cmultiplyC "cmultiplyR" (cmat a) (cmat b) diff --git a/lib/Numeric/LinearAlgebra/Interface.hs b/lib/Numeric/LinearAlgebra/Interface.hs index 4a9b309..0ae9698 100644 --- a/lib/Numeric/LinearAlgebra/Interface.hs +++ b/lib/Numeric/LinearAlgebra/Interface.hs @@ -29,7 +29,7 @@ import Numeric.LinearAlgebra.Algorithms class Mul a b c | a b -> c where infixl 7 <> -- | matrix product - (<>) :: Element t => a t -> b t -> c t + (<>) :: Field t => a t -> b t -> c t instance Mul Matrix Matrix Matrix where (<>) = multiply @@ -43,7 +43,7 @@ instance Mul Vector Matrix Vector where --------------------------------------------------- -- | @u \<.\> v = dot u v@ -(<.>) :: (Element t) => Vector t -> Vector t -> t +(<.>) :: (Field t) => Vector t -> Vector t -> t infixl 7 <.> (<.>) = dot diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index 310f6ee..0dccea2 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c @@ -814,3 +814,77 @@ int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) { free(auxipiv); OK } + +//////////////////////////////////////////////////////////// + +int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)) { + REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); + int i,j,k; + for (i=0;iw) exit(1); + //printf("%d",w>w); + temp.r += aik.r * bkj.r - aik.i * bkj.i; + temp.i += aik.r * bkj.i + aik.i * bkj.r; + //printf("%f %f %f %f \n",aik.r,aik.i,bkj.r,bkj.i); + //printf("%f %f %f \n",w,temp.r,temp.i); + + } + ((doublecomplex*)rp)[i*rc+j] = temp; + //printf("%f %f\n",temp.r,temp.i); + } + } + OK +} + +void dgemm_(char *, char *, integer *, integer *, integer *, + double *, const double *, integer *, const double *, + integer *, double *, double *, integer *); + +void zgemm_(char *, char *, integer *, integer *, integer *, + doublecomplex *, const doublecomplex *, integer *, const doublecomplex *, + integer *, doublecomplex *, doublecomplex *, integer *); + + +int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)) { + REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); + double alpha = 1; + double beta = 0; + integer m = ar; + integer n = bc; + integer k = ac; + int i,j; + dgemm_("N","N",&m,&n,&k,&alpha,ap,&m,bp,&k,&beta,rp,&m); + OK +} + +int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)) { + REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); + integer m = ar; + integer n = bc; + integer k = ac; + doublecomplex alpha = {1,0}; + doublecomplex beta = {0,0}; + zgemm_("N","N",&m,&n,&k,&alpha,(doublecomplex*)ap,&m,(doublecomplex*)bp,&k,&beta,(doublecomplex*)rp,&m); + OK +} diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h index 79e52be..c0361a6 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h @@ -84,3 +84,9 @@ int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)); int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)); int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)); + +int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)); +int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)); + +int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)); +int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)); diff --git a/lib/Numeric/LinearAlgebra/Linear.hs b/lib/Numeric/LinearAlgebra/Linear.hs index 0ddbb55..1bf8b04 100644 --- a/lib/Numeric/LinearAlgebra/Linear.hs +++ b/lib/Numeric/LinearAlgebra/Linear.hs @@ -15,12 +15,11 @@ Basic optimized operations on vectors and matrices. ----------------------------------------------------------------------------- module Numeric.LinearAlgebra.Linear ( - Linear(..), - multiply, dot, outer, kronecker + Linear(..) ) where -import Data.Packed.Internal(multiply,partit) +import Data.Packed.Internal(partit) import Data.Packed import Numeric.GSL.Vector import Complex @@ -69,52 +68,3 @@ instance (Linear Vector a, Container Matrix a) => (Linear Matrix a) where mul = liftMatrix2 mul divide = liftMatrix2 divide equal a b = cols a == cols b && flatten a `equal` flatten b - --------------------------------------------------- - --- | euclidean inner product -dot :: (Element t) => Vector t -> Vector t -> t -dot u v = multiply r c @@> (0,0) - where r = asRow u - c = asColumn v - - -{- | Outer product of two vectors. - -@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] -(3><3) - [ 5.0, 2.0, 3.0 - , 10.0, 4.0, 6.0 - , 15.0, 6.0, 9.0 ]@ --} -outer :: (Element t) => Vector t -> Vector t -> Matrix t -outer u v = asColumn u `multiply` asRow v - -{- | Kronecker product of two matrices. - -@m1=(2><3) - [ 1.0, 2.0, 0.0 - , 0.0, -1.0, 3.0 ] -m2=(4><3) - [ 1.0, 2.0, 3.0 - , 4.0, 5.0, 6.0 - , 7.0, 8.0, 9.0 - , 10.0, 11.0, 12.0 ]@ - -@\> kronecker m1 m2 -(8><9) - [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0 - , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0 - , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0 - , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0 - , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0 - , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0 - , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0 - , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@ --} -kronecker :: (Element t) => Matrix t -> Matrix t -> Matrix t -kronecker a b = fromBlocks - . partit (cols a) - . map (reshape (cols b)) - . toRows - $ flatten a `outer` flatten b diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index 7b28075..07b9f63 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs @@ -123,6 +123,11 @@ runTests :: Int -- ^ maximum dimension runTests n = do setErrorHandlerOff let test p = qCheck n p + putStrLn "------ mult" + test (multProp1 . rConsist) + test (multProp1 . cConsist) + test (multProp2 . rConsist) + test (multProp2 . cConsist) putStrLn "------ lu" test (luProp . rM) test (luProp . cM) diff --git a/lib/Numeric/LinearAlgebra/Tests/Instances.hs b/lib/Numeric/LinearAlgebra/Tests/Instances.hs index af486c8..e7fecf2 100644 --- a/lib/Numeric/LinearAlgebra/Tests/Instances.hs +++ b/lib/Numeric/LinearAlgebra/Tests/Instances.hs @@ -20,6 +20,7 @@ module Numeric.LinearAlgebra.Tests.Instances( WC(..), rWC,cWC, SqWC(..), rSqWC, cSqWC, PosDef(..), rPosDef, cPosDef, + Consistent(..), rConsist, cConsist, RM,CM, rM,cM ) where @@ -116,6 +117,19 @@ instance (Field a, Arbitrary a) => Arbitrary (PosDef a) where return $ PosDef (0.5 .* p + 0.5 .* ctrans p) coarbitrary = undefined +-- a pair of matrices that can be multiplied +newtype (Consistent a) = Consistent (Matrix a, Matrix a) deriving Show +instance (Field a, Arbitrary a) => Arbitrary (Consistent a) where + arbitrary = do + n <- chooseDim + k <- chooseDim + m <- chooseDim + la <- vector (n*k) + lb <- vector (k*m) + return $ Consistent ((n> c && upperTriang c expmDiagProp m = expm (logm m) :~ 7 ~: complex m where logm m = matFunc log m +multProp1 (a,b) = a <> b |~| mulH a b + +multProp2 (a,b) = trans (a <> b) |~| trans b <> trans a -- cgit v1.2.3