summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/Numeric/ContainerBoot.hs39
-rw-r--r--lib/Numeric/LinearAlgebra/Tests.hs13
2 files changed, 44 insertions, 8 deletions
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs
index d9f0d78..250d4c5 100644
--- a/lib/Numeric/ContainerBoot.hs
+++ b/lib/Numeric/ContainerBoot.hs
@@ -121,15 +121,8 @@ class (Complexable c, Fractional e, Element e) => Container c e where
121 sumElements :: c e -> e 121 sumElements :: c e -> e
122 -- | the product of elements (faster than using @fold@) 122 -- | the product of elements (faster than using @fold@)
123 prodElements :: c e -> e 123 prodElements :: c e -> e
124 -- | a more efficient implementation of @cmap (\x -> if x>0 then 1 else 0)@ 124 -- | a more efficient implementation of @cmap (\\x -> if x>0 then 1 else 0)@
125 step :: RealElement e => c e -> c e 125 step :: RealElement e => c e -> c e
126 -- | find index of elements which satisfy a predicate
127 find :: (e -> Bool) -> c e -> [IndexOf c]
128 -- | create a structure from an association list
129 assoc :: IndexOf c -- ^ size
130 -> e -- ^ default value
131 -> [(IndexOf c, e)] -- ^ association list
132 -> c e -- ^ result
133 126
134 -- | element by element @case compare a b of LT -> l, EQ -> e, GT -> g@ 127 -- | element by element @case compare a b of LT -> l, EQ -> e, GT -> g@
135 cond :: RealElement e 128 cond :: RealElement e
@@ -140,6 +133,21 @@ class (Complexable c, Fractional e, Element e) => Container c e where
140 -> c e -- ^ g 133 -> c e -- ^ g
141 -> c e -- ^ result 134 -> c e -- ^ result
142 135
136 -- | find index of elements which satisfy a predicate
137 find :: (e -> Bool) -> c e -> [IndexOf c]
138
139 -- | create a structure from an association list
140 assoc :: IndexOf c -- ^ size
141 -> e -- ^ default value
142 -> [(IndexOf c, e)] -- ^ association list
143 -> c e -- ^ result
144
145 -- | modify a structure using an update function
146 accum :: c e -- ^ initial structure
147 -> (e -> e -> e) -- ^ update function
148 -> [(IndexOf c, e)] -- ^ association list
149 -> c e -- ^ result
150
143-------------------------------------------------------------------------- 151--------------------------------------------------------------------------
144 152
145instance Container Vector Float where 153instance Container Vector Float where
@@ -167,6 +175,7 @@ instance Container Vector Float where
167 step = stepF 175 step = stepF
168 find = findV 176 find = findV
169 assoc = assocV 177 assoc = assocV
178 accum = accumV
170 cond = condV condF 179 cond = condV condF
171 180
172instance Container Vector Double where 181instance Container Vector Double where
@@ -194,6 +203,7 @@ instance Container Vector Double where
194 step = stepD 203 step = stepD
195 find = findV 204 find = findV
196 assoc = assocV 205 assoc = assocV
206 accum = accumV
197 cond = condV condD 207 cond = condV condD
198 208
199instance Container Vector (Complex Double) where 209instance Container Vector (Complex Double) where
@@ -221,6 +231,7 @@ instance Container Vector (Complex Double) where
221 step = undefined -- cannot match 231 step = undefined -- cannot match
222 find = findV 232 find = findV
223 assoc = assocV 233 assoc = assocV
234 accum = accumV
224 cond = undefined -- cannot match 235 cond = undefined -- cannot match
225 236
226instance Container Vector (Complex Float) where 237instance Container Vector (Complex Float) where
@@ -248,6 +259,7 @@ instance Container Vector (Complex Float) where
248 step = undefined -- cannot match 259 step = undefined -- cannot match
249 find = findV 260 find = findV
250 assoc = assocV 261 assoc = assocV
262 accum = accumV
251 cond = undefined -- cannot match 263 cond = undefined -- cannot match
252 264
253--------------------------------------------------------------- 265---------------------------------------------------------------
@@ -281,6 +293,7 @@ instance (Container Vector a) => Container Matrix a where
281 step = liftMatrix step 293 step = liftMatrix step
282 find = findM 294 find = findM
283 assoc = assocM 295 assoc = assocM
296 accum = accumM
284 cond = condM 297 cond = condM
285 298
286---------------------------------------------------- 299----------------------------------------------------
@@ -637,6 +650,16 @@ assocM (r,c) z xs = ST.runSTMatrix $ do
637 mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs 650 mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs
638 return m 651 return m
639 652
653accumV v0 f xs = ST.runSTVector $ do
654 v <- ST.thawVector v0
655 mapM_ (\(k,x) -> ST.modifyVector v k (f x)) xs
656 return v
657
658accumM m0 f xs = ST.runSTMatrix $ do
659 m <- ST.thawMatrix m0
660 mapM_ (\((i,j),x) -> ST.modifyMatrix m i j (f x)) xs
661 return m
662
640---------------------------------------------------------------------- 663----------------------------------------------------------------------
641 664
642condM a b l e t = reshape (cols a'') $ cond a' b' l' e' t' 665condM a b l e t = reshape (cols a'') $ cond a' b' l' e' t'
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs
index 3bcfec5..32cd39d 100644
--- a/lib/Numeric/LinearAlgebra/Tests.hs
+++ b/lib/Numeric/LinearAlgebra/Tests.hs
@@ -392,6 +392,18 @@ conformTest = utest "conform" ok
392 392
393--------------------------------------------------------------------- 393---------------------------------------------------------------------
394 394
395accumTest = utest "accum" ok
396 where
397 x = ident 3 :: Matrix Double
398 ok = accum x (+) [((1,2),7), ((2,2),3)]
399 == (3><3) [1,0,0
400 ,0,1,7
401 ,0,0,4]
402 &&
403 toList (flatten x) == [1,0,0,0,1,0,0,0,1]
404
405---------------------------------------------------------------------
406
395-- | All tests must pass with a maximum dimension of about 20 407-- | All tests must pass with a maximum dimension of about 20
396-- (some tests may fail with bigger sizes due to precision loss). 408-- (some tests may fail with bigger sizes due to precision loss).
397runTests :: Int -- ^ maximum dimension 409runTests :: Int -- ^ maximum dimension
@@ -562,6 +574,7 @@ runTests n = do
562 , findAssocTest 574 , findAssocTest
563 , condTest 575 , condTest
564 , conformTest 576 , conformTest
577 , accumTest
565 ] 578 ]
566 return () 579 return ()
567 580