diff options
Diffstat (limited to 'packages/base/src/Internal/ST.hs')
-rw-r--r-- | packages/base/src/Internal/ST.hs | 27 |
1 files changed, 19 insertions, 8 deletions
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 434fe63..25e7f03 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -1,5 +1,6 @@ | |||
1 | {-# LANGUAGE Rank2Types #-} | 1 | {-# LANGUAGE Rank2Types #-} |
2 | {-# LANGUAGE BangPatterns #-} | 2 | {-# LANGUAGE BangPatterns #-} |
3 | {-# LANGUAGE ViewPatterns #-} | ||
3 | 4 | ||
4 | ----------------------------------------------------------------------------- | 5 | ----------------------------------------------------------------------------- |
5 | -- | | 6 | -- | |
@@ -15,14 +16,14 @@ | |||
15 | ----------------------------------------------------------------------------- | 16 | ----------------------------------------------------------------------------- |
16 | 17 | ||
17 | module Internal.ST ( | 18 | module Internal.ST ( |
19 | ST, runST, | ||
18 | -- * Mutable Vectors | 20 | -- * Mutable Vectors |
19 | STVector, newVector, thawVector, freezeVector, runSTVector, | 21 | STVector, newVector, thawVector, freezeVector, runSTVector, |
20 | readVector, writeVector, modifyVector, liftSTVector, | 22 | readVector, writeVector, modifyVector, liftSTVector, |
21 | -- * Mutable Matrices | 23 | -- * Mutable Matrices |
22 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, | 24 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, |
23 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, | 25 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, |
24 | -- axpy, scal, swap, rowOp, | 26 | mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..), |
25 | mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), | ||
26 | -- * Unsafe functions | 27 | -- * Unsafe functions |
27 | newUndefinedVector, | 28 | newUndefinedVector, |
28 | unsafeReadVector, unsafeWriteVector, | 29 | unsafeReadVector, unsafeWriteVector, |
@@ -70,13 +71,13 @@ unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k | |||
70 | modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s () | 71 | modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s () |
71 | modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k | 72 | modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k |
72 | 73 | ||
73 | liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a | 74 | liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s t -> ST s a |
74 | liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x | 75 | liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x |
75 | 76 | ||
76 | freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) | 77 | freezeVector :: (Storable t) => STVector s t -> ST s (Vector t) |
77 | freezeVector v = liftSTVector id v | 78 | freezeVector v = liftSTVector id v |
78 | 79 | ||
79 | unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) | 80 | unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t) |
80 | unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x | 81 | unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x |
81 | 82 | ||
82 | {-# INLINE safeIndexV #-} | 83 | {-# INLINE safeIndexV #-} |
@@ -139,14 +140,14 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c | |||
139 | modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () | 140 | modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () |
140 | modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c | 141 | modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c |
141 | 142 | ||
142 | liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a | 143 | liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s t -> ST s a |
143 | liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x | 144 | liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x |
144 | 145 | ||
145 | unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) | 146 | unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) |
146 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | 147 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x |
147 | 148 | ||
148 | 149 | ||
149 | freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) | 150 | freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) |
150 | freezeMatrix m = liftSTMatrix id m | 151 | freezeMatrix m = liftSTMatrix id m |
151 | 152 | ||
152 | cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o) | 153 | cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o) |
@@ -227,6 +228,16 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (i | |||
227 | (i1,i2) = getRowRange (rows m) rr | 228 | (i1,i2) = getRowRange (rows m) rr |
228 | (j1,j2) = getColRange (cols m) rc | 229 | (j1,j2) = getColRange (cols m) rc |
229 | 230 | ||
231 | data Slice s t = Slice (STMatrix s t) Int Int Int Int | ||
232 | |||
233 | slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1]) | ||
234 | |||
235 | gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res | ||
236 | where | ||
237 | res = unsafeIOToST (gemm u v a b r) | ||
238 | u = fromList [alpha,beta] | ||
239 | v = vjoin[pa,pb,pr] | ||
240 | |||
230 | 241 | ||
231 | mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) | 242 | mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) |
232 | mutable f a = runST $ do | 243 | mutable f a = runST $ do |