summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/ST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/ST.hs')
-rw-r--r--packages/base/src/Internal/ST.hs26
1 files changed, 24 insertions, 2 deletions
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index ae75a1b..107d3c3 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -1,5 +1,6 @@
1{-# LANGUAGE Rank2Types #-} 1{-# LANGUAGE Rank2Types #-}
2{-# LANGUAGE BangPatterns #-} 2{-# LANGUAGE BangPatterns #-}
3
3----------------------------------------------------------------------------- 4-----------------------------------------------------------------------------
4-- | 5-- |
5-- Module : Internal.ST 6-- Module : Internal.ST
@@ -20,6 +21,8 @@ module Internal.ST (
20 -- * Mutable Matrices 21 -- * Mutable Matrices
21 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, 22 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
22 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, 23 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
24 axpy, scal, swap, extractRect,
25 mutable,
23 -- * Unsafe functions 26 -- * Unsafe functions
24 newUndefinedVector, 27 newUndefinedVector,
25 unsafeReadVector, unsafeWriteVector, 28 unsafeReadVector, unsafeWriteVector,
@@ -34,8 +37,6 @@ import Internal.Matrix
34import Internal.Vectorized 37import Internal.Vectorized
35import Control.Monad.ST(ST, runST) 38import Control.Monad.ST(ST, runST)
36import Foreign.Storable(Storable, peekElemOff, pokeElemOff) 39import Foreign.Storable(Storable, peekElemOff, pokeElemOff)
37
38
39import Control.Monad.ST.Unsafe(unsafeIOToST) 40import Control.Monad.ST.Unsafe(unsafeIOToST)
40 41
41{-# INLINE ioReadV #-} 42{-# INLINE ioReadV #-}
@@ -144,6 +145,7 @@ liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
144unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) 145unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
145unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x 146unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
146 147
148
147freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) 149freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
148freezeMatrix m = liftSTMatrix id m 150freezeMatrix m = liftSTMatrix id m
149 151
@@ -171,3 +173,23 @@ newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c
171newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t) 173newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
172newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) 174newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c)
173 175
176--------------------------------------------------------------------------------
177
178rowOpST :: Element t => Int -> t -> Int -> Int -> Int -> Int -> STMatrix s t -> ST s ()
179rowOpST c x i1 i2 j1 j2 (STMatrix m) = unsafeIOToST (rowOp c x i1 i2 j1 j2 m)
180
181axpy (STMatrix m) a i j = rowOpST 0 a i j 0 (cols m -1) (STMatrix m)
182scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m)
183swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m)
184
185extractRect (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2]))
186
187--------------------------------------------------------------------------------
188
189mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
190mutable f a = runST $ do
191 x <- thawMatrix a
192 info <- f (rows a, cols a) x
193 r <- unsafeFreezeMatrix x
194 return (r,info)
195