diff options
Diffstat (limited to 'lib/Data/Packed')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 34 | ||||
-rw-r--r-- | lib/Data/Packed/Matrix.hs | 28 |
2 files changed, 43 insertions, 19 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 29aba51..57142b7 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -31,7 +31,8 @@ module Data.Packed.Internal.Matrix( | |||
31 | liftMatrix, liftMatrix2, | 31 | liftMatrix, liftMatrix2, |
32 | (@@>), | 32 | (@@>), |
33 | saveMatrix, | 33 | saveMatrix, |
34 | singleton | 34 | singleton, |
35 | size, shSize, conformVs, conformMs, conformVTo, conformMTo | ||
35 | ) where | 36 | ) where |
36 | 37 | ||
37 | import Data.Packed.Internal.Common | 38 | import Data.Packed.Internal.Common |
@@ -441,3 +442,34 @@ saveMatrix filename fmt m = do | |||
441 | free charfmt | 442 | free charfmt |
442 | 443 | ||
443 | foreign import ccall "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM | 444 | foreign import ccall "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM |
445 | |||
446 | ---------------------------------------------------------------------- | ||
447 | |||
448 | conformMs ms = map (conformMTo (r,c)) ms | ||
449 | where | ||
450 | r = maximum (map rows ms) | ||
451 | c = maximum (map cols ms) | ||
452 | |||
453 | conformVs vs = map (conformVTo n) vs | ||
454 | where | ||
455 | n = maximum (map dim vs) | ||
456 | |||
457 | conformMTo (r,c) m | ||
458 | | size m == (r,c) = m | ||
459 | | size m == (1,1) = reshape c (constantD (m@@>(0,0)) (r*c)) | ||
460 | | size m == (r,1) = repCols c m | ||
461 | | size m == (1,c) = repRows r m | ||
462 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" | ||
463 | |||
464 | conformVTo n v | ||
465 | | dim v == n = v | ||
466 | | dim v == 1 = constantD (v@>0) n | ||
467 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n | ||
468 | |||
469 | repRows n x = fromRows (replicate n (flatten x)) | ||
470 | repCols n x = fromColumns (replicate n (flatten x)) | ||
471 | |||
472 | size m = (rows m, cols m) | ||
473 | |||
474 | shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" | ||
475 | |||
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index aad73dd..ab68618 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs | |||
@@ -302,30 +302,22 @@ repmat m r c = fromBlocks $ splitEvery c $ replicate (r*c) m | |||
302 | -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. | 302 | -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. |
303 | liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 303 | liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |
304 | liftMatrix2Auto f m1 m2 | 304 | liftMatrix2Auto f m1 m2 |
305 | | compat' m1 m2 = lM f m1 m2 | 305 | | compat' m1 m2 = lM f m1 m2 |
306 | 306 | | ok = lM f m1' m2' | |
307 | | r1 == 1 && c2 == 1 = lM f (repRows r2 m1) (repCols c1 m2) | 307 | | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " ++ shSize m1 ++ ", " ++ shSize m2 |
308 | | c1 == 1 && r2 == 1 = lM f (repCols c2 m1) (repRows r1 m2) | ||
309 | |||
310 | | r1 == r2 && c2 == 1 = lM f m1 (repCols c1 m2) | ||
311 | | r1 == r2 && c1 == 1 = lM f (repCols c2 m1) m2 | ||
312 | |||
313 | | c1 == c2 && r2 == 1 = lM f m1 (repRows r1 m2) | ||
314 | | c1 == c2 && r1 == 1 = lM f (repRows r2 m1) m2 | ||
315 | |||
316 | | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " | ||
317 | ++ show (size m1) ++ ", " ++ show (size m2) | ||
318 | where | 308 | where |
319 | (r1,c1) = size m1 | 309 | (r1,c1) = size m1 |
320 | (r2,c2) = size m2 | 310 | (r2,c2) = size m2 |
321 | 311 | r = max r1 r2 | |
322 | size m = (rows m, cols m) | 312 | c = max c1 c2 |
313 | r0 = min r1 r2 | ||
314 | c0 = min c1 c2 | ||
315 | ok = r0 == 1 || r1 == r2 && c0 == 1 || c1 == c2 | ||
316 | m1' = conformMTo (r,c) m1 | ||
317 | m2' = conformMTo (r,c) m2 | ||
323 | 318 | ||
324 | lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2)) | 319 | lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2)) |
325 | 320 | ||
326 | repRows n x = fromRows (replicate n (flatten x)) | ||
327 | repCols n x = fromColumns (replicate n (flatten x)) | ||
328 | |||
329 | compat' :: Matrix a -> Matrix b -> Bool | 321 | compat' :: Matrix a -> Matrix b -> Bool |
330 | compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 | 322 | compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 |
331 | where | 323 | where |