From 05908719a7323110ba1955038d8341a8b7483351 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Mon, 20 Sep 2010 17:08:34 +0000 Subject: generalized diagRect --- lib/Data/Packed/Internal/Matrix.hs | 5 ++-- lib/Data/Packed/Matrix.hs | 50 +++++++++++++------------------------- lib/Data/Packed/ST.hs | 8 +++--- 3 files changed, 23 insertions(+), 40 deletions(-) (limited to 'lib/Data/Packed') diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index c0824a3..94b56cf 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs @@ -221,13 +221,13 @@ where r is the desired number of rows.) , 9.0, 10.0, 11.0, 12.0 ]@ -} -reshape :: Element t => Int -> Vector t -> Matrix t +reshape :: Storable t => Int -> Vector t -> Matrix t reshape c v = matrixFromVector RowMajor c v singleton x = reshape 1 (fromList [x]) -- | application of a vector function on the flattened matrix elements -liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b +liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b liftMatrix f MC { icols = c, cdat = d } = matrixFromVector RowMajor c (f d) liftMatrix f MF { icols = c, fdat = d } = matrixFromVector ColumnMajor c (f d) @@ -246,7 +246,6 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 ------------------------------------------------------------------ -- | Supported element types for basic matrix operations. ---class (Storable a, Floating a) => Element a where class (Storable a) => Element a where subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index b8c309c..ea16748 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs @@ -22,7 +22,7 @@ module Data.Packed.Matrix ( Element, Matrix,rows,cols, (><), - trans, ctrans, + trans, reshape, flatten, fromLists, toLists, buildMatrix, (@@>), @@ -33,7 +33,7 @@ module Data.Packed.Matrix ( flipud, fliprl, subMatrix, takeRows, dropRows, takeColumns, dropColumns, extractRows, - ident, diag, diagRect, takeDiag, + diagRect, takeDiag, liftMatrix, liftMatrix2, liftMatrix2Auto, dispf, disps, dispcf, vecdisp, latexFormat, format, loadMatrix, saveMatrix, fromFile, fileDimensions, @@ -169,28 +169,19 @@ fliprl m = fromColumns . reverse . toColumns $ m ------------------------------------------------------------ --- | Creates a square matrix with a given diagonal. -diag :: (Num a, Element a) => Vector a -> Matrix a -diag v = ST.runSTMatrix $ do - let d = dim v - m <- ST.newMatrix 0 d d - mapM_ (\k -> ST.writeMatrix m k k (v@>k)) [0..d-1] - return m +{- | creates a rectangular diagonal matrix: -{- | creates a rectangular diagonal matrix - -@> diagRect (constant 5 3) 3 4 :: Matrix Double -(3><4) - [ 5.0, 0.0, 0.0, 0.0 - , 0.0, 5.0, 0.0, 0.0 - , 0.0, 0.0, 5.0, 0.0 ]@ +@> diagRect 7 (fromList [10,20,30]) 4 5 :: Matrix Double +(4><5) + [ 10.0, 7.0, 7.0, 7.0, 7.0 + , 7.0, 20.0, 7.0, 7.0, 7.0 + , 7.0, 7.0, 30.0, 7.0, 7.0 + , 7.0, 7.0, 7.0, 7.0, 7.0 ]@ -} -diagRect :: (Element t, Num t) => Vector t -> Int -> Int -> Matrix t -diagRect v r c - | dim v < min r c = error "diagRect called with dim v < min r c" - | otherwise = ST.runSTMatrix $ do - m <- ST.newMatrix 0 r c - let d = min r c +diagRect :: (Storable t) => t -> Vector t -> Int -> Int -> Matrix t +diagRect z v r c = ST.runSTMatrix $ do + m <- ST.newMatrix z r c + let d = min r c `min` (dim v) mapM_ (\k -> ST.writeMatrix m k k (v@>k)) [0..d-1] return m @@ -198,10 +189,6 @@ diagRect v r c takeDiag :: (Element t) => Matrix t -> Vector t takeDiag m = fromList [flatten m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] --- | creates the identity matrix of given dimension -ident :: (Num a, Element a) => Int -> Matrix a -ident n = diag (constantD 1 n) - ------------------------------------------------------------ {- | An easy way to create a matrix: @@ -225,7 +212,7 @@ Example: , 4.0, 5.0, 6.0 ]@ -} -(><) :: (Element a) => Int -> Int -> [a] -> Matrix a +(><) :: (Storable a) => Int -> Int -> [a] -> Matrix a r >< c = f where f l | dim v == r*c = matrixFromVector RowMajor c v | otherwise = error $ "inconsistent list size = " @@ -261,16 +248,13 @@ fromLists :: Element t => [[t]] -> Matrix t fromLists = fromRows . map fromList -- | creates a 1-row matrix from a vector -asRow :: Element a => Vector a -> Matrix a +asRow :: Storable a => Vector a -> Matrix a asRow v = reshape (dim v) v -- | creates a 1-column matrix from a vector -asColumn :: Element a => Vector a -> Matrix a +asColumn :: Storable a => Vector a -> Matrix a asColumn v = reshape 1 v --- | conjugate transpose -ctrans :: Element e => Matrix e -> Matrix e -ctrans = liftMatrix conjugateD . trans {- | creates a Matrix of the specified size using the supplied function to @@ -289,7 +273,7 @@ buildMatrix rc cc f = ----------------------------------------------------- -fromArray2D :: (Element e) => Array (Int, Int) e -> Matrix e +fromArray2D :: (Storable e) => Array (Int, Int) e -> Matrix e fromArray2D m = (r> STVector s t -> Int -> t -> ST s () writeVector = safeIndexV unsafeWriteVector {-# NOINLINE newUndefinedVector #-} -newUndefinedVector :: Element t => Int -> ST s (STVector s t) +newUndefinedVector :: Storable t => Int -> ST s (STVector s t) newUndefinedVector = unsafeIOToST . fmap STVector . createVector {-# INLINE newVector #-} -newVector :: Element t => t -> Int -> ST s (STVector s t) +newVector :: Storable t => t -> Int -> ST s (STVector s t) newVector x n = do v <- newUndefinedVector n let go (-1) = return v @@ -164,9 +164,9 @@ writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () writeMatrix = safeIndexM unsafeWriteMatrix {-# NOINLINE newUndefinedMatrix #-} -newUndefinedMatrix :: Element t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) +newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) newUndefinedMatrix order r c = unsafeIOToST $ fmap STMatrix $ createMatrix order r c {-# NOINLINE newMatrix #-} -newMatrix :: Element t => t -> Int -> Int -> ST s (STMatrix s t) +newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t) newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) -- cgit v1.2.3