summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-30 12:04:21 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-30 12:04:21 +0200
commitb9329d636d19f6a26da1cf1fd7e8d7cbd0b04cce (patch)
treec0beb22b3b394ed9d18a6a98d5cf1ca6d4ea8960 /packages/base/src
parent9c05df0cd663bafaf0b69eafee53fce45569dc95 (diff)
support slice in multiply
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Internal/C/lapack-aux.c104
-rw-r--r--packages/base/src/Internal/Matrix.hs8
-rw-r--r--packages/base/src/Internal/Modular.hs4
-rw-r--r--packages/base/src/Internal/ST.hs9
4 files changed, 53 insertions, 72 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c
index ca60846..30689bf 100644
--- a/packages/base/src/Internal/C/lapack-aux.c
+++ b/packages/base/src/Internal/C/lapack-aux.c
@@ -1093,16 +1093,15 @@ void dgemm_(char *, char *, integer *, integer *, integer *,
1093 integer *, double *, double *, integer *); 1093 integer *, double *, double *, integer *);
1094 1094
1095int multiplyR(int ta, int tb, KODMAT(a),KODMAT(b),ODMAT(r)) { 1095int multiplyR(int ta, int tb, KODMAT(a),KODMAT(b),ODMAT(r)) {
1096 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1097 DEBUGMSG("dgemm_"); 1096 DEBUGMSG("dgemm_");
1098 CHECKNANR(a,"NaN multR Input\n") 1097 CHECKNANR(a,"NaN multR Input\n")
1099 CHECKNANR(b,"NaN multR Input\n") 1098 CHECKNANR(b,"NaN multR Input\n")
1100 integer m = ta?ac:ar; 1099 integer m = ta?ac:ar;
1101 integer n = tb?br:bc; 1100 integer n = tb?br:bc;
1102 integer k = ta?ar:ac; 1101 integer k = ta?ar:ac;
1103 integer lda = ar; 1102 integer lda = aXc;
1104 integer ldb = br; 1103 integer ldb = bXc;
1105 integer ldc = rr; 1104 integer ldc = rXc;
1106 double alpha = 1; 1105 double alpha = 1;
1107 double beta = 0; 1106 double beta = 0;
1108 dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); 1107 dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc);
@@ -1115,16 +1114,15 @@ void zgemm_(char *, char *, integer *, integer *, integer *,
1115 integer *, doublecomplex *, doublecomplex *, integer *); 1114 integer *, doublecomplex *, doublecomplex *, integer *);
1116 1115
1117int multiplyC(int ta, int tb, KOCMAT(a),KOCMAT(b),OCMAT(r)) { 1116int multiplyC(int ta, int tb, KOCMAT(a),KOCMAT(b),OCMAT(r)) {
1118 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1119 DEBUGMSG("zgemm_"); 1117 DEBUGMSG("zgemm_");
1120 CHECKNANC(a,"NaN multC Input\n") 1118 CHECKNANC(a,"NaN multC Input\n")
1121 CHECKNANC(b,"NaN multC Input\n") 1119 CHECKNANC(b,"NaN multC Input\n")
1122 integer m = ta?ac:ar; 1120 integer m = ta?ac:ar;
1123 integer n = tb?br:bc; 1121 integer n = tb?br:bc;
1124 integer k = ta?ar:ac; 1122 integer k = ta?ar:ac;
1125 integer lda = ar; 1123 integer lda = aXc;
1126 integer ldb = br; 1124 integer ldb = bXc;
1127 integer ldc = rr; 1125 integer ldc = rXc;
1128 doublecomplex alpha = {1,0}; 1126 doublecomplex alpha = {1,0};
1129 doublecomplex beta = {0,0}; 1127 doublecomplex beta = {0,0};
1130 zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, 1128 zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,
@@ -1140,14 +1138,13 @@ void sgemm_(char *, char *, integer *, integer *, integer *,
1140 integer *, float *, float *, integer *); 1138 integer *, float *, float *, integer *);
1141 1139
1142int multiplyF(int ta, int tb, KOFMAT(a),KOFMAT(b),OFMAT(r)) { 1140int multiplyF(int ta, int tb, KOFMAT(a),KOFMAT(b),OFMAT(r)) {
1143 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1144 DEBUGMSG("sgemm_"); 1141 DEBUGMSG("sgemm_");
1145 integer m = ta?ac:ar; 1142 integer m = ta?ac:ar;
1146 integer n = tb?br:bc; 1143 integer n = tb?br:bc;
1147 integer k = ta?ar:ac; 1144 integer k = ta?ar:ac;
1148 integer lda = ar; 1145 integer lda = aXc;
1149 integer ldb = br; 1146 integer ldb = bXc;
1150 integer ldc = rr; 1147 integer ldc = rXc;
1151 float alpha = 1; 1148 float alpha = 1;
1152 float beta = 0; 1149 float beta = 0;
1153 sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); 1150 sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc);
@@ -1159,14 +1156,13 @@ void cgemm_(char *, char *, integer *, integer *, integer *,
1159 integer *, complex *, complex *, integer *); 1156 integer *, complex *, complex *, integer *);
1160 1157
1161int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) { 1158int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) {
1162 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1163 DEBUGMSG("cgemm_"); 1159 DEBUGMSG("cgemm_");
1164 integer m = ta?ac:ar; 1160 integer m = ta?ac:ar;
1165 integer n = tb?br:bc; 1161 integer n = tb?br:bc;
1166 integer k = ta?ar:ac; 1162 integer k = ta?ar:ac;
1167 integer lda = ar; 1163 integer lda = aXc;
1168 integer ldb = br; 1164 integer ldb = bXc;
1169 integer ldc = rr; 1165 integer ldc = rXc;
1170 complex alpha = {1,0}; 1166 complex alpha = {1,0};
1171 complex beta = {0,0}; 1167 complex beta = {0,0};
1172 cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, 1168 cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,
@@ -1187,15 +1183,15 @@ int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) {
1187 } \ 1183 } \
1188 } 1184 }
1189 1185
1190#define MULT_IMP { \ 1186#define MULT_IMP(M) { \
1191 if (m==1) { \ 1187 if (m==1) { \
1192 MULT_IMP_VER( AT(r,i,j) += AT(a,i,k) * AT(b,k,j); ) \ 1188 MULT_IMP_VER( AT(r,i,j) += AT(a,i,k) * AT(b,k,j); ) \
1193 } else { \ 1189 } else { \
1194 MULT_IMP_VER( AT(r,i,j) = (AT(r,i,j) + (AT(a,i,k) * AT(b,k,j)) % m) % m ; ) \ 1190 MULT_IMP_VER( AT(r,i,j) = M(AT(r,i,j) + M(AT(a,i,k) * AT(b,k,j), m) , m) ; ) \
1195 } OK } 1191 } OK }
1196 1192
1197int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP 1193int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP(mod)
1198int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP 1194int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP(mod_l)
1199 1195
1200/////////////////////////////// inplace row ops //////////////////////////////// 1196/////////////////////////////// inplace row ops ////////////////////////////////
1201 1197
@@ -1277,27 +1273,19 @@ ROWOP_MOD(int64_t,mod_l)
1277 1273
1278/////////////////////////////// inplace GEMM //////////////////////////////// 1274/////////////////////////////// inplace GEMM ////////////////////////////////
1279 1275
1280#define GEMM(T) int gemm_##T(VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \ 1276#define GEMM(T) int gemm_##T(VECG(T,c),MATG(T,a),MATG(T,b),MATG(T,r)) { \
1281 T a = cp[0], b = cp[1]; \ 1277 T a = cp[0], b = cp[1]; \
1282 int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \ 1278 T t; \
1283 int r1b = pp[4], c1b = pp[6] ; \ 1279 int k; \
1284 int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \ 1280 { TRAV(r,i,j) { \
1285 int dra = r1a - r1r; \ 1281 t = 0; \
1286 int dcb = c1b-c1r; \ 1282 for(k=0; k<ac; k++) { \
1287 int nk = c2a-c1a+1; \ 1283 t += AT(a,i,k) * AT(b,k,j); \
1288 int i,j,k; \ 1284 } \
1289 T t; \ 1285 AT(r,i,j) = b*AT(r,i,j) + a*t; \
1290 for (i=r1r; i<=r2r; i++) { \ 1286 } \
1291 for (j=c1r; j<=c2r; j++) { \ 1287 } OK }
1292 t = 0; \ 1288
1293 for(k=0; k<nk; k++) { \
1294 t += AT(a,i+dra,k+c1a) * AT(b,k+r1b,j+dcb); \
1295 } \
1296 AT(r,i,j) = b*AT(r,i,j) + a*t; \
1297 } \
1298 } \
1299 OK \
1300}
1301 1289
1302GEMM(double) 1290GEMM(double)
1303GEMM(float) 1291GEMM(float)
@@ -1306,27 +1294,19 @@ GEMM(TCF)
1306GEMM(int32_t) 1294GEMM(int32_t)
1307GEMM(int64_t) 1295GEMM(int64_t)
1308 1296
1309#define GEMM_MOD(T,M) int gemm_mod_##T(T m, VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \ 1297#define GEMM_MOD(T,M) int gemm_mod_##T(T m, VECG(T,c),MATG(T,a),MATG(T,b),MATG(T,r)) { \
1310 T a = cp[0], b = cp[1]; \ 1298 T a = cp[0], b = cp[1]; \
1311 int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \ 1299 int k; \
1312 int r1b = pp[4], c1b = pp[6] ; \ 1300 T t; \
1313 int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \ 1301 { TRAV(r,i,j) { \
1314 int dra = r1a - r1r; \ 1302 t = 0; \
1315 int dcb = c1b-c1r; \ 1303 for(k=0; k<ac; k++) { \
1316 int nk = c2a-c1a+1; \ 1304 t = M(t+M(AT(a,i,k) * AT(b,k,j))); \
1317 int i,j,k; \ 1305 } \
1318 T t; \ 1306 AT(r,i,j) = M(M(b*AT(r,i,j)) + M(a*t)); \
1319 for (i=r1r; i<=r2r; i++) { \ 1307 } \
1320 for (j=c1r; j<=c2r; j++) { \ 1308 } OK }
1321 t = 0; \ 1309
1322 for(k=0; k<nk; k++) { \
1323 t = M(t+M(AT(a,i+dra,k+c1a) * AT(b,k+r1b,j+dcb))); \
1324 } \
1325 AT(r,i,j) = M(M(b*AT(r,i,j)) + M(a*t)); \
1326 } \
1327 } \
1328 OK \
1329}
1330 1310
1331#define MOD32(X) mod(X,m) 1311#define MOD32(X) mod(X,m)
1332#define MOD64(X) mod_l(X,m) 1312#define MOD64(X) mod_l(X,m)
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 8597dcb..a789cae 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -226,6 +226,8 @@ atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m))
226 226
227------------------------------------------------------------------ 227------------------------------------------------------------------
228 228
229matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 }
230matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d }
229matrixFromVector o r c v 231matrixFromVector o r c v
230 | r * c == dim v = m 232 | r * c == dim v = m
231 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m 233 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
@@ -280,7 +282,7 @@ class (Storable a) => Element a where
280 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a 282 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
281 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a 283 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
282 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () 284 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
283 gemm :: Vector a -> Vector I -> Matrix a -> Matrix a -> Matrix a -> IO () 285 gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
284 286
285 287
286instance Element Float where 288instance Element Float where
@@ -569,9 +571,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
569 571
570-------------------------------------------------------------------------------- 572--------------------------------------------------------------------------------
571 573
572gemmg f u v m1 m2 m3 = f # u # v # m1 # m2 # m3 #|"gemmg" 574gemmg f v m1 m2 m3 = f # v # m1 # m2 # m3 #|"gemmg"
573 575
574type Tgemm x = x :> I :> x ::> x ::> x ::> Ok 576type Tgemm x = x :> x ::> x ::> x ::> Ok
575 577
576foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R 578foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R
577foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float 579foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index 54d9cb8..8fa2747 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -137,7 +137,7 @@ instance KnownNat m => Element (Mod m I)
137 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) 137 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x)
138 where 138 where
139 m' = fromIntegral . natVal $ (undefined :: Proxy m) 139 m' = fromIntegral . natVal $ (undefined :: Proxy m)
140 gemm u p a b c = gemmg (c_gemmMI m') (f2i u) p (f2iM a) (f2iM b) (f2iM c) 140 gemm u a b c = gemmg (c_gemmMI m') (f2i u) (f2iM a) (f2iM b) (f2iM c)
141 where 141 where
142 m' = fromIntegral . natVal $ (undefined :: Proxy m) 142 m' = fromIntegral . natVal $ (undefined :: Proxy m)
143 143
@@ -154,7 +154,7 @@ instance KnownNat m => Element (Mod m Z)
154 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) 154 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x)
155 where 155 where
156 m' = fromIntegral . natVal $ (undefined :: Proxy m) 156 m' = fromIntegral . natVal $ (undefined :: Proxy m)
157 gemm u p a b c = gemmg (c_gemmML m') (f2i u) p (f2iM a) (f2iM b) (f2iM c) 157 gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c)
158 where 158 where
159 m' = fromIntegral . natVal $ (undefined :: Proxy m) 159 m' = fromIntegral . natVal $ (undefined :: Proxy m)
160 160
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 91c2a11..62dfddf 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -231,14 +231,13 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[
231-- | r0 c0 height width 231-- | r0 c0 height width
232data Slice s t = Slice (STMatrix s t) Int Int Int Int 232data Slice s t = Slice (STMatrix s t) Int Int Int Int
233 233
234slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1]) 234slice (Slice (STMatrix m) r0 c0 nr nc) = sliceMatrix (r0,c0) (nr,nc) m
235 235
236gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () 236gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s ()
237gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res 237gemmm beta (slice->r) alpha (slice->a) (slice->b) = res
238 where 238 where
239 res = unsafeIOToST (gemm u v a b r) 239 res = unsafeIOToST (gemm v a b r)
240 u = fromList [alpha,beta] 240 v = fromList [alpha,beta]
241 v = vjoin[pa,pb,pr]
242 241
243 242
244mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 243mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)