summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/C
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-30 12:04:21 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-30 12:04:21 +0200
commitb9329d636d19f6a26da1cf1fd7e8d7cbd0b04cce (patch)
treec0beb22b3b394ed9d18a6a98d5cf1ca6d4ea8960 /packages/base/src/Internal/C
parent9c05df0cd663bafaf0b69eafee53fce45569dc95 (diff)
support slice in multiply
Diffstat (limited to 'packages/base/src/Internal/C')
-rw-r--r--packages/base/src/Internal/C/lapack-aux.c104
1 files changed, 42 insertions, 62 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c
index ca60846..30689bf 100644
--- a/packages/base/src/Internal/C/lapack-aux.c
+++ b/packages/base/src/Internal/C/lapack-aux.c
@@ -1093,16 +1093,15 @@ void dgemm_(char *, char *, integer *, integer *, integer *,
1093 integer *, double *, double *, integer *); 1093 integer *, double *, double *, integer *);
1094 1094
1095int multiplyR(int ta, int tb, KODMAT(a),KODMAT(b),ODMAT(r)) { 1095int multiplyR(int ta, int tb, KODMAT(a),KODMAT(b),ODMAT(r)) {
1096 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1097 DEBUGMSG("dgemm_"); 1096 DEBUGMSG("dgemm_");
1098 CHECKNANR(a,"NaN multR Input\n") 1097 CHECKNANR(a,"NaN multR Input\n")
1099 CHECKNANR(b,"NaN multR Input\n") 1098 CHECKNANR(b,"NaN multR Input\n")
1100 integer m = ta?ac:ar; 1099 integer m = ta?ac:ar;
1101 integer n = tb?br:bc; 1100 integer n = tb?br:bc;
1102 integer k = ta?ar:ac; 1101 integer k = ta?ar:ac;
1103 integer lda = ar; 1102 integer lda = aXc;
1104 integer ldb = br; 1103 integer ldb = bXc;
1105 integer ldc = rr; 1104 integer ldc = rXc;
1106 double alpha = 1; 1105 double alpha = 1;
1107 double beta = 0; 1106 double beta = 0;
1108 dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); 1107 dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc);
@@ -1115,16 +1114,15 @@ void zgemm_(char *, char *, integer *, integer *, integer *,
1115 integer *, doublecomplex *, doublecomplex *, integer *); 1114 integer *, doublecomplex *, doublecomplex *, integer *);
1116 1115
1117int multiplyC(int ta, int tb, KOCMAT(a),KOCMAT(b),OCMAT(r)) { 1116int multiplyC(int ta, int tb, KOCMAT(a),KOCMAT(b),OCMAT(r)) {
1118 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1119 DEBUGMSG("zgemm_"); 1117 DEBUGMSG("zgemm_");
1120 CHECKNANC(a,"NaN multC Input\n") 1118 CHECKNANC(a,"NaN multC Input\n")
1121 CHECKNANC(b,"NaN multC Input\n") 1119 CHECKNANC(b,"NaN multC Input\n")
1122 integer m = ta?ac:ar; 1120 integer m = ta?ac:ar;
1123 integer n = tb?br:bc; 1121 integer n = tb?br:bc;
1124 integer k = ta?ar:ac; 1122 integer k = ta?ar:ac;
1125 integer lda = ar; 1123 integer lda = aXc;
1126 integer ldb = br; 1124 integer ldb = bXc;
1127 integer ldc = rr; 1125 integer ldc = rXc;
1128 doublecomplex alpha = {1,0}; 1126 doublecomplex alpha = {1,0};
1129 doublecomplex beta = {0,0}; 1127 doublecomplex beta = {0,0};
1130 zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, 1128 zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,
@@ -1140,14 +1138,13 @@ void sgemm_(char *, char *, integer *, integer *, integer *,
1140 integer *, float *, float *, integer *); 1138 integer *, float *, float *, integer *);
1141 1139
1142int multiplyF(int ta, int tb, KOFMAT(a),KOFMAT(b),OFMAT(r)) { 1140int multiplyF(int ta, int tb, KOFMAT(a),KOFMAT(b),OFMAT(r)) {
1143 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1144 DEBUGMSG("sgemm_"); 1141 DEBUGMSG("sgemm_");
1145 integer m = ta?ac:ar; 1142 integer m = ta?ac:ar;
1146 integer n = tb?br:bc; 1143 integer n = tb?br:bc;
1147 integer k = ta?ar:ac; 1144 integer k = ta?ar:ac;
1148 integer lda = ar; 1145 integer lda = aXc;
1149 integer ldb = br; 1146 integer ldb = bXc;
1150 integer ldc = rr; 1147 integer ldc = rXc;
1151 float alpha = 1; 1148 float alpha = 1;
1152 float beta = 0; 1149 float beta = 0;
1153 sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); 1150 sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc);
@@ -1159,14 +1156,13 @@ void cgemm_(char *, char *, integer *, integer *, integer *,
1159 integer *, complex *, complex *, integer *); 1156 integer *, complex *, complex *, integer *);
1160 1157
1161int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) { 1158int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) {
1162 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1163 DEBUGMSG("cgemm_"); 1159 DEBUGMSG("cgemm_");
1164 integer m = ta?ac:ar; 1160 integer m = ta?ac:ar;
1165 integer n = tb?br:bc; 1161 integer n = tb?br:bc;
1166 integer k = ta?ar:ac; 1162 integer k = ta?ar:ac;
1167 integer lda = ar; 1163 integer lda = aXc;
1168 integer ldb = br; 1164 integer ldb = bXc;
1169 integer ldc = rr; 1165 integer ldc = rXc;
1170 complex alpha = {1,0}; 1166 complex alpha = {1,0};
1171 complex beta = {0,0}; 1167 complex beta = {0,0};
1172 cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, 1168 cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,
@@ -1187,15 +1183,15 @@ int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) {
1187 } \ 1183 } \
1188 } 1184 }
1189 1185
1190#define MULT_IMP { \ 1186#define MULT_IMP(M) { \
1191 if (m==1) { \ 1187 if (m==1) { \
1192 MULT_IMP_VER( AT(r,i,j) += AT(a,i,k) * AT(b,k,j); ) \ 1188 MULT_IMP_VER( AT(r,i,j) += AT(a,i,k) * AT(b,k,j); ) \
1193 } else { \ 1189 } else { \
1194 MULT_IMP_VER( AT(r,i,j) = (AT(r,i,j) + (AT(a,i,k) * AT(b,k,j)) % m) % m ; ) \ 1190 MULT_IMP_VER( AT(r,i,j) = M(AT(r,i,j) + M(AT(a,i,k) * AT(b,k,j), m) , m) ; ) \
1195 } OK } 1191 } OK }
1196 1192
1197int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP 1193int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP(mod)
1198int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP 1194int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP(mod_l)
1199 1195
1200/////////////////////////////// inplace row ops //////////////////////////////// 1196/////////////////////////////// inplace row ops ////////////////////////////////
1201 1197
@@ -1277,27 +1273,19 @@ ROWOP_MOD(int64_t,mod_l)
1277 1273
1278/////////////////////////////// inplace GEMM //////////////////////////////// 1274/////////////////////////////// inplace GEMM ////////////////////////////////
1279 1275
1280#define GEMM(T) int gemm_##T(VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \ 1276#define GEMM(T) int gemm_##T(VECG(T,c),MATG(T,a),MATG(T,b),MATG(T,r)) { \
1281 T a = cp[0], b = cp[1]; \ 1277 T a = cp[0], b = cp[1]; \
1282 int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \ 1278 T t; \
1283 int r1b = pp[4], c1b = pp[6] ; \ 1279 int k; \
1284 int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \ 1280 { TRAV(r,i,j) { \
1285 int dra = r1a - r1r; \ 1281 t = 0; \
1286 int dcb = c1b-c1r; \ 1282 for(k=0; k<ac; k++) { \
1287 int nk = c2a-c1a+1; \ 1283 t += AT(a,i,k) * AT(b,k,j); \
1288 int i,j,k; \ 1284 } \
1289 T t; \ 1285 AT(r,i,j) = b*AT(r,i,j) + a*t; \
1290 for (i=r1r; i<=r2r; i++) { \ 1286 } \
1291 for (j=c1r; j<=c2r; j++) { \ 1287 } OK }
1292 t = 0; \ 1288
1293 for(k=0; k<nk; k++) { \
1294 t += AT(a,i+dra,k+c1a) * AT(b,k+r1b,j+dcb); \
1295 } \
1296 AT(r,i,j) = b*AT(r,i,j) + a*t; \
1297 } \
1298 } \
1299 OK \
1300}
1301 1289
1302GEMM(double) 1290GEMM(double)
1303GEMM(float) 1291GEMM(float)
@@ -1306,27 +1294,19 @@ GEMM(TCF)
1306GEMM(int32_t) 1294GEMM(int32_t)
1307GEMM(int64_t) 1295GEMM(int64_t)
1308 1296
1309#define GEMM_MOD(T,M) int gemm_mod_##T(T m, VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \ 1297#define GEMM_MOD(T,M) int gemm_mod_##T(T m, VECG(T,c),MATG(T,a),MATG(T,b),MATG(T,r)) { \
1310 T a = cp[0], b = cp[1]; \ 1298 T a = cp[0], b = cp[1]; \
1311 int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \ 1299 int k; \
1312 int r1b = pp[4], c1b = pp[6] ; \ 1300 T t; \
1313 int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \ 1301 { TRAV(r,i,j) { \
1314 int dra = r1a - r1r; \ 1302 t = 0; \
1315 int dcb = c1b-c1r; \ 1303 for(k=0; k<ac; k++) { \
1316 int nk = c2a-c1a+1; \ 1304 t = M(t+M(AT(a,i,k) * AT(b,k,j))); \
1317 int i,j,k; \ 1305 } \
1318 T t; \ 1306 AT(r,i,j) = M(M(b*AT(r,i,j)) + M(a*t)); \
1319 for (i=r1r; i<=r2r; i++) { \ 1307 } \
1320 for (j=c1r; j<=c2r; j++) { \ 1308 } OK }
1321 t = 0; \ 1309
1322 for(k=0; k<nk; k++) { \
1323 t = M(t+M(AT(a,i+dra,k+c1a) * AT(b,k+r1b,j+dcb))); \
1324 } \
1325 AT(r,i,j) = M(M(b*AT(r,i,j)) + M(a*t)); \
1326 } \
1327 } \
1328 OK \
1329}
1330 1310
1331#define MOD32(X) mod(X,m) 1311#define MOD32(X) mod(X,m)
1332#define MOD64(X) mod_l(X,m) 1312#define MOD64(X) mod_l(X,m)