From 2749f4ef144cbc8541d70434f46abf312a1bb42e Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Sun, 28 Jun 2015 14:18:25 +0200 Subject: copy slice --- packages/base/src/Internal/Matrix.hs | 17 +++-------------- packages/base/src/Internal/ST.hs | 11 +++++------ 2 files changed, 8 insertions(+), 20 deletions(-) (limited to 'packages/base') diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index bdf2785..8597dcb 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs @@ -142,10 +142,9 @@ a # b = apply a b -------------------------------------------------------------------------------- -extractAll ord m = unsafePerformIO $ - extractR ord m - 0 (idxs[0,rows m-1]) - 0 (idxs[0,cols m-1]) +copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) + +extractAll ord m = unsafePerformIO (copy ord m) {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. @@ -204,16 +203,6 @@ toRows m sub k = subVector (k*xRow m) (cols m) (xdat m) ext k = xdat $ unsafePerformIO $ extractR RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) -{- - | c == 0 = replicate r (fromList[]) - | otherwise = toRows' 0 - where - v = flatten m - r = rows m - c = cols m - toRows' k | k == r*c = [] - | otherwise = subVector k c v : toRows' (k+c) --} -- | Creates a matrix from a list of vectors, as columns fromColumns :: Element t => [Vector t] -> Matrix t diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 23fda99..91c2a11 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs @@ -119,7 +119,7 @@ ioWriteM m r c val = ioWriteV (xdat m) (r * xRow m + c * xCol m) val newtype STMatrix s t = STMatrix (Matrix t) -thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) +thawMatrix :: Element t => Matrix t -> ST s (STMatrix s t) thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) @@ -140,18 +140,17 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c -liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s t -> ST s a +liftSTMatrix :: (Element t) => (Matrix t -> a) -> STMatrix s t -> ST s a liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x -freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) +freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) freezeMatrix m = liftSTMatrix id m --- FIXME -cloneMatrix m = cloneVector (xdat m) >>= return . (\d' -> m{xdat = d'}) +cloneMatrix m = copy (orderOf m) m {-# INLINE safeIndexM #-} safeIndexM f (STMatrix m) r c @@ -242,7 +241,7 @@ gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res v = vjoin[pa,pb,pr] -mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) +mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) mutable f a = runST $ do x <- thawMatrix a info <- f (rows a, cols a) x -- cgit v1.2.3