summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/ST.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-28 14:18:25 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-28 14:18:25 +0200
commit2749f4ef144cbc8541d70434f46abf312a1bb42e (patch)
tree7e93d163e0a07e3fc6c1001433be953bcdd197f6 /packages/base/src/Internal/ST.hs
parent4d96b90c4cfd38cdb51f3dc66a8a644bd87cdbff (diff)
copy slice
Diffstat (limited to 'packages/base/src/Internal/ST.hs')
-rw-r--r--packages/base/src/Internal/ST.hs11
1 files changed, 5 insertions, 6 deletions
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 23fda99..91c2a11 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -119,7 +119,7 @@ ioWriteM m r c val = ioWriteV (xdat m) (r * xRow m + c * xCol m) val
119 119
120newtype STMatrix s t = STMatrix (Matrix t) 120newtype STMatrix s t = STMatrix (Matrix t)
121 121
122thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) 122thawMatrix :: Element t => Matrix t -> ST s (STMatrix s t)
123thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix 123thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix
124 124
125unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) 125unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
@@ -140,18 +140,17 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
140modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () 140modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
141modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c 141modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c
142 142
143liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s t -> ST s a 143liftSTMatrix :: (Element t) => (Matrix t -> a) -> STMatrix s t -> ST s a
144liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x 144liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
145 145
146unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) 146unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
147unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x 147unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
148 148
149 149
150freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) 150freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t)
151freezeMatrix m = liftSTMatrix id m 151freezeMatrix m = liftSTMatrix id m
152 152
153-- FIXME 153cloneMatrix m = copy (orderOf m) m
154cloneMatrix m = cloneVector (xdat m) >>= return . (\d' -> m{xdat = d'})
155 154
156{-# INLINE safeIndexM #-} 155{-# INLINE safeIndexM #-}
157safeIndexM f (STMatrix m) r c 156safeIndexM f (STMatrix m) r c
@@ -242,7 +241,7 @@ gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res
242 v = vjoin[pa,pb,pr] 241 v = vjoin[pa,pb,pr]
243 242
244 243
245mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 244mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
246mutable f a = runST $ do 245mutable f a = runST $ do
247 x <- thawMatrix a 246 x <- thawMatrix a
248 info <- f (rows a, cols a) x 247 info <- f (rows a, cols a) x