summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/ST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/ST.hs')
-rw-r--r--packages/base/src/Internal/ST.hs9
1 files changed, 4 insertions, 5 deletions
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 91c2a11..62dfddf 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -231,14 +231,13 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[
231-- | r0 c0 height width 231-- | r0 c0 height width
232data Slice s t = Slice (STMatrix s t) Int Int Int Int 232data Slice s t = Slice (STMatrix s t) Int Int Int Int
233 233
234slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1]) 234slice (Slice (STMatrix m) r0 c0 nr nc) = sliceMatrix (r0,c0) (nr,nc) m
235 235
236gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () 236gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s ()
237gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res 237gemmm beta (slice->r) alpha (slice->a) (slice->b) = res
238 where 238 where
239 res = unsafeIOToST (gemm u v a b r) 239 res = unsafeIOToST (gemm v a b r)
240 u = fromList [alpha,beta] 240 v = fromList [alpha,beta]
241 v = vjoin[pa,pb,pr]
242 241
243 242
244mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 243mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)