diff options
Diffstat (limited to 'packages/hmatrix/src/Data/Packed/ST.hs')
-rw-r--r-- | packages/hmatrix/src/Data/Packed/ST.hs | 179 |
1 files changed, 0 insertions, 179 deletions
diff --git a/packages/hmatrix/src/Data/Packed/ST.hs b/packages/hmatrix/src/Data/Packed/ST.hs deleted file mode 100644 index 1cef296..0000000 --- a/packages/hmatrix/src/Data/Packed/ST.hs +++ /dev/null | |||
@@ -1,179 +0,0 @@ | |||
1 | {-# LANGUAGE CPP #-} | ||
2 | {-# LANGUAGE TypeOperators #-} | ||
3 | {-# LANGUAGE Rank2Types #-} | ||
4 | {-# LANGUAGE BangPatterns #-} | ||
5 | ----------------------------------------------------------------------------- | ||
6 | -- | | ||
7 | -- Module : Data.Packed.ST | ||
8 | -- Copyright : (c) Alberto Ruiz 2008 | ||
9 | -- License : GPL-style | ||
10 | -- | ||
11 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
12 | -- Stability : provisional | ||
13 | -- Portability : portable | ||
14 | -- | ||
15 | -- In-place manipulation inside the ST monad. | ||
16 | -- See examples/inplace.hs in the distribution. | ||
17 | -- | ||
18 | ----------------------------------------------------------------------------- | ||
19 | {-# OPTIONS_HADDOCK hide #-} | ||
20 | |||
21 | module Data.Packed.ST ( | ||
22 | -- * Mutable Vectors | ||
23 | STVector, newVector, thawVector, freezeVector, runSTVector, | ||
24 | readVector, writeVector, modifyVector, liftSTVector, | ||
25 | -- * Mutable Matrices | ||
26 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, | ||
27 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, | ||
28 | -- * Unsafe functions | ||
29 | newUndefinedVector, | ||
30 | unsafeReadVector, unsafeWriteVector, | ||
31 | unsafeThawVector, unsafeFreezeVector, | ||
32 | newUndefinedMatrix, | ||
33 | unsafeReadMatrix, unsafeWriteMatrix, | ||
34 | unsafeThawMatrix, unsafeFreezeMatrix | ||
35 | ) where | ||
36 | |||
37 | import Data.Packed.Internal | ||
38 | |||
39 | import Control.Monad.ST(ST, runST) | ||
40 | import Foreign.Storable(Storable, peekElemOff, pokeElemOff) | ||
41 | |||
42 | #if MIN_VERSION_base(4,4,0) | ||
43 | import Control.Monad.ST.Unsafe(unsafeIOToST) | ||
44 | #else | ||
45 | import Control.Monad.ST(unsafeIOToST) | ||
46 | #endif | ||
47 | |||
48 | {-# INLINE ioReadV #-} | ||
49 | ioReadV :: Storable t => Vector t -> Int -> IO t | ||
50 | ioReadV v k = unsafeWith v $ \s -> peekElemOff s k | ||
51 | |||
52 | {-# INLINE ioWriteV #-} | ||
53 | ioWriteV :: Storable t => Vector t -> Int -> t -> IO () | ||
54 | ioWriteV v k x = unsafeWith v $ \s -> pokeElemOff s k x | ||
55 | |||
56 | newtype STVector s t = STVector (Vector t) | ||
57 | |||
58 | thawVector :: Storable t => Vector t -> ST s (STVector s t) | ||
59 | thawVector = unsafeIOToST . fmap STVector . cloneVector | ||
60 | |||
61 | unsafeThawVector :: Storable t => Vector t -> ST s (STVector s t) | ||
62 | unsafeThawVector = unsafeIOToST . return . STVector | ||
63 | |||
64 | runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t | ||
65 | runSTVector st = runST (st >>= unsafeFreezeVector) | ||
66 | |||
67 | {-# INLINE unsafeReadVector #-} | ||
68 | unsafeReadVector :: Storable t => STVector s t -> Int -> ST s t | ||
69 | unsafeReadVector (STVector x) = unsafeIOToST . ioReadV x | ||
70 | |||
71 | {-# INLINE unsafeWriteVector #-} | ||
72 | unsafeWriteVector :: Storable t => STVector s t -> Int -> t -> ST s () | ||
73 | unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k | ||
74 | |||
75 | {-# INLINE modifyVector #-} | ||
76 | modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s () | ||
77 | modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k | ||
78 | |||
79 | liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a | ||
80 | liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x | ||
81 | |||
82 | freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) | ||
83 | freezeVector v = liftSTVector id v | ||
84 | |||
85 | unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) | ||
86 | unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x | ||
87 | |||
88 | {-# INLINE safeIndexV #-} | ||
89 | safeIndexV f (STVector v) k | ||
90 | | k < 0 || k>= dim v = error $ "out of range error in vector (dim=" | ||
91 | ++show (dim v)++", pos="++show k++")" | ||
92 | | otherwise = f (STVector v) k | ||
93 | |||
94 | {-# INLINE readVector #-} | ||
95 | readVector :: Storable t => STVector s t -> Int -> ST s t | ||
96 | readVector = safeIndexV unsafeReadVector | ||
97 | |||
98 | {-# INLINE writeVector #-} | ||
99 | writeVector :: Storable t => STVector s t -> Int -> t -> ST s () | ||
100 | writeVector = safeIndexV unsafeWriteVector | ||
101 | |||
102 | newUndefinedVector :: Storable t => Int -> ST s (STVector s t) | ||
103 | newUndefinedVector = unsafeIOToST . fmap STVector . createVector | ||
104 | |||
105 | {-# INLINE newVector #-} | ||
106 | newVector :: Storable t => t -> Int -> ST s (STVector s t) | ||
107 | newVector x n = do | ||
108 | v <- newUndefinedVector n | ||
109 | let go (-1) = return v | ||
110 | go !k = unsafeWriteVector v k x >> go (k-1 :: Int) | ||
111 | go (n-1) | ||
112 | |||
113 | ------------------------------------------------------------------------- | ||
114 | |||
115 | {-# INLINE ioReadM #-} | ||
116 | ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t | ||
117 | ioReadM (Matrix _ nc cv RowMajor) r c = ioReadV cv (r*nc+c) | ||
118 | ioReadM (Matrix nr _ fv ColumnMajor) r c = ioReadV fv (c*nr+r) | ||
119 | |||
120 | {-# INLINE ioWriteM #-} | ||
121 | ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () | ||
122 | ioWriteM (Matrix _ nc cv RowMajor) r c val = ioWriteV cv (r*nc+c) val | ||
123 | ioWriteM (Matrix nr _ fv ColumnMajor) r c val = ioWriteV fv (c*nr+r) val | ||
124 | |||
125 | newtype STMatrix s t = STMatrix (Matrix t) | ||
126 | |||
127 | thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) | ||
128 | thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix | ||
129 | |||
130 | unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) | ||
131 | unsafeThawMatrix = unsafeIOToST . return . STMatrix | ||
132 | |||
133 | runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t | ||
134 | runSTMatrix st = runST (st >>= unsafeFreezeMatrix) | ||
135 | |||
136 | {-# INLINE unsafeReadMatrix #-} | ||
137 | unsafeReadMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t | ||
138 | unsafeReadMatrix (STMatrix x) r = unsafeIOToST . ioReadM x r | ||
139 | |||
140 | {-# INLINE unsafeWriteMatrix #-} | ||
141 | unsafeWriteMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () | ||
142 | unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c | ||
143 | |||
144 | {-# INLINE modifyMatrix #-} | ||
145 | modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () | ||
146 | modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c | ||
147 | |||
148 | liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a | ||
149 | liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x | ||
150 | |||
151 | unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) | ||
152 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | ||
153 | |||
154 | freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) | ||
155 | freezeMatrix m = liftSTMatrix id m | ||
156 | |||
157 | cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o) | ||
158 | |||
159 | {-# INLINE safeIndexM #-} | ||
160 | safeIndexM f (STMatrix m) r c | ||
161 | | r<0 || r>=rows m || | ||
162 | c<0 || c>=cols m = error $ "out of range error in matrix (size=" | ||
163 | ++show (rows m,cols m)++", pos="++show (r,c)++")" | ||
164 | | otherwise = f (STMatrix m) r c | ||
165 | |||
166 | {-# INLINE readMatrix #-} | ||
167 | readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t | ||
168 | readMatrix = safeIndexM unsafeReadMatrix | ||
169 | |||
170 | {-# INLINE writeMatrix #-} | ||
171 | writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () | ||
172 | writeMatrix = safeIndexM unsafeWriteMatrix | ||
173 | |||
174 | newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) | ||
175 | newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c | ||
176 | |||
177 | {-# NOINLINE newMatrix #-} | ||
178 | newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t) | ||
179 | newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) | ||