From 9fc924dc107cd619c60a421d288dafb92f417b8c Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Sat, 1 Jan 2011 20:59:47 +0000 Subject: accum --- lib/Numeric/ContainerBoot.hs | 39 ++++++++++++++++++++++++++++++-------- lib/Numeric/LinearAlgebra/Tests.hs | 13 +++++++++++++ 2 files changed, 44 insertions(+), 8 deletions(-) (limited to 'lib') diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs index d9f0d78..250d4c5 100644 --- a/lib/Numeric/ContainerBoot.hs +++ b/lib/Numeric/ContainerBoot.hs @@ -121,15 +121,8 @@ class (Complexable c, Fractional e, Element e) => Container c e where sumElements :: c e -> e -- | the product of elements (faster than using @fold@) prodElements :: c e -> e - -- | a more efficient implementation of @cmap (\x -> if x>0 then 1 else 0)@ + -- | a more efficient implementation of @cmap (\\x -> if x>0 then 1 else 0)@ step :: RealElement e => c e -> c e - -- | find index of elements which satisfy a predicate - find :: (e -> Bool) -> c e -> [IndexOf c] - -- | create a structure from an association list - assoc :: IndexOf c -- ^ size - -> e -- ^ default value - -> [(IndexOf c, e)] -- ^ association list - -> c e -- ^ result -- | element by element @case compare a b of LT -> l, EQ -> e, GT -> g@ cond :: RealElement e @@ -140,6 +133,21 @@ class (Complexable c, Fractional e, Element e) => Container c e where -> c e -- ^ g -> c e -- ^ result + -- | find index of elements which satisfy a predicate + find :: (e -> Bool) -> c e -> [IndexOf c] + + -- | create a structure from an association list + assoc :: IndexOf c -- ^ size + -> e -- ^ default value + -> [(IndexOf c, e)] -- ^ association list + -> c e -- ^ result + + -- | modify a structure using an update function + accum :: c e -- ^ initial structure + -> (e -> e -> e) -- ^ update function + -> [(IndexOf c, e)] -- ^ association list + -> c e -- ^ result + -------------------------------------------------------------------------- instance Container Vector Float where @@ -167,6 +175,7 @@ instance Container Vector Float where step = stepF find = findV assoc = assocV + accum = accumV cond = condV condF instance Container Vector Double where @@ -194,6 +203,7 @@ instance Container Vector Double where step = stepD find = findV assoc = assocV + accum = accumV cond = condV condD instance Container Vector (Complex Double) where @@ -221,6 +231,7 @@ instance Container Vector (Complex Double) where step = undefined -- cannot match find = findV assoc = assocV + accum = accumV cond = undefined -- cannot match instance Container Vector (Complex Float) where @@ -248,6 +259,7 @@ instance Container Vector (Complex Float) where step = undefined -- cannot match find = findV assoc = assocV + accum = accumV cond = undefined -- cannot match --------------------------------------------------------------- @@ -281,6 +293,7 @@ instance (Container Vector a) => Container Matrix a where step = liftMatrix step find = findM assoc = assocM + accum = accumM cond = condM ---------------------------------------------------- @@ -637,6 +650,16 @@ assocM (r,c) z xs = ST.runSTMatrix $ do mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs return m +accumV v0 f xs = ST.runSTVector $ do + v <- ST.thawVector v0 + mapM_ (\(k,x) -> ST.modifyVector v k (f x)) xs + return v + +accumM m0 f xs = ST.runSTMatrix $ do + m <- ST.thawMatrix m0 + mapM_ (\((i,j),x) -> ST.modifyMatrix m i j (f x)) xs + return m + ---------------------------------------------------------------------- condM a b l e t = reshape (cols a'') $ cond a' b' l' e' t' diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index 3bcfec5..32cd39d 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs @@ -392,6 +392,18 @@ conformTest = utest "conform" ok --------------------------------------------------------------------- +accumTest = utest "accum" ok + where + x = ident 3 :: Matrix Double + ok = accum x (+) [((1,2),7), ((2,2),3)] + == (3><3) [1,0,0 + ,0,1,7 + ,0,0,4] + && + toList (flatten x) == [1,0,0,0,1,0,0,0,1] + +--------------------------------------------------------------------- + -- | All tests must pass with a maximum dimension of about 20 -- (some tests may fail with bigger sizes due to precision loss). runTests :: Int -- ^ maximum dimension @@ -562,6 +574,7 @@ runTests n = do , findAssocTest , condTest , conformTest + , accumTest ] return () -- cgit v1.2.3