From 2b5ea5fdbf68b8c125a9a256aa15a6de849cdbca Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Sat, 1 Jan 2011 20:25:47 +0000 Subject: simplified liftMatrix2Auto --- lib/Data/Packed/Internal/Matrix.hs | 34 +++++++++++++++++++++++++++++++++- lib/Data/Packed/Matrix.hs | 28 ++++++++++------------------ 2 files changed, 43 insertions(+), 19 deletions(-) (limited to 'lib/Data') diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 29aba51..57142b7 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs @@ -31,7 +31,8 @@ module Data.Packed.Internal.Matrix( liftMatrix, liftMatrix2, (@@>), saveMatrix, - singleton + singleton, + size, shSize, conformVs, conformMs, conformVTo, conformMTo ) where import Data.Packed.Internal.Common @@ -441,3 +442,34 @@ saveMatrix filename fmt m = do free charfmt foreign import ccall "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM + +---------------------------------------------------------------------- + +conformMs ms = map (conformMTo (r,c)) ms + where + r = maximum (map rows ms) + c = maximum (map cols ms) + +conformVs vs = map (conformVTo n) vs + where + n = maximum (map dim vs) + +conformMTo (r,c) m + | size m == (r,c) = m + | size m == (1,1) = reshape c (constantD (m@@>(0,0)) (r*c)) + | size m == (r,1) = repCols c m + | size m == (1,c) = repRows r m + | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" + +conformVTo n v + | dim v == n = v + | dim v == 1 = constantD (v@>0) n + | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n + +repRows n x = fromRows (replicate n (flatten x)) +repCols n x = fromColumns (replicate n (flatten x)) + +size m = (rows m, cols m) + +shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" + diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index aad73dd..ab68618 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs @@ -302,30 +302,22 @@ repmat m r c = fromBlocks $ splitEvery c $ replicate (r*c) m -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t liftMatrix2Auto f m1 m2 - | compat' m1 m2 = lM f m1 m2 - - | r1 == 1 && c2 == 1 = lM f (repRows r2 m1) (repCols c1 m2) - | c1 == 1 && r2 == 1 = lM f (repCols c2 m1) (repRows r1 m2) - - | r1 == r2 && c2 == 1 = lM f m1 (repCols c1 m2) - | r1 == r2 && c1 == 1 = lM f (repCols c2 m1) m2 - - | c1 == c2 && r2 == 1 = lM f m1 (repRows r1 m2) - | c1 == c2 && r1 == 1 = lM f (repRows r2 m1) m2 - - | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " - ++ show (size m1) ++ ", " ++ show (size m2) + | compat' m1 m2 = lM f m1 m2 + | ok = lM f m1' m2' + | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " ++ shSize m1 ++ ", " ++ shSize m2 where (r1,c1) = size m1 (r2,c2) = size m2 - -size m = (rows m, cols m) + r = max r1 r2 + c = max c1 c2 + r0 = min r1 r2 + c0 = min c1 c2 + ok = r0 == 1 || r1 == r2 && c0 == 1 || c1 == c2 + m1' = conformMTo (r,c) m1 + m2' = conformMTo (r,c) m2 lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2)) -repRows n x = fromRows (replicate n (flatten x)) -repCols n x = fromColumns (replicate n (flatten x)) - compat' :: Matrix a -> Matrix b -> Bool compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 where -- cgit v1.2.3