From 61d90ff66af8bfe53ef8cdda8dfe1e70463c213c Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Wed, 17 Jun 2015 13:02:40 +0200 Subject: gemmm --- packages/base/src/Internal/Matrix.hs | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) (limited to 'packages/base/src/Internal/Matrix.hs') diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 75e92a5..8f8c219 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs @@ -274,6 +274,7 @@ class (Storable a) => Element a where selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () + gemm :: Vector a -> Vector I -> Matrix a -> Matrix a -> Matrix a -> IO () instance Element Float where @@ -287,6 +288,7 @@ instance Element Float where selectV = selectF remapM = remapF rowOp = rowOpAux c_rowOpF + gemm = gemmg c_gemmF instance Element Double where transdata = transdataAux ctransR @@ -299,7 +301,7 @@ instance Element Double where selectV = selectD remapM = remapD rowOp = rowOpAux c_rowOpD - + gemm = gemmg c_gemmD instance Element (Complex Float) where transdata = transdataAux ctransQ @@ -312,7 +314,7 @@ instance Element (Complex Float) where selectV = selectQ remapM = remapQ rowOp = rowOpAux c_rowOpQ - + gemm = gemmg c_gemmQ instance Element (Complex Double) where transdata = transdataAux ctransC @@ -325,6 +327,7 @@ instance Element (Complex Double) where selectV = selectC remapM = remapC rowOp = rowOpAux c_rowOpC + gemm = gemmg c_gemmC instance Element (CInt) where transdata = transdataAux ctransI @@ -337,6 +340,7 @@ instance Element (CInt) where selectV = selectI remapM = remapI rowOp = rowOpAux c_rowOpI + gemm = gemmg c_gemmI instance Element Z where transdata = transdataAux ctransL @@ -349,6 +353,7 @@ instance Element Z where selectV = selectL remapM = remapL rowOp = rowOpAux c_rowOpL + gemm = gemmg c_gemmL ------------------------------------------------------------------- @@ -575,6 +580,21 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z -------------------------------------------------------------------------------- +gemmg f u v m1 m2 m3 = app5 f vec u vec v omat m1 omat m2 omat m3 "gemmg" + +type Tgemm x = x :> I :> x ::> x ::> x ::> Ok + +foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R +foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float +foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C +foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float) +foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I +foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z +foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I +foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z + +-------------------------------------------------------------------------------- + foreign import ccall unsafe "saveMatrix" c_saveMatrix :: CString -> CString -> Double ..> Ok -- cgit v1.2.3