From 2a3baa18b508851a9d30e4abc7d7de7978146557 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Mon, 1 Feb 2010 10:29:20 +0000 Subject: export liftMatrix2Auto --- lib/Data/Packed/Matrix.hs | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) (limited to 'lib/Data') diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index 4cb7a88..66e5082 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs @@ -27,7 +27,7 @@ module Data.Packed.Matrix ( subMatrix, takeRows, dropRows, takeColumns, dropColumns, extractRows, ident, diag, diagRect, takeDiag, - liftMatrix, liftMatrix2, + liftMatrix, liftMatrix2, liftMatrix2Auto, format, dispf, disps, loadMatrix, saveMatrix, fromFile, fileDimensions, readMatrix, fromArray2D @@ -386,3 +386,28 @@ extractRows l m = fromRows $ extract (toRows $ m) l -} repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t repmat m r c = fromBlocks $ partit c $ replicate (r*c) m + +-- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. +liftMatrix2Auto :: (Element t, Element a, Element b) + => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t +liftMatrix2Auto f m1 m2 | compat' m1 m2 = lM f m1 m2 + | rows m1 == rows m2 && cols m2 == 1 = lM f m1 (repCols (cols m1) m2) + | rows m1 == rows m2 && cols m1 == 1 = lM f (repCols (cols m2) m1) m2 + | cols m1 == cols m2 && rows m2 == 1 = lM f m1 (repRows (rows m1) m2) + | cols m1 == cols m2 && cols m1 == 1 = lM f (repRows (rows m2) m1) m2 + | rows m1 == 1 && cols m2 == 1 = lM f (repRows (rows m2) m1) (repCols (cols m1) m2) + | cols m1 == 1 && rows m2 == 1 = lM f (repCols (cols m2) m1) (repRows (rows m1) m2) + | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " ++ show (size m1) ++ ", " ++ show (size m2) + +size m = (rows m, cols m) + +lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2)) + +repRows n x = fromRows (replicate n (flatten x)) +repCols n x = fromColumns (replicate n (flatten x)) + +compat' :: Matrix a -> Matrix b -> Bool +compat' m1 m2 = rows m1 == 1 && cols m1 == 1 + || rows m2 == 1 && cols m2 == 1 + || rows m1 == rows m2 && cols m1 == cols m2 + -- cgit v1.2.3