summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/ST.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-12 20:58:13 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-12 20:58:13 +0200
commit4b3e29097aa272d429f8005fe17b459cf0c049c8 (patch)
treedf01591ec7bdffe61f68062cc09e95f69e745a90 /packages/base/src/Internal/ST.hs
parent0396adb9f10f5b337e54d64fec365c9cb01e9745 (diff)
row ops in ST
Diffstat (limited to 'packages/base/src/Internal/ST.hs')
-rw-r--r--packages/base/src/Internal/ST.hs26
1 files changed, 24 insertions, 2 deletions
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index ae75a1b..107d3c3 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -1,5 +1,6 @@
1{-# LANGUAGE Rank2Types #-} 1{-# LANGUAGE Rank2Types #-}
2{-# LANGUAGE BangPatterns #-} 2{-# LANGUAGE BangPatterns #-}
3
3----------------------------------------------------------------------------- 4-----------------------------------------------------------------------------
4-- | 5-- |
5-- Module : Internal.ST 6-- Module : Internal.ST
@@ -20,6 +21,8 @@ module Internal.ST (
20 -- * Mutable Matrices 21 -- * Mutable Matrices
21 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, 22 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
22 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, 23 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
24 axpy, scal, swap, extractRect,
25 mutable,
23 -- * Unsafe functions 26 -- * Unsafe functions
24 newUndefinedVector, 27 newUndefinedVector,
25 unsafeReadVector, unsafeWriteVector, 28 unsafeReadVector, unsafeWriteVector,
@@ -34,8 +37,6 @@ import Internal.Matrix
34import Internal.Vectorized 37import Internal.Vectorized
35import Control.Monad.ST(ST, runST) 38import Control.Monad.ST(ST, runST)
36import Foreign.Storable(Storable, peekElemOff, pokeElemOff) 39import Foreign.Storable(Storable, peekElemOff, pokeElemOff)
37
38
39import Control.Monad.ST.Unsafe(unsafeIOToST) 40import Control.Monad.ST.Unsafe(unsafeIOToST)
40 41
41{-# INLINE ioReadV #-} 42{-# INLINE ioReadV #-}
@@ -144,6 +145,7 @@ liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
144unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) 145unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
145unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x 146unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
146 147
148
147freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) 149freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
148freezeMatrix m = liftSTMatrix id m 150freezeMatrix m = liftSTMatrix id m
149 151
@@ -171,3 +173,23 @@ newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c
171newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t) 173newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
172newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) 174newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c)
173 175
176--------------------------------------------------------------------------------
177
178rowOpST :: Element t => Int -> t -> Int -> Int -> Int -> Int -> STMatrix s t -> ST s ()
179rowOpST c x i1 i2 j1 j2 (STMatrix m) = unsafeIOToST (rowOp c x i1 i2 j1 j2 m)
180
181axpy (STMatrix m) a i j = rowOpST 0 a i j 0 (cols m -1) (STMatrix m)
182scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m)
183swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m)
184
185extractRect (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2]))
186
187--------------------------------------------------------------------------------
188
189mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
190mutable f a = runST $ do
191 x <- thawMatrix a
192 info <- f (rows a, cols a) x
193 r <- unsafeFreezeMatrix x
194 return (r,info)
195