summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2011-01-01 20:25:47 +0000
committerAlberto Ruiz <aruiz@um.es>2011-01-01 20:25:47 +0000
commit2b5ea5fdbf68b8c125a9a256aa15a6de849cdbca (patch)
tree842245c80e82a1838db400655695b1f22f56cf31 /lib
parent7633e42d95095e16ad459de6cd65b9f7e700136b (diff)
simplified liftMatrix2Auto
Diffstat (limited to 'lib')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs34
-rw-r--r--lib/Data/Packed/Matrix.hs28
-rw-r--r--lib/Numeric/ContainerBoot.hs29
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs2
-rw-r--r--lib/Numeric/LinearAlgebra/Tests.hs12
5 files changed, 59 insertions, 46 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 29aba51..57142b7 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -31,7 +31,8 @@ module Data.Packed.Internal.Matrix(
31 liftMatrix, liftMatrix2, 31 liftMatrix, liftMatrix2,
32 (@@>), 32 (@@>),
33 saveMatrix, 33 saveMatrix,
34 singleton 34 singleton,
35 size, shSize, conformVs, conformMs, conformVTo, conformMTo
35) where 36) where
36 37
37import Data.Packed.Internal.Common 38import Data.Packed.Internal.Common
@@ -441,3 +442,34 @@ saveMatrix filename fmt m = do
441 free charfmt 442 free charfmt
442 443
443foreign import ccall "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM 444foreign import ccall "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM
445
446----------------------------------------------------------------------
447
448conformMs ms = map (conformMTo (r,c)) ms
449 where
450 r = maximum (map rows ms)
451 c = maximum (map cols ms)
452
453conformVs vs = map (conformVTo n) vs
454 where
455 n = maximum (map dim vs)
456
457conformMTo (r,c) m
458 | size m == (r,c) = m
459 | size m == (1,1) = reshape c (constantD (m@@>(0,0)) (r*c))
460 | size m == (r,1) = repCols c m
461 | size m == (1,c) = repRows r m
462 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")"
463
464conformVTo n v
465 | dim v == n = v
466 | dim v == 1 = constantD (v@>0) n
467 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
468
469repRows n x = fromRows (replicate n (flatten x))
470repCols n x = fromColumns (replicate n (flatten x))
471
472size m = (rows m, cols m)
473
474shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")"
475
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs
index aad73dd..ab68618 100644
--- a/lib/Data/Packed/Matrix.hs
+++ b/lib/Data/Packed/Matrix.hs
@@ -302,30 +302,22 @@ repmat m r c = fromBlocks $ splitEvery c $ replicate (r*c) m
302-- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. 302-- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix.
303liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 303liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
304liftMatrix2Auto f m1 m2 304liftMatrix2Auto f m1 m2
305 | compat' m1 m2 = lM f m1 m2 305 | compat' m1 m2 = lM f m1 m2
306 306 | ok = lM f m1' m2'
307 | r1 == 1 && c2 == 1 = lM f (repRows r2 m1) (repCols c1 m2) 307 | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " ++ shSize m1 ++ ", " ++ shSize m2
308 | c1 == 1 && r2 == 1 = lM f (repCols c2 m1) (repRows r1 m2)
309
310 | r1 == r2 && c2 == 1 = lM f m1 (repCols c1 m2)
311 | r1 == r2 && c1 == 1 = lM f (repCols c2 m1) m2
312
313 | c1 == c2 && r2 == 1 = lM f m1 (repRows r1 m2)
314 | c1 == c2 && r1 == 1 = lM f (repRows r2 m1) m2
315
316 | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: "
317 ++ show (size m1) ++ ", " ++ show (size m2)
318 where 308 where
319 (r1,c1) = size m1 309 (r1,c1) = size m1
320 (r2,c2) = size m2 310 (r2,c2) = size m2
321 311 r = max r1 r2
322size m = (rows m, cols m) 312 c = max c1 c2
313 r0 = min r1 r2
314 c0 = min c1 c2
315 ok = r0 == 1 || r1 == r2 && c0 == 1 || c1 == c2
316 m1' = conformMTo (r,c) m1
317 m2' = conformMTo (r,c) m2
323 318
324lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2)) 319lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2))
325 320
326repRows n x = fromRows (replicate n (flatten x))
327repCols n x = fromColumns (replicate n (flatten x))
328
329compat' :: Matrix a -> Matrix b -> Bool 321compat' :: Matrix a -> Matrix b -> Bool
330compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 322compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2
331 where 323 where
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs
index 276eaa8..d9f0d78 100644
--- a/lib/Numeric/ContainerBoot.hs
+++ b/lib/Numeric/ContainerBoot.hs
@@ -639,33 +639,12 @@ assocM (r,c) z xs = ST.runSTMatrix $ do
639 639
640---------------------------------------------------------------------- 640----------------------------------------------------------------------
641 641
642conformMTo (r,c) m 642condM a b l e t = reshape (cols a'') $ cond a' b' l' e' t'
643 | size m == (r,c) = m
644 | size m == (1,1) = konst (m@@>(0,0)) (r,c)
645 | size m == (r,1) = repCols c m
646 | size m == (1,c) = repRows r m
647 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")"
648
649conformVTo n v
650 | dim v == n = v
651 | dim v == 1 = konst (v@>0) n
652 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
653
654repRows n x = fromRows (replicate n (flatten x))
655repCols n x = fromColumns (replicate n (flatten x))
656
657size m = (rows m, cols m)
658
659shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")"
660
661condM a b l e t = reshape c $ cond a' b' l' e' t'
662 where 643 where
663 r = maximum (map rows [a,b,l,e,t]) 644 args@(a'':_) = conformMs [a,b,l,e,t]
664 c = maximum (map cols [a,b,l,e,t]) 645 [a', b', l', e', t'] = map flatten args
665 [a', b', l', e', t'] = map (flatten . conformMTo (r,c)) [a,b,l,e,t]
666 646
667condV f a b l e t = f a' b' l' e' t' 647condV f a b l e t = f a' b' l' e' t'
668 where 648 where
669 n = maximum (map dim [a,b,l,e,t]) 649 [a', b', l', e', t'] = conformVs [a,b,l,e,t]
670 [a', b', l', e', t'] = map (conformVTo n) [a,b,l,e,t]
671 650
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index 83464f4..a6b3174 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -167,8 +167,6 @@ vertical m = rows m >= cols m
167 167
168exactHermitian m = m `equal` ctrans m 168exactHermitian m = m `equal` ctrans m
169 169
170shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")"
171
172-------------------------------------------------------------- 170--------------------------------------------------------------
173 171
174-- | Full singular value decomposition. 172-- | Full singular value decomposition.
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs
index 76eaaae..3bcfec5 100644
--- a/lib/Numeric/LinearAlgebra/Tests.hs
+++ b/lib/Numeric/LinearAlgebra/Tests.hs
@@ -380,6 +380,17 @@ condTest = utest "cond" ok
380 380
381--------------------------------------------------------------------- 381---------------------------------------------------------------------
382 382
383conformTest = utest "conform" ok
384 where
385 ok = 1 + row [1,2,3] + col [10,20,30,40] + (4><3) [1..]
386 == (4><3) [13,15,17
387 ,26,28,30
388 ,39,41,43
389 ,52,54,56]
390 row = asRow . fromList
391 col = asColumn . fromList :: [Double] -> Matrix Double
392
393---------------------------------------------------------------------
383 394
384-- | All tests must pass with a maximum dimension of about 20 395-- | All tests must pass with a maximum dimension of about 20
385-- (some tests may fail with bigger sizes due to precision loss). 396-- (some tests may fail with bigger sizes due to precision loss).
@@ -550,6 +561,7 @@ runTests n = do
550 , succTest 561 , succTest
551 , findAssocTest 562 , findAssocTest
552 , condTest 563 , condTest
564 , conformTest
553 ] 565 ]
554 return () 566 return ()
555 567