diff options
author | Alberto Ruiz <aruiz@um.es> | 2014-05-06 08:50:50 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2014-05-06 08:50:50 +0200 |
commit | c9914d694d3b86ece46fa0c76e0466c6cd394d14 (patch) | |
tree | 7fa1c5a95b204912f5d560c843ae6045ee8d2780 /lib/Data | |
parent | 4078cf44c98b42960be27843782f6983bb66017f (diff) |
extend conformability to empty arrays
Diffstat (limited to 'lib/Data')
-rw-r--r-- | lib/Data/Packed/Internal/Common.hs | 6 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 38 | ||||
-rw-r--r-- | lib/Data/Packed/Matrix.hs | 44 |
3 files changed, 69 insertions, 19 deletions
diff --git a/lib/Data/Packed/Internal/Common.hs b/lib/Data/Packed/Internal/Common.hs index 49f17b0..edef3c2 100644 --- a/lib/Data/Packed/Internal/Common.hs +++ b/lib/Data/Packed/Internal/Common.hs | |||
@@ -49,7 +49,11 @@ common f = commonval . map f where | |||
49 | compatdim :: [Int] -> Maybe Int | 49 | compatdim :: [Int] -> Maybe Int |
50 | compatdim [] = Nothing | 50 | compatdim [] = Nothing |
51 | compatdim [a] = Just a | 51 | compatdim [a] = Just a |
52 | compatdim (a:b:xs) = if a==b || a==1 || b==1 then compatdim (max a b:xs) else Nothing | 52 | compatdim (a:b:xs) |
53 | | a==b = compatdim (b:xs) | ||
54 | | a==1 = compatdim (b:xs) | ||
55 | | b==1 = compatdim (a:xs) | ||
56 | | otherwise = Nothing | ||
53 | 57 | ||
54 | -- | Formatting tool | 58 | -- | Formatting tool |
55 | table :: String -> [[String]] -> String | 59 | table :: String -> [[String]] -> String |
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 2004e85..9719fc0 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -32,6 +32,7 @@ module Data.Packed.Internal.Matrix( | |||
32 | (@@>), atM', | 32 | (@@>), atM', |
33 | saveMatrix, | 33 | saveMatrix, |
34 | singleton, | 34 | singleton, |
35 | emptyM, | ||
35 | size, shSize, conformVs, conformMs, conformVTo, conformMTo | 36 | size, shSize, conformVs, conformMs, conformVTo, conformMTo |
36 | ) where | 37 | ) where |
37 | 38 | ||
@@ -157,16 +158,24 @@ toLists m = splitEvery (cols m) . toList . flatten $ m | |||
157 | -- All vectors must have the same dimension, | 158 | -- All vectors must have the same dimension, |
158 | -- or dimension 1, which is are automatically expanded. | 159 | -- or dimension 1, which is are automatically expanded. |
159 | fromRows :: Element t => [Vector t] -> Matrix t | 160 | fromRows :: Element t => [Vector t] -> Matrix t |
161 | fromRows [] = emptyM 0 0 | ||
160 | fromRows vs = case compatdim (map dim vs) of | 162 | fromRows vs = case compatdim (map dim vs) of |
161 | Nothing -> error "fromRows applied to [] or to vectors with different sizes" | 163 | Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs) |
162 | Just c -> reshape c . vjoin . map (adapt c) $ vs | 164 | Just 0 -> emptyM r 0 |
165 | Just c -> matrixFromVector RowMajor r c . vjoin . map (adapt c) $ vs | ||
163 | where | 166 | where |
164 | adapt c v | dim v == c = v | 167 | r = length vs |
165 | | otherwise = constantD (v@>0) c | 168 | adapt c v |
169 | | c == 0 = fromList[] | ||
170 | | dim v == c = v | ||
171 | | otherwise = constantD (v@>0) c | ||
166 | 172 | ||
167 | -- | extracts the rows of a matrix as a list of vectors | 173 | -- | extracts the rows of a matrix as a list of vectors |
168 | toRows :: Element t => Matrix t -> [Vector t] | 174 | toRows :: Element t => Matrix t -> [Vector t] |
169 | toRows m = toRows' 0 where | 175 | toRows m |
176 | | c == 0 = replicate r (fromList[]) | ||
177 | | otherwise = toRows' 0 | ||
178 | where | ||
170 | v = flatten m | 179 | v = flatten m |
171 | r = rows m | 180 | r = rows m |
172 | c = cols m | 181 | c = cols m |
@@ -200,7 +209,7 @@ atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i) | |||
200 | 209 | ||
201 | matrixFromVector o r c v | 210 | matrixFromVector o r c v |
202 | | r * c == dim v = m | 211 | | r * c == dim v = m |
203 | | otherwise = error $ "matrixFromVector " ++ shSize m ++ " <- " ++ show (dim v) | 212 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m |
204 | where | 213 | where |
205 | m = Matrix { irows = r, icols = c, xdat = v, order = o } | 214 | m = Matrix { irows = r, icols = c, xdat = v, order = o } |
206 | 215 | ||
@@ -398,8 +407,8 @@ subMatrix :: Element a | |||
398 | -> Matrix a -- ^ input matrix | 407 | -> Matrix a -- ^ input matrix |
399 | -> Matrix a -- ^ result | 408 | -> Matrix a -- ^ result |
400 | subMatrix (r0,c0) (rt,ct) m | 409 | subMatrix (r0,c0) (rt,ct) m |
401 | | 0 <= r0 && 0 < rt && r0+rt <= (rows m) && | 410 | | 0 <= r0 && 0 <= rt && r0+rt <= (rows m) && |
402 | 0 <= c0 && 0 < ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m | 411 | 0 <= c0 && 0 <= ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m |
403 | | otherwise = error $ "wrong subMatrix "++ | 412 | | otherwise = error $ "wrong subMatrix "++ |
404 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) | 413 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) |
405 | 414 | ||
@@ -437,18 +446,21 @@ foreign import ccall unsafe "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr | |||
437 | 446 | ||
438 | ---------------------------------------------------------------------- | 447 | ---------------------------------------------------------------------- |
439 | 448 | ||
449 | maxZ xs = if minimum xs == 0 then 0 else maximum xs | ||
450 | |||
440 | conformMs ms = map (conformMTo (r,c)) ms | 451 | conformMs ms = map (conformMTo (r,c)) ms |
441 | where | 452 | where |
442 | r = maximum (map rows ms) | 453 | r = maxZ (map rows ms) |
443 | c = maximum (map cols ms) | 454 | c = maxZ (map cols ms) |
455 | |||
444 | 456 | ||
445 | conformVs vs = map (conformVTo n) vs | 457 | conformVs vs = map (conformVTo n) vs |
446 | where | 458 | where |
447 | n = maximum (map dim vs) | 459 | n = maxZ (map dim vs) |
448 | 460 | ||
449 | conformMTo (r,c) m | 461 | conformMTo (r,c) m |
450 | | size m == (r,c) = m | 462 | | size m == (r,c) = m |
451 | | size m == (1,1) = reshape c (constantD (m@@>(0,0)) (r*c)) | 463 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) |
452 | | size m == (r,1) = repCols c m | 464 | | size m == (r,1) = repCols c m |
453 | | size m == (1,c) = repRows r m | 465 | | size m == (1,c) = repRows r m |
454 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" | 466 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" |
@@ -465,6 +477,8 @@ size m = (rows m, cols m) | |||
465 | 477 | ||
466 | shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" | 478 | shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" |
467 | 479 | ||
480 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) | ||
481 | |||
468 | ---------------------------------------------------------------------- | 482 | ---------------------------------------------------------------------- |
469 | 483 | ||
470 | instance (Storable t, NFData t) => NFData (Matrix t) | 484 | instance (Storable t, NFData t) => NFData (Matrix t) |
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index b92d60f..d94d167 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs | |||
@@ -35,7 +35,7 @@ module Data.Packed.Matrix ( | |||
35 | repmat, | 35 | repmat, |
36 | flipud, fliprl, | 36 | flipud, fliprl, |
37 | subMatrix, takeRows, dropRows, takeColumns, dropColumns, | 37 | subMatrix, takeRows, dropRows, takeColumns, dropColumns, |
38 | extractRows, | 38 | extractRows, extractColumns, |
39 | diagRect, takeDiag, | 39 | diagRect, takeDiag, |
40 | mapMatrix, mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_, | 40 | mapMatrix, mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_, |
41 | liftMatrix, liftMatrix2, liftMatrix2Auto,fromArray2D | 41 | liftMatrix, liftMatrix2, liftMatrix2Auto,fromArray2D |
@@ -104,6 +104,7 @@ breakAt c l = (a++[c],tail b) where | |||
104 | 104 | ||
105 | -- | creates a matrix from a vertical list of matrices | 105 | -- | creates a matrix from a vertical list of matrices |
106 | joinVert :: Element t => [Matrix t] -> Matrix t | 106 | joinVert :: Element t => [Matrix t] -> Matrix t |
107 | joinVert [] = emptyM 0 0 | ||
107 | joinVert ms = case common cols ms of | 108 | joinVert ms = case common cols ms of |
108 | Nothing -> error "(impossible) joinVert on matrices with different number of columns" | 109 | Nothing -> error "(impossible) joinVert on matrices with different number of columns" |
109 | Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms) | 110 | Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms) |
@@ -173,6 +174,11 @@ adaptBlocks ms = ms' where | |||
173 | 0 0 0 0 0 0 0 5 | 174 | 0 0 0 0 0 0 0 5 |
174 | 0 0 0 0 0 0 0 7 | 175 | 0 0 0 0 0 0 0 7 |
175 | 176 | ||
177 | >>> diagBlock [(0><4)[], konst 2 (2,3)] :: Matrix Double | ||
178 | (2><7) | ||
179 | [ 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 | ||
180 | , 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 ] | ||
181 | |||
176 | -} | 182 | -} |
177 | diagBlock :: (Element t, Num t) => [Matrix t] -> Matrix t | 183 | diagBlock :: (Element t, Num t) => [Matrix t] -> Matrix t |
178 | diagBlock ms = fromBlocks $ zipWith f ms [0..] | 184 | diagBlock ms = fromBlocks $ zipWith f ms [0..] |
@@ -186,11 +192,15 @@ diagBlock ms = fromBlocks $ zipWith f ms [0..] | |||
186 | 192 | ||
187 | -- | Reverse rows | 193 | -- | Reverse rows |
188 | flipud :: Element t => Matrix t -> Matrix t | 194 | flipud :: Element t => Matrix t -> Matrix t |
189 | flipud m = fromRows . reverse . toRows $ m | 195 | flipud m = extractRows [r-1,r-2 .. 0] $ m |
196 | where | ||
197 | r = rows m | ||
190 | 198 | ||
191 | -- | Reverse columns | 199 | -- | Reverse columns |
192 | fliprl :: Element t => Matrix t -> Matrix t | 200 | fliprl :: Element t => Matrix t -> Matrix t |
193 | fliprl m = fromColumns . reverse . toColumns $ m | 201 | fliprl m = extractColumns [c-1,c-2 .. 0] $ m |
202 | where | ||
203 | c = cols m | ||
194 | 204 | ||
195 | ------------------------------------------------------------ | 205 | ------------------------------------------------------------ |
196 | 206 | ||
@@ -327,8 +337,25 @@ fromArray2D m = (r><c) (elems m) | |||
327 | 337 | ||
328 | -- | rearranges the rows of a matrix according to the order given in a list of integers. | 338 | -- | rearranges the rows of a matrix according to the order given in a list of integers. |
329 | extractRows :: Element t => [Int] -> Matrix t -> Matrix t | 339 | extractRows :: Element t => [Int] -> Matrix t -> Matrix t |
340 | extractRows [] m = emptyM 0 (cols m) | ||
330 | extractRows l m = fromRows $ extract (toRows m) l | 341 | extractRows l m = fromRows $ extract (toRows m) l |
331 | where extract l' is = [l'!!i |i<-is] | 342 | where |
343 | extract l' is = [l'!!i | i<- map verify is] | ||
344 | verify k | ||
345 | | k >= 0 && k < rows m = k | ||
346 | | otherwise = error $ "can't extract row " | ||
347 | ++show k++" in list " ++ show l ++ " from matrix " ++ shSize m | ||
348 | |||
349 | -- | rearranges the rows of a matrix according to the order given in a list of integers. | ||
350 | extractColumns :: Element t => [Int] -> Matrix t -> Matrix t | ||
351 | extractColumns l m = trans . extractRows (map verify l) . trans $ m | ||
352 | where | ||
353 | verify k | ||
354 | | k >= 0 && k < cols m = k | ||
355 | | otherwise = error $ "can't extract column " | ||
356 | ++show k++" in list " ++ show l ++ " from matrix " ++ shSize m | ||
357 | |||
358 | |||
332 | 359 | ||
333 | {- | creates matrix by repetition of a matrix a given number of rows and columns | 360 | {- | creates matrix by repetition of a matrix a given number of rows and columns |
334 | 361 | ||
@@ -341,7 +368,9 @@ extractRows l m = fromRows $ extract (toRows m) l | |||
341 | 368 | ||
342 | -} | 369 | -} |
343 | repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t | 370 | repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t |
344 | repmat m r c = fromBlocks $ splitEvery c $ replicate (r*c) m | 371 | repmat m r c |
372 | | r == 0 || c == 0 = emptyM (r*rows m) (c*cols m) | ||
373 | | otherwise = fromBlocks $ replicate r $ replicate c $ m | ||
345 | 374 | ||
346 | -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. | 375 | -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. |
347 | liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 376 | liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |
@@ -390,7 +419,10 @@ toBlocks rs cs m = map (toBlockCols cs) . toBlockRows rs $ m | |||
390 | -- | Fully partition a matrix into blocks of the same size. If the dimensions are not | 419 | -- | Fully partition a matrix into blocks of the same size. If the dimensions are not |
391 | -- a multiple of the given size the last blocks will be smaller. | 420 | -- a multiple of the given size the last blocks will be smaller. |
392 | toBlocksEvery :: (Element t) => Int -> Int -> Matrix t -> [[Matrix t]] | 421 | toBlocksEvery :: (Element t) => Int -> Int -> Matrix t -> [[Matrix t]] |
393 | toBlocksEvery r c m = toBlocks rs cs m where | 422 | toBlocksEvery r c m |
423 | | r < 1 || c < 1 = error $ "toBlocksEvery expects block sizes > 0, given "++show r++" and "++ show c | ||
424 | | otherwise = toBlocks rs cs m | ||
425 | where | ||
394 | (qr,rr) = rows m `divMod` r | 426 | (qr,rr) = rows m `divMod` r |
395 | (qc,rc) = cols m `divMod` c | 427 | (qc,rc) = cols m `divMod` c |
396 | rs = replicate qr r ++ if rr > 0 then [rr] else [] | 428 | rs = replicate qr r ++ if rr > 0 then [rr] else [] |