From 20831d6521e54b42aa8410a1434d55c5b9bc004b Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Tue, 14 Dec 2010 09:24:16 +0000 Subject: fixed bug in liftMatrix2Auto --- lib/Data/Packed/Matrix.hs | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) (limited to 'lib/Data') diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index 2922cbe..aad73dd 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs @@ -300,16 +300,24 @@ repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t repmat m r c = fromBlocks $ splitEvery c $ replicate (r*c) m -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. -liftMatrix2Auto :: (Element t, Element a, Element b) - => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t -liftMatrix2Auto f m1 m2 | compat' m1 m2 = lM f m1 m2 - | rows m1 == rows m2 && cols m2 == 1 = lM f m1 (repCols (cols m1) m2) - | rows m1 == rows m2 && cols m1 == 1 = lM f (repCols (cols m2) m1) m2 - | cols m1 == cols m2 && rows m2 == 1 = lM f m1 (repRows (rows m1) m2) - | cols m1 == cols m2 && cols m1 == 1 = lM f (repRows (rows m2) m1) m2 - | rows m1 == 1 && cols m2 == 1 = lM f (repRows (rows m2) m1) (repCols (cols m1) m2) - | cols m1 == 1 && rows m2 == 1 = lM f (repCols (cols m2) m1) (repRows (rows m1) m2) - | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " ++ show (size m1) ++ ", " ++ show (size m2) +liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t +liftMatrix2Auto f m1 m2 + | compat' m1 m2 = lM f m1 m2 + + | r1 == 1 && c2 == 1 = lM f (repRows r2 m1) (repCols c1 m2) + | c1 == 1 && r2 == 1 = lM f (repCols c2 m1) (repRows r1 m2) + + | r1 == r2 && c2 == 1 = lM f m1 (repCols c1 m2) + | r1 == r2 && c1 == 1 = lM f (repCols c2 m1) m2 + + | c1 == c2 && r2 == 1 = lM f m1 (repRows r1 m2) + | c1 == c2 && r1 == 1 = lM f (repRows r2 m1) m2 + + | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " + ++ show (size m1) ++ ", " ++ show (size m2) + where + (r1,c1) = size m1 + (r2,c2) = size m2 size m = (rows m, cols m) @@ -319,9 +327,10 @@ repRows n x = fromRows (replicate n (flatten x)) repCols n x = fromColumns (replicate n (flatten x)) compat' :: Matrix a -> Matrix b -> Bool -compat' m1 m2 = rows m1 == 1 && cols m1 == 1 - || rows m2 == 1 && cols m2 == 1 - || rows m1 == rows m2 && cols m1 == cols m2 +compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 + where + s1 = size m1 + s2 = size m2 ------------------------------------------------------------ -- cgit v1.2.3