diff options
Diffstat (limited to 'packages/base/src/Internal/ST.hs')
-rw-r--r-- | packages/base/src/Internal/ST.hs | 131 |
1 files changed, 118 insertions, 13 deletions
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 7d54e6d..326b90a 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -1,6 +1,7 @@ | |||
1 | {-# LANGUAGE Rank2Types #-} | 1 | {-# LANGUAGE Rank2Types #-} |
2 | {-# LANGUAGE BangPatterns #-} | 2 | {-# LANGUAGE BangPatterns #-} |
3 | {-# LANGUAGE ViewPatterns #-} | 3 | {-# LANGUAGE ViewPatterns #-} |
4 | {-# LANGUAGE PatternSynonyms #-} | ||
4 | 5 | ||
5 | ----------------------------------------------------------------------------- | 6 | ----------------------------------------------------------------------------- |
6 | -- | | 7 | -- | |
@@ -30,14 +31,20 @@ module Internal.ST ( | |||
30 | unsafeThawVector, unsafeFreezeVector, | 31 | unsafeThawVector, unsafeFreezeVector, |
31 | newUndefinedMatrix, | 32 | newUndefinedMatrix, |
32 | unsafeReadMatrix, unsafeWriteMatrix, | 33 | unsafeReadMatrix, unsafeWriteMatrix, |
33 | unsafeThawMatrix, unsafeFreezeMatrix | 34 | unsafeThawMatrix, unsafeFreezeMatrix, |
35 | setRect | ||
34 | ) where | 36 | ) where |
35 | 37 | ||
36 | import Internal.Vector | 38 | import Internal.Vector |
37 | import Internal.Matrix | 39 | import Internal.Matrix |
38 | import Internal.Vectorized | 40 | import Internal.Vectorized |
41 | import Internal.Devel ((#|)) | ||
39 | import Control.Monad.ST(ST, runST) | 42 | import Control.Monad.ST(ST, runST) |
40 | import Foreign.Storable(Storable, peekElemOff, pokeElemOff) | 43 | import Control.Monad |
44 | import Data.Function | ||
45 | import Data.Int | ||
46 | import Foreign.Ptr | ||
47 | import Foreign.Storable | ||
41 | import Control.Monad.ST.Unsafe(unsafeIOToST) | 48 | import Control.Monad.ST.Unsafe(unsafeIOToST) |
42 | 49 | ||
43 | {-# INLINE ioReadV #-} | 50 | {-# INLINE ioReadV #-} |
@@ -121,7 +128,7 @@ ioWriteM m r c val = ioWriteV (xdat m) (r * xRow m + c * xCol m) val | |||
121 | 128 | ||
-- | A matrix that can be updated in place inside @'ST' s@.  The type
-- parameter @s@ is phantom: it does not occur in the wrapped 'Matrix'.
newtype STMatrix s t = STMatrix (Matrix t)
123 | 130 | ||
-- | Obtain a mutable matrix from an immutable one.  The argument is
-- duplicated with 'cloneMatrix' first, so it is the copy that gets wrapped.
thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
thawMatrix m = unsafeIOToST (fmap STMatrix (cloneMatrix m))
126 | 133 | ||
127 | unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) | 134 | unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) |
@@ -142,17 +149,17 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c | |||
-- | Replace the element at row @r@, column @c@ by the result of applying @f@
-- to it.  The element is fetched with 'readMatrix' and stored back with
-- 'unsafeWriteMatrix'.
modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
modifyMatrix x r c f = do
    old <- readMatrix x r c
    unsafeWriteMatrix x r c (f old)
144 | 151 | ||
-- | Apply a pure function to a snapshot of a mutable matrix.  The snapshot is
-- produced by 'cloneMatrix', so @f@ never sees the mutable matrix itself.
liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s t -> ST s a
liftSTMatrix f (STMatrix x) = unsafeIOToST $ do
    snapshot <- cloneMatrix x
    return (f snapshot)
147 | 154 | ||
-- | Unwrap a mutable matrix without copying: the 'Matrix' held inside the
-- 'STMatrix' is returned as is.
unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
unsafeFreezeMatrix (STMatrix x) = unsafeIOToST (return x)
150 | 157 | ||
151 | 158 | ||
-- | Immutable copy of a mutable matrix: 'liftSTMatrix' applied to 'id', so
-- the result is the clone made by 'liftSTMatrix'.
freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
freezeMatrix = liftSTMatrix id
154 | 161 | ||
-- | Duplicate a matrix by calling 'copy' with the matrix's own 'orderOf'.
cloneMatrix :: Storable t => Matrix t -> IO (Matrix t)
cloneMatrix m = copy order m
  where
    order = orderOf m
157 | 164 | ||
158 | {-# INLINE safeIndexM #-} | 165 | {-# INLINE safeIndexM #-} |
@@ -172,7 +179,7 @@ readMatrix = safeIndexM unsafeReadMatrix | |||
-- | Store value @v@ at row @r@, column @c@: 'unsafeWriteMatrix' routed
-- through 'safeIndexM'.
writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
writeMatrix m r c v = safeIndexM unsafeWriteMatrix m r c v
174 | 181 | ||
-- | Write the immutable matrix @m@ into the mutable matrix at offset
-- (@i@, @j@); the copying is performed by 'setRect'.
setMatrix :: Storable t => STMatrix s t -> Int -> Int -> Matrix t -> ST s ()
setMatrix (STMatrix x) i j m = unsafeIOToST (setRect i j m x)
177 | 184 | ||
178 | newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) | 185 | newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) |
@@ -210,7 +217,7 @@ data RowOper t = AXPY t Int Int ColRange | |||
210 | | SCAL t RowRange ColRange | 217 | | SCAL t RowRange ColRange |
211 | | SWAP Int Int ColRange | 218 | | SWAP Int Int ColRange |
212 | 219 | ||
213 | rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s () | 220 | rowOper :: (Num t, Storable t) => RowOper t -> STMatrix s t -> ST s () |
214 | 221 | ||
215 | rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m | 222 | rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m |
216 | where | 223 | where |
@@ -230,8 +237,8 @@ rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m | |||
230 | i2' = i2 `mod` (rows m) | 237 | i2' = i2 `mod` (rows m) |
231 | 238 | ||
232 | 239 | ||
-- | Extract the rows selected by @rr@ and the columns selected by @rc@ from a
-- mutable matrix as a new immutable 'Matrix', using 'extractAux'.
--
-- NOTE(review): the state index of the argument (@t@) is unrelated to the
-- @s@ of the resulting 'ST' action; confirm whether they were meant to agree.
extractMatrix :: Storable a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a)
extractMatrix (STMatrix m) rr rc =
    unsafeIOToST $ extractAux (orderOf m) m 0 rowSel 0 colSel
  where
    (i1,i2) = getRowRange (rows m) rr
    (j1,j2) = getColRange (cols m) rc
    rowSel  = idxs [i1,i2]
    colSel  = idxs [j1,j2]
@@ -239,19 +246,117 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[ | |||
-- | A rectangular window of a mutable matrix.  After the matrix come, in
-- order: first row @r0@, first column @c0@, height and width.
data Slice s t = Slice (STMatrix s t) Int Int Int Int
241 | 248 | ||
-- | The 'subMatrix' described by a 'Slice': it starts at (@r0@, @c0@) and
-- spans @nr@ rows and @nc@ columns of the wrapped matrix.
slice :: Storable a => Slice t a -> Matrix a
slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix origin extent m
  where
    origin = (r0, c0)
    extent = (nr, nc)
244 | 251 | ||
-- | Run 'gemm' on three slices inside 'ST'.  The coefficient vector passed to
-- 'gemm' is @[alpha, beta]@; the operands are the second and third slices,
-- and the first slice is handed over as the result matrix.
gemmm :: (Storable t, Num t) => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s ()
gemmm beta dst alpha lhs rhs =
    unsafeIOToST $ gemm coeffs (slice lhs) (slice rhs) (slice dst)
  where
    coeffs = fromList [alpha, beta]
250 | 257 | ||
251 | 258 | ||
-- | Run an in-place algorithm on a thawed copy of a matrix.
--
-- The action receives the dimensions @(rows, cols)@ of the input together
-- with the mutable copy made by 'thawMatrix'.  The result pairs the copy,
-- frozen with 'unsafeFreezeMatrix' once the action has finished, with the
-- value the action returned.
mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
mutable f a = runST $
    thawMatrix a >>= \work ->
    f (rows a, cols a) work >>= \extra ->
    unsafeFreezeMatrix work >>= \frozen ->
    return (frozen, extra)
265 | |||
266 | |||
267 | |||
-- | Copy matrix @m@ into matrix @r@ at the offset (@i@, @j@).  The element
-- loop is 'setRectStorable', which receives the offset converted with 'fi';
-- the outcome is passed through '#|' under the label @"setRect"@.
setRect :: Storable t => Int -> Int -> Matrix t -> Matrix t -> IO ()
setRect i j m r = (m Internal.Matrix.#! r) kernel #| "setRect"
  where
    kernel = setRectStorable (fi i) (fi j)
270 | |||
-- | Element-by-element copy of a source matrix into a destination matrix at
-- an offset.
--
-- Arguments, in order: the offset @i@ @j@; the source as row count @mr@,
-- column count @mc@, strides @mXr@ @mXc@ and data pointer @mp@; the
-- destination as @rr@ @rc@ @rXr@ @rXc@ @rp@.  Source element @(a,b)@ is read
-- at element offset @mXr*a + mXc*b@ and written to destination position
-- @(a+i, b+j)@ at element offset @rXr*(a+i) + rXc*(b+j)@.  Positions outside
-- @[0,rr) x [0,rc)@ are skipped.  Always returns 0.
setRectStorable :: Storable t =>
       Int32 -> Int32
    -> Int32 -> Int32 -> Int32 -> Int32 -> {- const -} Ptr t
    -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t
    -> IO Int32
setRectStorable i j mr mc mXr mXc mp rr rc rXr rXc rp = do
    eachBelow mr $ \srcRow ->
        eachBelow mc $ \srcCol -> do
            let dstRow = srcRow + i
                dstCol = srcCol + j
                inside = dstRow >= 0 && dstRow < rr && dstCol >= 0 && dstCol < rc
            when inside $ do
                v <- peekElemOff mp (fromIntegral $ mXr*srcRow + mXc*srcCol)
                pokeElemOff rp (fromIntegral $ rXr*dstRow + rXc*dstCol) v
    return 0
  where
    -- Run the body for 0, 1, ... while the counter stays below the bound.
    eachBelow :: Int32 -> (Int32 -> IO ()) -> IO ()
    eachBelow bound body = go 0
      where
        go k
            | k < bound = body k >> go (succ k)
            | otherwise = return ()
287 | |||
-- | Apply an elementary row operation to a matrix: 'rowOpAux' instantiated
-- with the 'rowOpStorable' kernel.  The arguments are forwarded unchanged.
rowOp :: (Storable t, Num t) => Int -> t -> Int -> Int -> Int -> Int -> Matrix t -> IO ()
rowOp code x i1 i2 j1 j2 m = rowOpAux rowOpStorable code x i1 i2 j1 j2 m
290 | |||
-- | Status returned by 'rowOpStorable' when the operation selector is not
-- one of 0, 1 or 2.
pattern BAD_CODE = 2001

-- | Elementary row operations on a matrix given by strides and a data
-- pointer.
--
-- Arguments, in order: operation selector; pointer @pa@ to the scalar (only
-- dereferenced by operations 0 and 1); row indices @i1@ @i2@; inclusive
-- column bounds @j1@ @j2@; two matrix-size arguments that are not consulted
-- here; strides @rXr@ @rXc@; data pointer @rp@.  Element @(i,j)@ lives at
-- element offset @rXr*i + rXc*j@ from @rp@.
--
-- * selector 0 (AXPY): row @i2@ becomes @row i2 + a * row i1@ on columns
--   @j1..j2@;
-- * selector 1 (SCAL): every element in rows @i1..i2@, columns @j1..j2@ is
--   multiplied by @a@;
-- * selector 2 (SWAP): rows @i1@ and @i2@ are exchanged on columns
--   @j1..j2@ (nothing is done when @i1 == i2@);
-- * anything else returns 'BAD_CODE' without touching the data.
--
-- The three operations return 0.  No bounds checking is done on the indices.
rowOpStorable :: (Storable t, Num t) =>
       Int32 -> Ptr t -> Int32 -> Int32 -> Int32 -> Int32
    -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t
    -> IO Int32
rowOpStorable code pa i1 i2 j1 j2 _rr _rc rXr rXc rp = case code of
    0 -> do
        -- AXPY_IMP
        a <- peek pa
        fromTo j1 j2 $ \j -> do
            ri1j <- peekElemOff rp (at i1 j)
            ri2j <- peekElemOff rp (at i2 j)
            pokeElemOff rp (at i2 j) (ri2j + a*ri1j)
        return 0
    1 -> do
        -- SCAL_IMP
        a <- peek pa
        fromTo i1 i2 $ \i ->
            fromTo j1 j2 $ \j -> do
                -- The offset counts elements, so it must go through
                -- peekElemOff/pokeElemOff; plusPtr would advance by bytes.
                rij <- peekElemOff rp (at i j)
                pokeElemOff rp (at i j) (a * rij)
        return 0
    2 | i1 == i2 -> return 0
      | otherwise -> do
        -- SWAP_IMP
        fromTo j1 j2 $ \k -> do
            aux <- peekElemOff rp (at i1 k)
            pokeElemOff rp (at i1 k) =<< peekElemOff rp (at i2 k)
            pokeElemOff rp (at i2 k) aux
        return 0
    _ -> return BAD_CODE
  where
    -- Element offset of position (i,j).
    at :: Int32 -> Int32 -> Int
    at i j = fromIntegral (rXr*i + rXc*j)

    -- Run the body for every index from lo to hi, both inclusive.
    fromTo :: Int32 -> Int32 -> (Int32 -> IO ()) -> IO ()
    fromTo lo hi body = go lo
      where
        go k
            | k <= hi   = body k >> go (succ k)
            | otherwise = return ()
331 | |||
-- | Matrix multiply-accumulate through the 'gemmStorable' kernel.  The
-- kernel receives the vector @v@ followed by the matrices @m1@, @m2@ and
-- @m3@; by the kernel's definition @m3@ is the matrix that is updated.
-- The outcome is passed through '#|' under the label @"gemm"@.
gemm :: (Storable t, Num t) => Vector t -> Matrix t -> Matrix t -> Matrix t -> IO ()
gemm v m1 m2 m3 = withOperands gemmStorable #| "gemm"
  where
    withOperands = v Internal.Matrix.# m1 Internal.Matrix.# m2 Internal.Matrix.#! m3
334 | |||
-- ScalarLike t
-- | General matrix multiply-accumulate on raw strided storage.
--
-- With @alpha@ and @beta@ the first and second elements read from @cp@,
-- every element @(i,j)@ of the result matrix (@rr@ rows, @rc@ columns) is
-- replaced by @beta * r(i,j) + alpha * sum_k a(i,k) * b(k,j)@, where @k@
-- runs over @0 .. ac-1@.  Each matrix is addressed through its own pair of
-- strides as @row*Xr + col*Xc@ elements from its pointer.  The arguments
-- @cn@, @ar@, @br@ and @bc@ are accepted but not consulted.  Always
-- returns 0.
gemmStorable :: (Storable t, Num t) =>
       Int32 -> Ptr t -- VECG(T,c)
    -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -- MATG(T,a)
    -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -- MATG(T,b)
    -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -- MATG(T,r)
    -> IO Int32
gemmStorable cn cp
    ar ac aXr aXc ap
    br bc bXr bXc bp
    rr rc rXr rXc rp = do
    alpha <- peek cp
    beta  <- peekElemOff cp 1
    eachBelow rr $ \i ->
        eachBelow rc $ \j -> do
            dot <- innerProduct i j
            let ij = fromIntegral $ i*rXr + j*rXc
            old <- peekElemOff rp ij
            pokeElemOff rp ij (beta*old + alpha*dot)
    return 0
  where
    -- Run the body for 0, 1, ... while the counter stays below the bound.
    eachBelow :: Int32 -> (Int32 -> IO ()) -> IO ()
    eachBelow bound body = go 0
      where
        go n
            | n < bound = body n >> go (succ n)
            | otherwise = return ()

    -- Sum of a(i,k) * b(k,j) over k, kept strict in the accumulator.
    innerProduct i j = go 0 0
      where
        go k !acc
            | k < ac = do
                aik <- peekElemOff ap (fromIntegral $ i*aXr + k*aXc)
                bkj <- peekElemOff bp (fromIntegral $ k*bXr + j*bXc)
                go (succ k) (acc + aik*bkj)
            | otherwise = return acc