diff options
Diffstat (limited to 'lib/Data/Packed/Internal')
-rw-r--r-- | lib/Data/Packed/Internal/Common.hs | 6 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 38 |
2 files changed, 31 insertions, 13 deletions
diff --git a/lib/Data/Packed/Internal/Common.hs b/lib/Data/Packed/Internal/Common.hs index 49f17b0..edef3c2 100644 --- a/lib/Data/Packed/Internal/Common.hs +++ b/lib/Data/Packed/Internal/Common.hs | |||
@@ -49,7 +49,11 @@ common f = commonval . map f where | |||
49 | compatdim :: [Int] -> Maybe Int | 49 | compatdim :: [Int] -> Maybe Int |
50 | compatdim [] = Nothing | 50 | compatdim [] = Nothing |
51 | compatdim [a] = Just a | 51 | compatdim [a] = Just a |
52 | compatdim (a:b:xs) = if a==b || a==1 || b==1 then compatdim (max a b:xs) else Nothing | 52 | compatdim (a:b:xs) |
53 | | a==b = compatdim (b:xs) | ||
54 | | a==1 = compatdim (b:xs) | ||
55 | | b==1 = compatdim (a:xs) | ||
56 | | otherwise = Nothing | ||
53 | 57 | ||
54 | -- | Formatting tool | 58 | -- | Formatting tool |
55 | table :: String -> [[String]] -> String | 59 | table :: String -> [[String]] -> String |
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 2004e85..9719fc0 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -32,6 +32,7 @@ module Data.Packed.Internal.Matrix( | |||
32 | (@@>), atM', | 32 | (@@>), atM', |
33 | saveMatrix, | 33 | saveMatrix, |
34 | singleton, | 34 | singleton, |
35 | emptyM, | ||
35 | size, shSize, conformVs, conformMs, conformVTo, conformMTo | 36 | size, shSize, conformVs, conformMs, conformVTo, conformMTo |
36 | ) where | 37 | ) where |
37 | 38 | ||
@@ -157,16 +158,24 @@ toLists m = splitEvery (cols m) . toList . flatten $ m | |||
157 | -- All vectors must have the same dimension, | 158 | -- All vectors must have the same dimension, |
158 | -- or dimension 1, which is are automatically expanded. | 159 | -- or dimension 1, which is are automatically expanded. |
159 | fromRows :: Element t => [Vector t] -> Matrix t | 160 | fromRows :: Element t => [Vector t] -> Matrix t |
161 | fromRows [] = emptyM 0 0 | ||
160 | fromRows vs = case compatdim (map dim vs) of | 162 | fromRows vs = case compatdim (map dim vs) of |
161 | Nothing -> error "fromRows applied to [] or to vectors with different sizes" | 163 | Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs) |
162 | Just c -> reshape c . vjoin . map (adapt c) $ vs | 164 | Just 0 -> emptyM r 0 |
165 | Just c -> matrixFromVector RowMajor r c . vjoin . map (adapt c) $ vs | ||
163 | where | 166 | where |
164 | adapt c v | dim v == c = v | 167 | r = length vs |
165 | | otherwise = constantD (v@>0) c | 168 | adapt c v |
169 | | c == 0 = fromList[] | ||
170 | | dim v == c = v | ||
171 | | otherwise = constantD (v@>0) c | ||
166 | 172 | ||
167 | -- | extracts the rows of a matrix as a list of vectors | 173 | -- | extracts the rows of a matrix as a list of vectors |
168 | toRows :: Element t => Matrix t -> [Vector t] | 174 | toRows :: Element t => Matrix t -> [Vector t] |
169 | toRows m = toRows' 0 where | 175 | toRows m |
176 | | c == 0 = replicate r (fromList[]) | ||
177 | | otherwise = toRows' 0 | ||
178 | where | ||
170 | v = flatten m | 179 | v = flatten m |
171 | r = rows m | 180 | r = rows m |
172 | c = cols m | 181 | c = cols m |
@@ -200,7 +209,7 @@ atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i) | |||
200 | 209 | ||
201 | matrixFromVector o r c v | 210 | matrixFromVector o r c v |
202 | | r * c == dim v = m | 211 | | r * c == dim v = m |
203 | | otherwise = error $ "matrixFromVector " ++ shSize m ++ " <- " ++ show (dim v) | 212 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m |
204 | where | 213 | where |
205 | m = Matrix { irows = r, icols = c, xdat = v, order = o } | 214 | m = Matrix { irows = r, icols = c, xdat = v, order = o } |
206 | 215 | ||
@@ -398,8 +407,8 @@ subMatrix :: Element a | |||
398 | -> Matrix a -- ^ input matrix | 407 | -> Matrix a -- ^ input matrix |
399 | -> Matrix a -- ^ result | 408 | -> Matrix a -- ^ result |
400 | subMatrix (r0,c0) (rt,ct) m | 409 | subMatrix (r0,c0) (rt,ct) m |
401 | | 0 <= r0 && 0 < rt && r0+rt <= (rows m) && | 410 | | 0 <= r0 && 0 <= rt && r0+rt <= (rows m) && |
402 | 0 <= c0 && 0 < ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m | 411 | 0 <= c0 && 0 <= ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m |
403 | | otherwise = error $ "wrong subMatrix "++ | 412 | | otherwise = error $ "wrong subMatrix "++ |
404 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) | 413 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) |
405 | 414 | ||
@@ -437,18 +446,21 @@ foreign import ccall unsafe "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr | |||
437 | 446 | ||
438 | ---------------------------------------------------------------------- | 447 | ---------------------------------------------------------------------- |
439 | 448 | ||
449 | maxZ xs = if minimum xs == 0 then 0 else maximum xs | ||
450 | |||
440 | conformMs ms = map (conformMTo (r,c)) ms | 451 | conformMs ms = map (conformMTo (r,c)) ms |
441 | where | 452 | where |
442 | r = maximum (map rows ms) | 453 | r = maxZ (map rows ms) |
443 | c = maximum (map cols ms) | 454 | c = maxZ (map cols ms) |
455 | |||
444 | 456 | ||
445 | conformVs vs = map (conformVTo n) vs | 457 | conformVs vs = map (conformVTo n) vs |
446 | where | 458 | where |
447 | n = maximum (map dim vs) | 459 | n = maxZ (map dim vs) |
448 | 460 | ||
449 | conformMTo (r,c) m | 461 | conformMTo (r,c) m |
450 | | size m == (r,c) = m | 462 | | size m == (r,c) = m |
451 | | size m == (1,1) = reshape c (constantD (m@@>(0,0)) (r*c)) | 463 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) |
452 | | size m == (r,1) = repCols c m | 464 | | size m == (r,1) = repCols c m |
453 | | size m == (1,c) = repRows r m | 465 | | size m == (1,c) = repRows r m |
454 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" | 466 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" |
@@ -465,6 +477,8 @@ size m = (rows m, cols m) | |||
465 | 477 | ||
466 | shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" | 478 | shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" |
467 | 479 | ||
480 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) | ||
481 | |||
468 | ---------------------------------------------------------------------- | 482 | ---------------------------------------------------------------------- |
469 | 483 | ||
470 | instance (Storable t, NFData t) => NFData (Matrix t) | 484 | instance (Storable t, NFData t) => NFData (Matrix t) |