diff options
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 34 | ||||
-rw-r--r-- | lib/Data/Packed/Matrix.hs | 28 | ||||
-rw-r--r-- | lib/Numeric/ContainerBoot.hs | 29 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 2 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests.hs | 12 |
5 files changed, 59 insertions, 46 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 29aba51..57142b7 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -31,7 +31,8 @@ module Data.Packed.Internal.Matrix( | |||
31 | liftMatrix, liftMatrix2, | 31 | liftMatrix, liftMatrix2, |
32 | (@@>), | 32 | (@@>), |
33 | saveMatrix, | 33 | saveMatrix, |
34 | singleton | 34 | singleton, |
35 | size, shSize, conformVs, conformMs, conformVTo, conformMTo | ||
35 | ) where | 36 | ) where |
36 | 37 | ||
37 | import Data.Packed.Internal.Common | 38 | import Data.Packed.Internal.Common |
@@ -441,3 +442,34 @@ saveMatrix filename fmt m = do | |||
441 | free charfmt | 442 | free charfmt |
442 | 443 | ||
443 | foreign import ccall "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM | 444 | foreign import ccall "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM |
445 | |||
446 | ---------------------------------------------------------------------- | ||
447 | |||
448 | conformMs ms = map (conformMTo (r,c)) ms | ||
449 | where | ||
450 | r = maximum (map rows ms) | ||
451 | c = maximum (map cols ms) | ||
452 | |||
453 | conformVs vs = map (conformVTo n) vs | ||
454 | where | ||
455 | n = maximum (map dim vs) | ||
456 | |||
457 | conformMTo (r,c) m | ||
458 | | size m == (r,c) = m | ||
459 | | size m == (1,1) = reshape c (constantD (m@@>(0,0)) (r*c)) | ||
460 | | size m == (r,1) = repCols c m | ||
461 | | size m == (1,c) = repRows r m | ||
462 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" | ||
463 | |||
464 | conformVTo n v | ||
465 | | dim v == n = v | ||
466 | | dim v == 1 = constantD (v@>0) n | ||
467 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n | ||
468 | |||
469 | repRows n x = fromRows (replicate n (flatten x)) | ||
470 | repCols n x = fromColumns (replicate n (flatten x)) | ||
471 | |||
472 | size m = (rows m, cols m) | ||
473 | |||
474 | shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" | ||
475 | |||
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index aad73dd..ab68618 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs | |||
@@ -302,30 +302,22 @@ repmat m r c = fromBlocks $ splitEvery c $ replicate (r*c) m | |||
302 | -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. | 302 | -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. |
303 | liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 303 | liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |
304 | liftMatrix2Auto f m1 m2 | 304 | liftMatrix2Auto f m1 m2 |
305 | | compat' m1 m2 = lM f m1 m2 | 305 | | compat' m1 m2 = lM f m1 m2 |
306 | 306 | | ok = lM f m1' m2' | |
307 | | r1 == 1 && c2 == 1 = lM f (repRows r2 m1) (repCols c1 m2) | 307 | | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " ++ shSize m1 ++ ", " ++ shSize m2 |
308 | | c1 == 1 && r2 == 1 = lM f (repCols c2 m1) (repRows r1 m2) | ||
309 | |||
310 | | r1 == r2 && c2 == 1 = lM f m1 (repCols c1 m2) | ||
311 | | r1 == r2 && c1 == 1 = lM f (repCols c2 m1) m2 | ||
312 | |||
313 | | c1 == c2 && r2 == 1 = lM f m1 (repRows r1 m2) | ||
314 | | c1 == c2 && r1 == 1 = lM f (repRows r2 m1) m2 | ||
315 | |||
316 | | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " | ||
317 | ++ show (size m1) ++ ", " ++ show (size m2) | ||
318 | where | 308 | where |
319 | (r1,c1) = size m1 | 309 | (r1,c1) = size m1 |
320 | (r2,c2) = size m2 | 310 | (r2,c2) = size m2 |
321 | 311 | r = max r1 r2 | |
322 | size m = (rows m, cols m) | 312 | c = max c1 c2 |
313 | r0 = min r1 r2 | ||
314 | c0 = min c1 c2 | ||
315 | ok = r0 == 1 || r1 == r2 && c0 == 1 || c1 == c2 | ||
316 | m1' = conformMTo (r,c) m1 | ||
317 | m2' = conformMTo (r,c) m2 | ||
323 | 318 | ||
324 | lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2)) | 319 | lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2)) |
325 | 320 | ||
326 | repRows n x = fromRows (replicate n (flatten x)) | ||
327 | repCols n x = fromColumns (replicate n (flatten x)) | ||
328 | |||
329 | compat' :: Matrix a -> Matrix b -> Bool | 321 | compat' :: Matrix a -> Matrix b -> Bool |
330 | compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 | 322 | compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 |
331 | where | 323 | where |
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs index 276eaa8..d9f0d78 100644 --- a/lib/Numeric/ContainerBoot.hs +++ b/lib/Numeric/ContainerBoot.hs | |||
@@ -639,33 +639,12 @@ assocM (r,c) z xs = ST.runSTMatrix $ do | |||
639 | 639 | ||
640 | ---------------------------------------------------------------------- | 640 | ---------------------------------------------------------------------- |
641 | 641 | ||
642 | conformMTo (r,c) m | 642 | condM a b l e t = reshape (cols a'') $ cond a' b' l' e' t' |
643 | | size m == (r,c) = m | ||
644 | | size m == (1,1) = konst (m@@>(0,0)) (r,c) | ||
645 | | size m == (r,1) = repCols c m | ||
646 | | size m == (1,c) = repRows r m | ||
647 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" | ||
648 | |||
649 | conformVTo n v | ||
650 | | dim v == n = v | ||
651 | | dim v == 1 = konst (v@>0) n | ||
652 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n | ||
653 | |||
654 | repRows n x = fromRows (replicate n (flatten x)) | ||
655 | repCols n x = fromColumns (replicate n (flatten x)) | ||
656 | |||
657 | size m = (rows m, cols m) | ||
658 | |||
659 | shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" | ||
660 | |||
661 | condM a b l e t = reshape c $ cond a' b' l' e' t' | ||
662 | where | 643 | where |
663 | r = maximum (map rows [a,b,l,e,t]) | 644 | args@(a'':_) = conformMs [a,b,l,e,t] |
664 | c = maximum (map cols [a,b,l,e,t]) | 645 | [a', b', l', e', t'] = map flatten args |
665 | [a', b', l', e', t'] = map (flatten . conformMTo (r,c)) [a,b,l,e,t] | ||
666 | 646 | ||
667 | condV f a b l e t = f a' b' l' e' t' | 647 | condV f a b l e t = f a' b' l' e' t' |
668 | where | 648 | where |
669 | n = maximum (map dim [a,b,l,e,t]) | 649 | [a', b', l', e', t'] = conformVs [a,b,l,e,t] |
670 | [a', b', l', e', t'] = map (conformVTo n) [a,b,l,e,t] | ||
671 | 650 | ||
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 83464f4..a6b3174 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -167,8 +167,6 @@ vertical m = rows m >= cols m | |||
167 | 167 | ||
168 | exactHermitian m = m `equal` ctrans m | 168 | exactHermitian m = m `equal` ctrans m |
169 | 169 | ||
170 | shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" | ||
171 | |||
172 | -------------------------------------------------------------- | 170 | -------------------------------------------------------------- |
173 | 171 | ||
174 | -- | Full singular value decomposition. | 172 | -- | Full singular value decomposition. |
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index 76eaaae..3bcfec5 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs | |||
@@ -380,6 +380,17 @@ condTest = utest "cond" ok | |||
380 | 380 | ||
381 | --------------------------------------------------------------------- | 381 | --------------------------------------------------------------------- |
382 | 382 | ||
383 | conformTest = utest "conform" ok | ||
384 | where | ||
385 | ok = 1 + row [1,2,3] + col [10,20,30,40] + (4><3) [1..] | ||
386 | == (4><3) [13,15,17 | ||
387 | ,26,28,30 | ||
388 | ,39,41,43 | ||
389 | ,52,54,56] | ||
390 | row = asRow . fromList | ||
391 | col = asColumn . fromList :: [Double] -> Matrix Double | ||
392 | |||
393 | --------------------------------------------------------------------- | ||
383 | 394 | ||
384 | -- | All tests must pass with a maximum dimension of about 20 | 395 | -- | All tests must pass with a maximum dimension of about 20 |
385 | -- (some tests may fail with bigger sizes due to precision loss). | 396 | -- (some tests may fail with bigger sizes due to precision loss). |
@@ -550,6 +561,7 @@ runTests n = do | |||
550 | , succTest | 561 | , succTest |
551 | , findAssocTest | 562 | , findAssocTest |
552 | , condTest | 563 | , condTest |
564 | , conformTest | ||
553 | ] | 565 | ] |
554 | return () | 566 | return () |
555 | 567 | ||