diff options
Diffstat (limited to 'packages/base/src/Internal/ST.hs')
-rw-r--r-- | packages/base/src/Internal/ST.hs | 59 |
1 files changed, 49 insertions, 10 deletions
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index a84ca25..434fe63 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -10,7 +10,7 @@ | |||
10 | -- Stability : provisional | 10 | -- Stability : provisional |
11 | -- | 11 | -- |
12 | -- In-place manipulation inside the ST monad. | 12 | -- In-place manipulation inside the ST monad. |
13 | -- See examples/inplace.hs in the distribution. | 13 | -- See @examples/inplace.hs@ in the repository. |
14 | -- | 14 | -- |
15 | ----------------------------------------------------------------------------- | 15 | ----------------------------------------------------------------------------- |
16 | 16 | ||
@@ -21,8 +21,8 @@ module Internal.ST ( | |||
21 | -- * Mutable Matrices | 21 | -- * Mutable Matrices |
22 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, | 22 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, |
23 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, | 23 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, |
24 | axpy, scal, swap, extractMatrix, setMatrix, rowOpST, | 24 | -- axpy, scal, swap, rowOp, |
25 | mutable, | 25 | mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), |
26 | -- * Unsafe functions | 26 | -- * Unsafe functions |
27 | newUndefinedVector, | 27 | newUndefinedVector, |
28 | unsafeReadVector, unsafeWriteVector, | 28 | unsafeReadVector, unsafeWriteVector, |
@@ -178,16 +178,55 @@ newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) | |||
178 | 178 | ||
179 | -------------------------------------------------------------------------------- | 179 | -------------------------------------------------------------------------------- |
180 | 180 | ||
181 | rowOpST :: Element t => Int -> t -> Int -> Int -> Int -> Int -> STMatrix s t -> ST s () | 181 | data ColRange = AllCols |
182 | rowOpST c x i1 i2 j1 j2 (STMatrix m) = unsafeIOToST (rowOp c x i1 i2 j1 j2 m) | 182 | | ColRange Int Int |
183 | | Col Int | ||
184 | | FromCol Int | ||
183 | 185 | ||
184 | axpy (STMatrix m) a i j = rowOpST 0 a i j 0 (cols m -1) (STMatrix m) | 186 | getColRange c AllCols = (0,c-1) |
185 | scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m) | 187 | getColRange c (ColRange a b) = (a `mod` c, b `mod` c) |
186 | swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m) | 188 | getColRange c (Col a) = (a `mod` c, a `mod` c) |
189 | getColRange c (FromCol a) = (a `mod` c, c-1) | ||
187 | 190 | ||
188 | extractMatrix (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) | 191 | data RowRange = AllRows |
192 | | RowRange Int Int | ||
193 | | Row Int | ||
194 | | FromRow Int | ||
195 | |||
196 | getRowRange r AllRows = (0,r-1) | ||
197 | getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) | ||
198 | getRowRange r (Row a) = (a `mod` r, a `mod` r) | ||
199 | getRowRange r (FromRow a) = (a `mod` r, r-1) | ||
200 | |||
201 | data RowOper t = AXPY t Int Int ColRange | ||
202 | | SCAL t RowRange ColRange | ||
203 | | SWAP Int Int ColRange | ||
204 | |||
205 | rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s () | ||
206 | |||
207 | rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m | ||
208 | where | ||
209 | (j1,j2) = getColRange (cols m) r | ||
210 | i1' = i1 `mod` (rows m) | ||
211 | i2' = i2 `mod` (rows m) | ||
212 | |||
213 | rowOper (SCAL x rr rc) (STMatrix m) = unsafeIOToST $ rowOp 1 x i1 i2 j1 j2 m | ||
214 | where | ||
215 | (i1,i2) = getRowRange (rows m) rr | ||
216 | (j1,j2) = getColRange (cols m) rc | ||
217 | |||
218 | rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m | ||
219 | where | ||
220 | (j1,j2) = getColRange (cols m) r | ||
221 | i1' = i1 `mod` (rows m) | ||
222 | i2' = i2 `mod` (rows m) | ||
223 | |||
224 | |||
225 | extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) | ||
226 | where | ||
227 | (i1,i2) = getRowRange (rows m) rr | ||
228 | (j1,j2) = getColRange (cols m) rc | ||
189 | 229 | ||
190 | -------------------------------------------------------------------------------- | ||
191 | 230 | ||
192 | mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) | 231 | mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) |
193 | mutable f a = runST $ do | 232 | mutable f a = runST $ do |