{-# LANGUAGE Rank2Types    #-}
{-# LANGUAGE BangPatterns  #-}
{-# LANGUAGE ViewPatterns  #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Internal.ST
-- Copyright   :  (c) Alberto Ruiz 2008
-- License     :  BSD3
-- Maintainer  :  Alberto Ruiz
-- Stability   :  provisional
--
-- In-place manipulation inside the ST monad.
-- See @examples/inplace.hs@ in the repository.
--
-----------------------------------------------------------------------------

module Internal.ST (
    ST, runST,
    -- * Mutable Vectors
    STVector, newVector, thawVector, freezeVector, runSTVector,
    readVector, writeVector, modifyVector, liftSTVector,
    -- * Mutable Matrices
    STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
    readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
    mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..),
    -- * Unsafe functions
    newUndefinedVector,
    unsafeReadVector, unsafeWriteVector,
    unsafeThawVector, unsafeFreezeVector,
    newUndefinedMatrix,
    unsafeReadMatrix, unsafeWriteMatrix,
    unsafeThawMatrix, unsafeFreezeMatrix
) where

import Internal.Vector
import Internal.Matrix
import Internal.Vectorized
import Control.Monad.ST(ST, runST)
import Foreign.Storable(Storable, peekElemOff, pokeElemOff)
import Control.Monad.ST.Unsafe(unsafeIOToST)

-- Read the element at offset @k@ of the vector's buffer. No bounds check.
{-# INLINE ioReadV #-}
ioReadV :: Storable t => Vector t -> Int -> IO t
ioReadV v k = unsafeWith v $ \s -> peekElemOff s k

-- Write @x@ at offset @k@ of the vector's buffer. No bounds check.
{-# INLINE ioWriteV #-}
ioWriteV :: Storable t => Vector t -> Int -> t -> IO ()
ioWriteV v k x = unsafeWith v $ \s -> pokeElemOff s k x

-- | A vector that can be modified in place inside @'ST' s@.
-- It is a plain 'Vector' tagged with the phantom state type @s@.
newtype STVector s t = STVector (Vector t)

-- | Make a mutable vector from the result of 'cloneVector' applied to the
-- argument (presumably a fresh copy, so the argument is not aliased).
thawVector :: Storable t => Vector t -> ST s (STVector s t)
thawVector = unsafeIOToST . fmap STVector . cloneVector

-- | Wrap the given vector as a mutable vector without copying it.
-- In-place writes through the result act on the very same vector value.
unsafeThawVector :: Storable t => Vector t -> ST s (STVector s t)
unsafeThawVector = return . STVector

-- | Run an 'ST' computation that produces a mutable vector and return that
-- vector as an immutable one (frozen with 'unsafeFreezeVector', no copy).
runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t
runSTVector st = runST (st >>= unsafeFreezeVector)

{-# INLINE unsafeReadVector #-}
-- | Read the element at the given position. The index is not checked.
unsafeReadVector :: Storable t => STVector s t -> Int -> ST s t
unsafeReadVector (STVector x) = unsafeIOToST . ioReadV x

{-# INLINE unsafeWriteVector #-}
-- | Overwrite the element at the given position. The index is not checked.
unsafeWriteVector :: Storable t => STVector s t -> Int -> t -> ST s ()
unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k

{-# INLINE modifyVector #-}
-- | Replace the element at position @k@ by the result of applying @f@ to it.
-- The index is validated once by 'readVector'; the write-back to the same
-- position can therefore use the unchecked primitive.
modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
modifyVector x k f = fmap f (readVector x k) >>= unsafeWriteVector x k

-- | Apply a pure function to the result of 'cloneVector' on the current
-- contents (presumably a snapshot copy) of the mutable vector.
liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s t -> ST s a
liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x

-- | Return the current contents as an immutable vector, going through
-- 'liftSTVector' (and hence 'cloneVector').
freezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
freezeVector v = liftSTVector id v

-- | Return the underlying vector without copying it.
unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
unsafeFreezeVector (STVector x) = return x

-- Add a bounds check to an indexed vector operation: positions outside
-- @[0, dim v)@ raise an 'error', otherwise the operation @f@ is applied.
{-# INLINE safeIndexV #-}
safeIndexV :: Storable t => (STVector s t -> Int -> a) -> STVector s t -> Int -> a
safeIndexV f (STVector v) k
    | k < 0 || k >= dim v = error $ "out of range error in vector (dim="
                                    ++ show (dim v) ++ ", pos=" ++ show k ++ ")"
    | otherwise = f (STVector v) k

{-# INLINE readVector #-}
-- | Bounds-checked version of 'unsafeReadVector'.
readVector :: Storable t => STVector s t -> Int -> ST s t
readVector = safeIndexV unsafeReadVector

{-# INLINE writeVector #-}
-- | Bounds-checked version of 'unsafeWriteVector'.
writeVector :: Storable t => STVector s t -> Int -> t -> ST s ()
writeVector = safeIndexV unsafeWriteVector

-- | Create a mutable vector of the given size with 'createVector'.
-- This function does not write any element (contents presumably undefined).
newUndefinedVector :: Storable t => Int -> ST s (STVector s t)
newUndefinedVector = unsafeIOToST . fmap STVector . createVector

{-# INLINE newVector #-}
-- | @newVector x n@ creates a mutable vector of size @n@ in which every
-- position holds @x@.
newVector :: Storable t => t -> Int -> ST s (STVector s t)
newVector x n = do
    v <- newUndefinedVector n
    -- Fill from the last index down to 0; index -1 ends the loop.
    let go (-1) = return v
        go !k = unsafeWriteVector v k x >> go (k-1 :: Int)
    go (n-1)

-------------------------------------------------------------------------

-- Read the element at row @r@, column @c@. The offset into the data vector
-- is computed from the matrix strides 'xRow' and 'xCol'. No bounds check.
{-# INLINE ioReadM #-}
ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
ioReadM m r c = ioReadV (xdat m) (r * xRow m + c * xCol m)

-- Write @val@ at row @r@, column @c@, using the same offset computation as
-- 'ioReadM'. No bounds check.
{-# INLINE ioWriteM #-}
ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
ioWriteM m r c val = ioWriteV (xdat m) (r * xRow m + c * xCol m) val

-- | A matrix that can be modified in place inside @'ST' s@.
-- It is a plain 'Matrix' tagged with the phantom state type @s@.
newtype STMatrix s t = STMatrix (Matrix t)

-- | Make a mutable matrix from the result of 'cloneMatrix' applied to the
-- argument (presumably a fresh copy, so the argument is not aliased).
thawMatrix :: Element t => Matrix t -> ST s (STMatrix s t)
thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix

-- | Wrap the given matrix as a mutable matrix without copying it.
-- In-place writes through the result act on the very same matrix value.
unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
unsafeThawMatrix = return . STMatrix

-- | Run an 'ST' computation that produces a mutable matrix and return that
-- matrix as an immutable one (frozen with 'unsafeFreezeMatrix', no copy).
runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t
runSTMatrix st = runST (st >>= unsafeFreezeMatrix)

{-# INLINE unsafeReadMatrix #-}
-- | Read the element at the given row and column. Indices are not checked.
unsafeReadMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
unsafeReadMatrix (STMatrix x) r = unsafeIOToST . ioReadM x r

{-# INLINE unsafeWriteMatrix #-}
-- | Overwrite the element at the given row and column. Indices are not
-- checked.
unsafeWriteMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c

{-# INLINE modifyMatrix #-}
-- | Replace the element at @(r,c)@ by the result of applying @f@ to it.
-- The position is validated once by 'readMatrix'; the write-back to the
-- same position can therefore use the unchecked primitive.
modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
modifyMatrix x r c f = fmap f (readMatrix x r c) >>= unsafeWriteMatrix x r c

-- | Apply a pure function to the result of 'cloneMatrix' on the current
-- contents (presumably a snapshot copy) of the mutable matrix.
liftSTMatrix :: (Element t) => (Matrix t -> a) -> STMatrix s t -> ST s a
liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x

-- | Return the underlying matrix without copying it.
unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
unsafeFreezeMatrix (STMatrix x) = return x

-- | Return the current contents as an immutable matrix, going through
-- 'liftSTMatrix' (and hence 'cloneMatrix').
freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t)
freezeMatrix m = liftSTMatrix id m

-- Duplicate a matrix with 'copy', keeping the order reported by 'orderOf'.
cloneMatrix :: Element t => Matrix t -> IO (Matrix t)
cloneMatrix m = copy (orderOf m) m

-- Add a bounds check to an indexed matrix operation: positions outside
-- @[0, rows m) x [0, cols m)@ raise an 'error', otherwise @f@ is applied.
{-# INLINE safeIndexM #-}
safeIndexM :: Storable t => (STMatrix s t -> Int -> Int -> a) -> STMatrix s t -> Int -> Int -> a
safeIndexM f (STMatrix m) r c
    | r < 0 || r >= rows m ||
      c < 0 || c >= cols m = error $ "out of range error in matrix (size="
                                     ++ show (rows m,cols m) ++ ", pos=" ++ show (r,c) ++ ")"
    | otherwise = f (STMatrix m) r c

{-# INLINE readMatrix #-}
-- | Bounds-checked version of 'unsafeReadMatrix'.
readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
readMatrix = safeIndexM unsafeReadMatrix

{-# INLINE writeMatrix #-}
-- | Bounds-checked version of 'unsafeWriteMatrix'.
writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
writeMatrix = safeIndexM unsafeWriteMatrix

-- | @setMatrix x i j m@ hands @i@, @j@, the immutable matrix @m@ and the
-- mutable target @x@ to 'setRect'.
-- NOTE(review): presumably copies @m@ into @x@ with its top-left corner at
-- row @i@, column @j@ -- confirm against 'setRect'.
setMatrix :: Element t => STMatrix s t -> Int -> Int -> Matrix t -> ST s ()
setMatrix (STMatrix x) i j m = unsafeIOToST $ setRect i j m x

-- | Create a mutable matrix of the given order, rows and columns with
-- 'createMatrix'. This function does not write any element (contents
-- presumably undefined).
newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c

{-# NOINLINE newMatrix #-}
-- | @newMatrix v r c@ creates a mutable matrix in which every element is
-- @v@: a vector of @r*c@ copies of @v@ is built with 'newVector', given @c@
-- columns with 'reshape', and wrapped without copying.
-- The NOINLINE pragma is kept from the original; its reason is not visible
-- in this module.
newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c)

--------------------------------------------------------------------------------

-- | Selection of columns. Explicit indices are reduced modulo the number
-- of columns (see 'getColRange'), so negative indices count from the end.
data ColRange = AllCols
              | ColRange Int Int   -- ^ first and last column (inclusive pair passed on as is)
              | Col Int            -- ^ a single column
              | FromCol Int        -- ^ from the given column to the last one

-- Resolve a 'ColRange' to a (first, last) pair for a matrix with @c@ columns.
-- Explicit indices go through `mod` c; with @c == 0@ those cases raise a
-- divide-by-zero exception ('AllCols' does not).
getColRange :: Int -> ColRange -> (Int, Int)
getColRange c AllCols = (0,c-1)
getColRange c (ColRange a b) = (a `mod` c, b `mod` c)
getColRange c (Col a) = (a `mod` c, a `mod` c)
getColRange c (FromCol a) = (a `mod` c, c-1)

-- | Selection of rows. Explicit indices are reduced modulo the number of
-- rows (see 'getRowRange'), so negative indices count from the end.
data RowRange = AllRows
              | RowRange Int Int   -- ^ first and last row (inclusive pair passed on as is)
              | Row Int            -- ^ a single row
              | FromRow Int        -- ^ from the given row to the last one

-- Resolve a 'RowRange' to a (first, last) pair for a matrix with @r@ rows.
-- Explicit indices go through `mod` r; with @r == 0@ those cases raise a
-- divide-by-zero exception ('AllRows' does not).
getRowRange :: Int -> RowRange -> (Int, Int)
getRowRange r AllRows = (0,r-1)
getRowRange r (RowRange a b) = (a `mod` r, b `mod` r)
getRowRange r (Row a) = (a `mod` r, a `mod` r)
getRowRange r (FromRow a) = (a `mod` r, r-1)

-- | In-place row operations understood by 'rowOper'.
-- NOTE(review): the exact arithmetic is implemented by 'rowOp', which is
-- not visible here; the constructor names suggest the BLAS-like meanings
-- (axpy, scal, swap) -- confirm against 'rowOp'.
data RowOper t = AXPY t Int Int ColRange   -- ^ scalar, two row indices, column range
               | SCAL t RowRange ColRange  -- ^ scalar, row range, column range
               | SWAP Int Int ColRange     -- ^ two row indices, column range

-- | Apply a 'RowOper' to a mutable matrix in place.
-- Each constructor is dispatched to 'rowOp' with an operation code
-- (0 for 'AXPY', 1 for 'SCAL', 2 for 'SWAP'); row indices are reduced
-- modulo the number of rows and column ranges resolved by 'getColRange'.
rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s ()

rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m
  where
    (j1,j2) = getColRange (cols m) r
    i1' = i1 `mod` (rows m)
    i2' = i2 `mod` (rows m)

rowOper (SCAL x rr rc) (STMatrix m) = unsafeIOToST $ rowOp 1 x i1 i2 j1 j2 m
  where
    (i1,i2) = getRowRange (rows m) rr
    (j1,j2) = getColRange (cols m) rc

-- The scalar argument of 'rowOp' is passed as 0 for a swap.
rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m
  where
    (j1,j2) = getColRange (cols m) r
    i1' = i1 `mod` (rows m)
    i2' = i2 `mod` (rows m)

-- | Extract the submatrix selected by a row range and a column range from
-- the current contents of a mutable matrix, by calling 'extractR' with the
-- resolved (first, last) index pairs.
-- NOTE(review): no type signature is given because the types of 'extractR'
-- and 'idxs' are not visible in this module; the two literal 0 arguments
-- are presumably range-mode selectors -- confirm against 'extractR'.
extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2]))
  where
    (i1,i2) = getRowRange (rows m) rr
    (j1,j2) = getColRange (cols m) rc

-- | r0 c0 height width
data Slice s t = Slice (STMatrix s t) Int Int Int Int

-- View a 'Slice' as a matrix via 'sliceMatrix', starting at @(r0,c0)@ with
-- size @(nr,nc)@.
-- NOTE(review): presumably shares storage with the mutable matrix (gemmm
-- relies on writing its result through such a view) -- confirm against
-- 'sliceMatrix'.
slice (Slice (STMatrix m) r0 c0 nr nc) = sliceMatrix (r0,c0) (nr,nc) m

-- | @gemmm beta r alpha a b@ calls 'gemm' with the coefficient vector
-- @[alpha,beta]@ and the matrices obtained from the slices @a@, @b@ and @r@.
-- NOTE(review): presumably computes @r <- alpha * a * b + beta * r@ in
-- place -- confirm against 'gemm'.
gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s ()
gemmm beta (slice->r) alpha (slice->a) (slice->b) = res
  where
    res = unsafeIOToST (gemm v a b r)
    v = fromList [alpha,beta]

-- | Run an in-place computation on a thawed version of a matrix.
-- The action receives the dimensions @(rows, cols)@ of the input and the
-- mutable matrix produced by 'thawMatrix'; the result is the final state
-- of that mutable matrix (frozen without copying) paired with the value
-- returned by the action.
mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
mutable f a = runST $ do
    x <- thawMatrix a
    info <- f (rows a, cols a) x
    r <- unsafeFreezeMatrix x
    return (r,info)