diff options
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 8597dcb..a789cae 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -226,6 +226,8 @@ atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) | |||
226 | 226 | ||
227 | ------------------------------------------------------------------ | 227 | ------------------------------------------------------------------ |
228 | 228 | ||
229 | matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } | ||
230 | matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } | ||
229 | matrixFromVector o r c v | 231 | matrixFromVector o r c v |
230 | | r * c == dim v = m | 232 | | r * c == dim v = m |
231 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m | 233 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m |
@@ -280,7 +282,7 @@ class (Storable a) => Element a where | |||
280 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a | 282 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a |
281 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a | 283 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a |
282 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () | 284 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () |
283 | gemm :: Vector a -> Vector I -> Matrix a -> Matrix a -> Matrix a -> IO () | 285 | gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () |
284 | 286 | ||
285 | 287 | ||
286 | instance Element Float where | 288 | instance Element Float where |
@@ -569,9 +571,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | |||
569 | 571 | ||
570 | -------------------------------------------------------------------------------- | 572 | -------------------------------------------------------------------------------- |
571 | 573 | ||
572 | gemmg f u v m1 m2 m3 = f # u # v # m1 # m2 # m3 #|"gemmg" | 574 | gemmg f v m1 m2 m3 = f # v # m1 # m2 # m3 #|"gemmg" |
573 | 575 | ||
574 | type Tgemm x = x :> I :> x ::> x ::> x ::> Ok | 576 | type Tgemm x = x :> x ::> x ::> x ::> Ok |
575 | 577 | ||
576 | foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R | 578 | foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R |
577 | foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float | 579 | foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float |