From a407c0e101f8f6db44fcf731ebb8460d0f691196 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Fri, 6 Jun 2008 15:41:48 +0000 Subject: range checking and other additions to the mutable interface --- lib/Data/Packed/ST.hs | 97 +++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 78 insertions(+), 19 deletions(-) (limited to 'lib/Data/Packed/ST.hs') diff --git a/lib/Data/Packed/ST.hs b/lib/Data/Packed/ST.hs index 3d94014..1311ff9 100644 --- a/lib/Data/Packed/ST.hs +++ b/lib/Data/Packed/ST.hs @@ -16,10 +16,17 @@ ----------------------------------------------------------------------------- module Data.Packed.ST ( - STVector, thawVector, freezeVector, runSTVector, + -- * Mutable Vectors + STVector, newVector, thawVector, freezeVector, runSTVector, readVector, writeVector, modifyVector, liftSTVector, - STMatrix, thawMatrix, freezeMatrix, runSTMatrix, - readMatrix, writeMatrix, modifyMatrix, liftSTMatrix + -- * Mutable Matrices + STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, + readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, + -- * Unsafe functions + unsafeReadVector, unsafeWriteVector, + unsafeThawVector, unsafeFreezeVector, + unsafeReadMatrix, unsafeWriteMatrix, + unsafeThawMatrix, unsafeFreezeMatrix ) where import Data.Packed.Internal @@ -28,44 +35,71 @@ import Control.Monad.ST import Data.Array.ST import Foreign - +{-# INLINE ioReadV #-} ioReadV :: Storable t => Vector t -> Int -> IO t ioReadV v k = withForeignPtr (fptr v) $ \s -> peekElemOff s k +{-# INLINE ioWriteV #-} ioWriteV :: Storable t => Vector t -> Int -> t -> IO () ioWriteV v k x = withForeignPtr (fptr v) $ \s -> pokeElemOff s k x -newtype STVector s t = Mut (Vector t) +newtype STVector s t = STVector (Vector t) thawVector :: Storable t => Vector t -> ST s (STVector s t) -thawVector = unsafeIOToST . fmap Mut . cloneVector +thawVector = unsafeIOToST . fmap STVector . cloneVector -unsafeFreezeVector (Mut x) = unsafeIOToST . return $ x +unsafeThawVector :: Storable t => Vector t -> ST s (STVector s t) +unsafeThawVector = unsafeIOToST . return . STVector runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t runSTVector st = runST (st >>= unsafeFreezeVector) -readVector :: Storable t => STVector s t -> Int -> ST s t -readVector (Mut x) = unsafeIOToST . ioReadV x +{-# INLINE unsafeReadVector #-} +unsafeReadVector :: Storable t => STVector s t -> Int -> ST s t +unsafeReadVector (STVector x) = unsafeIOToST . ioReadV x -writeVector :: Storable t => STVector s t -> Int -> t -> ST s () -writeVector (Mut x) k = unsafeIOToST . ioWriteV x k +{-# INLINE unsafeWriteVector #-} +unsafeWriteVector :: Storable t => STVector s t -> Int -> t -> ST s () +unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k +{-# INLINE modifyVector #-} modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s () -modifyVector x k f = readVector x k >>= return . f >>= writeVector x k +modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a -liftSTVector f (Mut x) = unsafeIOToST . fmap f . cloneVector $ x +liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) freezeVector v = liftSTVector id v +unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) +unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x + +{-# INLINE safeIndexV #-} +safeIndexV f (STVector v) k + | k < 0 || k>= dim v = error $ "out of range error in vector (dim=" + ++show (dim v)++", pos="++show k++")" + | otherwise = f (STVector v) k + +{-# INLINE readVector #-} +readVector :: Storable t => STVector s t -> Int -> ST s t +readVector = safeIndexV unsafeReadVector + +{-# INLINE writeVector #-} +writeVector :: Storable t => STVector s t -> Int -> t -> ST s () +writeVector = safeIndexV unsafeWriteVector + +newVector :: Element t => t -> Int -> ST s (STVector s t) +newVector v = unsafeThawVector . constant v + ------------------------------------------------------------------------- +{-# INLINE ioReadM #-} ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t ioReadM (MC nr nc cv) r c = ioReadV cv (r*nc+c) ioReadM (MF nr nc fv) r c = ioReadV fv (c*nr+r) +{-# INLINE ioWriteM #-} ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () ioWriteM (MC nr nc cv) r c val = ioWriteV cv (r*nc+c) val ioWriteM (MF nr nc fv) r c val = ioWriteV fv (c*nr+r) val @@ -75,25 +109,50 @@ newtype STMatrix s t = STMatrix (Matrix t) thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix -unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x +unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) +unsafeThawMatrix = unsafeIOToST . return . STMatrix runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t runSTMatrix st = runST (st >>= unsafeFreezeMatrix) -readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t -readMatrix (STMatrix x) r = unsafeIOToST . ioReadM x r +{-# INLINE unsafeReadMatrix #-} +unsafeReadMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t +unsafeReadMatrix (STMatrix x) r = unsafeIOToST . ioReadM x r -writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () -writeMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c +{-# INLINE unsafeWriteMatrix #-} +unsafeWriteMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () +unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c +{-# INLINE modifyMatrix #-} modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () -modifyMatrix x r c f = readMatrix x r c >>= return . f >>= writeMatrix x r c +modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x +unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) +unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x + freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) freezeMatrix m = liftSTMatrix id m cloneMatrix (MC r c d) = cloneVector d >>= return . MC r c cloneMatrix (MF r c d) = cloneVector d >>= return . MF r c + +{-# INLINE safeIndexM #-} +safeIndexM f (STMatrix m) r c + | r<0 || r>=rows m || + c<0 || c>=cols m = error $ "out of range error in matrix (size=" + ++show (rows m,cols m)++", pos="++show (r,c)++")" + | otherwise = f (STMatrix m) r c + +{-# INLINE readMatrix #-} +readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t +readMatrix = safeIndexM unsafeReadMatrix + +{-# INLINE writeMatrix #-} +writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () +writeMatrix = safeIndexM unsafeWriteMatrix + +newMatrix :: Element t => t -> Int -> Int -> ST s (STMatrix s t) +newMatrix v r c = unsafeThawMatrix . reshape c . constant v $ r*c -- cgit v1.2.3