summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVivian McPhail <haskell.vivian.mcphail@gmail.com>2010-08-28 11:29:57 +0000
committerVivian McPhail <haskell.vivian.mcphail@gmail.com>2010-08-28 11:29:57 +0000
commit693cae17c1e4ae3570f35324119f47ca6103f3cf (patch)
tree52cd26440025575a2e339fb542593325a3668df4
parent5e60b08d76e666643c795131bcbb18d196a39520 (diff)
add withIndex traversal
-rw-r--r--examples/vector-map.hs33
-rw-r--r--lib/Data/Packed/Internal/Vector.hs46
-rw-r--r--lib/Data/Packed/Vector.hs4
3 files changed, 79 insertions, 4 deletions
diff --git a/examples/vector-map.hs b/examples/vector-map.hs
index f116946..7796cc0 100644
--- a/examples/vector-map.hs
+++ b/examples/vector-map.hs
@@ -30,12 +30,45 @@ addInitialM = mapVectorM_ (\x -> do
30 put $ x + i 30 put $ x + i
31 ) 31 )
32 32
33-- sum the values of the even indiced elements
34sumEvens :: Vector Int -> Int
35sumEvens = foldVectorWithIndex (\x a b -> if x `mod` 2 == 0 then a + b else b) 0
36
37-- sum and print running total of evens
38sumEvensAndPrint :: Vector Int -> VectorMonadT ()
39sumEvensAndPrint = mapVectorWithIndexM_ (\ i x -> do
40 when (i `mod` 2 == 0) (do
41 v <- get
42 put $ v + x
43 v' <- get
44 liftIO $ putStr $ (show v') ++ " "
45 return ())
46 return ()
47 )
48
49indexPlusSum :: Vector Int -> VectorMonadT ()
50indexPlusSum v' = do
51 v <- mapVectorWithIndexM (\i x -> do
52 s <- get
53 let inc = x+s
54 liftIO $ putStr $ show (i,inc) ++ " "
55 put inc
56 return inc) v'
57 liftIO $ do
58 putStrLn ""
59 putStrLn $ show v
60
33------------------------------------------- 61-------------------------------------------
62
34main = do 63main = do
35 v' <- test1 v 64 v' <- test1 v
36 putStrLn "" 65 putStrLn ""
37 putStrLn $ show v' 66 putStrLn $ show v'
38 evalStateT (addInitialM v) 0 67 evalStateT (addInitialM v) 0
39 putStrLn "" 68 putStrLn ""
69 putStrLn $ show (sumEvens v)
70 evalStateT (sumEvensAndPrint v) 0
71 putStrLn ""
72 evalStateT (indexPlusSum v) 0
40 return () 73 return ()
41 74
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
index be2fcbb..a47c376 100644
--- a/lib/Data/Packed/Internal/Vector.hs
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -18,8 +18,8 @@ module Data.Packed.Internal.Vector (
18 fromList, toList, (|>), 18 fromList, toList, (|>),
19 join, (@>), safe, at, at', subVector, takesV, 19 join, (@>), safe, at, at', subVector, takesV,
20 mapVector, zipVectorWith, unzipVectorWith, 20 mapVector, zipVectorWith, unzipVectorWith,
21 mapVectorM, mapVectorM_, 21 mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_,
22 foldVector, foldVectorG, foldLoop, 22 foldVector, foldVectorG, foldLoop, foldVectorWithIndex,
23 createVector, vec, 23 createVector, vec,
24 asComplex, asReal, float2DoubleV, double2FloatV, 24 asComplex, asReal, float2DoubleV, double2FloatV,
25 fwriteVector, freadVector, fprintfVector, fscanfVector, 25 fwriteVector, freadVector, fprintfVector, fscanfVector,
@@ -364,6 +364,16 @@ foldVector f x v = unsafePerformIO $
364 go (dim v -1) x 364 go (dim v -1) x
365{-# INLINE foldVector #-} 365{-# INLINE foldVector #-}
366 366
367-- the zero-indexed index is passed to the folding function
368foldVectorWithIndex :: Storable a => (Int -> a -> b -> b) -> b -> Vector a -> b
369foldVectorWithIndex f x v = unsafePerformIO $
370 unsafeWith v $ \p -> do
371 let go (-1) s = return s
372 go !k !s = do y <- peekElemOff p k
373 go (k-1::Int) (f k y s)
374 go (dim v -1) x
375{-# INLINE foldVectorWithIndex #-}
376
367foldLoop f s0 d = go (d - 1) s0 377foldLoop f s0 d = go (d - 1) s0
368 where 378 where
369 go 0 s = f (0::Int) s 379 go 0 s = f (0::Int) s
@@ -408,6 +418,38 @@ mapVectorM_ f v = do
408 mapVectorM' f' v' (k+1) t 418 mapVectorM' f' v' (k+1) t
409{-# INLINE mapVectorM_ #-} 419{-# INLINE mapVectorM_ #-}
410 420
421-- | monadic map over Vectors with the zero-indexed index passed to the mapping function
422mapVectorWithIndexM :: (Storable a, Storable b, Monad m) => (Int -> a -> m b) -> Vector a -> m (Vector b)
423mapVectorWithIndexM f v = do
424 w <- return $! unsafePerformIO $! createVector (dim v)
425 mapVectorM' f v w 0 (dim v -1)
426 return w
427 where mapVectorM' f' v' w' !k !t
428 | k == t = do
429 x <- return $! inlinePerformIO $! unsafeWith v' $! \p -> peekElemOff p k
430 y <- f' k x
431 return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
432 | otherwise = do
433 x <- return $! inlinePerformIO $! unsafeWith v' $! \p -> peekElemOff p k
434 y <- f' k x
435 _ <- return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
436 mapVectorM' f' v' w' (k+1) t
437{-# INLINE mapVectorWithIndexM #-}
438
439-- | monadic map over Vectors with the zero-indexed index passed to the mapping function
440mapVectorWithIndexM_ :: (Storable a, Monad m) => (Int -> a -> m ()) -> Vector a -> m ()
441mapVectorWithIndexM_ f v = do
442 mapVectorM' f v 0 (dim v -1)
443 where mapVectorM' f' v' !k !t
444 | k == t = do
445 x <- return $! inlinePerformIO $! unsafeWith v' $! \p -> peekElemOff p k
446 f' k x
447 | otherwise = do
448 x <- return $! inlinePerformIO $! unsafeWith v' $! \p -> peekElemOff p k
449 _ <- f' k x
450 mapVectorM' f' v' (k+1) t
451{-# INLINE mapVectorWithIndexM_ #-}
452
411------------------------------------------------------------------- 453-------------------------------------------------------------------
412 454
413 455
diff --git a/lib/Data/Packed/Vector.hs b/lib/Data/Packed/Vector.hs
index a526caa..ad690f9 100644
--- a/lib/Data/Packed/Vector.hs
+++ b/lib/Data/Packed/Vector.hs
@@ -27,9 +27,9 @@ module Data.Packed.Vector (
27-- vectorMax, vectorMin, 27-- vectorMax, vectorMin,
28 vectorMaxIndex, vectorMinIndex, 28 vectorMaxIndex, vectorMinIndex,
29 mapVector, zipVector, zipVectorWith, unzipVector, unzipVectorWith, 29 mapVector, zipVector, zipVectorWith, unzipVector, unzipVectorWith,
30 mapVectorM, mapVectorM_, 30 mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_,
31 fscanfVector, fprintfVector, freadVector, fwriteVector, 31 fscanfVector, fprintfVector, freadVector, fwriteVector,
32 foldLoop, foldVector, foldVectorG 32 foldLoop, foldVector, foldVectorG, foldVectorWithIndex
33) where 33) where
34 34
35import Data.Packed.Internal.Vector 35import Data.Packed.Internal.Vector