From 1a68793247b8845cefad4d157e4f4d25b1731b42 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Fri, 30 Mar 2018 12:48:20 +0100 Subject: Implement CI --- packages/base/src/Internal/ST.hs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'packages/base/src/Internal/ST.hs') diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 544c9e4..7d54e6d 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs @@ -81,6 +81,8 @@ unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t) unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x {-# INLINE safeIndexV #-} +safeIndexV :: Storable t2 + => (STVector s t2 -> Int -> t) -> STVector t1 t2 -> Int -> t safeIndexV f (STVector v) k | k < 0 || k>= dim v = error $ "out of range error in vector (dim=" ++show (dim v)++", pos="++show k++")" @@ -150,9 +152,12 @@ unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) freezeMatrix m = liftSTMatrix id m +cloneMatrix :: Element t => Matrix t -> IO (Matrix t) cloneMatrix m = copy (orderOf m) m {-# INLINE safeIndexM #-} +safeIndexM :: (STMatrix s t2 -> Int -> Int -> t) + -> STMatrix t1 t2 -> Int -> Int -> t safeIndexM f (STMatrix m) r c | r<0 || r>=rows m || c<0 || c>=cols m = error $ "out of range error in matrix (size=" @@ -184,6 +189,7 @@ data ColRange = AllCols | Col Int | FromCol Int +getColRange :: Int -> ColRange -> (Int, Int) getColRange c AllCols = (0,c-1) getColRange c (ColRange a b) = (a `mod` c, b `mod` c) getColRange c (Col a) = (a `mod` c, a `mod` c) @@ -194,6 +200,7 @@ data RowRange = AllRows | Row Int | FromRow Int +getRowRange :: Int -> RowRange -> (Int, Int) getRowRange r AllRows = (0,r-1) getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) getRowRange r (Row a) = (a `mod` r, a `mod` r) @@ -223,6 +230,7 @@ rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m i2' = i2 `mod` (rows m) +extractMatrix :: Element a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a) extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) where (i1,i2) = getRowRange (rows m) rr @@ -231,6 +239,7 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[ -- | r0 c0 height width data Slice s t = Slice (STMatrix s t) Int Int Int Int +slice :: Element a => Slice t a -> Matrix a slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () @@ -238,7 +247,7 @@ gemmm beta (slice->r) alpha (slice->a) (slice->b) = res where res = unsafeIOToST (gemm v a b r) v = fromList [alpha,beta] - + mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) mutable f a = runST $ do @@ -246,4 +255,3 @@ mutable f a = runST $ do info <- f (rows a, cols a) x r <- unsafeFreezeMatrix x return (r,info) - -- cgit v1.2.3