summaryrefslogtreecommitdiff
path: root/packages/base/src/Data/Packed/ST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Data/Packed/ST.hs')
-rw-r--r--packages/base/src/Data/Packed/ST.hs178
1 files changed, 178 insertions, 0 deletions
diff --git a/packages/base/src/Data/Packed/ST.hs b/packages/base/src/Data/Packed/ST.hs
new file mode 100644
index 0000000..dae457c
--- /dev/null
+++ b/packages/base/src/Data/Packed/ST.hs
@@ -0,0 +1,178 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE TypeOperators #-}
3{-# LANGUAGE Rank2Types #-}
4{-# LANGUAGE BangPatterns #-}
5-----------------------------------------------------------------------------
6-- |
7-- Module : Data.Packed.ST
8-- Copyright : (c) Alberto Ruiz 2008
9-- License : BSD3
10-- Maintainer : Alberto Ruiz
11-- Stability : provisional
12-- Portability : portable
13--
14-- In-place manipulation inside the ST monad.
15-- See examples/inplace.hs in the distribution.
16--
17-----------------------------------------------------------------------------
18
19module Data.Packed.ST (
20 -- * Mutable Vectors
21 STVector, newVector, thawVector, freezeVector, runSTVector,
22 readVector, writeVector, modifyVector, liftSTVector,
23 -- * Mutable Matrices
24 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
25 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
26 -- * Unsafe functions
27 newUndefinedVector,
28 unsafeReadVector, unsafeWriteVector,
29 unsafeThawVector, unsafeFreezeVector,
30 newUndefinedMatrix,
31 unsafeReadMatrix, unsafeWriteMatrix,
32 unsafeThawMatrix, unsafeFreezeMatrix
33) where
34
35import Data.Packed.Internal
36
37import Control.Monad.ST(ST, runST)
38import Foreign.Storable(Storable, peekElemOff, pokeElemOff)
39
40#if MIN_VERSION_base(4,4,0)
41import Control.Monad.ST.Unsafe(unsafeIOToST)
42#else
43import Control.Monad.ST(unsafeIOToST)
44#endif
45
46{-# INLINE ioReadV #-}
47ioReadV :: Storable t => Vector t -> Int -> IO t
48ioReadV v k = unsafeWith v $ \s -> peekElemOff s k
49
50{-# INLINE ioWriteV #-}
51ioWriteV :: Storable t => Vector t -> Int -> t -> IO ()
52ioWriteV v k x = unsafeWith v $ \s -> pokeElemOff s k x
53
54newtype STVector s t = STVector (Vector t)
55
56thawVector :: Storable t => Vector t -> ST s (STVector s t)
57thawVector = unsafeIOToST . fmap STVector . cloneVector
58
59unsafeThawVector :: Storable t => Vector t -> ST s (STVector s t)
60unsafeThawVector = unsafeIOToST . return . STVector
61
62runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t
63runSTVector st = runST (st >>= unsafeFreezeVector)
64
65{-# INLINE unsafeReadVector #-}
66unsafeReadVector :: Storable t => STVector s t -> Int -> ST s t
67unsafeReadVector (STVector x) = unsafeIOToST . ioReadV x
68
69{-# INLINE unsafeWriteVector #-}
70unsafeWriteVector :: Storable t => STVector s t -> Int -> t -> ST s ()
71unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k
72
73{-# INLINE modifyVector #-}
74modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
75modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k
76
77liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a
78liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x
79
80freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
81freezeVector v = liftSTVector id v
82
83unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
84unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x
85
86{-# INLINE safeIndexV #-}
87safeIndexV f (STVector v) k
88 | k < 0 || k>= dim v = error $ "out of range error in vector (dim="
89 ++show (dim v)++", pos="++show k++")"
90 | otherwise = f (STVector v) k
91
92{-# INLINE readVector #-}
93readVector :: Storable t => STVector s t -> Int -> ST s t
94readVector = safeIndexV unsafeReadVector
95
96{-# INLINE writeVector #-}
97writeVector :: Storable t => STVector s t -> Int -> t -> ST s ()
98writeVector = safeIndexV unsafeWriteVector
99
100newUndefinedVector :: Storable t => Int -> ST s (STVector s t)
101newUndefinedVector = unsafeIOToST . fmap STVector . createVector
102
103{-# INLINE newVector #-}
104newVector :: Storable t => t -> Int -> ST s (STVector s t)
105newVector x n = do
106 v <- newUndefinedVector n
107 let go (-1) = return v
108 go !k = unsafeWriteVector v k x >> go (k-1 :: Int)
109 go (n-1)
110
111-------------------------------------------------------------------------
112
113{-# INLINE ioReadM #-}
114ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
115ioReadM (Matrix _ nc cv RowMajor) r c = ioReadV cv (r*nc+c)
116ioReadM (Matrix nr _ fv ColumnMajor) r c = ioReadV fv (c*nr+r)
117
118{-# INLINE ioWriteM #-}
119ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
120ioWriteM (Matrix _ nc cv RowMajor) r c val = ioWriteV cv (r*nc+c) val
121ioWriteM (Matrix nr _ fv ColumnMajor) r c val = ioWriteV fv (c*nr+r) val
122
123newtype STMatrix s t = STMatrix (Matrix t)
124
125thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
126thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix
127
128unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
129unsafeThawMatrix = unsafeIOToST . return . STMatrix
130
131runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t
132runSTMatrix st = runST (st >>= unsafeFreezeMatrix)
133
134{-# INLINE unsafeReadMatrix #-}
135unsafeReadMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
136unsafeReadMatrix (STMatrix x) r = unsafeIOToST . ioReadM x r
137
138{-# INLINE unsafeWriteMatrix #-}
139unsafeWriteMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
140unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
141
142{-# INLINE modifyMatrix #-}
143modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
144modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c
145
146liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a
147liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
148
149unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
150unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
151
152freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
153freezeMatrix m = liftSTMatrix id m
154
155cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o)
156
157{-# INLINE safeIndexM #-}
158safeIndexM f (STMatrix m) r c
159 | r<0 || r>=rows m ||
160 c<0 || c>=cols m = error $ "out of range error in matrix (size="
161 ++show (rows m,cols m)++", pos="++show (r,c)++")"
162 | otherwise = f (STMatrix m) r c
163
164{-# INLINE readMatrix #-}
165readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
166readMatrix = safeIndexM unsafeReadMatrix
167
168{-# INLINE writeMatrix #-}
169writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
170writeMatrix = safeIndexM unsafeWriteMatrix
171
172newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
173newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c
174
175{-# NOINLINE newMatrix #-}
176newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
177newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c)
178