diff options
Diffstat (limited to 'lib/Numeric/ContainerBoot.hs')
-rw-r--r-- | lib/Numeric/ContainerBoot.hs | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs index 5a8e243..e33857a 100644 --- a/lib/Numeric/ContainerBoot.hs +++ b/lib/Numeric/ContainerBoot.hs | |||
@@ -127,6 +127,8 @@ class (Complexable c, Fractional e, Element e) => Container c e where | |||
127 | find :: (e -> Bool) -> c e -> [IndexOf c] | 127 | find :: (e -> Bool) -> c e -> [IndexOf c] |
128 | -- | create a structure from an association list | 128 | -- | create a structure from an association list |
129 | assoc :: IndexOf c -> e -> [(IndexOf c, e)] -> c e | 129 | assoc :: IndexOf c -> e -> [(IndexOf c, e)] -> c e |
130 | -- | a vectorized form of case 'compare' a_i b_i of LT -> l_i; EQ -> e_i; GT -> g_i | ||
131 | cond :: RealFloat e => c e -> c e -> c e -> c e -> c e -> c e | ||
130 | 132 | ||
131 | -------------------------------------------------------------------------- | 133 | -------------------------------------------------------------------------- |
132 | 134 | ||
@@ -155,6 +157,7 @@ instance Container Vector Float where | |||
155 | step = stepF | 157 | step = stepF |
156 | find = findV | 158 | find = findV |
157 | assoc = assocV | 159 | assoc = assocV |
160 | cond = condV condF | ||
158 | 161 | ||
159 | instance Container Vector Double where | 162 | instance Container Vector Double where |
160 | scale = vectorMapValR Scale | 163 | scale = vectorMapValR Scale |
@@ -181,6 +184,7 @@ instance Container Vector Double where | |||
181 | step = stepD | 184 | step = stepD |
182 | find = findV | 185 | find = findV |
183 | assoc = assocV | 186 | assoc = assocV |
187 | cond = condV condD | ||
184 | 188 | ||
185 | instance Container Vector (Complex Double) where | 189 | instance Container Vector (Complex Double) where |
186 | scale = vectorMapValC Scale | 190 | scale = vectorMapValC Scale |
@@ -207,6 +211,7 @@ instance Container Vector (Complex Double) where | |||
207 | step = undefined -- cannot match | 211 | step = undefined -- cannot match |
208 | find = findV | 212 | find = findV |
209 | assoc = assocV | 213 | assoc = assocV |
214 | cond = undefined -- cannot match | ||
210 | 215 | ||
211 | instance Container Vector (Complex Float) where | 216 | instance Container Vector (Complex Float) where |
212 | scale = vectorMapValQ Scale | 217 | scale = vectorMapValQ Scale |
@@ -233,6 +238,7 @@ instance Container Vector (Complex Float) where | |||
233 | step = undefined -- cannot match | 238 | step = undefined -- cannot match |
234 | find = findV | 239 | find = findV |
235 | assoc = assocV | 240 | assoc = assocV |
241 | cond = undefined -- cannot match | ||
236 | 242 | ||
237 | --------------------------------------------------------------- | 243 | --------------------------------------------------------------- |
238 | 244 | ||
@@ -265,6 +271,7 @@ instance (Container Vector a) => Container Matrix a where | |||
265 | step = liftMatrix step | 271 | step = liftMatrix step |
266 | find = findM | 272 | find = findM |
267 | assoc = assocM | 273 | assoc = assocM |
274 | cond = condM | ||
268 | 275 | ||
269 | ---------------------------------------------------- | 276 | ---------------------------------------------------- |
270 | 277 | ||
@@ -620,3 +627,35 @@ assocM (r,c) z xs = ST.runSTMatrix $ do | |||
620 | mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs | 627 | mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs |
621 | return m | 628 | return m |
622 | 629 | ||
630 | ---------------------------------------------------------------------- | ||
631 | |||
632 | conformMTo (r,c) m | ||
633 | | size m == (r,c) = m | ||
634 | | size m == (1,1) = konst (m@@>(0,0)) (r,c) | ||
635 | | size m == (r,1) = repCols c m | ||
636 | | size m == (1,c) = repRows r m | ||
637 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" | ||
638 | |||
639 | conformVTo n v | ||
640 | | dim v == n = v | ||
641 | | dim v == 1 = konst (v@>0) n | ||
642 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n | ||
643 | |||
644 | repRows n x = fromRows (replicate n (flatten x)) | ||
645 | repCols n x = fromColumns (replicate n (flatten x)) | ||
646 | |||
647 | size m = (rows m, cols m) | ||
648 | |||
649 | shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" | ||
650 | |||
651 | condM a b l e t = reshape c $ cond a' b' l' e' t' | ||
652 | where | ||
653 | r = maximum (map rows [a,b,l,e,t]) | ||
654 | c = maximum (map cols [a,b,l,e,t]) | ||
655 | [a', b', l', e', t'] = map (flatten . conformMTo (r,c)) [a,b,l,e,t] | ||
656 | |||
657 | condV f a b l e t = f a' b' l' e' t' | ||
658 | where | ||
659 | n = maximum (map dim [a,b,l,e,t]) | ||
660 | [a', b', l', e', t'] = map (conformVTo n) [a,b,l,e,t] | ||
661 | |||