diff options
Diffstat (limited to 'packages/base/src/Internal/ST.hs')
-rw-r--r-- | packages/base/src/Internal/ST.hs | 26 |
1 files changed, 24 insertions, 2 deletions
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index ae75a1b..107d3c3 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -1,5 +1,6 @@ | |||
1 | {-# LANGUAGE Rank2Types #-} | 1 | {-# LANGUAGE Rank2Types #-} |
2 | {-# LANGUAGE BangPatterns #-} | 2 | {-# LANGUAGE BangPatterns #-} |
3 | |||
3 | ----------------------------------------------------------------------------- | 4 | ----------------------------------------------------------------------------- |
4 | -- | | 5 | -- | |
5 | -- Module : Internal.ST | 6 | -- Module : Internal.ST |
@@ -20,6 +21,8 @@ module Internal.ST ( | |||
20 | -- * Mutable Matrices | 21 | -- * Mutable Matrices |
21 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, | 22 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, |
22 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, | 23 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, |
24 | axpy, scal, swap, extractRect, | ||
25 | mutable, | ||
23 | -- * Unsafe functions | 26 | -- * Unsafe functions |
24 | newUndefinedVector, | 27 | newUndefinedVector, |
25 | unsafeReadVector, unsafeWriteVector, | 28 | unsafeReadVector, unsafeWriteVector, |
@@ -34,8 +37,6 @@ import Internal.Matrix | |||
34 | import Internal.Vectorized | 37 | import Internal.Vectorized |
35 | import Control.Monad.ST(ST, runST) | 38 | import Control.Monad.ST(ST, runST) |
36 | import Foreign.Storable(Storable, peekElemOff, pokeElemOff) | 39 | import Foreign.Storable(Storable, peekElemOff, pokeElemOff) |
37 | |||
38 | |||
39 | import Control.Monad.ST.Unsafe(unsafeIOToST) | 40 | import Control.Monad.ST.Unsafe(unsafeIOToST) |
40 | 41 | ||
41 | {-# INLINE ioReadV #-} | 42 | {-# INLINE ioReadV #-} |
@@ -144,6 +145,7 @@ liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x | |||
144 | unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) | 145 | unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) |
145 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | 146 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x |
146 | 147 | ||
148 | |||
147 | freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) | 149 | freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) |
148 | freezeMatrix m = liftSTMatrix id m | 150 | freezeMatrix m = liftSTMatrix id m |
149 | 151 | ||
@@ -171,3 +173,23 @@ newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c | |||
171 | newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t) | 173 | newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t) |
172 | newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) | 174 | newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) |
173 | 175 | ||
176 | -------------------------------------------------------------------------------- | ||
177 | |||
178 | rowOpST :: Element t => Int -> t -> Int -> Int -> Int -> Int -> STMatrix s t -> ST s () | ||
179 | rowOpST c x i1 i2 j1 j2 (STMatrix m) = unsafeIOToST (rowOp c x i1 i2 j1 j2 m) | ||
180 | |||
181 | axpy (STMatrix m) a i j = rowOpST 0 a i j 0 (cols m -1) (STMatrix m) | ||
182 | scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m) | ||
183 | swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m) | ||
184 | |||
185 | extractRect (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) | ||
186 | |||
187 | -------------------------------------------------------------------------------- | ||
188 | |||
189 | mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) | ||
190 | mutable f a = runST $ do | ||
191 | x <- thawMatrix a | ||
192 | info <- f (rows a, cols a) x | ||
193 | r <- unsafeFreezeMatrix x | ||
194 | return (r,info) | ||
195 | |||