summaryrefslogtreecommitdiff
path: root/lib/Data/Packed
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Data/Packed')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs34
-rw-r--r--lib/Data/Packed/Matrix.hs28
2 files changed, 43 insertions, 19 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 29aba51..57142b7 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -31,7 +31,8 @@ module Data.Packed.Internal.Matrix(
31 liftMatrix, liftMatrix2, 31 liftMatrix, liftMatrix2,
32 (@@>), 32 (@@>),
33 saveMatrix, 33 saveMatrix,
34 singleton 34 singleton,
35 size, shSize, conformVs, conformMs, conformVTo, conformMTo
35) where 36) where
36 37
37import Data.Packed.Internal.Common 38import Data.Packed.Internal.Common
@@ -441,3 +442,34 @@ saveMatrix filename fmt m = do
441 free charfmt 442 free charfmt
442 443
443foreign import ccall "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM 444foreign import ccall "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM
445
446----------------------------------------------------------------------
447
448conformMs ms = map (conformMTo (r,c)) ms
449 where
450 r = maximum (map rows ms)
451 c = maximum (map cols ms)
452
453conformVs vs = map (conformVTo n) vs
454 where
455 n = maximum (map dim vs)
456
457conformMTo (r,c) m
458 | size m == (r,c) = m
459 | size m == (1,1) = reshape c (constantD (m@@>(0,0)) (r*c))
460 | size m == (r,1) = repCols c m
461 | size m == (1,c) = repRows r m
462 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")"
463
464conformVTo n v
465 | dim v == n = v
466 | dim v == 1 = constantD (v@>0) n
467 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
468
469repRows n x = fromRows (replicate n (flatten x))
470repCols n x = fromColumns (replicate n (flatten x))
471
472size m = (rows m, cols m)
473
474shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")"
475
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs
index aad73dd..ab68618 100644
--- a/lib/Data/Packed/Matrix.hs
+++ b/lib/Data/Packed/Matrix.hs
@@ -302,30 +302,22 @@ repmat m r c = fromBlocks $ splitEvery c $ replicate (r*c) m
302-- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. 302-- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix.
303liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 303liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
304liftMatrix2Auto f m1 m2 304liftMatrix2Auto f m1 m2
305 | compat' m1 m2 = lM f m1 m2 305 | compat' m1 m2 = lM f m1 m2
306 306 | ok = lM f m1' m2'
307 | r1 == 1 && c2 == 1 = lM f (repRows r2 m1) (repCols c1 m2) 307 | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " ++ shSize m1 ++ ", " ++ shSize m2
308 | c1 == 1 && r2 == 1 = lM f (repCols c2 m1) (repRows r1 m2)
309
310 | r1 == r2 && c2 == 1 = lM f m1 (repCols c1 m2)
311 | r1 == r2 && c1 == 1 = lM f (repCols c2 m1) m2
312
313 | c1 == c2 && r2 == 1 = lM f m1 (repRows r1 m2)
314 | c1 == c2 && r1 == 1 = lM f (repRows r2 m1) m2
315
316 | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: "
317 ++ show (size m1) ++ ", " ++ show (size m2)
318 where 308 where
319 (r1,c1) = size m1 309 (r1,c1) = size m1
320 (r2,c2) = size m2 310 (r2,c2) = size m2
321 311 r = max r1 r2
322size m = (rows m, cols m) 312 c = max c1 c2
313 r0 = min r1 r2
314 c0 = min c1 c2
315 ok = r0 == 1 || r1 == r2 && c0 == 1 || c1 == c2
316 m1' = conformMTo (r,c) m1
317 m2' = conformMTo (r,c) m2
323 318
324lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2)) 319lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2))
325 320
326repRows n x = fromRows (replicate n (flatten x))
327repCols n x = fromColumns (replicate n (flatten x))
328
329compat' :: Matrix a -> Matrix b -> Bool 321compat' :: Matrix a -> Matrix b -> Bool
330compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 322compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2
331 where 323 where