diff options
author | Alberto Ruiz <aruiz@um.es> | 2014-05-06 08:50:50 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2014-05-06 08:50:50 +0200 |
commit | c9914d694d3b86ece46fa0c76e0466c6cd394d14 (patch) | |
tree | 7fa1c5a95b204912f5d560c843ae6045ee8d2780 /lib/Data/Packed/Internal/Matrix.hs | |
parent | 4078cf44c98b42960be27843782f6983bb66017f (diff) |
extend conformability to empty arrays
Diffstat (limited to 'lib/Data/Packed/Internal/Matrix.hs')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 38 |
1 files changed, 26 insertions, 12 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 2004e85..9719fc0 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -32,6 +32,7 @@ module Data.Packed.Internal.Matrix( | |||
32 | (@@>), atM', | 32 | (@@>), atM', |
33 | saveMatrix, | 33 | saveMatrix, |
34 | singleton, | 34 | singleton, |
35 | emptyM, | ||
35 | size, shSize, conformVs, conformMs, conformVTo, conformMTo | 36 | size, shSize, conformVs, conformMs, conformVTo, conformMTo |
36 | ) where | 37 | ) where |
37 | 38 | ||
@@ -157,16 +158,24 @@ toLists m = splitEvery (cols m) . toList . flatten $ m | |||
157 | -- All vectors must have the same dimension, | 158 | -- All vectors must have the same dimension, |
158 | -- or dimension 1, which is are automatically expanded. | 159 | -- or dimension 1, which is are automatically expanded. |
159 | fromRows :: Element t => [Vector t] -> Matrix t | 160 | fromRows :: Element t => [Vector t] -> Matrix t |
161 | fromRows [] = emptyM 0 0 | ||
160 | fromRows vs = case compatdim (map dim vs) of | 162 | fromRows vs = case compatdim (map dim vs) of |
161 | Nothing -> error "fromRows applied to [] or to vectors with different sizes" | 163 | Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs) |
162 | Just c -> reshape c . vjoin . map (adapt c) $ vs | 164 | Just 0 -> emptyM r 0 |
165 | Just c -> matrixFromVector RowMajor r c . vjoin . map (adapt c) $ vs | ||
163 | where | 166 | where |
164 | adapt c v | dim v == c = v | 167 | r = length vs |
165 | | otherwise = constantD (v@>0) c | 168 | adapt c v |
169 | | c == 0 = fromList[] | ||
170 | | dim v == c = v | ||
171 | | otherwise = constantD (v@>0) c | ||
166 | 172 | ||
167 | -- | extracts the rows of a matrix as a list of vectors | 173 | -- | extracts the rows of a matrix as a list of vectors |
168 | toRows :: Element t => Matrix t -> [Vector t] | 174 | toRows :: Element t => Matrix t -> [Vector t] |
169 | toRows m = toRows' 0 where | 175 | toRows m |
176 | | c == 0 = replicate r (fromList[]) | ||
177 | | otherwise = toRows' 0 | ||
178 | where | ||
170 | v = flatten m | 179 | v = flatten m |
171 | r = rows m | 180 | r = rows m |
172 | c = cols m | 181 | c = cols m |
@@ -200,7 +209,7 @@ atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i) | |||
200 | 209 | ||
201 | matrixFromVector o r c v | 210 | matrixFromVector o r c v |
202 | | r * c == dim v = m | 211 | | r * c == dim v = m |
203 | | otherwise = error $ "matrixFromVector " ++ shSize m ++ " <- " ++ show (dim v) | 212 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m |
204 | where | 213 | where |
205 | m = Matrix { irows = r, icols = c, xdat = v, order = o } | 214 | m = Matrix { irows = r, icols = c, xdat = v, order = o } |
206 | 215 | ||
@@ -398,8 +407,8 @@ subMatrix :: Element a | |||
398 | -> Matrix a -- ^ input matrix | 407 | -> Matrix a -- ^ input matrix |
399 | -> Matrix a -- ^ result | 408 | -> Matrix a -- ^ result |
400 | subMatrix (r0,c0) (rt,ct) m | 409 | subMatrix (r0,c0) (rt,ct) m |
401 | | 0 <= r0 && 0 < rt && r0+rt <= (rows m) && | 410 | | 0 <= r0 && 0 <= rt && r0+rt <= (rows m) && |
402 | 0 <= c0 && 0 < ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m | 411 | 0 <= c0 && 0 <= ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m |
403 | | otherwise = error $ "wrong subMatrix "++ | 412 | | otherwise = error $ "wrong subMatrix "++ |
404 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) | 413 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) |
405 | 414 | ||
@@ -437,18 +446,21 @@ foreign import ccall unsafe "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr | |||
437 | 446 | ||
438 | ---------------------------------------------------------------------- | 447 | ---------------------------------------------------------------------- |
439 | 448 | ||
449 | maxZ xs = if minimum xs == 0 then 0 else maximum xs | ||
450 | |||
440 | conformMs ms = map (conformMTo (r,c)) ms | 451 | conformMs ms = map (conformMTo (r,c)) ms |
441 | where | 452 | where |
442 | r = maximum (map rows ms) | 453 | r = maxZ (map rows ms) |
443 | c = maximum (map cols ms) | 454 | c = maxZ (map cols ms) |
455 | |||
444 | 456 | ||
445 | conformVs vs = map (conformVTo n) vs | 457 | conformVs vs = map (conformVTo n) vs |
446 | where | 458 | where |
447 | n = maximum (map dim vs) | 459 | n = maxZ (map dim vs) |
448 | 460 | ||
449 | conformMTo (r,c) m | 461 | conformMTo (r,c) m |
450 | | size m == (r,c) = m | 462 | | size m == (r,c) = m |
451 | | size m == (1,1) = reshape c (constantD (m@@>(0,0)) (r*c)) | 463 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) |
452 | | size m == (r,1) = repCols c m | 464 | | size m == (r,1) = repCols c m |
453 | | size m == (1,c) = repRows r m | 465 | | size m == (1,c) = repRows r m |
454 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" | 466 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" |
@@ -465,6 +477,8 @@ size m = (rows m, cols m) | |||
465 | 477 | ||
466 | shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" | 478 | shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" |
467 | 479 | ||
480 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) | ||
481 | |||
468 | ---------------------------------------------------------------------- | 482 | ---------------------------------------------------------------------- |
469 | 483 | ||
470 | instance (Storable t, NFData t) => NFData (Matrix t) | 484 | instance (Storable t, NFData t) => NFData (Matrix t) |