summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-17 13:02:40 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-17 13:02:40 +0200
commit61d90ff66af8bfe53ef8cdda8dfe1e70463c213c (patch)
tree79ca6e2024731708a4378ff691b5317d8da11d11 /packages/base/src/Internal
parente7d2916f78b5c140738fc4f4f95c9b13c1768293 (diff)
gemmm
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r--packages/base/src/Internal/C/lapack-aux.c59
-rw-r--r--packages/base/src/Internal/C/lapack-aux.h1
-rw-r--r--packages/base/src/Internal/Matrix.hs24
-rw-r--r--packages/base/src/Internal/Modular.hs6
-rw-r--r--packages/base/src/Internal/ST.hs27
5 files changed, 107 insertions, 10 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c
index 2843ab5..4d48594 100644
--- a/packages/base/src/Internal/C/lapack-aux.c
+++ b/packages/base/src/Internal/C/lapack-aux.c
@@ -1398,6 +1398,65 @@ ROWOP(int64_t)
1398ROWOP_MOD(int32_t,mod) 1398ROWOP_MOD(int32_t,mod)
1399ROWOP_MOD(int64_t,mod_l) 1399ROWOP_MOD(int64_t,mod_l)
1400 1400
1401/////////////////////////////// inplace GEMM ////////////////////////////////
1402
1403#define GEMM(T) int gemm_##T(VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \
1404 T a = cp[0], b = cp[1]; \
1405 int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \
1406 int r1b = pp[4], c1b = pp[6] ; \
1407 int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \
1408 int dra = r1a - r1r; \
1409 int dcb = c1b-c1r; \
1410 int nk = c2a-c1a+1; \
1411 int i,j,k; \
1412 T t; \
1413 for (i=r1r; i<=r2r; i++) { \
1414 for (j=c1r; j<=c2r; j++) { \
1415 t = 0; \
1416 for(k=0; k<nk; k++) { \
1417 t += AT(a,i+dra,k+c1a) * AT(b,k+r1b,j+dcb); \
1418 } \
1419 AT(r,i,j) = b*AT(r,i,j) + a*t; \
1420 } \
1421 } \
1422 OK \
1423}
1424
1425GEMM(double)
1426GEMM(float)
1427GEMM(TCD)
1428GEMM(TCF)
1429GEMM(int32_t)
1430GEMM(int64_t)
1431
1432#define GEMM_MOD(T,M) int gemm_mod_##T(T m, VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \
1433 T a = cp[0], b = cp[1]; \
1434 int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \
1435 int r1b = pp[4], c1b = pp[6] ; \
1436 int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \
1437 int dra = r1a - r1r; \
1438 int dcb = c1b-c1r; \
1439 int nk = c2a-c1a+1; \
1440 int i,j,k; \
1441 T t; \
1442 for (i=r1r; i<=r2r; i++) { \
1443 for (j=c1r; j<=c2r; j++) { \
1444 t = 0; \
1445 for(k=0; k<nk; k++) { \
1446 t = M(t+M(AT(a,i+dra,k+c1a) * AT(b,k+r1b,j+dcb))); \
1447 } \
1448 AT(r,i,j) = M(M(b*AT(r,i,j)) + M(a*t)); \
1449 } \
1450 } \
1451 OK \
1452}
1453
1454#define MOD32(X) mod(X,m)
1455#define MOD64(X) mod_l(X,m)
1456
1457GEMM_MOD(int32_t,MOD32)
1458GEMM_MOD(int64_t,MOD64)
1459
1401////////////////// sparse matrix-product /////////////////////////////////////// 1460////////////////// sparse matrix-product ///////////////////////////////////////
1402 1461
1403 1462
diff --git a/packages/base/src/Internal/C/lapack-aux.h b/packages/base/src/Internal/C/lapack-aux.h
index e4d95bc..bf8c5e9 100644
--- a/packages/base/src/Internal/C/lapack-aux.h
+++ b/packages/base/src/Internal/C/lapack-aux.h
@@ -59,6 +59,7 @@ typedef short ftnlen;
59#define OQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, complex* A##p 59#define OQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, complex* A##p
60#define OCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, doublecomplex* A##p 60#define OCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, doublecomplex* A##p
61 61
62#define VECG(T,A) int A##n, T* A##p
62#define MATG(T,A) int A##r, int A##c, int A##Xr, int A##Xc, T* A##p 63#define MATG(T,A) int A##r, int A##c, int A##Xr, int A##Xc, T* A##p
63 64
64#define KIVEC(A) int A##n, const int*A##p 65#define KIVEC(A) int A##n, const int*A##p
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 75e92a5..8f8c219 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -274,6 +274,7 @@ class (Storable a) => Element a where
274 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a 274 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
275 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a 275 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
276 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () 276 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
277 gemm :: Vector a -> Vector I -> Matrix a -> Matrix a -> Matrix a -> IO ()
277 278
278 279
279instance Element Float where 280instance Element Float where
@@ -287,6 +288,7 @@ instance Element Float where
287 selectV = selectF 288 selectV = selectF
288 remapM = remapF 289 remapM = remapF
289 rowOp = rowOpAux c_rowOpF 290 rowOp = rowOpAux c_rowOpF
291 gemm = gemmg c_gemmF
290 292
291instance Element Double where 293instance Element Double where
292 transdata = transdataAux ctransR 294 transdata = transdataAux ctransR
@@ -299,7 +301,7 @@ instance Element Double where
299 selectV = selectD 301 selectV = selectD
300 remapM = remapD 302 remapM = remapD
301 rowOp = rowOpAux c_rowOpD 303 rowOp = rowOpAux c_rowOpD
302 304 gemm = gemmg c_gemmD
303 305
304instance Element (Complex Float) where 306instance Element (Complex Float) where
305 transdata = transdataAux ctransQ 307 transdata = transdataAux ctransQ
@@ -312,7 +314,7 @@ instance Element (Complex Float) where
312 selectV = selectQ 314 selectV = selectQ
313 remapM = remapQ 315 remapM = remapQ
314 rowOp = rowOpAux c_rowOpQ 316 rowOp = rowOpAux c_rowOpQ
315 317 gemm = gemmg c_gemmQ
316 318
317instance Element (Complex Double) where 319instance Element (Complex Double) where
318 transdata = transdataAux ctransC 320 transdata = transdataAux ctransC
@@ -325,6 +327,7 @@ instance Element (Complex Double) where
325 selectV = selectC 327 selectV = selectC
326 remapM = remapC 328 remapM = remapC
327 rowOp = rowOpAux c_rowOpC 329 rowOp = rowOpAux c_rowOpC
330 gemm = gemmg c_gemmC
328 331
329instance Element (CInt) where 332instance Element (CInt) where
330 transdata = transdataAux ctransI 333 transdata = transdataAux ctransI
@@ -337,6 +340,7 @@ instance Element (CInt) where
337 selectV = selectI 340 selectV = selectI
338 remapM = remapI 341 remapM = remapI
339 rowOp = rowOpAux c_rowOpI 342 rowOp = rowOpAux c_rowOpI
343 gemm = gemmg c_gemmI
340 344
341instance Element Z where 345instance Element Z where
342 transdata = transdataAux ctransL 346 transdata = transdataAux ctransL
@@ -349,6 +353,7 @@ instance Element Z where
349 selectV = selectL 353 selectV = selectL
350 remapM = remapL 354 remapM = remapL
351 rowOp = rowOpAux c_rowOpL 355 rowOp = rowOpAux c_rowOpL
356 gemm = gemmg c_gemmL
352 357
353------------------------------------------------------------------- 358-------------------------------------------------------------------
354 359
@@ -575,6 +580,21 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
575 580
576-------------------------------------------------------------------------------- 581--------------------------------------------------------------------------------
577 582
583gemmg f u v m1 m2 m3 = app5 f vec u vec v omat m1 omat m2 omat m3 "gemmg"
584
585type Tgemm x = x :> I :> x ::> x ::> x ::> Ok
586
587foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R
588foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float
589foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C
590foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float)
591foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I
592foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z
593foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I
594foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z
595
596--------------------------------------------------------------------------------
597
578foreign import ccall unsafe "saveMatrix" c_saveMatrix 598foreign import ccall unsafe "saveMatrix" c_saveMatrix
579 :: CString -> CString -> Double ..> Ok 599 :: CString -> CString -> Double ..> Ok
580 600
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index 6c6d5c5..d158111 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -131,6 +131,9 @@ instance KnownNat m => Element (Mod m I)
131 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) 131 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x)
132 where 132 where
133 m' = fromIntegral . natVal $ (undefined :: Proxy m) 133 m' = fromIntegral . natVal $ (undefined :: Proxy m)
134 gemm u p a b c = gemmg (c_gemmMI m') (f2i u) p (f2iM a) (f2iM b) (f2iM c)
135 where
136 m' = fromIntegral . natVal $ (undefined :: Proxy m)
134 137
135instance KnownNat m => Element (Mod m Z) 138instance KnownNat m => Element (Mod m Z)
136 where 139 where
@@ -146,6 +149,9 @@ instance KnownNat m => Element (Mod m Z)
146 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) 149 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x)
147 where 150 where
148 m' = fromIntegral . natVal $ (undefined :: Proxy m) 151 m' = fromIntegral . natVal $ (undefined :: Proxy m)
152 gemm u p a b c = gemmg (c_gemmML m') (f2i u) p (f2iM a) (f2iM b) (f2iM c)
153 where
154 m' = fromIntegral . natVal $ (undefined :: Proxy m)
149 155
150 156
151instance forall m . KnownNat m => Container Vector (Mod m I) 157instance forall m . KnownNat m => Container Vector (Mod m I)
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 434fe63..25e7f03 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -1,5 +1,6 @@
1{-# LANGUAGE Rank2Types #-} 1{-# LANGUAGE Rank2Types #-}
2{-# LANGUAGE BangPatterns #-} 2{-# LANGUAGE BangPatterns #-}
3{-# LANGUAGE ViewPatterns #-}
3 4
4----------------------------------------------------------------------------- 5-----------------------------------------------------------------------------
5-- | 6-- |
@@ -15,14 +16,14 @@
15----------------------------------------------------------------------------- 16-----------------------------------------------------------------------------
16 17
17module Internal.ST ( 18module Internal.ST (
19 ST, runST,
18 -- * Mutable Vectors 20 -- * Mutable Vectors
19 STVector, newVector, thawVector, freezeVector, runSTVector, 21 STVector, newVector, thawVector, freezeVector, runSTVector,
20 readVector, writeVector, modifyVector, liftSTVector, 22 readVector, writeVector, modifyVector, liftSTVector,
21 -- * Mutable Matrices 23 -- * Mutable Matrices
22 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, 24 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
23 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, 25 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
24-- axpy, scal, swap, rowOp, 26 mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..),
25 mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..),
26 -- * Unsafe functions 27 -- * Unsafe functions
27 newUndefinedVector, 28 newUndefinedVector,
28 unsafeReadVector, unsafeWriteVector, 29 unsafeReadVector, unsafeWriteVector,
@@ -70,13 +71,13 @@ unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k
70modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s () 71modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
71modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k 72modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k
72 73
73liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a 74liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s t -> ST s a
74liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x 75liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x
75 76
76freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) 77freezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
77freezeVector v = liftSTVector id v 78freezeVector v = liftSTVector id v
78 79
79unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) 80unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
80unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x 81unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x
81 82
82{-# INLINE safeIndexV #-} 83{-# INLINE safeIndexV #-}
@@ -139,14 +140,14 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
139modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () 140modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
140modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c 141modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c
141 142
142liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a 143liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s t -> ST s a
143liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x 144liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
144 145
145unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) 146unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
146unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x 147unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
147 148
148 149
149freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) 150freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
150freezeMatrix m = liftSTMatrix id m 151freezeMatrix m = liftSTMatrix id m
151 152
152cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o) 153cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o)
@@ -227,6 +228,16 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (i
227 (i1,i2) = getRowRange (rows m) rr 228 (i1,i2) = getRowRange (rows m) rr
228 (j1,j2) = getColRange (cols m) rc 229 (j1,j2) = getColRange (cols m) rc
229 230
231data Slice s t = Slice (STMatrix s t) Int Int Int Int
232
233slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1])
234
235gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res
236 where
237 res = unsafeIOToST (gemm u v a b r)
238 u = fromList [alpha,beta]
239 v = vjoin[pa,pb,pr]
240
230 241
231mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 242mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
232mutable f a = runST $ do 243mutable f a = runST $ do