diff options
Diffstat (limited to 'packages/base/src')
-rw-r--r-- | packages/base/src/Internal/C/lapack-aux.c | 104 | ||||
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 8 | ||||
-rw-r--r-- | packages/base/src/Internal/Modular.hs | 4 | ||||
-rw-r--r-- | packages/base/src/Internal/ST.hs | 9 |
4 files changed, 53 insertions, 72 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index ca60846..30689bf 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c | |||
@@ -1093,16 +1093,15 @@ void dgemm_(char *, char *, integer *, integer *, integer *, | |||
1093 | integer *, double *, double *, integer *); | 1093 | integer *, double *, double *, integer *); |
1094 | 1094 | ||
1095 | int multiplyR(int ta, int tb, KODMAT(a),KODMAT(b),ODMAT(r)) { | 1095 | int multiplyR(int ta, int tb, KODMAT(a),KODMAT(b),ODMAT(r)) { |
1096 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
1097 | DEBUGMSG("dgemm_"); | 1096 | DEBUGMSG("dgemm_"); |
1098 | CHECKNANR(a,"NaN multR Input\n") | 1097 | CHECKNANR(a,"NaN multR Input\n") |
1099 | CHECKNANR(b,"NaN multR Input\n") | 1098 | CHECKNANR(b,"NaN multR Input\n") |
1100 | integer m = ta?ac:ar; | 1099 | integer m = ta?ac:ar; |
1101 | integer n = tb?br:bc; | 1100 | integer n = tb?br:bc; |
1102 | integer k = ta?ar:ac; | 1101 | integer k = ta?ar:ac; |
1103 | integer lda = ar; | 1102 | integer lda = aXc; |
1104 | integer ldb = br; | 1103 | integer ldb = bXc; |
1105 | integer ldc = rr; | 1104 | integer ldc = rXc; |
1106 | double alpha = 1; | 1105 | double alpha = 1; |
1107 | double beta = 0; | 1106 | double beta = 0; |
1108 | dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); | 1107 | dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); |
@@ -1115,16 +1114,15 @@ void zgemm_(char *, char *, integer *, integer *, integer *, | |||
1115 | integer *, doublecomplex *, doublecomplex *, integer *); | 1114 | integer *, doublecomplex *, doublecomplex *, integer *); |
1116 | 1115 | ||
1117 | int multiplyC(int ta, int tb, KOCMAT(a),KOCMAT(b),OCMAT(r)) { | 1116 | int multiplyC(int ta, int tb, KOCMAT(a),KOCMAT(b),OCMAT(r)) { |
1118 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
1119 | DEBUGMSG("zgemm_"); | 1117 | DEBUGMSG("zgemm_"); |
1120 | CHECKNANC(a,"NaN multC Input\n") | 1118 | CHECKNANC(a,"NaN multC Input\n") |
1121 | CHECKNANC(b,"NaN multC Input\n") | 1119 | CHECKNANC(b,"NaN multC Input\n") |
1122 | integer m = ta?ac:ar; | 1120 | integer m = ta?ac:ar; |
1123 | integer n = tb?br:bc; | 1121 | integer n = tb?br:bc; |
1124 | integer k = ta?ar:ac; | 1122 | integer k = ta?ar:ac; |
1125 | integer lda = ar; | 1123 | integer lda = aXc; |
1126 | integer ldb = br; | 1124 | integer ldb = bXc; |
1127 | integer ldc = rr; | 1125 | integer ldc = rXc; |
1128 | doublecomplex alpha = {1,0}; | 1126 | doublecomplex alpha = {1,0}; |
1129 | doublecomplex beta = {0,0}; | 1127 | doublecomplex beta = {0,0}; |
1130 | zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, | 1128 | zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, |
@@ -1140,14 +1138,13 @@ void sgemm_(char *, char *, integer *, integer *, integer *, | |||
1140 | integer *, float *, float *, integer *); | 1138 | integer *, float *, float *, integer *); |
1141 | 1139 | ||
1142 | int multiplyF(int ta, int tb, KOFMAT(a),KOFMAT(b),OFMAT(r)) { | 1140 | int multiplyF(int ta, int tb, KOFMAT(a),KOFMAT(b),OFMAT(r)) { |
1143 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
1144 | DEBUGMSG("sgemm_"); | 1141 | DEBUGMSG("sgemm_"); |
1145 | integer m = ta?ac:ar; | 1142 | integer m = ta?ac:ar; |
1146 | integer n = tb?br:bc; | 1143 | integer n = tb?br:bc; |
1147 | integer k = ta?ar:ac; | 1144 | integer k = ta?ar:ac; |
1148 | integer lda = ar; | 1145 | integer lda = aXc; |
1149 | integer ldb = br; | 1146 | integer ldb = bXc; |
1150 | integer ldc = rr; | 1147 | integer ldc = rXc; |
1151 | float alpha = 1; | 1148 | float alpha = 1; |
1152 | float beta = 0; | 1149 | float beta = 0; |
1153 | sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); | 1150 | sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); |
@@ -1159,14 +1156,13 @@ void cgemm_(char *, char *, integer *, integer *, integer *, | |||
1159 | integer *, complex *, complex *, integer *); | 1156 | integer *, complex *, complex *, integer *); |
1160 | 1157 | ||
1161 | int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) { | 1158 | int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) { |
1162 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | ||
1163 | DEBUGMSG("cgemm_"); | 1159 | DEBUGMSG("cgemm_"); |
1164 | integer m = ta?ac:ar; | 1160 | integer m = ta?ac:ar; |
1165 | integer n = tb?br:bc; | 1161 | integer n = tb?br:bc; |
1166 | integer k = ta?ar:ac; | 1162 | integer k = ta?ar:ac; |
1167 | integer lda = ar; | 1163 | integer lda = aXc; |
1168 | integer ldb = br; | 1164 | integer ldb = bXc; |
1169 | integer ldc = rr; | 1165 | integer ldc = rXc; |
1170 | complex alpha = {1,0}; | 1166 | complex alpha = {1,0}; |
1171 | complex beta = {0,0}; | 1167 | complex beta = {0,0}; |
1172 | cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, | 1168 | cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, |
@@ -1187,15 +1183,15 @@ int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) { | |||
1187 | } \ | 1183 | } \ |
1188 | } | 1184 | } |
1189 | 1185 | ||
1190 | #define MULT_IMP { \ | 1186 | #define MULT_IMP(M) { \ |
1191 | if (m==1) { \ | 1187 | if (m==1) { \ |
1192 | MULT_IMP_VER( AT(r,i,j) += AT(a,i,k) * AT(b,k,j); ) \ | 1188 | MULT_IMP_VER( AT(r,i,j) += AT(a,i,k) * AT(b,k,j); ) \ |
1193 | } else { \ | 1189 | } else { \ |
1194 | MULT_IMP_VER( AT(r,i,j) = (AT(r,i,j) + (AT(a,i,k) * AT(b,k,j)) % m) % m ; ) \ | 1190 | MULT_IMP_VER( AT(r,i,j) = M(AT(r,i,j) + M(AT(a,i,k) * AT(b,k,j), m) , m) ; ) \ |
1195 | } OK } | 1191 | } OK } |
1196 | 1192 | ||
1197 | int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP | 1193 | int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP(mod) |
1198 | int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP | 1194 | int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP(mod_l) |
1199 | 1195 | ||
1200 | /////////////////////////////// inplace row ops //////////////////////////////// | 1196 | /////////////////////////////// inplace row ops //////////////////////////////// |
1201 | 1197 | ||
@@ -1277,27 +1273,19 @@ ROWOP_MOD(int64_t,mod_l) | |||
1277 | 1273 | ||
1278 | /////////////////////////////// inplace GEMM //////////////////////////////// | 1274 | /////////////////////////////// inplace GEMM //////////////////////////////// |
1279 | 1275 | ||
1280 | #define GEMM(T) int gemm_##T(VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \ | 1276 | #define GEMM(T) int gemm_##T(VECG(T,c),MATG(T,a),MATG(T,b),MATG(T,r)) { \ |
1281 | T a = cp[0], b = cp[1]; \ | 1277 | T a = cp[0], b = cp[1]; \ |
1282 | int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \ | 1278 | T t; \ |
1283 | int r1b = pp[4], c1b = pp[6] ; \ | 1279 | int k; \ |
1284 | int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \ | 1280 | { TRAV(r,i,j) { \ |
1285 | int dra = r1a - r1r; \ | 1281 | t = 0; \ |
1286 | int dcb = c1b-c1r; \ | 1282 | for(k=0; k<ac; k++) { \ |
1287 | int nk = c2a-c1a+1; \ | 1283 | t += AT(a,i,k) * AT(b,k,j); \ |
1288 | int i,j,k; \ | 1284 | } \ |
1289 | T t; \ | 1285 | AT(r,i,j) = b*AT(r,i,j) + a*t; \ |
1290 | for (i=r1r; i<=r2r; i++) { \ | 1286 | } \ |
1291 | for (j=c1r; j<=c2r; j++) { \ | 1287 | } OK } |
1292 | t = 0; \ | 1288 | |
1293 | for(k=0; k<nk; k++) { \ | ||
1294 | t += AT(a,i+dra,k+c1a) * AT(b,k+r1b,j+dcb); \ | ||
1295 | } \ | ||
1296 | AT(r,i,j) = b*AT(r,i,j) + a*t; \ | ||
1297 | } \ | ||
1298 | } \ | ||
1299 | OK \ | ||
1300 | } | ||
1301 | 1289 | ||
1302 | GEMM(double) | 1290 | GEMM(double) |
1303 | GEMM(float) | 1291 | GEMM(float) |
@@ -1306,27 +1294,19 @@ GEMM(TCF) | |||
1306 | GEMM(int32_t) | 1294 | GEMM(int32_t) |
1307 | GEMM(int64_t) | 1295 | GEMM(int64_t) |
1308 | 1296 | ||
1309 | #define GEMM_MOD(T,M) int gemm_mod_##T(T m, VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \ | 1297 | #define GEMM_MOD(T,M) int gemm_mod_##T(T m, VECG(T,c),MATG(T,a),MATG(T,b),MATG(T,r)) { \ |
1310 | T a = cp[0], b = cp[1]; \ | 1298 | T a = cp[0], b = cp[1]; \ |
1311 | int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \ | 1299 | int k; \ |
1312 | int r1b = pp[4], c1b = pp[6] ; \ | 1300 | T t; \ |
1313 | int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \ | 1301 | { TRAV(r,i,j) { \ |
1314 | int dra = r1a - r1r; \ | 1302 | t = 0; \ |
1315 | int dcb = c1b-c1r; \ | 1303 | for(k=0; k<ac; k++) { \ |
1316 | int nk = c2a-c1a+1; \ | 1304 | t = M(t+M(AT(a,i,k) * AT(b,k,j))); \ |
1317 | int i,j,k; \ | 1305 | } \ |
1318 | T t; \ | 1306 | AT(r,i,j) = M(M(b*AT(r,i,j)) + M(a*t)); \ |
1319 | for (i=r1r; i<=r2r; i++) { \ | 1307 | } \ |
1320 | for (j=c1r; j<=c2r; j++) { \ | 1308 | } OK } |
1321 | t = 0; \ | 1309 | |
1322 | for(k=0; k<nk; k++) { \ | ||
1323 | t = M(t+M(AT(a,i+dra,k+c1a) * AT(b,k+r1b,j+dcb))); \ | ||
1324 | } \ | ||
1325 | AT(r,i,j) = M(M(b*AT(r,i,j)) + M(a*t)); \ | ||
1326 | } \ | ||
1327 | } \ | ||
1328 | OK \ | ||
1329 | } | ||
1330 | 1310 | ||
1331 | #define MOD32(X) mod(X,m) | 1311 | #define MOD32(X) mod(X,m) |
1332 | #define MOD64(X) mod_l(X,m) | 1312 | #define MOD64(X) mod_l(X,m) |
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 8597dcb..a789cae 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -226,6 +226,8 @@ atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) | |||
226 | 226 | ||
227 | ------------------------------------------------------------------ | 227 | ------------------------------------------------------------------ |
228 | 228 | ||
229 | matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } | ||
230 | matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } | ||
229 | matrixFromVector o r c v | 231 | matrixFromVector o r c v |
230 | | r * c == dim v = m | 232 | | r * c == dim v = m |
231 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m | 233 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m |
@@ -280,7 +282,7 @@ class (Storable a) => Element a where | |||
280 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a | 282 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a |
281 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a | 283 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a |
282 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () | 284 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () |
283 | gemm :: Vector a -> Vector I -> Matrix a -> Matrix a -> Matrix a -> IO () | 285 | gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () |
284 | 286 | ||
285 | 287 | ||
286 | instance Element Float where | 288 | instance Element Float where |
@@ -569,9 +571,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | |||
569 | 571 | ||
570 | -------------------------------------------------------------------------------- | 572 | -------------------------------------------------------------------------------- |
571 | 573 | ||
572 | gemmg f u v m1 m2 m3 = f # u # v # m1 # m2 # m3 #|"gemmg" | 574 | gemmg f v m1 m2 m3 = f # v # m1 # m2 # m3 #|"gemmg" |
573 | 575 | ||
574 | type Tgemm x = x :> I :> x ::> x ::> x ::> Ok | 576 | type Tgemm x = x :> x ::> x ::> x ::> Ok |
575 | 577 | ||
576 | foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R | 578 | foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R |
577 | foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float | 579 | foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float |
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 54d9cb8..8fa2747 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -137,7 +137,7 @@ instance KnownNat m => Element (Mod m I) | |||
137 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) | 137 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) |
138 | where | 138 | where |
139 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 139 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
140 | gemm u p a b c = gemmg (c_gemmMI m') (f2i u) p (f2iM a) (f2iM b) (f2iM c) | 140 | gemm u a b c = gemmg (c_gemmMI m') (f2i u) (f2iM a) (f2iM b) (f2iM c) |
141 | where | 141 | where |
142 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 142 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
143 | 143 | ||
@@ -154,7 +154,7 @@ instance KnownNat m => Element (Mod m Z) | |||
154 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) | 154 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) |
155 | where | 155 | where |
156 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 156 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
157 | gemm u p a b c = gemmg (c_gemmML m') (f2i u) p (f2iM a) (f2iM b) (f2iM c) | 157 | gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c) |
158 | where | 158 | where |
159 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 159 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
160 | 160 | ||
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 91c2a11..62dfddf 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -231,14 +231,13 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[ | |||
231 | -- | r0 c0 height width | 231 | -- | r0 c0 height width |
232 | data Slice s t = Slice (STMatrix s t) Int Int Int Int | 232 | data Slice s t = Slice (STMatrix s t) Int Int Int Int |
233 | 233 | ||
234 | slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1]) | 234 | slice (Slice (STMatrix m) r0 c0 nr nc) = sliceMatrix (r0,c0) (nr,nc) m |
235 | 235 | ||
236 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () | 236 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () |
237 | gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res | 237 | gemmm beta (slice->r) alpha (slice->a) (slice->b) = res |
238 | where | 238 | where |
239 | res = unsafeIOToST (gemm u v a b r) | 239 | res = unsafeIOToST (gemm v a b r) |
240 | u = fromList [alpha,beta] | 240 | v = fromList [alpha,beta] |
241 | v = vjoin[pa,pb,pr] | ||
242 | 241 | ||
243 | 242 | ||
244 | mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) | 243 | mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) |