diff options
Diffstat (limited to 'lib/Data/Packed/Internal/Matrix.hs')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 34 |
1 files changed, 33 insertions, 1 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 29aba51..57142b7 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -31,7 +31,8 @@ module Data.Packed.Internal.Matrix( | |||
31 | liftMatrix, liftMatrix2, | 31 | liftMatrix, liftMatrix2, |
32 | (@@>), | 32 | (@@>), |
33 | saveMatrix, | 33 | saveMatrix, |
34 | singleton | 34 | singleton, |
35 | size, shSize, conformVs, conformMs, conformVTo, conformMTo | ||
35 | ) where | 36 | ) where |
36 | 37 | ||
37 | import Data.Packed.Internal.Common | 38 | import Data.Packed.Internal.Common |
@@ -441,3 +442,34 @@ saveMatrix filename fmt m = do | |||
441 | free charfmt | 442 | free charfmt |
442 | 443 | ||
443 | foreign import ccall "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM | 444 | foreign import ccall "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM |
445 | |||
446 | ---------------------------------------------------------------------- | ||
447 | |||
448 | conformMs ms = map (conformMTo (r,c)) ms | ||
449 | where | ||
450 | r = maximum (map rows ms) | ||
451 | c = maximum (map cols ms) | ||
452 | |||
453 | conformVs vs = map (conformVTo n) vs | ||
454 | where | ||
455 | n = maximum (map dim vs) | ||
456 | |||
457 | conformMTo (r,c) m | ||
458 | | size m == (r,c) = m | ||
459 | | size m == (1,1) = reshape c (constantD (m@@>(0,0)) (r*c)) | ||
460 | | size m == (r,1) = repCols c m | ||
461 | | size m == (1,c) = repRows r m | ||
462 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" | ||
463 | |||
464 | conformVTo n v | ||
465 | | dim v == n = v | ||
466 | | dim v == 1 = constantD (v@>0) n | ||
467 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n | ||
468 | |||
469 | repRows n x = fromRows (replicate n (flatten x)) | ||
470 | repCols n x = fromColumns (replicate n (flatten x)) | ||
471 | |||
472 | size m = (rows m, cols m) | ||
473 | |||
474 | shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" | ||
475 | |||