diff options
Diffstat (limited to 'packages/base/src/Internal/ST.hs')
-rw-r--r-- | packages/base/src/Internal/ST.hs | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 544c9e4..7d54e6d 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -81,6 +81,8 @@ unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t) | |||
81 | unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x | 81 | unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x |
82 | 82 | ||
83 | {-# INLINE safeIndexV #-} | 83 | {-# INLINE safeIndexV #-} |
84 | safeIndexV :: Storable t2 | ||
85 | => (STVector s t2 -> Int -> t) -> STVector t1 t2 -> Int -> t | ||
84 | safeIndexV f (STVector v) k | 86 | safeIndexV f (STVector v) k |
85 | | k < 0 || k>= dim v = error $ "out of range error in vector (dim=" | 87 | | k < 0 || k>= dim v = error $ "out of range error in vector (dim=" |
86 | ++show (dim v)++", pos="++show k++")" | 88 | ++show (dim v)++", pos="++show k++")" |
@@ -150,9 +152,12 @@ unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | |||
150 | freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) | 152 | freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) |
151 | freezeMatrix m = liftSTMatrix id m | 153 | freezeMatrix m = liftSTMatrix id m |
152 | 154 | ||
155 | cloneMatrix :: Element t => Matrix t -> IO (Matrix t) | ||
153 | cloneMatrix m = copy (orderOf m) m | 156 | cloneMatrix m = copy (orderOf m) m |
154 | 157 | ||
155 | {-# INLINE safeIndexM #-} | 158 | {-# INLINE safeIndexM #-} |
159 | safeIndexM :: (STMatrix s t2 -> Int -> Int -> t) | ||
160 | -> STMatrix t1 t2 -> Int -> Int -> t | ||
156 | safeIndexM f (STMatrix m) r c | 161 | safeIndexM f (STMatrix m) r c |
157 | | r<0 || r>=rows m || | 162 | | r<0 || r>=rows m || |
158 | c<0 || c>=cols m = error $ "out of range error in matrix (size=" | 163 | c<0 || c>=cols m = error $ "out of range error in matrix (size=" |
@@ -184,6 +189,7 @@ data ColRange = AllCols | |||
184 | | Col Int | 189 | | Col Int |
185 | | FromCol Int | 190 | | FromCol Int |
186 | 191 | ||
192 | getColRange :: Int -> ColRange -> (Int, Int) | ||
187 | getColRange c AllCols = (0,c-1) | 193 | getColRange c AllCols = (0,c-1) |
188 | getColRange c (ColRange a b) = (a `mod` c, b `mod` c) | 194 | getColRange c (ColRange a b) = (a `mod` c, b `mod` c) |
189 | getColRange c (Col a) = (a `mod` c, a `mod` c) | 195 | getColRange c (Col a) = (a `mod` c, a `mod` c) |
@@ -194,6 +200,7 @@ data RowRange = AllRows | |||
194 | | Row Int | 200 | | Row Int |
195 | | FromRow Int | 201 | | FromRow Int |
196 | 202 | ||
203 | getRowRange :: Int -> RowRange -> (Int, Int) | ||
197 | getRowRange r AllRows = (0,r-1) | 204 | getRowRange r AllRows = (0,r-1) |
198 | getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) | 205 | getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) |
199 | getRowRange r (Row a) = (a `mod` r, a `mod` r) | 206 | getRowRange r (Row a) = (a `mod` r, a `mod` r) |
@@ -223,6 +230,7 @@ rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m | |||
223 | i2' = i2 `mod` (rows m) | 230 | i2' = i2 `mod` (rows m) |
224 | 231 | ||
225 | 232 | ||
233 | extractMatrix :: Element a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a) | ||
226 | extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) | 234 | extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) |
227 | where | 235 | where |
228 | (i1,i2) = getRowRange (rows m) rr | 236 | (i1,i2) = getRowRange (rows m) rr |
@@ -231,6 +239,7 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[ | |||
231 | -- | r0 c0 height width | 239 | -- | r0 c0 height width |
232 | data Slice s t = Slice (STMatrix s t) Int Int Int Int | 240 | data Slice s t = Slice (STMatrix s t) Int Int Int Int |
233 | 241 | ||
242 | slice :: Element a => Slice t a -> Matrix a | ||
234 | slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m | 243 | slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m |
235 | 244 | ||
236 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () | 245 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () |
@@ -238,7 +247,7 @@ gemmm beta (slice->r) alpha (slice->a) (slice->b) = res | |||
238 | where | 247 | where |
239 | res = unsafeIOToST (gemm v a b r) | 248 | res = unsafeIOToST (gemm v a b r) |
240 | v = fromList [alpha,beta] | 249 | v = fromList [alpha,beta] |
241 | 250 | ||
242 | 251 | ||
243 | mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) | 252 | mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) |
244 | mutable f a = runST $ do | 253 | mutable f a = runST $ do |
@@ -246,4 +255,3 @@ mutable f a = runST $ do | |||
246 | info <- f (rows a, cols a) x | 255 | info <- f (rows a, cols a) x |
247 | r <- unsafeFreezeMatrix x | 256 | r <- unsafeFreezeMatrix x |
248 | return (r,info) | 257 | return (r,info) |
249 | |||