summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Matrix.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-30 12:04:21 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-30 12:04:21 +0200
commitb9329d636d19f6a26da1cf1fd7e8d7cbd0b04cce (patch)
treec0beb22b3b394ed9d18a6a98d5cf1ca6d4ea8960 /packages/base/src/Internal/Matrix.hs
parent9c05df0cd663bafaf0b69eafee53fce45569dc95 (diff)
support slice in multiply
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r--packages/base/src/Internal/Matrix.hs8
1 files changed, 5 insertions, 3 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 8597dcb..a789cae 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -226,6 +226,8 @@ atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m))
226 226
227------------------------------------------------------------------ 227------------------------------------------------------------------
228 228
229matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 }
230matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d }
229matrixFromVector o r c v 231matrixFromVector o r c v
230 | r * c == dim v = m 232 | r * c == dim v = m
231 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m 233 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
@@ -280,7 +282,7 @@ class (Storable a) => Element a where
280 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a 282 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
281 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a 283 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
282 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () 284 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
283 gemm :: Vector a -> Vector I -> Matrix a -> Matrix a -> Matrix a -> IO () 285 gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
284 286
285 287
286instance Element Float where 288instance Element Float where
@@ -569,9 +571,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
569 571
570-------------------------------------------------------------------------------- 572--------------------------------------------------------------------------------
571 573
572gemmg f u v m1 m2 m3 = f # u # v # m1 # m2 # m3 #|"gemmg" 574gemmg f v m1 m2 m3 = f # v # m1 # m2 # m3 #|"gemmg"
573 575
574type Tgemm x = x :> I :> x ::> x ::> x ::> Ok 576type Tgemm x = x :> x ::> x ::> x ::> Ok
575 577
576foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R 578foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R
577foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float 579foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float