summaryrefslogtreecommitdiff
path: root/lib/Numeric/ContainerBoot.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Numeric/ContainerBoot.hs')
-rw-r--r--lib/Numeric/ContainerBoot.hs39
1 files changed, 39 insertions, 0 deletions
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs
index 5a8e243..e33857a 100644
--- a/lib/Numeric/ContainerBoot.hs
+++ b/lib/Numeric/ContainerBoot.hs
@@ -127,6 +127,8 @@ class (Complexable c, Fractional e, Element e) => Container c e where
127 find :: (e -> Bool) -> c e -> [IndexOf c] 127 find :: (e -> Bool) -> c e -> [IndexOf c]
128 -- | create a structure from an association list 128 -- | create a structure from an association list
129 assoc :: IndexOf c -> e -> [(IndexOf c, e)] -> c e 129 assoc :: IndexOf c -> e -> [(IndexOf c, e)] -> c e
130 -- | a vectorized form of case 'compare' a_i b_i of LT -> l_i; EQ -> e_i; GT -> g_i
131 cond :: RealFloat e => c e -> c e -> c e -> c e -> c e -> c e
130 132
131-------------------------------------------------------------------------- 133--------------------------------------------------------------------------
132 134
@@ -155,6 +157,7 @@ instance Container Vector Float where
155 step = stepF 157 step = stepF
156 find = findV 158 find = findV
157 assoc = assocV 159 assoc = assocV
160 cond = condV condF
158 161
159instance Container Vector Double where 162instance Container Vector Double where
160 scale = vectorMapValR Scale 163 scale = vectorMapValR Scale
@@ -181,6 +184,7 @@ instance Container Vector Double where
181 step = stepD 184 step = stepD
182 find = findV 185 find = findV
183 assoc = assocV 186 assoc = assocV
187 cond = condV condD
184 188
185instance Container Vector (Complex Double) where 189instance Container Vector (Complex Double) where
186 scale = vectorMapValC Scale 190 scale = vectorMapValC Scale
@@ -207,6 +211,7 @@ instance Container Vector (Complex Double) where
207 step = undefined -- cannot match 211 step = undefined -- cannot match
208 find = findV 212 find = findV
209 assoc = assocV 213 assoc = assocV
214 cond = undefined -- cannot match
210 215
211instance Container Vector (Complex Float) where 216instance Container Vector (Complex Float) where
212 scale = vectorMapValQ Scale 217 scale = vectorMapValQ Scale
@@ -233,6 +238,7 @@ instance Container Vector (Complex Float) where
233 step = undefined -- cannot match 238 step = undefined -- cannot match
234 find = findV 239 find = findV
235 assoc = assocV 240 assoc = assocV
241 cond = undefined -- cannot match
236 242
237--------------------------------------------------------------- 243---------------------------------------------------------------
238 244
@@ -265,6 +271,7 @@ instance (Container Vector a) => Container Matrix a where
265 step = liftMatrix step 271 step = liftMatrix step
266 find = findM 272 find = findM
267 assoc = assocM 273 assoc = assocM
274 cond = condM
268 275
269---------------------------------------------------- 276----------------------------------------------------
270 277
@@ -620,3 +627,35 @@ assocM (r,c) z xs = ST.runSTMatrix $ do
620 mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs 627 mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs
621 return m 628 return m
622 629
630----------------------------------------------------------------------
631
632conformMTo (r,c) m
633 | size m == (r,c) = m
634 | size m == (1,1) = konst (m@@>(0,0)) (r,c)
635 | size m == (r,1) = repCols c m
636 | size m == (1,c) = repRows r m
637 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")"
638
639conformVTo n v
640 | dim v == n = v
641 | dim v == 1 = konst (v@>0) n
642 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
643
644repRows n x = fromRows (replicate n (flatten x))
645repCols n x = fromColumns (replicate n (flatten x))
646
647size m = (rows m, cols m)
648
649shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")"
650
651condM a b l e t = reshape c $ cond a' b' l' e' t'
652 where
653 r = maximum (map rows [a,b,l,e,t])
654 c = maximum (map cols [a,b,l,e,t])
655 [a', b', l', e', t'] = map (flatten . conformMTo (r,c)) [a,b,l,e,t]
656
657condV f a b l e t = f a' b' l' e' t'
658 where
659 n = maximum (map dim [a,b,l,e,t])
660 [a', b', l', e', t'] = map (conformVTo n) [a,b,l,e,t]
661