summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/ST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/ST.hs')
-rw-r--r--packages/base/src/Internal/ST.hs59
1 files changed, 49 insertions, 10 deletions
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index a84ca25..434fe63 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -10,7 +10,7 @@
10-- Stability : provisional 10-- Stability : provisional
11-- 11--
12-- In-place manipulation inside the ST monad. 12-- In-place manipulation inside the ST monad.
13-- See examples/inplace.hs in the distribution. 13-- See @examples/inplace.hs@ in the repository.
14-- 14--
15----------------------------------------------------------------------------- 15-----------------------------------------------------------------------------
16 16
@@ -21,8 +21,8 @@ module Internal.ST (
21 -- * Mutable Matrices 21 -- * Mutable Matrices
22 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, 22 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
23 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, 23 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
24 axpy, scal, swap, extractMatrix, setMatrix, rowOpST, 24-- axpy, scal, swap, rowOp,
25 mutable, 25 mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..),
26 -- * Unsafe functions 26 -- * Unsafe functions
27 newUndefinedVector, 27 newUndefinedVector,
28 unsafeReadVector, unsafeWriteVector, 28 unsafeReadVector, unsafeWriteVector,
@@ -178,16 +178,55 @@ newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c)
178 178
179-------------------------------------------------------------------------------- 179--------------------------------------------------------------------------------
180 180
181rowOpST :: Element t => Int -> t -> Int -> Int -> Int -> Int -> STMatrix s t -> ST s () 181data ColRange = AllCols
182rowOpST c x i1 i2 j1 j2 (STMatrix m) = unsafeIOToST (rowOp c x i1 i2 j1 j2 m) 182 | ColRange Int Int
183 | Col Int
184 | FromCol Int
183 185
184axpy (STMatrix m) a i j = rowOpST 0 a i j 0 (cols m -1) (STMatrix m) 186getColRange c AllCols = (0,c-1)
185scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m) 187getColRange c (ColRange a b) = (a `mod` c, b `mod` c)
186swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m) 188getColRange c (Col a) = (a `mod` c, a `mod` c)
189getColRange c (FromCol a) = (a `mod` c, c-1)
187 190
188extractMatrix (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) 191data RowRange = AllRows
192 | RowRange Int Int
193 | Row Int
194 | FromRow Int
195
196getRowRange r AllRows = (0,r-1)
197getRowRange r (RowRange a b) = (a `mod` r, b `mod` r)
198getRowRange r (Row a) = (a `mod` r, a `mod` r)
199getRowRange r (FromRow a) = (a `mod` r, r-1)
200
201data RowOper t = AXPY t Int Int ColRange
202 | SCAL t RowRange ColRange
203 | SWAP Int Int ColRange
204
205rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s ()
206
207rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m
208 where
209 (j1,j2) = getColRange (cols m) r
210 i1' = i1 `mod` (rows m)
211 i2' = i2 `mod` (rows m)
212
213rowOper (SCAL x rr rc) (STMatrix m) = unsafeIOToST $ rowOp 1 x i1 i2 j1 j2 m
214 where
215 (i1,i2) = getRowRange (rows m) rr
216 (j1,j2) = getColRange (cols m) rc
217
218rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m
219 where
220 (j1,j2) = getColRange (cols m) r
221 i1' = i1 `mod` (rows m)
222 i2' = i2 `mod` (rows m)
223
224
225extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2]))
226 where
227 (i1,i2) = getRowRange (rows m) rr
228 (j1,j2) = getColRange (cols m) rc
189 229
190--------------------------------------------------------------------------------
191 230
192mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 231mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
193mutable f a = runST $ do 232mutable f a = runST $ do