From b9329d636d19f6a26da1cf1fd7e8d7cbd0b04cce Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Tue, 30 Jun 2015 12:04:21 +0200 Subject: support slice in multiply --- packages/base/src/Internal/C/lapack-aux.c | 104 ++++++++++++------------------ packages/base/src/Internal/Matrix.hs | 8 ++- packages/base/src/Internal/Modular.hs | 4 +- packages/base/src/Internal/ST.hs | 9 ++- 4 files changed, 53 insertions(+), 72 deletions(-) (limited to 'packages/base') diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index ca60846..30689bf 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c @@ -1093,16 +1093,15 @@ void dgemm_(char *, char *, integer *, integer *, integer *, integer *, double *, double *, integer *); int multiplyR(int ta, int tb, KODMAT(a),KODMAT(b),ODMAT(r)) { - //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); DEBUGMSG("dgemm_"); CHECKNANR(a,"NaN multR Input\n") CHECKNANR(b,"NaN multR Input\n") integer m = ta?ac:ar; integer n = tb?br:bc; integer k = ta?ar:ac; - integer lda = ar; - integer ldb = br; - integer ldc = rr; + integer lda = aXc; + integer ldb = bXc; + integer ldc = rXc; double alpha = 1; double beta = 0; dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); @@ -1115,16 +1114,15 @@ void zgemm_(char *, char *, integer *, integer *, integer *, integer *, doublecomplex *, doublecomplex *, integer *); int multiplyC(int ta, int tb, KOCMAT(a),KOCMAT(b),OCMAT(r)) { - //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); DEBUGMSG("zgemm_"); CHECKNANC(a,"NaN multC Input\n") CHECKNANC(b,"NaN multC Input\n") integer m = ta?ac:ar; integer n = tb?br:bc; integer k = ta?ar:ac; - integer lda = ar; - integer ldb = br; - integer ldc = rr; + integer lda = aXc; + integer ldb = bXc; + integer ldc = rXc; doublecomplex alpha = {1,0}; doublecomplex beta = {0,0}; zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, @@ -1140,14 +1138,13 @@ void sgemm_(char *, char *, integer *, integer *, integer *, integer *, float *, float *, integer *); int multiplyF(int ta, int tb, KOFMAT(a),KOFMAT(b),OFMAT(r)) { - //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); DEBUGMSG("sgemm_"); integer m = ta?ac:ar; integer n = tb?br:bc; integer k = ta?ar:ac; - integer lda = ar; - integer ldb = br; - integer ldc = rr; + integer lda = aXc; + integer ldb = bXc; + integer ldc = rXc; float alpha = 1; float beta = 0; sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); @@ -1159,14 +1156,13 @@ void cgemm_(char *, char *, integer *, integer *, integer *, integer *, complex *, complex *, integer *); int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) { - //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); DEBUGMSG("cgemm_"); integer m = ta?ac:ar; integer n = tb?br:bc; integer k = ta?ar:ac; - integer lda = ar; - integer ldb = br; - integer ldc = rr; + integer lda = aXc; + integer ldb = bXc; + integer ldc = rXc; complex alpha = {1,0}; complex beta = {0,0}; cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, @@ -1187,15 +1183,15 @@ int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) { } \ } -#define MULT_IMP { \ +#define MULT_IMP(M) { \ if (m==1) { \ MULT_IMP_VER( AT(r,i,j) += AT(a,i,k) * AT(b,k,j); ) \ } else { \ - MULT_IMP_VER( AT(r,i,j) = (AT(r,i,j) + (AT(a,i,k) * AT(b,k,j)) % m) % m ; ) \ + MULT_IMP_VER( AT(r,i,j) = M(AT(r,i,j) + M(AT(a,i,k) * AT(b,k,j), m) , m) ; ) \ } OK } -int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP -int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP +int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP(mod) +int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP(mod_l) /////////////////////////////// inplace row ops //////////////////////////////// @@ -1277,27 +1273,19 @@ ROWOP_MOD(int64_t,mod_l) /////////////////////////////// inplace GEMM //////////////////////////////// -#define GEMM(T) int gemm_##T(VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \ - T a = cp[0], b = cp[1]; \ - int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \ - int r1b = pp[4], c1b = pp[6] ; \ - int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \ - int dra = r1a - r1r; \ - int dcb = c1b-c1r; \ - int nk = c2a-c1a+1; \ - int i,j,k; \ - T t; \ - for (i=r1r; i<=r2r; i++) { \ - for (j=c1r; j<=c2r; j++) { \ - t = 0; \ - for(k=0; kd) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } +matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } matrixFromVector o r c v | r * c == dim v = m | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m @@ -280,7 +282,7 @@ class (Storable a) => Element a where selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () - gemm :: Vector a -> Vector I -> Matrix a -> Matrix a -> Matrix a -> IO () + gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () instance Element Float where @@ -569,9 +571,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z -------------------------------------------------------------------------------- -gemmg f u v m1 m2 m3 = f # u # v # m1 # m2 # m3 #|"gemmg" +gemmg f v m1 m2 m3 = f # v # m1 # m2 # m3 #|"gemmg" -type Tgemm x = x :> I :> x ::> x ::> x ::> Ok +type Tgemm x = x :> x ::> x ::> x ::> Ok foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 54d9cb8..8fa2747 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs @@ -137,7 +137,7 @@ instance KnownNat m => Element (Mod m I) rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) where m' = fromIntegral . natVal $ (undefined :: Proxy m) - gemm u p a b c = gemmg (c_gemmMI m') (f2i u) p (f2iM a) (f2iM b) (f2iM c) + gemm u a b c = gemmg (c_gemmMI m') (f2i u) (f2iM a) (f2iM b) (f2iM c) where m' = fromIntegral . natVal $ (undefined :: Proxy m) @@ -154,7 +154,7 @@ instance KnownNat m => Element (Mod m Z) rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) where m' = fromIntegral . natVal $ (undefined :: Proxy m) - gemm u p a b c = gemmg (c_gemmML m') (f2i u) p (f2iM a) (f2iM b) (f2iM c) + gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c) where m' = fromIntegral . natVal $ (undefined :: Proxy m) diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 91c2a11..62dfddf 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs @@ -231,14 +231,13 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[ -- | r0 c0 height width data Slice s t = Slice (STMatrix s t) Int Int Int Int -slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1]) +slice (Slice (STMatrix m) r0 c0 nr nc) = sliceMatrix (r0,c0) (nr,nc) m gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () -gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res +gemmm beta (slice->r) alpha (slice->a) (slice->b) = res where - res = unsafeIOToST (gemm u v a b r) - u = fromList [alpha,beta] - v = vjoin[pa,pb,pr] + res = unsafeIOToST (gemm v a b r) + v = fromList [alpha,beta] mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) -- cgit v1.2.3