From 61d90ff66af8bfe53ef8cdda8dfe1e70463c213c Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Wed, 17 Jun 2015 13:02:40 +0200 Subject: gemmm --- packages/base/src/Internal/C/lapack-aux.c | 59 ++++++++++++++++++++++++ packages/base/src/Internal/C/lapack-aux.h | 1 + packages/base/src/Internal/Matrix.hs | 24 +++++++++- packages/base/src/Internal/Modular.hs | 6 +++ packages/base/src/Internal/ST.hs | 27 +++++++---- packages/base/src/Numeric/LinearAlgebra/Devel.hs | 2 +- 6 files changed, 108 insertions(+), 11 deletions(-) diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index 2843ab5..4d48594 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c @@ -1398,6 +1398,65 @@ ROWOP(int64_t) ROWOP_MOD(int32_t,mod) ROWOP_MOD(int64_t,mod_l) +/////////////////////////////// inplace GEMM //////////////////////////////// + +#define GEMM(T) int gemm_##T(VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \ + T a = cp[0], b = cp[1]; \ + int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \ + int r1b = pp[4], c1b = pp[6] ; \ + int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \ + int dra = r1a - r1r; \ + int dcb = c1b-c1r; \ + int nk = c2a-c1a+1; \ + int i,j,k; \ + T t; \ + for (i=r1r; i<=r2r; i++) { \ + for (j=c1r; j<=c2r; j++) { \ + t = 0; \ + for(k=0; k Element a where selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () + gemm :: Vector a -> Vector I -> Matrix a -> Matrix a -> Matrix a -> IO () instance Element Float where @@ -287,6 +288,7 @@ instance Element Float where selectV = selectF remapM = remapF rowOp = rowOpAux c_rowOpF + gemm = gemmg c_gemmF instance Element Double where transdata = transdataAux ctransR @@ -299,7 +301,7 @@ instance Element Double where selectV = selectD remapM = remapD rowOp = rowOpAux c_rowOpD - + gemm = gemmg c_gemmD instance Element (Complex Float) where transdata = transdataAux ctransQ @@ -312,7 +314,7 @@ instance Element (Complex Float) where selectV = selectQ remapM = remapQ rowOp = rowOpAux c_rowOpQ - + gemm = gemmg c_gemmQ instance Element (Complex Double) where transdata = transdataAux ctransC @@ -325,6 +327,7 @@ instance Element (Complex Double) where selectV = selectC remapM = remapC rowOp = rowOpAux c_rowOpC + gemm = gemmg c_gemmC instance Element (CInt) where transdata = transdataAux ctransI @@ -337,6 +340,7 @@ instance Element (CInt) where selectV = selectI remapM = remapI rowOp = rowOpAux c_rowOpI + gemm = gemmg c_gemmI instance Element Z where transdata = transdataAux ctransL @@ -349,6 +353,7 @@ instance Element Z where selectV = selectL remapM = remapL rowOp = rowOpAux c_rowOpL + gemm = gemmg c_gemmL ------------------------------------------------------------------- @@ -575,6 +580,21 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z -------------------------------------------------------------------------------- +gemmg f u v m1 m2 m3 = app5 f vec u vec v omat m1 omat m2 omat m3 "gemmg" + +type Tgemm x = x :> I :> x ::> x ::> x ::> Ok + +foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R +foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float +foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C +foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float) +foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I +foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z +foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I +foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z + +-------------------------------------------------------------------------------- + foreign import ccall unsafe "saveMatrix" c_saveMatrix :: CString -> CString -> Double ..> Ok diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 6c6d5c5..d158111 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs @@ -131,6 +131,9 @@ instance KnownNat m => Element (Mod m I) rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) where m' = fromIntegral . natVal $ (undefined :: Proxy m) + gemm u p a b c = gemmg (c_gemmMI m') (f2i u) p (f2iM a) (f2iM b) (f2iM c) + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) instance KnownNat m => Element (Mod m Z) where @@ -146,6 +149,9 @@ instance KnownNat m => Element (Mod m Z) rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) where m' = fromIntegral . natVal $ (undefined :: Proxy m) + gemm u p a b c = gemmg (c_gemmML m') (f2i u) p (f2iM a) (f2iM b) (f2iM c) + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) instance forall m . KnownNat m => Container Vector (Mod m I) diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 434fe63..25e7f03 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs @@ -1,5 +1,6 @@ {-# LANGUAGE Rank2Types #-} {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE ViewPatterns #-} ----------------------------------------------------------------------------- -- | @@ -15,14 +16,14 @@ ----------------------------------------------------------------------------- module Internal.ST ( + ST, runST, -- * Mutable Vectors STVector, newVector, thawVector, freezeVector, runSTVector, readVector, writeVector, modifyVector, liftSTVector, -- * Mutable Matrices STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, --- axpy, scal, swap, rowOp, - mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), + mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..), -- * Unsafe functions newUndefinedVector, unsafeReadVector, unsafeWriteVector, @@ -70,13 +71,13 @@ unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s () modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k -liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a +liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s t -> ST s a liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x -freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) +freezeVector :: (Storable t) => STVector s t -> ST s (Vector t) freezeVector v = liftSTVector id v -unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) +unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t) unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x {-# INLINE safeIndexV #-} @@ -139,14 +140,14 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c -liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a +liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s t -> ST s a liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x -unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) +unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x -freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) +freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) freezeMatrix m = liftSTMatrix id m cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o) @@ -227,6 +228,16 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (i (i1,i2) = getRowRange (rows m) rr (j1,j2) = getColRange (cols m) rc +data Slice s t = Slice (STMatrix s t) Int Int Int Int + +slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1]) + +gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res + where + res = unsafeIOToST (gemm u v a b r) + u = fromList [alpha,beta] + v = vjoin[pa,pb,pr] + mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) mutable f a = runST $ do diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs index 36c5f03..db4236b 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs @@ -43,7 +43,7 @@ module Numeric.LinearAlgebra.Devel( -- ** Mutable Matrices STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, - mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), + mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..), -- ** Unsafe functions newUndefinedVector, unsafeReadVector, unsafeWriteVector, -- cgit v1.2.3