summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-17 13:02:40 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-17 13:02:40 +0200
commit61d90ff66af8bfe53ef8cdda8dfe1e70463c213c (patch)
tree79ca6e2024731708a4378ff691b5317d8da11d11
parente7d2916f78b5c140738fc4f4f95c9b13c1768293 (diff)
gemmm
-rw-r--r--packages/base/src/Internal/C/lapack-aux.c59
-rw-r--r--packages/base/src/Internal/C/lapack-aux.h1
-rw-r--r--packages/base/src/Internal/Matrix.hs24
-rw-r--r--packages/base/src/Internal/Modular.hs6
-rw-r--r--packages/base/src/Internal/ST.hs27
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Devel.hs2
6 files changed, 108 insertions, 11 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c
index 2843ab5..4d48594 100644
--- a/packages/base/src/Internal/C/lapack-aux.c
+++ b/packages/base/src/Internal/C/lapack-aux.c
@@ -1398,6 +1398,65 @@ ROWOP(int64_t)
1398ROWOP_MOD(int32_t,mod) 1398ROWOP_MOD(int32_t,mod)
1399ROWOP_MOD(int64_t,mod_l) 1399ROWOP_MOD(int64_t,mod_l)
1400 1400
1401/////////////////////////////// inplace GEMM ////////////////////////////////
1402
1403#define GEMM(T) int gemm_##T(VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \
1404 T a = cp[0], b = cp[1]; \
1405 int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \
1406 int r1b = pp[4], c1b = pp[6] ; \
1407 int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \
1408 int dra = r1a - r1r; \
1409 int dcb = c1b-c1r; \
1410 int nk = c2a-c1a+1; \
1411 int i,j,k; \
1412 T t; \
1413 for (i=r1r; i<=r2r; i++) { \
1414 for (j=c1r; j<=c2r; j++) { \
1415 t = 0; \
1416 for(k=0; k<nk; k++) { \
1417 t += AT(a,i+dra,k+c1a) * AT(b,k+r1b,j+dcb); \
1418 } \
1419 AT(r,i,j) = b*AT(r,i,j) + a*t; \
1420 } \
1421 } \
1422 OK \
1423}
1424
1425GEMM(double)
1426GEMM(float)
1427GEMM(TCD)
1428GEMM(TCF)
1429GEMM(int32_t)
1430GEMM(int64_t)
1431
1432#define GEMM_MOD(T,M) int gemm_mod_##T(T m, VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \
1433 T a = cp[0], b = cp[1]; \
1434 int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \
1435 int r1b = pp[4], c1b = pp[6] ; \
1436 int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \
1437 int dra = r1a - r1r; \
1438 int dcb = c1b-c1r; \
1439 int nk = c2a-c1a+1; \
1440 int i,j,k; \
1441 T t; \
1442 for (i=r1r; i<=r2r; i++) { \
1443 for (j=c1r; j<=c2r; j++) { \
1444 t = 0; \
1445 for(k=0; k<nk; k++) { \
1446 t = M(t+M(AT(a,i+dra,k+c1a) * AT(b,k+r1b,j+dcb))); \
1447 } \
1448 AT(r,i,j) = M(M(b*AT(r,i,j)) + M(a*t)); \
1449 } \
1450 } \
1451 OK \
1452}
1453
1454#define MOD32(X) mod(X,m)
1455#define MOD64(X) mod_l(X,m)
1456
1457GEMM_MOD(int32_t,MOD32)
1458GEMM_MOD(int64_t,MOD64)
1459
1401////////////////// sparse matrix-product /////////////////////////////////////// 1460////////////////// sparse matrix-product ///////////////////////////////////////
1402 1461
1403 1462
diff --git a/packages/base/src/Internal/C/lapack-aux.h b/packages/base/src/Internal/C/lapack-aux.h
index e4d95bc..bf8c5e9 100644
--- a/packages/base/src/Internal/C/lapack-aux.h
+++ b/packages/base/src/Internal/C/lapack-aux.h
@@ -59,6 +59,7 @@ typedef short ftnlen;
59#define OQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, complex* A##p 59#define OQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, complex* A##p
60#define OCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, doublecomplex* A##p 60#define OCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, doublecomplex* A##p
61 61
62#define VECG(T,A) int A##n, T* A##p
62#define MATG(T,A) int A##r, int A##c, int A##Xr, int A##Xc, T* A##p 63#define MATG(T,A) int A##r, int A##c, int A##Xr, int A##Xc, T* A##p
63 64
64#define KIVEC(A) int A##n, const int*A##p 65#define KIVEC(A) int A##n, const int*A##p
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 75e92a5..8f8c219 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -274,6 +274,7 @@ class (Storable a) => Element a where
274 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a 274 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
275 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a 275 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
276 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () 276 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
277 gemm :: Vector a -> Vector I -> Matrix a -> Matrix a -> Matrix a -> IO ()
277 278
278 279
279instance Element Float where 280instance Element Float where
@@ -287,6 +288,7 @@ instance Element Float where
287 selectV = selectF 288 selectV = selectF
288 remapM = remapF 289 remapM = remapF
289 rowOp = rowOpAux c_rowOpF 290 rowOp = rowOpAux c_rowOpF
291 gemm = gemmg c_gemmF
290 292
291instance Element Double where 293instance Element Double where
292 transdata = transdataAux ctransR 294 transdata = transdataAux ctransR
@@ -299,7 +301,7 @@ instance Element Double where
299 selectV = selectD 301 selectV = selectD
300 remapM = remapD 302 remapM = remapD
301 rowOp = rowOpAux c_rowOpD 303 rowOp = rowOpAux c_rowOpD
302 304 gemm = gemmg c_gemmD
303 305
304instance Element (Complex Float) where 306instance Element (Complex Float) where
305 transdata = transdataAux ctransQ 307 transdata = transdataAux ctransQ
@@ -312,7 +314,7 @@ instance Element (Complex Float) where
312 selectV = selectQ 314 selectV = selectQ
313 remapM = remapQ 315 remapM = remapQ
314 rowOp = rowOpAux c_rowOpQ 316 rowOp = rowOpAux c_rowOpQ
315 317 gemm = gemmg c_gemmQ
316 318
317instance Element (Complex Double) where 319instance Element (Complex Double) where
318 transdata = transdataAux ctransC 320 transdata = transdataAux ctransC
@@ -325,6 +327,7 @@ instance Element (Complex Double) where
325 selectV = selectC 327 selectV = selectC
326 remapM = remapC 328 remapM = remapC
327 rowOp = rowOpAux c_rowOpC 329 rowOp = rowOpAux c_rowOpC
330 gemm = gemmg c_gemmC
328 331
329instance Element (CInt) where 332instance Element (CInt) where
330 transdata = transdataAux ctransI 333 transdata = transdataAux ctransI
@@ -337,6 +340,7 @@ instance Element (CInt) where
337 selectV = selectI 340 selectV = selectI
338 remapM = remapI 341 remapM = remapI
339 rowOp = rowOpAux c_rowOpI 342 rowOp = rowOpAux c_rowOpI
343 gemm = gemmg c_gemmI
340 344
341instance Element Z where 345instance Element Z where
342 transdata = transdataAux ctransL 346 transdata = transdataAux ctransL
@@ -349,6 +353,7 @@ instance Element Z where
349 selectV = selectL 353 selectV = selectL
350 remapM = remapL 354 remapM = remapL
351 rowOp = rowOpAux c_rowOpL 355 rowOp = rowOpAux c_rowOpL
356 gemm = gemmg c_gemmL
352 357
353------------------------------------------------------------------- 358-------------------------------------------------------------------
354 359
@@ -575,6 +580,21 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
575 580
576-------------------------------------------------------------------------------- 581--------------------------------------------------------------------------------
577 582
583gemmg f u v m1 m2 m3 = app5 f vec u vec v omat m1 omat m2 omat m3 "gemmg"
584
585type Tgemm x = x :> I :> x ::> x ::> x ::> Ok
586
587foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R
588foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float
589foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C
590foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float)
591foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I
592foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z
593foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I
594foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z
595
596--------------------------------------------------------------------------------
597
578foreign import ccall unsafe "saveMatrix" c_saveMatrix 598foreign import ccall unsafe "saveMatrix" c_saveMatrix
579 :: CString -> CString -> Double ..> Ok 599 :: CString -> CString -> Double ..> Ok
580 600
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index 6c6d5c5..d158111 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -131,6 +131,9 @@ instance KnownNat m => Element (Mod m I)
131 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) 131 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x)
132 where 132 where
133 m' = fromIntegral . natVal $ (undefined :: Proxy m) 133 m' = fromIntegral . natVal $ (undefined :: Proxy m)
134 gemm u p a b c = gemmg (c_gemmMI m') (f2i u) p (f2iM a) (f2iM b) (f2iM c)
135 where
136 m' = fromIntegral . natVal $ (undefined :: Proxy m)
134 137
135instance KnownNat m => Element (Mod m Z) 138instance KnownNat m => Element (Mod m Z)
136 where 139 where
@@ -146,6 +149,9 @@ instance KnownNat m => Element (Mod m Z)
146 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) 149 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x)
147 where 150 where
148 m' = fromIntegral . natVal $ (undefined :: Proxy m) 151 m' = fromIntegral . natVal $ (undefined :: Proxy m)
152 gemm u p a b c = gemmg (c_gemmML m') (f2i u) p (f2iM a) (f2iM b) (f2iM c)
153 where
154 m' = fromIntegral . natVal $ (undefined :: Proxy m)
149 155
150 156
151instance forall m . KnownNat m => Container Vector (Mod m I) 157instance forall m . KnownNat m => Container Vector (Mod m I)
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 434fe63..25e7f03 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -1,5 +1,6 @@
1{-# LANGUAGE Rank2Types #-} 1{-# LANGUAGE Rank2Types #-}
2{-# LANGUAGE BangPatterns #-} 2{-# LANGUAGE BangPatterns #-}
3{-# LANGUAGE ViewPatterns #-}
3 4
4----------------------------------------------------------------------------- 5-----------------------------------------------------------------------------
5-- | 6-- |
@@ -15,14 +16,14 @@
15----------------------------------------------------------------------------- 16-----------------------------------------------------------------------------
16 17
17module Internal.ST ( 18module Internal.ST (
19 ST, runST,
18 -- * Mutable Vectors 20 -- * Mutable Vectors
19 STVector, newVector, thawVector, freezeVector, runSTVector, 21 STVector, newVector, thawVector, freezeVector, runSTVector,
20 readVector, writeVector, modifyVector, liftSTVector, 22 readVector, writeVector, modifyVector, liftSTVector,
21 -- * Mutable Matrices 23 -- * Mutable Matrices
22 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, 24 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
23 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, 25 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
24-- axpy, scal, swap, rowOp, 26 mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..),
25 mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..),
26 -- * Unsafe functions 27 -- * Unsafe functions
27 newUndefinedVector, 28 newUndefinedVector,
28 unsafeReadVector, unsafeWriteVector, 29 unsafeReadVector, unsafeWriteVector,
@@ -70,13 +71,13 @@ unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k
70modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s () 71modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
71modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k 72modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k
72 73
73liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a 74liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s t -> ST s a
74liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x 75liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x
75 76
76freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) 77freezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
77freezeVector v = liftSTVector id v 78freezeVector v = liftSTVector id v
78 79
79unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) 80unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
80unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x 81unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x
81 82
82{-# INLINE safeIndexV #-} 83{-# INLINE safeIndexV #-}
@@ -139,14 +140,14 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
139modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () 140modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
140modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c 141modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c
141 142
142liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a 143liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s t -> ST s a
143liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x 144liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
144 145
145unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) 146unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
146unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x 147unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
147 148
148 149
149freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) 150freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
150freezeMatrix m = liftSTMatrix id m 151freezeMatrix m = liftSTMatrix id m
151 152
152cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o) 153cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o)
@@ -227,6 +228,16 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (i
227 (i1,i2) = getRowRange (rows m) rr 228 (i1,i2) = getRowRange (rows m) rr
228 (j1,j2) = getColRange (cols m) rc 229 (j1,j2) = getColRange (cols m) rc
229 230
231data Slice s t = Slice (STMatrix s t) Int Int Int Int
232
233slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1])
234
235gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res
236 where
237 res = unsafeIOToST (gemm u v a b r)
238 u = fromList [alpha,beta]
239 v = vjoin[pa,pb,pr]
240
230 241
231mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 242mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
232mutable f a = runST $ do 243mutable f a = runST $ do
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs
index 36c5f03..db4236b 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs
@@ -43,7 +43,7 @@ module Numeric.LinearAlgebra.Devel(
43 -- ** Mutable Matrices 43 -- ** Mutable Matrices
44 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, 44 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
45 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, 45 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
46 mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), 46 mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..),
47 -- ** Unsafe functions 47 -- ** Unsafe functions
48 newUndefinedVector, 48 newUndefinedVector,
49 unsafeReadVector, unsafeWriteVector, 49 unsafeReadVector, unsafeWriteVector,