summaryrefslogtreecommitdiff
path: root/lib/Data/Packed
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Data/Packed')
-rw-r--r--lib/Data/Packed/Internal/Common.hs6
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs38
-rw-r--r--lib/Data/Packed/Matrix.hs44
3 files changed, 69 insertions, 19 deletions
diff --git a/lib/Data/Packed/Internal/Common.hs b/lib/Data/Packed/Internal/Common.hs
index 49f17b0..edef3c2 100644
--- a/lib/Data/Packed/Internal/Common.hs
+++ b/lib/Data/Packed/Internal/Common.hs
@@ -49,7 +49,11 @@ common f = commonval . map f where
49compatdim :: [Int] -> Maybe Int 49compatdim :: [Int] -> Maybe Int
50compatdim [] = Nothing 50compatdim [] = Nothing
51compatdim [a] = Just a 51compatdim [a] = Just a
52compatdim (a:b:xs) = if a==b || a==1 || b==1 then compatdim (max a b:xs) else Nothing 52compatdim (a:b:xs)
53 | a==b = compatdim (b:xs)
54 | a==1 = compatdim (b:xs)
55 | b==1 = compatdim (a:xs)
56 | otherwise = Nothing
53 57
54-- | Formatting tool 58-- | Formatting tool
55table :: String -> [[String]] -> String 59table :: String -> [[String]] -> String
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 2004e85..9719fc0 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -32,6 +32,7 @@ module Data.Packed.Internal.Matrix(
32 (@@>), atM', 32 (@@>), atM',
33 saveMatrix, 33 saveMatrix,
34 singleton, 34 singleton,
35 emptyM,
35 size, shSize, conformVs, conformMs, conformVTo, conformMTo 36 size, shSize, conformVs, conformMs, conformVTo, conformMTo
36) where 37) where
37 38
@@ -157,16 +158,24 @@ toLists m = splitEvery (cols m) . toList . flatten $ m
157-- All vectors must have the same dimension, 158-- All vectors must have the same dimension,
158-- or dimension 1, which is are automatically expanded. 159-- or dimension 1, which is are automatically expanded.
159fromRows :: Element t => [Vector t] -> Matrix t 160fromRows :: Element t => [Vector t] -> Matrix t
161fromRows [] = emptyM 0 0
160fromRows vs = case compatdim (map dim vs) of 162fromRows vs = case compatdim (map dim vs) of
161 Nothing -> error "fromRows applied to [] or to vectors with different sizes" 163 Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs)
162 Just c -> reshape c . vjoin . map (adapt c) $ vs 164 Just 0 -> emptyM r 0
165 Just c -> matrixFromVector RowMajor r c . vjoin . map (adapt c) $ vs
163 where 166 where
164 adapt c v | dim v == c = v 167 r = length vs
165 | otherwise = constantD (v@>0) c 168 adapt c v
169 | c == 0 = fromList[]
170 | dim v == c = v
171 | otherwise = constantD (v@>0) c
166 172
167-- | extracts the rows of a matrix as a list of vectors 173-- | extracts the rows of a matrix as a list of vectors
168toRows :: Element t => Matrix t -> [Vector t] 174toRows :: Element t => Matrix t -> [Vector t]
169toRows m = toRows' 0 where 175toRows m
176 | c == 0 = replicate r (fromList[])
177 | otherwise = toRows' 0
178 where
170 v = flatten m 179 v = flatten m
171 r = rows m 180 r = rows m
172 c = cols m 181 c = cols m
@@ -200,7 +209,7 @@ atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i)
200 209
201matrixFromVector o r c v 210matrixFromVector o r c v
202 | r * c == dim v = m 211 | r * c == dim v = m
203 | otherwise = error $ "matrixFromVector " ++ shSize m ++ " <- " ++ show (dim v) 212 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
204 where 213 where
205 m = Matrix { irows = r, icols = c, xdat = v, order = o } 214 m = Matrix { irows = r, icols = c, xdat = v, order = o }
206 215
@@ -398,8 +407,8 @@ subMatrix :: Element a
398 -> Matrix a -- ^ input matrix 407 -> Matrix a -- ^ input matrix
399 -> Matrix a -- ^ result 408 -> Matrix a -- ^ result
400subMatrix (r0,c0) (rt,ct) m 409subMatrix (r0,c0) (rt,ct) m
401 | 0 <= r0 && 0 < rt && r0+rt <= (rows m) && 410 | 0 <= r0 && 0 <= rt && r0+rt <= (rows m) &&
402 0 <= c0 && 0 < ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m 411 0 <= c0 && 0 <= ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m
403 | otherwise = error $ "wrong subMatrix "++ 412 | otherwise = error $ "wrong subMatrix "++
404 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) 413 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
405 414
@@ -437,18 +446,21 @@ foreign import ccall unsafe "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr
437 446
438---------------------------------------------------------------------- 447----------------------------------------------------------------------
439 448
449maxZ xs = if minimum xs == 0 then 0 else maximum xs
450
440conformMs ms = map (conformMTo (r,c)) ms 451conformMs ms = map (conformMTo (r,c)) ms
441 where 452 where
442 r = maximum (map rows ms) 453 r = maxZ (map rows ms)
443 c = maximum (map cols ms) 454 c = maxZ (map cols ms)
455
444 456
445conformVs vs = map (conformVTo n) vs 457conformVs vs = map (conformVTo n) vs
446 where 458 where
447 n = maximum (map dim vs) 459 n = maxZ (map dim vs)
448 460
449conformMTo (r,c) m 461conformMTo (r,c) m
450 | size m == (r,c) = m 462 | size m == (r,c) = m
451 | size m == (1,1) = reshape c (constantD (m@@>(0,0)) (r*c)) 463 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
452 | size m == (r,1) = repCols c m 464 | size m == (r,1) = repCols c m
453 | size m == (1,c) = repRows r m 465 | size m == (1,c) = repRows r m
454 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" 466 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")"
@@ -465,6 +477,8 @@ size m = (rows m, cols m)
465 477
466shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" 478shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")"
467 479
480emptyM r c = matrixFromVector RowMajor r c (fromList[])
481
468---------------------------------------------------------------------- 482----------------------------------------------------------------------
469 483
470instance (Storable t, NFData t) => NFData (Matrix t) 484instance (Storable t, NFData t) => NFData (Matrix t)
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs
index b92d60f..d94d167 100644
--- a/lib/Data/Packed/Matrix.hs
+++ b/lib/Data/Packed/Matrix.hs
@@ -35,7 +35,7 @@ module Data.Packed.Matrix (
35 repmat, 35 repmat,
36 flipud, fliprl, 36 flipud, fliprl,
37 subMatrix, takeRows, dropRows, takeColumns, dropColumns, 37 subMatrix, takeRows, dropRows, takeColumns, dropColumns,
38 extractRows, 38 extractRows, extractColumns,
39 diagRect, takeDiag, 39 diagRect, takeDiag,
40 mapMatrix, mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_, 40 mapMatrix, mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_,
41 liftMatrix, liftMatrix2, liftMatrix2Auto,fromArray2D 41 liftMatrix, liftMatrix2, liftMatrix2Auto,fromArray2D
@@ -104,6 +104,7 @@ breakAt c l = (a++[c],tail b) where
104 104
105-- | creates a matrix from a vertical list of matrices 105-- | creates a matrix from a vertical list of matrices
106joinVert :: Element t => [Matrix t] -> Matrix t 106joinVert :: Element t => [Matrix t] -> Matrix t
107joinVert [] = emptyM 0 0
107joinVert ms = case common cols ms of 108joinVert ms = case common cols ms of
108 Nothing -> error "(impossible) joinVert on matrices with different number of columns" 109 Nothing -> error "(impossible) joinVert on matrices with different number of columns"
109 Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms) 110 Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms)
@@ -173,6 +174,11 @@ adaptBlocks ms = ms' where
1730 0 0 0 0 0 0 5 1740 0 0 0 0 0 0 5
1740 0 0 0 0 0 0 7 1750 0 0 0 0 0 0 7
175 176
177>>> diagBlock [(0><4)[], konst 2 (2,3)] :: Matrix Double
178(2><7)
179 [ 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0
180 , 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 ]
181
176-} 182-}
177diagBlock :: (Element t, Num t) => [Matrix t] -> Matrix t 183diagBlock :: (Element t, Num t) => [Matrix t] -> Matrix t
178diagBlock ms = fromBlocks $ zipWith f ms [0..] 184diagBlock ms = fromBlocks $ zipWith f ms [0..]
@@ -186,11 +192,15 @@ diagBlock ms = fromBlocks $ zipWith f ms [0..]
186 192
187-- | Reverse rows 193-- | Reverse rows
188flipud :: Element t => Matrix t -> Matrix t 194flipud :: Element t => Matrix t -> Matrix t
189flipud m = fromRows . reverse . toRows $ m 195flipud m = extractRows [r-1,r-2 .. 0] $ m
196 where
197 r = rows m
190 198
191-- | Reverse columns 199-- | Reverse columns
192fliprl :: Element t => Matrix t -> Matrix t 200fliprl :: Element t => Matrix t -> Matrix t
193fliprl m = fromColumns . reverse . toColumns $ m 201fliprl m = extractColumns [c-1,c-2 .. 0] $ m
202 where
203 c = cols m
194 204
195------------------------------------------------------------ 205------------------------------------------------------------
196 206
@@ -327,8 +337,25 @@ fromArray2D m = (r><c) (elems m)
327 337
328-- | rearranges the rows of a matrix according to the order given in a list of integers. 338-- | rearranges the rows of a matrix according to the order given in a list of integers.
329extractRows :: Element t => [Int] -> Matrix t -> Matrix t 339extractRows :: Element t => [Int] -> Matrix t -> Matrix t
340extractRows [] m = emptyM 0 (cols m)
330extractRows l m = fromRows $ extract (toRows m) l 341extractRows l m = fromRows $ extract (toRows m) l
331 where extract l' is = [l'!!i |i<-is] 342 where
343 extract l' is = [l'!!i | i<- map verify is]
344 verify k
345 | k >= 0 && k < rows m = k
346 | otherwise = error $ "can't extract row "
347 ++show k++" in list " ++ show l ++ " from matrix " ++ shSize m
348
349-- | rearranges the rows of a matrix according to the order given in a list of integers.
350extractColumns :: Element t => [Int] -> Matrix t -> Matrix t
351extractColumns l m = trans . extractRows (map verify l) . trans $ m
352 where
353 verify k
354 | k >= 0 && k < cols m = k
355 | otherwise = error $ "can't extract column "
356 ++show k++" in list " ++ show l ++ " from matrix " ++ shSize m
357
358
332 359
333{- | creates matrix by repetition of a matrix a given number of rows and columns 360{- | creates matrix by repetition of a matrix a given number of rows and columns
334 361
@@ -341,7 +368,9 @@ extractRows l m = fromRows $ extract (toRows m) l
341 368
342-} 369-}
343repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t 370repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t
344repmat m r c = fromBlocks $ splitEvery c $ replicate (r*c) m 371repmat m r c
372 | r == 0 || c == 0 = emptyM (r*rows m) (c*cols m)
373 | otherwise = fromBlocks $ replicate r $ replicate c $ m
345 374
346-- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. 375-- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix.
347liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 376liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
@@ -390,7 +419,10 @@ toBlocks rs cs m = map (toBlockCols cs) . toBlockRows rs $ m
390-- | Fully partition a matrix into blocks of the same size. If the dimensions are not 419-- | Fully partition a matrix into blocks of the same size. If the dimensions are not
391-- a multiple of the given size the last blocks will be smaller. 420-- a multiple of the given size the last blocks will be smaller.
392toBlocksEvery :: (Element t) => Int -> Int -> Matrix t -> [[Matrix t]] 421toBlocksEvery :: (Element t) => Int -> Int -> Matrix t -> [[Matrix t]]
393toBlocksEvery r c m = toBlocks rs cs m where 422toBlocksEvery r c m
423 | r < 1 || c < 1 = error $ "toBlocksEvery expects block sizes > 0, given "++show r++" and "++ show c
424 | otherwise = toBlocks rs cs m
425 where
394 (qr,rr) = rows m `divMod` r 426 (qr,rr) = rows m `divMod` r
395 (qc,rc) = cols m `divMod` c 427 (qc,rc) = cols m `divMod` c
396 rs = replicate qr r ++ if rr > 0 then [rr] else [] 428 rs = replicate qr r ++ if rr > 0 then [rr] else []