diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/Numeric/ContainerBoot.hs | 39 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests.hs | 13 |
2 files changed, 44 insertions, 8 deletions
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs index d9f0d78..250d4c5 100644 --- a/lib/Numeric/ContainerBoot.hs +++ b/lib/Numeric/ContainerBoot.hs | |||
@@ -121,15 +121,8 @@ class (Complexable c, Fractional e, Element e) => Container c e where | |||
121 | sumElements :: c e -> e | 121 | sumElements :: c e -> e |
122 | -- | the product of elements (faster than using @fold@) | 122 | -- | the product of elements (faster than using @fold@) |
123 | prodElements :: c e -> e | 123 | prodElements :: c e -> e |
124 | -- | a more efficient implementation of @cmap (\x -> if x>0 then 1 else 0)@ | 124 | -- | a more efficient implementation of @cmap (\\x -> if x>0 then 1 else 0)@ |
125 | step :: RealElement e => c e -> c e | 125 | step :: RealElement e => c e -> c e |
126 | -- | find index of elements which satisfy a predicate | ||
127 | find :: (e -> Bool) -> c e -> [IndexOf c] | ||
128 | -- | create a structure from an association list | ||
129 | assoc :: IndexOf c -- ^ size | ||
130 | -> e -- ^ default value | ||
131 | -> [(IndexOf c, e)] -- ^ association list | ||
132 | -> c e -- ^ result | ||
133 | 126 | ||
134 | -- | element by element @case compare a b of LT -> l, EQ -> e, GT -> g@ | 127 | -- | element by element @case compare a b of LT -> l, EQ -> e, GT -> g@ |
135 | cond :: RealElement e | 128 | cond :: RealElement e |
@@ -140,6 +133,21 @@ class (Complexable c, Fractional e, Element e) => Container c e where | |||
140 | -> c e -- ^ g | 133 | -> c e -- ^ g |
141 | -> c e -- ^ result | 134 | -> c e -- ^ result |
142 | 135 | ||
136 | -- | find index of elements which satisfy a predicate | ||
137 | find :: (e -> Bool) -> c e -> [IndexOf c] | ||
138 | |||
139 | -- | create a structure from an association list | ||
140 | assoc :: IndexOf c -- ^ size | ||
141 | -> e -- ^ default value | ||
142 | -> [(IndexOf c, e)] -- ^ association list | ||
143 | -> c e -- ^ result | ||
144 | |||
145 | -- | modify a structure using an update function | ||
146 | accum :: c e -- ^ initial structure | ||
147 | -> (e -> e -> e) -- ^ update function | ||
148 | -> [(IndexOf c, e)] -- ^ association list | ||
149 | -> c e -- ^ result | ||
150 | |||
143 | -------------------------------------------------------------------------- | 151 | -------------------------------------------------------------------------- |
144 | 152 | ||
145 | instance Container Vector Float where | 153 | instance Container Vector Float where |
@@ -167,6 +175,7 @@ instance Container Vector Float where | |||
167 | step = stepF | 175 | step = stepF |
168 | find = findV | 176 | find = findV |
169 | assoc = assocV | 177 | assoc = assocV |
178 | accum = accumV | ||
170 | cond = condV condF | 179 | cond = condV condF |
171 | 180 | ||
172 | instance Container Vector Double where | 181 | instance Container Vector Double where |
@@ -194,6 +203,7 @@ instance Container Vector Double where | |||
194 | step = stepD | 203 | step = stepD |
195 | find = findV | 204 | find = findV |
196 | assoc = assocV | 205 | assoc = assocV |
206 | accum = accumV | ||
197 | cond = condV condD | 207 | cond = condV condD |
198 | 208 | ||
199 | instance Container Vector (Complex Double) where | 209 | instance Container Vector (Complex Double) where |
@@ -221,6 +231,7 @@ instance Container Vector (Complex Double) where | |||
221 | step = undefined -- cannot match | 231 | step = undefined -- cannot match |
222 | find = findV | 232 | find = findV |
223 | assoc = assocV | 233 | assoc = assocV |
234 | accum = accumV | ||
224 | cond = undefined -- cannot match | 235 | cond = undefined -- cannot match |
225 | 236 | ||
226 | instance Container Vector (Complex Float) where | 237 | instance Container Vector (Complex Float) where |
@@ -248,6 +259,7 @@ instance Container Vector (Complex Float) where | |||
248 | step = undefined -- cannot match | 259 | step = undefined -- cannot match |
249 | find = findV | 260 | find = findV |
250 | assoc = assocV | 261 | assoc = assocV |
262 | accum = accumV | ||
251 | cond = undefined -- cannot match | 263 | cond = undefined -- cannot match |
252 | 264 | ||
253 | --------------------------------------------------------------- | 265 | --------------------------------------------------------------- |
@@ -281,6 +293,7 @@ instance (Container Vector a) => Container Matrix a where | |||
281 | step = liftMatrix step | 293 | step = liftMatrix step |
282 | find = findM | 294 | find = findM |
283 | assoc = assocM | 295 | assoc = assocM |
296 | accum = accumM | ||
284 | cond = condM | 297 | cond = condM |
285 | 298 | ||
286 | ---------------------------------------------------- | 299 | ---------------------------------------------------- |
@@ -637,6 +650,16 @@ assocM (r,c) z xs = ST.runSTMatrix $ do | |||
637 | mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs | 650 | mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs |
638 | return m | 651 | return m |
639 | 652 | ||
653 | accumV v0 f xs = ST.runSTVector $ do | ||
654 | v <- ST.thawVector v0 | ||
655 | mapM_ (\(k,x) -> ST.modifyVector v k (f x)) xs | ||
656 | return v | ||
657 | |||
658 | accumM m0 f xs = ST.runSTMatrix $ do | ||
659 | m <- ST.thawMatrix m0 | ||
660 | mapM_ (\((i,j),x) -> ST.modifyMatrix m i j (f x)) xs | ||
661 | return m | ||
662 | |||
640 | ---------------------------------------------------------------------- | 663 | ---------------------------------------------------------------------- |
641 | 664 | ||
642 | condM a b l e t = reshape (cols a'') $ cond a' b' l' e' t' | 665 | condM a b l e t = reshape (cols a'') $ cond a' b' l' e' t' |
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index 3bcfec5..32cd39d 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs | |||
@@ -392,6 +392,18 @@ conformTest = utest "conform" ok | |||
392 | 392 | ||
393 | --------------------------------------------------------------------- | 393 | --------------------------------------------------------------------- |
394 | 394 | ||
395 | accumTest = utest "accum" ok | ||
396 | where | ||
397 | x = ident 3 :: Matrix Double | ||
398 | ok = accum x (+) [((1,2),7), ((2,2),3)] | ||
399 | == (3><3) [1,0,0 | ||
400 | ,0,1,7 | ||
401 | ,0,0,4] | ||
402 | && | ||
403 | toList (flatten x) == [1,0,0,0,1,0,0,0,1] | ||
404 | |||
405 | --------------------------------------------------------------------- | ||
406 | |||
395 | -- | All tests must pass with a maximum dimension of about 20 | 407 | -- | All tests must pass with a maximum dimension of about 20 |
396 | -- (some tests may fail with bigger sizes due to precision loss). | 408 | -- (some tests may fail with bigger sizes due to precision loss). |
397 | runTests :: Int -- ^ maximum dimension | 409 | runTests :: Int -- ^ maximum dimension |
@@ -562,6 +574,7 @@ runTests n = do | |||
562 | , findAssocTest | 574 | , findAssocTest |
563 | , condTest | 575 | , condTest |
564 | , conformTest | 576 | , conformTest |
577 | , accumTest | ||
565 | ] | 578 | ] |
566 | return () | 579 | return () |
567 | 580 | ||