summary | refs | log | tree | commit | diff
path: root/packages/base/src/Internal/ST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/ST.hs')
-rw-r--r-- packages/base/src/Internal/ST.hs 131
1 files changed, 118 insertions, 13 deletions
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 7d54e6d..326b90a 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -1,6 +1,7 @@
1{-# LANGUAGE Rank2Types #-} 1{-# LANGUAGE Rank2Types #-}
2{-# LANGUAGE BangPatterns #-} 2{-# LANGUAGE BangPatterns #-}
3{-# LANGUAGE ViewPatterns #-} 3{-# LANGUAGE ViewPatterns #-}
4{-# LANGUAGE PatternSynonyms #-}
4 5
5----------------------------------------------------------------------------- 6-----------------------------------------------------------------------------
6-- | 7-- |
@@ -30,14 +31,20 @@ module Internal.ST (
30 unsafeThawVector, unsafeFreezeVector, 31 unsafeThawVector, unsafeFreezeVector,
31 newUndefinedMatrix, 32 newUndefinedMatrix,
32 unsafeReadMatrix, unsafeWriteMatrix, 33 unsafeReadMatrix, unsafeWriteMatrix,
33 unsafeThawMatrix, unsafeFreezeMatrix 34 unsafeThawMatrix, unsafeFreezeMatrix,
35 setRect
34) where 36) where
35 37
36import Internal.Vector 38import Internal.Vector
37import Internal.Matrix 39import Internal.Matrix
38import Internal.Vectorized 40import Internal.Vectorized
41import Internal.Devel ((#|))
39import Control.Monad.ST(ST, runST) 42import Control.Monad.ST(ST, runST)
40import Foreign.Storable(Storable, peekElemOff, pokeElemOff) 43import Control.Monad
44import Data.Function
45import Data.Int
46import Foreign.Ptr
47import Foreign.Storable
41import Control.Monad.ST.Unsafe(unsafeIOToST) 48import Control.Monad.ST.Unsafe(unsafeIOToST)
42 49
43{-# INLINE ioReadV #-} 50{-# INLINE ioReadV #-}
@@ -121,7 +128,7 @@ ioWriteM m r c val = ioWriteV (xdat m) (r * xRow m + c * xCol m) val
121 128
122newtype STMatrix s t = STMatrix (Matrix t) 129newtype STMatrix s t = STMatrix (Matrix t)
123 130
-- | Builds a mutable matrix from an immutable one: the argument is passed
-- through 'cloneMatrix' and the result is wrapped in 'STMatrix', with the IO
-- action lifted into 'ST' by 'unsafeIOToST'.
-- (Diff row: the old side requires @Element t@, the new side @Storable t@.)
124thawMatrix :: Element t => Matrix t -> ST s (STMatrix s t) 131thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
125thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix 132thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix
126 133
127unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) 134unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
@@ -142,17 +149,17 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
-- | Updates one element in place: reads position @(r, c)@ with 'readMatrix',
-- applies @f@ to the value, and stores the result at the same position with
-- 'unsafeWriteMatrix'.
142modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () 149modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
143modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c 150modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c
144 151
-- | Applies a pure function to the result of 'cloneMatrix' on the wrapped
-- matrix, so @f@ is not handed the wrapped matrix itself.
-- (Diff row: the old side requires @Element t@, the new side @Storable t@.)
145liftSTMatrix :: (Element t) => (Matrix t -> a) -> STMatrix s t -> ST s a 152liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s t -> ST s a
146liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x 153liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
147 154
-- | Unwraps the 'STMatrix' and returns the wrapped matrix itself; unlike
-- 'freezeMatrix', no call to 'cloneMatrix' is made.
148unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) 155unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
149unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x 156unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
150 157
151 158
-- | Returns an immutable matrix obtained through 'liftSTMatrix' with 'id',
-- i.e. the result of cloning the wrapped matrix.
-- (Diff row: the old side requires @Element t@, the new side @Storable t@.)
152freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) 159freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
153freezeMatrix m = liftSTMatrix id m 160freezeMatrix m = liftSTMatrix id m
154 161
-- | Delegates to 'copy', requesting the same storage order as the argument
-- (@orderOf m@).
-- NOTE(review): 'copy' is defined outside this view; presumably it allocates
-- an independent buffer -- confirm.
-- (Diff row: the old side requires @Element t@, the new side @Storable t@.)
155cloneMatrix :: Element t => Matrix t -> IO (Matrix t) 162cloneMatrix :: Storable t => Matrix t -> IO (Matrix t)
156cloneMatrix m = copy (orderOf m) m 163cloneMatrix m = copy (orderOf m) m
157 164
158{-# INLINE safeIndexM #-} 165{-# INLINE safeIndexM #-}
@@ -172,7 +179,7 @@ readMatrix = safeIndexM unsafeReadMatrix
-- | Element write built by passing 'unsafeWriteMatrix' to 'safeIndexM'.
-- NOTE(review): 'safeIndexM' is defined outside this view; presumably it
-- validates the row/column indices before delegating -- confirm.
172writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () 179writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
173writeMatrix = safeIndexM unsafeWriteMatrix 180writeMatrix = safeIndexM unsafeWriteMatrix
174 181
-- | Runs @setRect i j m x@ on the wrapped matrix @x@, lifted into 'ST'.
-- NOTE(review): presumably this writes @m@ into @x@ with its top-left corner
-- at row @i@, column @j@ -- confirm against 'setRect'.
-- (Diff row: the old side requires @Element t@, the new side @Storable t@.)
175setMatrix :: Element t => STMatrix s t -> Int -> Int -> Matrix t -> ST s () 182setMatrix :: Storable t => STMatrix s t -> Int -> Int -> Matrix t -> ST s ()
176setMatrix (STMatrix x) i j m = unsafeIOToST $ setRect i j m x 183setMatrix (STMatrix x) i j m = unsafeIOToST $ setRect i j m x
177 184
178newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) 185newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
@@ -210,7 +217,7 @@ data RowOper t = AXPY t Int Int ColRange
210 | SCAL t RowRange ColRange 217 | SCAL t RowRange ColRange
211 | SWAP Int Int ColRange 218 | SWAP Int Int ColRange
212 219
213rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s () 220rowOper :: (Num t, Storable t) => RowOper t -> STMatrix s t -> ST s ()
214 221
215rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m 222rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m
216 where 223 where
@@ -230,8 +237,8 @@ rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m
230 i2' = i2 `mod` (rows m) 237 i2' = i2 `mod` (rows m)
231 238
232 239
-- | Extracts a block of the wrapped matrix as an immutable 'Matrix'.
-- The row and column ranges are resolved against @rows m@ / @cols m@ by
-- 'getRowRange' and 'getColRange' into the pairs @(i1,i2)@ and @(j1,j2)@,
-- which are handed (through @idxs@) to the extraction routine together with
-- the matrix's own storage order.
-- (Diff row: the old side requires @Element a@ and calls @extractR@; the new
-- side requires @Storable a@ and calls @extractAux@.)
-- NOTE(review): the state tag of the argument (@STMatrix t a@) is a different
-- type variable from the @s@ of the resulting @ST s@, so the signature does
-- not tie the matrix to the state thread it is read in -- confirm this is
-- intended.
233extractMatrix :: Element a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a) 240extractMatrix :: Storable a => STMatrix t a -> RowRange -> ColRange -> ST s (Matrix a)
234extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) 241extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractAux (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2]))
235 where 242 where
236 (i1,i2) = getRowRange (rows m) rr 243 (i1,i2) = getRowRange (rows m) rr
237 (j1,j2) = getColRange (cols m) rc 244 (j1,j2) = getColRange (cols m) rc
@@ -239,19 +246,117 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[
-- A rectangular window of a mutable matrix. The four 'Int' fields are, in
-- order: first row (r0), first column (c0), number of rows (height) and
-- number of columns (width); 'slice' passes them to 'subMatrix' in that role.
239-- | r0 c0 height width 246-- | r0 c0 height width
240data Slice s t = Slice (STMatrix s t) Int Int Int Int 247data Slice s t = Slice (STMatrix s t) Int Int Int Int
241 248
-- | Turns a 'Slice' into a 'Matrix' by calling 'subMatrix' on the wrapped
-- matrix with origin @(r0,c0)@ and size @(nr,nc)@.
-- (Diff row: the old side requires @Element a@, the new side @Storable a@.)
242slice :: Element a => Slice t a -> Matrix a 249slice :: Storable a => Slice t a -> Matrix a
243slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m 250slice (Slice (STMatrix m) r0 c0 nr nc) = subMatrix (r0,c0) (nr,nc) m
244 251
-- | Matrix multiply-accumulate on slices. Each 'Slice' argument is converted
-- with 'slice' (view patterns); the scalars are packed as the two-element
-- vector @[alpha, beta]@ and everything is handed to 'gemm' as
-- @gemm v a b r@.
-- NOTE(review): presumably the effect is @r <- beta*r + alpha*(a*b)@ written
-- into the storage behind the @r@ slice -- confirm against 'gemm' and
-- 'subMatrix'.
-- (Diff row: the old side requires @Element t@, the new side
-- @(Storable t, Num t)@.)
245gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () 252gemmm :: (Storable t, Num t) => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s ()
246gemmm beta (slice->r) alpha (slice->a) (slice->b) = res 253gemmm beta (slice->r) alpha (slice->a) (slice->b) = res
247 where 254 where
248 res = unsafeIOToST (gemm v a b r) 255 res = unsafeIOToST (gemm v a b r)
249 v = fromList [alpha,beta] 256 v = fromList [alpha,beta]
250 257
251 258
-- | Runs an 'ST' computation over a mutable version of @a@ and returns the
-- resulting matrix together with the computation's own result.
-- The input is thawed with 'thawMatrix' (which goes through 'cloneMatrix'),
-- @f@ receives @(rows a, cols a)@ and the mutable matrix, and the mutable
-- matrix is then returned via 'unsafeFreezeMatrix', i.e. without a second
-- clone.
-- (Diff row: the old side requires @Element t@, the new side @Storable t@.)
252mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 259mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
253mutable f a = runST $ do 260mutable f a = runST $ do
254 x <- thawMatrix a 261 x <- thawMatrix a
255 info <- f (rows a, cols a) x 262 info <- f (rows a, cols a) x
256 r <- unsafeFreezeMatrix x 263 r <- unsafeFreezeMatrix x
257 return (r,info) 264 return (r,info)
265
266
267
-- | Runs 'setRectStorable' over the matrices @m@ and @r@ with the offsets
-- @i@ and @j@ (converted with @fi@), and checks the returned status with
-- @#|@ under the label @"setRect"@.
setRect :: Storable t => Int -> Int -> Matrix t -> Matrix t -> IO ()
setRect i j m r = copyRect #| "setRect"
  where
    -- The raw kernel call; its status code is inspected by '#|' above.
    copyRect = (m Internal.Matrix.#! r) (setRectStorable (fi i) (fi j))
270
-- | Copies the elements of a source matrix @m@ into a destination matrix @r@,
-- placing source element @(a, b)@ at destination position @(a+i, b+j)@.
-- Positions that fall outside the destination (@rr@ rows, @rc@ columns) are
-- skipped. Always returns 0.
--
-- Arguments: offsets @i j@; then, for the source and for the destination,
-- row count, column count, row stride, column stride and data pointer.
-- Strides are in elements, since they feed 'peekElemOff' / 'pokeElemOff'.
setRectStorable :: Storable t =>
    Int32 -> Int32
    -> Int32 -> Int32 -> Int32 -> Int32 -> {- const -} Ptr t
    -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t
    -> IO Int32
setRectStorable i j mr mc mXr mXc mp rr rc rXr rXc rp = copyRows 0 >> return 0
  where
    -- Walk the source rows 0 .. mr-1.
    copyRows a
      | a < mr    = copyCols a 0 >> copyRows (succ a)
      | otherwise = return ()

    -- Walk the source columns 0 .. mc-1 of row a.
    copyCols a b
      | b < mc    = do
          when (insideTarget x y) $
            peekElemOff mp (fromIntegral $ mXr*a + mXc*b)
              >>= pokeElemOff rp (fromIntegral $ rXr*x + rXc*y)
          copyCols a (succ b)
      | otherwise = return ()
      where
        x = a + i
        y = b + j

    -- Only write where the shifted position lies within the destination.
    insideTarget x y = 0 <= x && x < rr && 0 <= y && y < rc
287
-- | Elementary row operation on a matrix: forwards the operation code, the
-- scalar, the two row indices, the two column bounds and the matrix to
-- 'rowOpAux', using 'rowOpStorable' as the kernel.
rowOp :: (Storable t, Num t) => Int -> t -> Int -> Int -> Int -> Int -> Matrix t -> IO ()
rowOp code x i1 i2 j1 j2 m = rowOpAux rowOpStorable code x i1 i2 j1 j2 m
290
-- | Status code returned by 'rowOpStorable' when the operation selector is
-- not one of the codes it implements (0, 1 or 2).
pattern BAD_CODE = 2001

-- | In-place elementary row operations on a strided matrix buffer.
--
-- Arguments, in order: operation code; pointer to the scalar; @i1 i2@ (row
-- indices); @j1 j2@ (inclusive column bounds); then the matrix's row count,
-- column count, row stride, column stride and data pointer. Strides are in
-- elements. The row and column counts are not consulted, so no bounds
-- checking happens here.
--
-- Operation codes:
--
-- * @0@ (AXPY): for every column @j@ in @j1..j2@,
--   @r[i2][j] <- r[i2][j] + a * r[i1][j]@.
-- * @1@ (SCAL): multiply every element in rows @i1..i2@, columns @j1..j2@
--   by @a@.
-- * @2@ (SWAP): exchange rows @i1@ and @i2@ over columns @j1..j2@; a no-op
--   when @i1 == i2@. The scalar pointer is not read.
--
-- Returns 0 on success and 'BAD_CODE' for any other operation code.
rowOpStorable :: (Storable t, Num t) =>
    Int32 -> Ptr t -> Int32 -> Int32 -> Int32 -> Int32
    -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t
    -> IO Int32
rowOpStorable 0 pa i1 i2 j1 j2 _rr _rc rXr rXc rp = do
    -- AXPY_IMP
    a <- peek pa
    ($ j1) $ fix $ \jloop j -> when (j<=j2) $ do
        ri1j <- peekElemOff rp $ fromIntegral $ rXr*i1 + rXc*j
        let i2j = fromIntegral $ rXr*i2 + rXc*j
        ri2j <- peekElemOff rp i2j
        pokeElemOff rp i2j $ ri2j + a*ri1j
        jloop (succ j)
    return 0
rowOpStorable 1 pa i1 i2 j1 j2 _rr _rc rXr rXc rp = do
    -- SCAL_IMP
    a <- peek pa
    ($ i1) $ fix $ \iloop i -> when (i<=i2) $ do
        ($ j1) $ fix $ \jloop j -> when (j<=j2) $ do
            -- The offset is an element index, so it must go through
            -- peekElemOff/pokeElemOff; 'plusPtr' would treat it as a byte
            -- count and hit the wrong (misaligned) address for any element
            -- wider than one byte.
            let ij = fromIntegral $ rXr*i + rXc*j
            rij <- peekElemOff rp ij
            pokeElemOff rp ij $ a * rij
            jloop (succ j)
        iloop (succ i)
    return 0
rowOpStorable 2 _pa i1 i2 j1 j2 _rr _rc rXr rXc rp
    -- Swapping a row with itself leaves the matrix unchanged.
    | i1 == i2  = return 0
    | otherwise = do
        -- SWAP_IMP
        ($ j1) $ fix $ \kloop k -> when (k<=j2) $ do
            let i1k = fromIntegral $ rXr*i1 + rXc*k
                i2k = fromIntegral $ rXr*i2 + rXc*k
            aux <- peekElemOff rp i1k
            pokeElemOff rp i1k =<< peekElemOff rp i2k
            pokeElemOff rp i2k aux
            kloop (succ k)
        return 0
rowOpStorable _ _pa _i1 _i2 _j1 _j2 _rr _rc _rXr _rXc _rp =
    return BAD_CODE
331
-- | Runs 'gemmStorable' over the coefficient vector @v@ and the matrices
-- @m1@, @m2@ and @m3@ (the last one attached with @#!@), and checks the
-- returned status with @#|@ under the label @"gemm"@.
gemm :: (Storable t, Num t) => Vector t -> Matrix t -> Matrix t -> Matrix t -> IO ()
gemm v m1 m2 m3 = multiply #| "gemm"
  where
    -- The raw kernel call; its status code is inspected by '#|' above.
    multiply = (v Internal.Matrix.# m1 Internal.Matrix.# m2 Internal.Matrix.#! m3) gemmStorable
334
335-- ScalarLike t
-- | General matrix multiply-accumulate on strided buffers:
-- @r[i][j] <- beta * r[i][j] + alpha * sum_k a[i][k] * b[k][j]@,
-- where @alpha@ is element 0 and @beta@ is element 1 of the coefficient
-- vector, @k@ ranges over the columns of @a@ (@ac@), and @(i, j)@ ranges
-- over the rows and columns of @r@. Strides are in elements. Always
-- returns 0.
--
-- The coefficient length and the dimensions @ar@, @br@ and @bc@ are accepted
-- but not consulted.
gemmStorable :: (Storable t, Num t) =>
    Int32 -> Ptr t -- VECG(T,c)
    -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -- MATG(T,a)
    -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -- MATG(T,b)
    -> Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -- MATG(T,r)
    -> IO Int32
gemmStorable cn cp
    ar ac aXr aXc ap
    br bc bXr bXc bp
    rr rc rXr rXc rp = do
    alpha <- peek cp
    beta  <- peekElemOff cp 1
    let -- Row i of a times column j of b, with a strict accumulator.
        dot i j = go 0 0
          where
            go k !acc
              | k < ac = do
                  aik <- peekElemOff ap (fromIntegral $ i*aXr + k*aXc)
                  bkj <- peekElemOff bp (fromIntegral $ k*bXr + j*bXc)
                  go (succ k) (acc + aik*bkj)
              | otherwise = return acc

        -- The product is computed first; the old value of r[i][j] is read
        -- only afterwards.
        updateCell i j = do
            t <- dot i j
            let ij = fromIntegral $ i*rXr + j*rXc
            rij <- peekElemOff rp ij
            pokeElemOff rp ij (beta*rij + alpha*t)

        columns i j = when (j < rc) $ do
            updateCell i j
            columns i (succ j)

        rowsFrom i = when (i < rr) $ do
            columns i 0
            rowsFrom (succ i)
    rowsFrom 0
    return 0