From 2b5ea5fdbf68b8c125a9a256aa15a6de849cdbca Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Sat, 1 Jan 2011 20:25:47 +0000 Subject: simplified liftMatrix2Auto --- lib/Data/Packed/Internal/Matrix.hs | 34 ++++++++++++++++++++++++++++++++- lib/Data/Packed/Matrix.hs | 28 ++++++++++----------------- lib/Numeric/ContainerBoot.hs | 29 ++++------------------------ lib/Numeric/LinearAlgebra/Algorithms.hs | 2 -- lib/Numeric/LinearAlgebra/Tests.hs | 12 ++++++++++++ 5 files changed, 59 insertions(+), 46 deletions(-) (limited to 'lib') diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 29aba51..57142b7 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs @@ -31,7 +31,8 @@ module Data.Packed.Internal.Matrix( liftMatrix, liftMatrix2, (@@>), saveMatrix, - singleton + singleton, + size, shSize, conformVs, conformMs, conformVTo, conformMTo ) where import Data.Packed.Internal.Common @@ -441,3 +442,34 @@ saveMatrix filename fmt m = do free charfmt foreign import ccall "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM + +---------------------------------------------------------------------- + +conformMs ms = map (conformMTo (r,c)) ms + where + r = maximum (map rows ms) + c = maximum (map cols ms) + +conformVs vs = map (conformVTo n) vs + where + n = maximum (map dim vs) + +conformMTo (r,c) m + | size m == (r,c) = m + | size m == (1,1) = reshape c (constantD (m@@>(0,0)) (r*c)) + | size m == (r,1) = repCols c m + | size m == (1,c) = repRows r m + | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" + +conformVTo n v + | dim v == n = v + | dim v == 1 = constantD (v@>0) n + | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n + +repRows n x = fromRows (replicate n (flatten x)) +repCols n x = fromColumns (replicate n (flatten x)) + +size m = (rows m, cols m) + +shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" + diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index aad73dd..ab68618 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs @@ -302,30 +302,22 @@ repmat m r c = fromBlocks $ splitEvery c $ replicate (r*c) m -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t liftMatrix2Auto f m1 m2 - | compat' m1 m2 = lM f m1 m2 - - | r1 == 1 && c2 == 1 = lM f (repRows r2 m1) (repCols c1 m2) - | c1 == 1 && r2 == 1 = lM f (repCols c2 m1) (repRows r1 m2) - - | r1 == r2 && c2 == 1 = lM f m1 (repCols c1 m2) - | r1 == r2 && c1 == 1 = lM f (repCols c2 m1) m2 - - | c1 == c2 && r2 == 1 = lM f m1 (repRows r1 m2) - | c1 == c2 && r1 == 1 = lM f (repRows r2 m1) m2 - - | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " - ++ show (size m1) ++ ", " ++ show (size m2) + | compat' m1 m2 = lM f m1 m2 + | ok = lM f m1' m2' + | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " ++ shSize m1 ++ ", " ++ shSize m2 where (r1,c1) = size m1 (r2,c2) = size m2 - -size m = (rows m, cols m) + r = max r1 r2 + c = max c1 c2 + r0 = min r1 r2 + c0 = min c1 c2 + ok = r0 == 1 || r1 == r2 && c0 == 1 || c1 == c2 + m1' = conformMTo (r,c) m1 + m2' = conformMTo (r,c) m2 lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2)) -repRows n x = fromRows (replicate n (flatten x)) -repCols n x = fromColumns (replicate n (flatten x)) - compat' :: Matrix a -> Matrix b -> Bool compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 where diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs index 276eaa8..d9f0d78 100644 --- a/lib/Numeric/ContainerBoot.hs +++ b/lib/Numeric/ContainerBoot.hs @@ -639,33 +639,12 @@ assocM (r,c) z xs = ST.runSTMatrix $ do ---------------------------------------------------------------------- -conformMTo (r,c) m - | size m == (r,c) = m - | size m == (1,1) = konst (m@@>(0,0)) (r,c) - | size m == (r,1) = repCols c m - | size m == (1,c) = repRows r m - | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" - -conformVTo n v - | dim v == n = v - | dim v == 1 = konst (v@>0) n - | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n - -repRows n x = fromRows (replicate n (flatten x)) -repCols n x = fromColumns (replicate n (flatten x)) - -size m = (rows m, cols m) - -shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" - -condM a b l e t = reshape c $ cond a' b' l' e' t' +condM a b l e t = reshape (cols a'') $ cond a' b' l' e' t' where - r = maximum (map rows [a,b,l,e,t]) - c = maximum (map cols [a,b,l,e,t]) - [a', b', l', e', t'] = map (flatten . conformMTo (r,c)) [a,b,l,e,t] + args@(a'':_) = conformMs [a,b,l,e,t] + [a', b', l', e', t'] = map flatten args condV f a b l e t = f a' b' l' e' t' where - n = maximum (map dim [a,b,l,e,t]) - [a', b', l', e', t'] = map (conformVTo n) [a,b,l,e,t] + [a', b', l', e', t'] = conformVs [a,b,l,e,t] diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 83464f4..a6b3174 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs @@ -167,8 +167,6 @@ vertical m = rows m >= cols m exactHermitian m = m `equal` ctrans m -shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" - -------------------------------------------------------------- -- | Full singular value decomposition. diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index 76eaaae..3bcfec5 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs @@ -380,6 +380,17 @@ condTest = utest "cond" ok --------------------------------------------------------------------- +conformTest = utest "conform" ok + where + ok = 1 + row [1,2,3] + col [10,20,30,40] + (4><3) [1..] + == (4><3) [13,15,17 + ,26,28,30 + ,39,41,43 + ,52,54,56] + row = asRow . fromList + col = asColumn . fromList :: [Double] -> Matrix Double + +--------------------------------------------------------------------- -- | All tests must pass with a maximum dimension of about 20 -- (some tests may fail with bigger sizes due to precision loss). @@ -550,6 +561,7 @@ runTests n = do , succTest , findAssocTest , condTest + , conformTest ] return () -- cgit v1.2.3