summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal/Matrix.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2014-05-06 08:50:50 +0200
committerAlberto Ruiz <aruiz@um.es>2014-05-06 08:50:50 +0200
commitc9914d694d3b86ece46fa0c76e0466c6cd394d14 (patch)
tree7fa1c5a95b204912f5d560c843ae6045ee8d2780 /lib/Data/Packed/Internal/Matrix.hs
parent4078cf44c98b42960be27843782f6983bb66017f (diff)
extend conformability to empty arrays
Diffstat (limited to 'lib/Data/Packed/Internal/Matrix.hs')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs38
1 files changed, 26 insertions, 12 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 2004e85..9719fc0 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -32,6 +32,7 @@ module Data.Packed.Internal.Matrix(
32 (@@>), atM', 32 (@@>), atM',
33 saveMatrix, 33 saveMatrix,
34 singleton, 34 singleton,
35 emptyM,
35 size, shSize, conformVs, conformMs, conformVTo, conformMTo 36 size, shSize, conformVs, conformMs, conformVTo, conformMTo
36) where 37) where
37 38
@@ -157,16 +158,24 @@ toLists m = splitEvery (cols m) . toList . flatten $ m
157-- All vectors must have the same dimension, 158-- All vectors must have the same dimension,
158-- or dimension 1, which is are automatically expanded. 159-- or dimension 1, which is are automatically expanded.
159fromRows :: Element t => [Vector t] -> Matrix t 160fromRows :: Element t => [Vector t] -> Matrix t
161fromRows [] = emptyM 0 0
160fromRows vs = case compatdim (map dim vs) of 162fromRows vs = case compatdim (map dim vs) of
161 Nothing -> error "fromRows applied to [] or to vectors with different sizes" 163 Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs)
162 Just c -> reshape c . vjoin . map (adapt c) $ vs 164 Just 0 -> emptyM r 0
165 Just c -> matrixFromVector RowMajor r c . vjoin . map (adapt c) $ vs
163 where 166 where
164 adapt c v | dim v == c = v 167 r = length vs
165 | otherwise = constantD (v@>0) c 168 adapt c v
169 | c == 0 = fromList[]
170 | dim v == c = v
171 | otherwise = constantD (v@>0) c
166 172
167-- | extracts the rows of a matrix as a list of vectors 173-- | extracts the rows of a matrix as a list of vectors
168toRows :: Element t => Matrix t -> [Vector t] 174toRows :: Element t => Matrix t -> [Vector t]
169toRows m = toRows' 0 where 175toRows m
176 | c == 0 = replicate r (fromList[])
177 | otherwise = toRows' 0
178 where
170 v = flatten m 179 v = flatten m
171 r = rows m 180 r = rows m
172 c = cols m 181 c = cols m
@@ -200,7 +209,7 @@ atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i)
200 209
201matrixFromVector o r c v 210matrixFromVector o r c v
202 | r * c == dim v = m 211 | r * c == dim v = m
203 | otherwise = error $ "matrixFromVector " ++ shSize m ++ " <- " ++ show (dim v) 212 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
204 where 213 where
205 m = Matrix { irows = r, icols = c, xdat = v, order = o } 214 m = Matrix { irows = r, icols = c, xdat = v, order = o }
206 215
@@ -398,8 +407,8 @@ subMatrix :: Element a
398 -> Matrix a -- ^ input matrix 407 -> Matrix a -- ^ input matrix
399 -> Matrix a -- ^ result 408 -> Matrix a -- ^ result
400subMatrix (r0,c0) (rt,ct) m 409subMatrix (r0,c0) (rt,ct) m
401 | 0 <= r0 && 0 < rt && r0+rt <= (rows m) && 410 | 0 <= r0 && 0 <= rt && r0+rt <= (rows m) &&
402 0 <= c0 && 0 < ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m 411 0 <= c0 && 0 <= ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m
403 | otherwise = error $ "wrong subMatrix "++ 412 | otherwise = error $ "wrong subMatrix "++
404 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) 413 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
405 414
@@ -437,18 +446,21 @@ foreign import ccall unsafe "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr
437 446
438---------------------------------------------------------------------- 447----------------------------------------------------------------------
439 448
449maxZ xs = if minimum xs == 0 then 0 else maximum xs
450
440conformMs ms = map (conformMTo (r,c)) ms 451conformMs ms = map (conformMTo (r,c)) ms
441 where 452 where
442 r = maximum (map rows ms) 453 r = maxZ (map rows ms)
443 c = maximum (map cols ms) 454 c = maxZ (map cols ms)
455
444 456
445conformVs vs = map (conformVTo n) vs 457conformVs vs = map (conformVTo n) vs
446 where 458 where
447 n = maximum (map dim vs) 459 n = maxZ (map dim vs)
448 460
449conformMTo (r,c) m 461conformMTo (r,c) m
450 | size m == (r,c) = m 462 | size m == (r,c) = m
451 | size m == (1,1) = reshape c (constantD (m@@>(0,0)) (r*c)) 463 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
452 | size m == (r,1) = repCols c m 464 | size m == (r,1) = repCols c m
453 | size m == (1,c) = repRows r m 465 | size m == (1,c) = repRows r m
454 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" 466 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")"
@@ -465,6 +477,8 @@ size m = (rows m, cols m)
465 477
466shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" 478shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")"
467 479
480emptyM r c = matrixFromVector RowMajor r c (fromList[])
481
468---------------------------------------------------------------------- 482----------------------------------------------------------------------
469 483
470instance (Storable t, NFData t) => NFData (Matrix t) 484instance (Storable t, NFData t) => NFData (Matrix t)