summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/ST.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-05 16:51:39 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-05 16:51:39 +0200
commitba04d65298266a2f6a37061bedaca4ea3cf7fae6 (patch)
tree5883897751f4b0864ffa53469799c55fda6ac41a /packages/base/src/Internal/ST.hs
parent83f3371700441e50927f24a19e28164cc5194f03 (diff)
move st
Diffstat (limited to 'packages/base/src/Internal/ST.hs')
-rw-r--r--packages/base/src/Internal/ST.hs176
1 files changed, 176 insertions, 0 deletions
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
new file mode 100644
index 0000000..25e7969
--- /dev/null
+++ b/packages/base/src/Internal/ST.hs
@@ -0,0 +1,176 @@
1{-# LANGUAGE Rank2Types #-}
2{-# LANGUAGE BangPatterns #-}
3-----------------------------------------------------------------------------
4-- |
5-- Module : Internal.ST
6-- Copyright : (c) Alberto Ruiz 2008
7-- License : BSD3
8-- Maintainer : Alberto Ruiz
9-- Stability : provisional
10--
11-- In-place manipulation inside the ST monad.
12-- See examples/inplace.hs in the distribution.
13--
14-----------------------------------------------------------------------------
15
16module Internal.ST (
17 -- * Mutable Vectors
18 STVector, newVector, thawVector, freezeVector, runSTVector,
19 readVector, writeVector, modifyVector, liftSTVector,
20 -- * Mutable Matrices
21 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
22 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
23 -- * Unsafe functions
24 newUndefinedVector,
25 unsafeReadVector, unsafeWriteVector,
26 unsafeThawVector, unsafeFreezeVector,
27 newUndefinedMatrix,
28 unsafeReadMatrix, unsafeWriteMatrix,
29 unsafeThawMatrix, unsafeFreezeMatrix
30) where
31
32import Internal.Vector
33import Internal.Matrix
34import Internal.Vectorized
35import Data.Vector.Storable(unsafeWith)
36
37
38import Control.Monad.ST(ST, runST)
39import Foreign.Storable(Storable, peekElemOff, pokeElemOff)
40
41
42import Control.Monad.ST.Unsafe(unsafeIOToST)
43
44{-# INLINE ioReadV #-}
45ioReadV :: Storable t => Vector t -> Int -> IO t
46ioReadV v k = unsafeWith v $ \s -> peekElemOff s k
47
48{-# INLINE ioWriteV #-}
49ioWriteV :: Storable t => Vector t -> Int -> t -> IO ()
50ioWriteV v k x = unsafeWith v $ \s -> pokeElemOff s k x
51
52newtype STVector s t = STVector (Vector t)
53
54thawVector :: Storable t => Vector t -> ST s (STVector s t)
55thawVector = unsafeIOToST . fmap STVector . cloneVector
56
57unsafeThawVector :: Storable t => Vector t -> ST s (STVector s t)
58unsafeThawVector = unsafeIOToST . return . STVector
59
60runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t
61runSTVector st = runST (st >>= unsafeFreezeVector)
62
63{-# INLINE unsafeReadVector #-}
64unsafeReadVector :: Storable t => STVector s t -> Int -> ST s t
65unsafeReadVector (STVector x) = unsafeIOToST . ioReadV x
66
67{-# INLINE unsafeWriteVector #-}
68unsafeWriteVector :: Storable t => STVector s t -> Int -> t -> ST s ()
69unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k
70
71{-# INLINE modifyVector #-}
72modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
73modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k
74
75liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a
76liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x
77
78freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
79freezeVector v = liftSTVector id v
80
81unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
82unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x
83
84{-# INLINE safeIndexV #-}
85safeIndexV f (STVector v) k
86 | k < 0 || k>= dim v = error $ "out of range error in vector (dim="
87 ++show (dim v)++", pos="++show k++")"
88 | otherwise = f (STVector v) k
89
90{-# INLINE readVector #-}
91readVector :: Storable t => STVector s t -> Int -> ST s t
92readVector = safeIndexV unsafeReadVector
93
94{-# INLINE writeVector #-}
95writeVector :: Storable t => STVector s t -> Int -> t -> ST s ()
96writeVector = safeIndexV unsafeWriteVector
97
98newUndefinedVector :: Storable t => Int -> ST s (STVector s t)
99newUndefinedVector = unsafeIOToST . fmap STVector . createVector
100
101{-# INLINE newVector #-}
102newVector :: Storable t => t -> Int -> ST s (STVector s t)
103newVector x n = do
104 v <- newUndefinedVector n
105 let go (-1) = return v
106 go !k = unsafeWriteVector v k x >> go (k-1 :: Int)
107 go (n-1)
108
109-------------------------------------------------------------------------
110
111{-# INLINE ioReadM #-}
112ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
113ioReadM (Matrix _ nc cv RowMajor) r c = ioReadV cv (r*nc+c)
114ioReadM (Matrix nr _ fv ColumnMajor) r c = ioReadV fv (c*nr+r)
115
116{-# INLINE ioWriteM #-}
117ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
118ioWriteM (Matrix _ nc cv RowMajor) r c val = ioWriteV cv (r*nc+c) val
119ioWriteM (Matrix nr _ fv ColumnMajor) r c val = ioWriteV fv (c*nr+r) val
120
121newtype STMatrix s t = STMatrix (Matrix t)
122
123thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
124thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix
125
126unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
127unsafeThawMatrix = unsafeIOToST . return . STMatrix
128
129runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t
130runSTMatrix st = runST (st >>= unsafeFreezeMatrix)
131
132{-# INLINE unsafeReadMatrix #-}
133unsafeReadMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
134unsafeReadMatrix (STMatrix x) r = unsafeIOToST . ioReadM x r
135
136{-# INLINE unsafeWriteMatrix #-}
137unsafeWriteMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
138unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
139
140{-# INLINE modifyMatrix #-}
141modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
142modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c
143
144liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a
145liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
146
147unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
148unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
149
150freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
151freezeMatrix m = liftSTMatrix id m
152
153cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o)
154
155{-# INLINE safeIndexM #-}
156safeIndexM f (STMatrix m) r c
157 | r<0 || r>=rows m ||
158 c<0 || c>=cols m = error $ "out of range error in matrix (size="
159 ++show (rows m,cols m)++", pos="++show (r,c)++")"
160 | otherwise = f (STMatrix m) r c
161
162{-# INLINE readMatrix #-}
163readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
164readMatrix = safeIndexM unsafeReadMatrix
165
166{-# INLINE writeMatrix #-}
167writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
168writeMatrix = safeIndexM unsafeWriteMatrix
169
170newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
171newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c
172
173{-# NOINLINE newMatrix #-}
174newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
175newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c)
176