From b9329d636d19f6a26da1cf1fd7e8d7cbd0b04cce Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Tue, 30 Jun 2015 12:04:21 +0200 Subject: support slice in multiply --- packages/base/src/Internal/Matrix.hs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'packages/base/src/Internal/Matrix.hs') diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 8597dcb..a789cae 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs @@ -226,6 +226,8 @@ atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) ------------------------------------------------------------------ +matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } +matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } matrixFromVector o r c v | r * c == dim v = m | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m @@ -280,7 +282,7 @@ class (Storable a) => Element a where selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () - gemm :: Vector a -> Vector I -> Matrix a -> Matrix a -> Matrix a -> IO () + gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () instance Element Float where @@ -569,9 +571,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z -------------------------------------------------------------------------------- -gemmg f u v m1 m2 m3 = f # u # v # m1 # m2 # m3 #|"gemmg" +gemmg f v m1 m2 m3 = f # v # m1 # m2 # m3 #|"gemmg" -type Tgemm x = x :> I :> x ::> x ::> x ::> Ok +type Tgemm x = x :> x ::> x ::> x ::> Ok foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float -- cgit v1.2.3