summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/ST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/ST.hs')
-rw-r--r--packages/base/src/Internal/ST.hs12
1 files changed, 10 insertions, 2 deletions
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 544c9e4..7d54e6d 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -81,6 +81,8 @@ unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
81unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x 81unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x
82 82
83{-# INLINE safeIndexV #-} 83{-# INLINE safeIndexV #-}
84safeIndexV :: Storable t2
85 => (STVector s t2 -> Int -> t) -> STVector t1 t2 -> Int -> t
84safeIndexV f (STVector v) k 86safeIndexV f (STVector v) k
85 | k < 0 || k>= dim v = error $ "out of range error in vector (dim=" 87 | k < 0 || k>= dim v = error $ "out of range error in vector (dim="
86 ++show (dim v)++", pos="++show k++")" 88 ++show (dim v)++", pos="++show k++")"
@@ -150,9 +152,12 @@ unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
150freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) 152freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t)
151freezeMatrix m = liftSTMatrix id m 153freezeMatrix m = liftSTMatrix id m
152 154
155cloneMatrix :: Element t => Matrix t -> IO (Matrix t)
153cloneMatrix m = copy (orderOf m) m 156cloneMatrix m = copy (orderOf m) m
154 157
155{-# INLINE safeIndexM #-} 158{-# INLINE safeIndexM #-}
159safeIndexM :: (STMatrix s t2 -> Int -> Int -> t)
160 -> STMatrix t1 t2 -> Int -> Int -> t
156safeIndexM f (STMatrix m) r c 161safeIndexM f (STMatrix m) r c
157 | r<0 || r>=rows m || 162 | r<0 || r>=rows m ||
158 c<0 || c>=cols m = error $ "out of range error in matrix (size=" 163 c<0 || c>=cols m = error $ "out of range error in matrix (size="
@@ -184,6 +189,7 @@ data ColRange = AllCols
184 | Col Int 189 | Col Int
185 | FromCol Int 190 | FromCol Int
186 191
192getColRange :: Int -> ColRange -> (Int, Int)
187getColRange c AllCols = (0,c-1) 193getColRange c AllCols = (0,c-1)
188getColRange c (ColRange a b) = (a `mod` c, b `mod` c) 194getColRange c (ColRange a b) = (a `mod` c, b `mod` c)
189getColRange c (Col a) = (a `mod` c, a `mod` c) 195getColRange c (Col a) = (a `mod` c, a `mod` c)
@@ -194,6 +200,7 @@ data RowRange = AllRows
194 | Row Int 200 | Row Int
195 | FromRow Int 201 | FromRow Int
196 202
203getRowRange :: Int -> RowRange -> (Int, Int)
197getRowRange r AllRows = (0,r-1) 204getRowRange r AllRows = (0,r-1)
198getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) 205getRowRange r (RowRange a b) = (a `mod` r, b `mod` r)
199getRowRange r (Row a) = (a `mod` r, a `mod` r) 206getRowRange r (Row a) = (a `mod` r, a `mod` r)
@@ -223,6 +230,7 @@ rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m
223 i2' = i2 `mod` (rows m) 230 i2' = i2 `mod` (rows m)
224 231
225 232
233extractMatrix :: Element a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a)
226extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) 234extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2]))
227 where 235 where
228 (i1,i2) = getRowRange (rows m) rr 236 (i1,i2) = getRowRange (rows m) rr
@@ -231,6 +239,7 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[
231-- | r0 c0 height width 239-- | r0 c0 height width
232data Slice s t = Slice (STMatrix s t) Int Int Int Int 240data Slice s t = Slice (STMatrix s t) Int Int Int Int
233 241
242slice :: Element a => Slice t a -> Matrix a
234slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m 243slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m
235 244
236gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () 245gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s ()
@@ -238,7 +247,7 @@ gemmm beta (slice->r) alpha (slice->a) (slice->b) = res
238 where 247 where
239 res = unsafeIOToST (gemm v a b r) 248 res = unsafeIOToST (gemm v a b r)
240 v = fromList [alpha,beta] 249 v = fromList [alpha,beta]
241 250
242 251
243mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 252mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
244mutable f a = runST $ do 253mutable f a = runST $ do
@@ -246,4 +255,3 @@ mutable f a = runST $ do
246 info <- f (rows a, cols a) x 255 info <- f (rows a, cols a) x
247 r <- unsafeFreezeMatrix x 256 r <- unsafeFreezeMatrix x
248 return (r,info) 257 return (r,info)
249