summaryrefslogtreecommitdiff
path: root/examples/monadic.hs
blob: cf8aacc0b8e7201689efa5707a7f602df48f8449 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
-- monadic computations
-- (contributed by Vivian McPhail)

{-# LANGUAGE FlexibleContexts #-}

import Control.Monad (mzero, when)
import Foreign.Storable (Storable)

import Control.Monad.State.Strict
import Control.Monad.Trans.Maybe
import Numeric.LinearAlgebra
import Numeric.LinearAlgebra.Devel
import System.Random (randomIO)

-------------------------------------------

-- | A state monad transformer over 'IO' whose state is a single value of type
-- 'I' (presumably hmatrix's integer element type — confirm). Because the base
-- monad is 'IO', computations in it can both thread that state ('get'/'put')
-- and perform IO through 'liftIO'.
type VectorMonadT = StateT I IO

-- | Print every element (each followed by a space) while building a new
-- vector whose elements are the originals plus one.
test1 :: Vector I -> IO (Vector I)
test1 = mapVectorM printAndBump
  where
    printAndBump x = do
        putStr (show x ++ " ")
        return (x + 1)

-- | Keep a running total in the state: each element is added to it, and the
-- new total is printed and stored. This shows that a stateful monad can be
-- combined with IO in the same traversal.
addInitialM :: Vector I -> VectorMonadT ()
addInitialM = mapVectorM_ step
  where
    step x = do
        acc <- get
        let total = x + acc
        liftIO $ putStr (show total ++ " ")
        put total

-- | Sum the elements that sit at even indices (0, 2, 4, ...).
sumEvens :: Vector I -> I
sumEvens = foldVectorWithIndex addIfEven 0
  where
    addIfEven idx x acc
        | even idx  = x + acc
        | otherwise = acc

-- | Walk the vector, adding each element found at an even index to the
-- running total held in the state, and print the new total after every
-- addition. Elements at odd indices are skipped.
sumEvensAndPrint :: Vector I -> VectorMonadT ()
sumEvensAndPrint = mapVectorWithIndexM_ $ \i x ->
    when (even i) $ do
        -- strict update so the accumulator doesn't build up a thunk chain
        modify' (+ x)
        total <- get
        liftIO $ putStr $ show total ++ " "


-- | Accumulate a running sum of the elements in the state. For each element,
-- print the pair (index, running sum), then print the vector of all running
-- sums. Left polymorphic; 'main' runs it as
-- @evalStateT (indexPlusSum v) 0@ with @v :: Vector I@.
indexPlusSum v' = do
    totals <- mapVectorWithIndexM accumulate v'
    liftIO $ do
        putStrLn ""
        print totals
  where
    accumulate i x = do
        running <- gets (x +)
        liftIO $ putStr $ show (i, running) ++ " "
        put running
        return running

-------------------------------------------

-- | One step of a short-circuiting monotonicity check. The state holds the
-- previously seen element. If the current element @d@ is smaller than it, the
-- enclosing 'MaybeT' computation fails (yielding 'Nothing'). Otherwise @d@
-- becomes the new "previous" element. Equal neighbours are accepted.
monoStep :: Double -> MaybeT (State Double) ()
monoStep d = do
    dp <- get
    -- 'mzero' is MaybeT's failure; 'fail' would accept a message that MaybeT
    -- immediately throws away.
    when (d < dp) mzero
    put d
{-# INLINE monoStep #-}

-- | 'True' when 'monoStep' accepts every element, i.e. no element is smaller
-- than the one before it. The state is seeded with the first element, so the
-- first comparison is that element against itself.
isMonotoneIncreasing :: Vector Double -> Bool
isMonotoneIncreasing vec = maybe False (const True) outcome
  where
    outcome = evalState (runMaybeT (mapVectorM_ monoStep vec)) (vec ! 0)


-------------------------------------------

-- | Apply a test to each pair of neighbouring elements. The result is 'True'
-- iff the test passes for all pairs.
--
-- The test is called as @t current previous@, so @successive_ (>)@ checks
-- that the vector is strictly increasing. A vector with fewer than two
-- elements has no pairs and is vacuously 'True'.
successive_ :: (Container Vector a, Indexable (Vector a) a) => (a -> a -> Bool) -> Vector a -> Bool
successive_ t v
    -- guard: 'subVector 1 (size v - 1)' and 'v ! 0' are invalid for an empty vector
    | size v < 2 = True
    | otherwise  = maybe False (const True) $
        evalState (runMaybeT (mapVectorM_ step (subVector 1 (size v - 1) v))) (v ! 0)
  where
    step e = do
        ep <- lift get
        if t e ep
            then lift (put e)
            -- MaybeT failure: stop and report 'Nothing'
            else mzero

-- | Apply @f@ to each pair of neighbouring elements, as
-- @f previous current@, and collect the results. The result is one element
-- shorter than the input. An empty or single-element input yields an empty
-- vector.
successive
  :: (Storable b, Container Vector s, Indexable (Vector s) s)
  => (s -> s -> b) -> Vector s -> Vector b
successive f v
    -- guard: 'subVector 1 (size v - 1)' and 'v ! 0' are invalid for an empty vector
    | size v < 2 = fromList []
    | otherwise  = evalState (mapVectorM step (subVector 1 (size v - 1) v)) (v ! 0)
  where
    -- the state carries the previous element
    step e = do
        ep <- get
        put e
        return $ f ep e

-------------------------------------------

-- | Sample integer vector holding 0 through 9, used by the examples in 'main'.
v :: Vector I
v = fromList [0 .. 9]

-- | Sample vector that rises from 1 to 10 and then falls from 10 back to 1,
-- so it is increasing only over a prefix.
w :: Vector Double
w = fromList ([1 .. 10] ++ [10, 9 .. 1])


-- | Run each example in turn, printing its output.
main :: IO ()
main = do
    -- monadic maps in IO and in the StateT-over-IO examples
    v' <- test1 v
    putStrLn ""
    print v'
    evalStateT (addInitialM v) 0
    putStrLn ""
    print (sumEvens v)
    evalStateT (sumEvensAndPrint v) 0
    putStrLn ""
    evalStateT (indexPlusSum v) 0
    putStrLn "-----------------------"
    -- monadic maps in plain IO, including ones drawing random numbers
    mapVectorM_ print v
    print =<< (mapVectorM (const randomIO) v :: IO (Vector Double))
    print =<< (mapVectorM (\a -> fmap (+a) randomIO) (5|>[0,100..1000]) :: IO (Vector Double))
    putStrLn "-----------------------"
    -- short-circuiting checks over neighbouring elements
    print $ isMonotoneIncreasing w
    print $ isMonotoneIncreasing (subVector 0 7 w)
    print $ successive_ (>) v
    print $ successive_ (>) w
    print $ successive (+) v