diff options
author | Alberto Ruiz <aruiz@um.es> | 2015-06-17 13:02:40 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2015-06-17 13:02:40 +0200 |
commit | 61d90ff66af8bfe53ef8cdda8dfe1e70463c213c (patch) | |
tree | 79ca6e2024731708a4378ff691b5317d8da11d11 /packages/base | |
parent | e7d2916f78b5c140738fc4f4f95c9b13c1768293 (diff) |
gemmm
Diffstat (limited to 'packages/base')
-rw-r--r-- | packages/base/src/Internal/C/lapack-aux.c | 59 | ||||
-rw-r--r-- | packages/base/src/Internal/C/lapack-aux.h | 1 | ||||
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 24 | ||||
-rw-r--r-- | packages/base/src/Internal/Modular.hs | 6 | ||||
-rw-r--r-- | packages/base/src/Internal/ST.hs | 27 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Devel.hs | 2 |
6 files changed, 108 insertions, 11 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index 2843ab5..4d48594 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c | |||
@@ -1398,6 +1398,65 @@ ROWOP(int64_t) | |||
1398 | ROWOP_MOD(int32_t,mod) | 1398 | ROWOP_MOD(int32_t,mod) |
1399 | ROWOP_MOD(int64_t,mod_l) | 1399 | ROWOP_MOD(int64_t,mod_l) |
1400 | 1400 | ||
1401 | /////////////////////////////// inplace GEMM //////////////////////////////// | ||
1402 | |||
1403 | #define GEMM(T) int gemm_##T(VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \ | ||
1404 | T a = cp[0], b = cp[1]; \ | ||
1405 | int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \ | ||
1406 | int r1b = pp[4], c1b = pp[6] ; \ | ||
1407 | int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \ | ||
1408 | int dra = r1a - r1r; \ | ||
1409 | int dcb = c1b-c1r; \ | ||
1410 | int nk = c2a-c1a+1; \ | ||
1411 | int i,j,k; \ | ||
1412 | T t; \ | ||
1413 | for (i=r1r; i<=r2r; i++) { \ | ||
1414 | for (j=c1r; j<=c2r; j++) { \ | ||
1415 | t = 0; \ | ||
1416 | for(k=0; k<nk; k++) { \ | ||
1417 | t += AT(a,i+dra,k+c1a) * AT(b,k+r1b,j+dcb); \ | ||
1418 | } \ | ||
1419 | AT(r,i,j) = b*AT(r,i,j) + a*t; \ | ||
1420 | } \ | ||
1421 | } \ | ||
1422 | OK \ | ||
1423 | } | ||
1424 | |||
1425 | GEMM(double) | ||
1426 | GEMM(float) | ||
1427 | GEMM(TCD) | ||
1428 | GEMM(TCF) | ||
1429 | GEMM(int32_t) | ||
1430 | GEMM(int64_t) | ||
1431 | |||
1432 | #define GEMM_MOD(T,M) int gemm_mod_##T(T m, VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \ | ||
1433 | T a = cp[0], b = cp[1]; \ | ||
1434 | int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \ | ||
1435 | int r1b = pp[4], c1b = pp[6] ; \ | ||
1436 | int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \ | ||
1437 | int dra = r1a - r1r; \ | ||
1438 | int dcb = c1b-c1r; \ | ||
1439 | int nk = c2a-c1a+1; \ | ||
1440 | int i,j,k; \ | ||
1441 | T t; \ | ||
1442 | for (i=r1r; i<=r2r; i++) { \ | ||
1443 | for (j=c1r; j<=c2r; j++) { \ | ||
1444 | t = 0; \ | ||
1445 | for(k=0; k<nk; k++) { \ | ||
1446 | t = M(t+M(AT(a,i+dra,k+c1a) * AT(b,k+r1b,j+dcb))); \ | ||
1447 | } \ | ||
1448 | AT(r,i,j) = M(M(b*AT(r,i,j)) + M(a*t)); \ | ||
1449 | } \ | ||
1450 | } \ | ||
1451 | OK \ | ||
1452 | } | ||
1453 | |||
1454 | #define MOD32(X) mod(X,m) | ||
1455 | #define MOD64(X) mod_l(X,m) | ||
1456 | |||
1457 | GEMM_MOD(int32_t,MOD32) | ||
1458 | GEMM_MOD(int64_t,MOD64) | ||
1459 | |||
1401 | ////////////////// sparse matrix-product /////////////////////////////////////// | 1460 | ////////////////// sparse matrix-product /////////////////////////////////////// |
1402 | 1461 | ||
1403 | 1462 | ||
diff --git a/packages/base/src/Internal/C/lapack-aux.h b/packages/base/src/Internal/C/lapack-aux.h index e4d95bc..bf8c5e9 100644 --- a/packages/base/src/Internal/C/lapack-aux.h +++ b/packages/base/src/Internal/C/lapack-aux.h | |||
@@ -59,6 +59,7 @@ typedef short ftnlen; | |||
59 | #define OQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, complex* A##p | 59 | #define OQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, complex* A##p |
60 | #define OCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, doublecomplex* A##p | 60 | #define OCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, doublecomplex* A##p |
61 | 61 | ||
62 | #define VECG(T,A) int A##n, T* A##p | ||
62 | #define MATG(T,A) int A##r, int A##c, int A##Xr, int A##Xc, T* A##p | 63 | #define MATG(T,A) int A##r, int A##c, int A##Xr, int A##Xc, T* A##p |
63 | 64 | ||
64 | #define KIVEC(A) int A##n, const int*A##p | 65 | #define KIVEC(A) int A##n, const int*A##p |
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 75e92a5..8f8c219 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -274,6 +274,7 @@ class (Storable a) => Element a where | |||
274 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a | 274 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a |
275 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a | 275 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a |
276 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () | 276 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () |
277 | gemm :: Vector a -> Vector I -> Matrix a -> Matrix a -> Matrix a -> IO () | ||
277 | 278 | ||
278 | 279 | ||
279 | instance Element Float where | 280 | instance Element Float where |
@@ -287,6 +288,7 @@ instance Element Float where | |||
287 | selectV = selectF | 288 | selectV = selectF |
288 | remapM = remapF | 289 | remapM = remapF |
289 | rowOp = rowOpAux c_rowOpF | 290 | rowOp = rowOpAux c_rowOpF |
291 | gemm = gemmg c_gemmF | ||
290 | 292 | ||
291 | instance Element Double where | 293 | instance Element Double where |
292 | transdata = transdataAux ctransR | 294 | transdata = transdataAux ctransR |
@@ -299,7 +301,7 @@ instance Element Double where | |||
299 | selectV = selectD | 301 | selectV = selectD |
300 | remapM = remapD | 302 | remapM = remapD |
301 | rowOp = rowOpAux c_rowOpD | 303 | rowOp = rowOpAux c_rowOpD |
302 | 304 | gemm = gemmg c_gemmD | |
303 | 305 | ||
304 | instance Element (Complex Float) where | 306 | instance Element (Complex Float) where |
305 | transdata = transdataAux ctransQ | 307 | transdata = transdataAux ctransQ |
@@ -312,7 +314,7 @@ instance Element (Complex Float) where | |||
312 | selectV = selectQ | 314 | selectV = selectQ |
313 | remapM = remapQ | 315 | remapM = remapQ |
314 | rowOp = rowOpAux c_rowOpQ | 316 | rowOp = rowOpAux c_rowOpQ |
315 | 317 | gemm = gemmg c_gemmQ | |
316 | 318 | ||
317 | instance Element (Complex Double) where | 319 | instance Element (Complex Double) where |
318 | transdata = transdataAux ctransC | 320 | transdata = transdataAux ctransC |
@@ -325,6 +327,7 @@ instance Element (Complex Double) where | |||
325 | selectV = selectC | 327 | selectV = selectC |
326 | remapM = remapC | 328 | remapM = remapC |
327 | rowOp = rowOpAux c_rowOpC | 329 | rowOp = rowOpAux c_rowOpC |
330 | gemm = gemmg c_gemmC | ||
328 | 331 | ||
329 | instance Element (CInt) where | 332 | instance Element (CInt) where |
330 | transdata = transdataAux ctransI | 333 | transdata = transdataAux ctransI |
@@ -337,6 +340,7 @@ instance Element (CInt) where | |||
337 | selectV = selectI | 340 | selectV = selectI |
338 | remapM = remapI | 341 | remapM = remapI |
339 | rowOp = rowOpAux c_rowOpI | 342 | rowOp = rowOpAux c_rowOpI |
343 | gemm = gemmg c_gemmI | ||
340 | 344 | ||
341 | instance Element Z where | 345 | instance Element Z where |
342 | transdata = transdataAux ctransL | 346 | transdata = transdataAux ctransL |
@@ -349,6 +353,7 @@ instance Element Z where | |||
349 | selectV = selectL | 353 | selectV = selectL |
350 | remapM = remapL | 354 | remapM = remapL |
351 | rowOp = rowOpAux c_rowOpL | 355 | rowOp = rowOpAux c_rowOpL |
356 | gemm = gemmg c_gemmL | ||
352 | 357 | ||
353 | ------------------------------------------------------------------- | 358 | ------------------------------------------------------------------- |
354 | 359 | ||
@@ -575,6 +580,21 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | |||
575 | 580 | ||
576 | -------------------------------------------------------------------------------- | 581 | -------------------------------------------------------------------------------- |
577 | 582 | ||
583 | gemmg f u v m1 m2 m3 = app5 f vec u vec v omat m1 omat m2 omat m3 "gemmg" | ||
584 | |||
585 | type Tgemm x = x :> I :> x ::> x ::> x ::> Ok | ||
586 | |||
587 | foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R | ||
588 | foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float | ||
589 | foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C | ||
590 | foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float) | ||
591 | foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I | ||
592 | foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z | ||
593 | foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I | ||
594 | foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z | ||
595 | |||
596 | -------------------------------------------------------------------------------- | ||
597 | |||
578 | foreign import ccall unsafe "saveMatrix" c_saveMatrix | 598 | foreign import ccall unsafe "saveMatrix" c_saveMatrix |
579 | :: CString -> CString -> Double ..> Ok | 599 | :: CString -> CString -> Double ..> Ok |
580 | 600 | ||
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 6c6d5c5..d158111 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -131,6 +131,9 @@ instance KnownNat m => Element (Mod m I) | |||
131 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) | 131 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) |
132 | where | 132 | where |
133 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 133 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
134 | gemm u p a b c = gemmg (c_gemmMI m') (f2i u) p (f2iM a) (f2iM b) (f2iM c) | ||
135 | where | ||
136 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
134 | 137 | ||
135 | instance KnownNat m => Element (Mod m Z) | 138 | instance KnownNat m => Element (Mod m Z) |
136 | where | 139 | where |
@@ -146,6 +149,9 @@ instance KnownNat m => Element (Mod m Z) | |||
146 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) | 149 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) |
147 | where | 150 | where |
148 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 151 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
152 | gemm u p a b c = gemmg (c_gemmML m') (f2i u) p (f2iM a) (f2iM b) (f2iM c) | ||
153 | where | ||
154 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
149 | 155 | ||
150 | 156 | ||
151 | instance forall m . KnownNat m => Container Vector (Mod m I) | 157 | instance forall m . KnownNat m => Container Vector (Mod m I) |
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 434fe63..25e7f03 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -1,5 +1,6 @@ | |||
1 | {-# LANGUAGE Rank2Types #-} | 1 | {-# LANGUAGE Rank2Types #-} |
2 | {-# LANGUAGE BangPatterns #-} | 2 | {-# LANGUAGE BangPatterns #-} |
3 | {-# LANGUAGE ViewPatterns #-} | ||
3 | 4 | ||
4 | ----------------------------------------------------------------------------- | 5 | ----------------------------------------------------------------------------- |
5 | -- | | 6 | -- | |
@@ -15,14 +16,14 @@ | |||
15 | ----------------------------------------------------------------------------- | 16 | ----------------------------------------------------------------------------- |
16 | 17 | ||
17 | module Internal.ST ( | 18 | module Internal.ST ( |
19 | ST, runST, | ||
18 | -- * Mutable Vectors | 20 | -- * Mutable Vectors |
19 | STVector, newVector, thawVector, freezeVector, runSTVector, | 21 | STVector, newVector, thawVector, freezeVector, runSTVector, |
20 | readVector, writeVector, modifyVector, liftSTVector, | 22 | readVector, writeVector, modifyVector, liftSTVector, |
21 | -- * Mutable Matrices | 23 | -- * Mutable Matrices |
22 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, | 24 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, |
23 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, | 25 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, |
24 | -- axpy, scal, swap, rowOp, | 26 | mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..), |
25 | mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), | ||
26 | -- * Unsafe functions | 27 | -- * Unsafe functions |
27 | newUndefinedVector, | 28 | newUndefinedVector, |
28 | unsafeReadVector, unsafeWriteVector, | 29 | unsafeReadVector, unsafeWriteVector, |
@@ -70,13 +71,13 @@ unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k | |||
70 | modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s () | 71 | modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s () |
71 | modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k | 72 | modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k |
72 | 73 | ||
73 | liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a | 74 | liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s t -> ST s a |
74 | liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x | 75 | liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x |
75 | 76 | ||
76 | freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) | 77 | freezeVector :: (Storable t) => STVector s t -> ST s (Vector t) |
77 | freezeVector v = liftSTVector id v | 78 | freezeVector v = liftSTVector id v |
78 | 79 | ||
79 | unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) | 80 | unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t) |
80 | unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x | 81 | unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x |
81 | 82 | ||
82 | {-# INLINE safeIndexV #-} | 83 | {-# INLINE safeIndexV #-} |
@@ -139,14 +140,14 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c | |||
139 | modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () | 140 | modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () |
140 | modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c | 141 | modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c |
141 | 142 | ||
142 | liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a | 143 | liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s t -> ST s a |
143 | liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x | 144 | liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x |
144 | 145 | ||
145 | unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) | 146 | unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) |
146 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | 147 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x |
147 | 148 | ||
148 | 149 | ||
149 | freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) | 150 | freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) |
150 | freezeMatrix m = liftSTMatrix id m | 151 | freezeMatrix m = liftSTMatrix id m |
151 | 152 | ||
152 | cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o) | 153 | cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o) |
@@ -227,6 +228,16 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (i | |||
227 | (i1,i2) = getRowRange (rows m) rr | 228 | (i1,i2) = getRowRange (rows m) rr |
228 | (j1,j2) = getColRange (cols m) rc | 229 | (j1,j2) = getColRange (cols m) rc |
229 | 230 | ||
231 | data Slice s t = Slice (STMatrix s t) Int Int Int Int | ||
232 | |||
233 | slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1]) | ||
234 | |||
235 | gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res | ||
236 | where | ||
237 | res = unsafeIOToST (gemm u v a b r) | ||
238 | u = fromList [alpha,beta] | ||
239 | v = vjoin[pa,pb,pr] | ||
240 | |||
230 | 241 | ||
231 | mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) | 242 | mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) |
232 | mutable f a = runST $ do | 243 | mutable f a = runST $ do |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs index 36c5f03..db4236b 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs | |||
@@ -43,7 +43,7 @@ module Numeric.LinearAlgebra.Devel( | |||
43 | -- ** Mutable Matrices | 43 | -- ** Mutable Matrices |
44 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, | 44 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, |
45 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, | 45 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, |
46 | mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), | 46 | mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..), |
47 | -- ** Unsafe functions | 47 | -- ** Unsafe functions |
48 | newUndefinedVector, | 48 | newUndefinedVector, |
49 | unsafeReadVector, unsafeWriteVector, | 49 | unsafeReadVector, unsafeWriteVector, |