summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/Data/Packed/Vector.hs53
-rw-r--r--lib/Numeric/LinearAlgebra/Tests.hs6
2 files changed, 58 insertions, 1 deletions
diff --git a/lib/Data/Packed/Vector.hs b/lib/Data/Packed/Vector.hs
index eaf4b9c..23fe37f 100644
--- a/lib/Data/Packed/Vector.hs
+++ b/lib/Data/Packed/Vector.hs
@@ -22,7 +22,8 @@ module Data.Packed.Vector (
22 subVector, takesV, join, 22 subVector, takesV, join,
23 mapVector, zipVector, zipVectorWith, unzipVector, unzipVectorWith, 23 mapVector, zipVector, zipVectorWith, unzipVector, unzipVectorWith,
24 mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_, 24 mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_,
25 foldLoop, foldVector, foldVectorG, foldVectorWithIndex 25 foldLoop, foldVector, foldVectorG, foldVectorWithIndex,
26 successive_, successive
26) where 27) where
27 28
28import Data.Packed.Internal.Vector 29import Data.Packed.Internal.Vector
@@ -82,4 +83,54 @@ zipVector = zipVectorWith (,)
82unzipVector :: (Storable a, Storable b, Storable (a,b)) => Vector (a,b) -> (Vector a,Vector b) 83unzipVector :: (Storable a, Storable b, Storable (a,b)) => Vector (a,b) -> (Vector a,Vector b)
83unzipVector = unzipVectorWith id 84unzipVector = unzipVectorWith id
84 85
86-------------------------------------------------------------------
87
88newtype State s a = State { runState :: s -> (a,s) }
89
90instance Monad (State s) where
91 return a = State $ \s -> (a,s)
92 m >>= f = State $ \s -> let (a,s') = runState m s
93 in runState (f a) s'
94
95state_get :: State s s
96state_get = State $ \s -> (s,s)
97
98state_put :: s -> State s ()
99state_put s = State $ \_ -> ((),s)
100
101evalState :: State s a -> s -> a
102evalState m s = fst $ runState m s
103
104newtype MaybeT m a = MaybeT { runMaybeT :: m (Maybe a) }
105
106instance Monad m => Monad (MaybeT m) where
107 return a = MaybeT $ return $ Just a
108 m >>= f = MaybeT $ do
109 res <- runMaybeT m
110 case res of
111 Nothing -> return Nothing
112 Just r -> runMaybeT (f r)
113 fail _ = MaybeT $ return Nothing
114
115lift_maybe m = MaybeT $ do
116 res <- m
117 return $ Just res
118
119-- | apply a test to successive elements of a vector, evaluates to true iff test passes for all pairs
120successive_ :: Storable a => (a -> a -> Bool) -> Vector a -> Bool
121successive_ t v = maybe False (\_ -> True) $ evalState (runMaybeT (mapVectorM_ step (subVector 1 (dim v - 1) v))) (v @> 0)
122 where step e = do
123 ep <- lift_maybe $ state_get
124 if t e ep
125 then lift_maybe $ state_put e
126 else (fail "successive_ test failed")
127
128-- | operate on successive elements of a vector and return the resulting vector, whose length 1 less than that of the input
129successive :: (Storable a, Storable b) => (a -> a -> b) -> Vector a -> Vector b
130successive f v = evalState (mapVectorM step (subVector 1 (dim v - 1) v)) (v @> 0)
131 where step e = do
132 ep <- state_get
133 state_put e
134 return $ f ep e
85 135
136-------------------------------------------------------------------
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs
index 0df29a8..a44c273 100644
--- a/lib/Numeric/LinearAlgebra/Tests.hs
+++ b/lib/Numeric/LinearAlgebra/Tests.hs
@@ -300,6 +300,11 @@ conjuTest m = mapVector conjugate (flatten (trans m)) == flatten (ctrans m)
300 300
301--------------------------------------------------------------------- 301---------------------------------------------------------------------
302 302
303succTest = utest "successive" $ successive_ (<) (fromList [1 :: Double,2,3,4]) == True
304 && successive_ (<) (fromList [1 :: Double,3,2,4]) == False
305
306---------------------------------------------------------------------
307
303 308
304-- | All tests must pass with a maximum dimension of about 20 309-- | All tests must pass with a maximum dimension of about 20
305-- (some tests may fail with bigger sizes due to precision loss). 310-- (some tests may fail with bigger sizes due to precision loss).
@@ -466,6 +471,7 @@ runTests n = do
466 , normsMTest 471 , normsMTest
467 , sumprodTest 472 , sumprodTest
468 , chainTest 473 , chainTest
474 , succTest
469 ] 475 ]
470 return () 476 return ()
471 477