summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/ST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/ST.hs')
-rw-r--r--packages/base/src/Internal/ST.hs27
1 files changed, 19 insertions, 8 deletions
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 434fe63..25e7f03 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -1,5 +1,6 @@
1{-# LANGUAGE Rank2Types #-} 1{-# LANGUAGE Rank2Types #-}
2{-# LANGUAGE BangPatterns #-} 2{-# LANGUAGE BangPatterns #-}
3{-# LANGUAGE ViewPatterns #-}
3 4
4----------------------------------------------------------------------------- 5-----------------------------------------------------------------------------
5-- | 6-- |
@@ -15,14 +16,14 @@
15----------------------------------------------------------------------------- 16-----------------------------------------------------------------------------
16 17
17module Internal.ST ( 18module Internal.ST (
19 ST, runST,
18 -- * Mutable Vectors 20 -- * Mutable Vectors
19 STVector, newVector, thawVector, freezeVector, runSTVector, 21 STVector, newVector, thawVector, freezeVector, runSTVector,
20 readVector, writeVector, modifyVector, liftSTVector, 22 readVector, writeVector, modifyVector, liftSTVector,
21 -- * Mutable Matrices 23 -- * Mutable Matrices
22 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, 24 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
23 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, 25 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
24-- axpy, scal, swap, rowOp, 26 mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..),
25 mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..),
26 -- * Unsafe functions 27 -- * Unsafe functions
27 newUndefinedVector, 28 newUndefinedVector,
28 unsafeReadVector, unsafeWriteVector, 29 unsafeReadVector, unsafeWriteVector,
@@ -70,13 +71,13 @@ unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k
70modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s () 71modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
71modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k 72modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k
72 73
73liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a 74liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s t -> ST s a
74liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x 75liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x
75 76
76freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) 77freezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
77freezeVector v = liftSTVector id v 78freezeVector v = liftSTVector id v
78 79
79unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) 80unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
80unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x 81unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x
81 82
82{-# INLINE safeIndexV #-} 83{-# INLINE safeIndexV #-}
@@ -139,14 +140,14 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
139modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () 140modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
140modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c 141modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c
141 142
142liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a 143liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s t -> ST s a
143liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x 144liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
144 145
145unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) 146unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
146unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x 147unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
147 148
148 149
149freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) 150freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
150freezeMatrix m = liftSTMatrix id m 151freezeMatrix m = liftSTMatrix id m
151 152
152cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o) 153cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o)
@@ -227,6 +228,16 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (i
227 (i1,i2) = getRowRange (rows m) rr 228 (i1,i2) = getRowRange (rows m) rr
228 (j1,j2) = getColRange (cols m) rc 229 (j1,j2) = getColRange (cols m) rc
229 230
231data Slice s t = Slice (STMatrix s t) Int Int Int Int
232
233slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1])
234
235gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res
236 where
237 res = unsafeIOToST (gemm u v a b r)
238 u = fromList [alpha,beta]
239 v = vjoin[pa,pb,pr]
240
230 241
231mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 242mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
232mutable f a = runST $ do 243mutable f a = runST $ do