From b9329d636d19f6a26da1cf1fd7e8d7cbd0b04cce Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Tue, 30 Jun 2015 12:04:21 +0200 Subject: support slice in multiply --- packages/base/src/Internal/C/lapack-aux.c | 104 ++++++++++++------------------ 1 file changed, 42 insertions(+), 62 deletions(-) (limited to 'packages/base/src/Internal/C') diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index ca60846..30689bf 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c @@ -1093,16 +1093,15 @@ void dgemm_(char *, char *, integer *, integer *, integer *, integer *, double *, double *, integer *); int multiplyR(int ta, int tb, KODMAT(a),KODMAT(b),ODMAT(r)) { - //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); DEBUGMSG("dgemm_"); CHECKNANR(a,"NaN multR Input\n") CHECKNANR(b,"NaN multR Input\n") integer m = ta?ac:ar; integer n = tb?br:bc; integer k = ta?ar:ac; - integer lda = ar; - integer ldb = br; - integer ldc = rr; + integer lda = aXc; + integer ldb = bXc; + integer ldc = rXc; double alpha = 1; double beta = 0; dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); @@ -1115,16 +1114,15 @@ void zgemm_(char *, char *, integer *, integer *, integer *, integer *, doublecomplex *, doublecomplex *, integer *); int multiplyC(int ta, int tb, KOCMAT(a),KOCMAT(b),OCMAT(r)) { - //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); DEBUGMSG("zgemm_"); CHECKNANC(a,"NaN multC Input\n") CHECKNANC(b,"NaN multC Input\n") integer m = ta?ac:ar; integer n = tb?br:bc; integer k = ta?ar:ac; - integer lda = ar; - integer ldb = br; - integer ldc = rr; + integer lda = aXc; + integer ldb = bXc; + integer ldc = rXc; doublecomplex alpha = {1,0}; doublecomplex beta = {0,0}; zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, @@ -1140,14 +1138,13 @@ void sgemm_(char *, char *, integer *, integer *, integer *, integer *, float *, float *, integer *); int multiplyF(int ta, int tb, KOFMAT(a),KOFMAT(b),OFMAT(r)) { - //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); DEBUGMSG("sgemm_"); integer m = ta?ac:ar; integer n = tb?br:bc; integer k = ta?ar:ac; - integer lda = ar; - integer ldb = br; - integer ldc = rr; + integer lda = aXc; + integer ldb = bXc; + integer ldc = rXc; float alpha = 1; float beta = 0; sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); @@ -1159,14 +1156,13 @@ void cgemm_(char *, char *, integer *, integer *, integer *, integer *, complex *, complex *, integer *); int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) { - //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); DEBUGMSG("cgemm_"); integer m = ta?ac:ar; integer n = tb?br:bc; integer k = ta?ar:ac; - integer lda = ar; - integer ldb = br; - integer ldc = rr; + integer lda = aXc; + integer ldb = bXc; + integer ldc = rXc; complex alpha = {1,0}; complex beta = {0,0}; cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, @@ -1187,15 +1183,15 @@ int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) { } \ } -#define MULT_IMP { \ +#define MULT_IMP(M) { \ if (m==1) { \ MULT_IMP_VER( AT(r,i,j) += AT(a,i,k) * AT(b,k,j); ) \ } else { \ - MULT_IMP_VER( AT(r,i,j) = (AT(r,i,j) + (AT(a,i,k) * AT(b,k,j)) % m) % m ; ) \ + MULT_IMP_VER( AT(r,i,j) = M(AT(r,i,j) + M(AT(a,i,k) * AT(b,k,j), m) , m) ; ) \ } OK } -int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP -int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP +int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP(mod) +int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP(mod_l) /////////////////////////////// inplace row ops //////////////////////////////// @@ -1277,27 +1273,19 @@ ROWOP_MOD(int64_t,mod_l) /////////////////////////////// inplace GEMM //////////////////////////////// -#define GEMM(T) int gemm_##T(VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \ - T a = cp[0], b = cp[1]; \ - int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \ - int r1b = pp[4], c1b = pp[6] ; \ - int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \ - int dra = r1a - r1r; \ - int dcb = c1b-c1r; \ - int nk = c2a-c1a+1; \ - int i,j,k; \ - T t; \ - for (i=r1r; i<=r2r; i++) { \ - for (j=c1r; j<=c2r; j++) { \ - t = 0; \ - for(k=0; k