diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/Data/Packed/Matrix.hs | 35 |
1 files changed, 22 insertions, 13 deletions
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index 2922cbe..aad73dd 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs | |||
@@ -300,16 +300,24 @@ repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t | |||
300 | repmat m r c = fromBlocks $ splitEvery c $ replicate (r*c) m | 300 | repmat m r c = fromBlocks $ splitEvery c $ replicate (r*c) m |
301 | 301 | ||
302 | -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. | 302 | -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. |
303 | liftMatrix2Auto :: (Element t, Element a, Element b) | 303 | liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |
304 | => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 304 | liftMatrix2Auto f m1 m2 |
305 | liftMatrix2Auto f m1 m2 | compat' m1 m2 = lM f m1 m2 | 305 | | compat' m1 m2 = lM f m1 m2 |
306 | | rows m1 == rows m2 && cols m2 == 1 = lM f m1 (repCols (cols m1) m2) | 306 | |
307 | | rows m1 == rows m2 && cols m1 == 1 = lM f (repCols (cols m2) m1) m2 | 307 | | r1 == 1 && c2 == 1 = lM f (repRows r2 m1) (repCols c1 m2) |
308 | | cols m1 == cols m2 && rows m2 == 1 = lM f m1 (repRows (rows m1) m2) | 308 | | c1 == 1 && r2 == 1 = lM f (repCols c2 m1) (repRows r1 m2) |
309 | | cols m1 == cols m2 && cols m1 == 1 = lM f (repRows (rows m2) m1) m2 | 309 | |
310 | | rows m1 == 1 && cols m2 == 1 = lM f (repRows (rows m2) m1) (repCols (cols m1) m2) | 310 | | r1 == r2 && c2 == 1 = lM f m1 (repCols c1 m2) |
311 | | cols m1 == 1 && rows m2 == 1 = lM f (repCols (cols m2) m1) (repRows (rows m1) m2) | 311 | | r1 == r2 && c1 == 1 = lM f (repCols c2 m1) m2 |
312 | | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " ++ show (size m1) ++ ", " ++ show (size m2) | 312 | |
313 | | c1 == c2 && r2 == 1 = lM f m1 (repRows r1 m2) | ||
314 | | c1 == c2 && r1 == 1 = lM f (repRows r2 m1) m2 | ||
315 | |||
316 | | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " | ||
317 | ++ show (size m1) ++ ", " ++ show (size m2) | ||
318 | where | ||
319 | (r1,c1) = size m1 | ||
320 | (r2,c2) = size m2 | ||
313 | 321 | ||
314 | size m = (rows m, cols m) | 322 | size m = (rows m, cols m) |
315 | 323 | ||
@@ -319,9 +327,10 @@ repRows n x = fromRows (replicate n (flatten x)) | |||
319 | repCols n x = fromColumns (replicate n (flatten x)) | 327 | repCols n x = fromColumns (replicate n (flatten x)) |
320 | 328 | ||
321 | compat' :: Matrix a -> Matrix b -> Bool | 329 | compat' :: Matrix a -> Matrix b -> Bool |
322 | compat' m1 m2 = rows m1 == 1 && cols m1 == 1 | 330 | compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 |
323 | || rows m2 == 1 && cols m2 == 1 | 331 | where |
324 | || rows m1 == rows m2 && cols m1 == cols m2 | 332 | s1 = size m1 |
333 | s2 = size m2 | ||
325 | 334 | ||
326 | ------------------------------------------------------------ | 335 | ------------------------------------------------------------ |
327 | 336 | ||