From 624046d6b55d37104f950e8888ab68c53a2e6bf0 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Wed, 24 Jun 2015 17:07:29 +0200 Subject: initial support of sliceMatrix, remove transdata --- packages/base/src/Internal/Matrix.hs | 190 ++++++++++++++--------------------- 1 file changed, 77 insertions(+), 113 deletions(-) (limited to 'packages/base/src/Internal/Matrix.hs') diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index f76b9dc..bdf2785 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs @@ -32,49 +32,13 @@ import Foreign.C.Types ( CInt(..) ) import Foreign.C.String ( CString, newCString ) import System.IO.Unsafe ( unsafePerformIO ) import Control.DeepSeq ( NFData(..) ) -import Data.List.Split(chunksOf) +import Text.Printf ----------------------------------------------------------------- -{- Design considerations for the Matrix Type - ----------------------------------------- - -- we must easily handle both row major and column major order, - for bindings to LAPACK and GSL/C - -- we'd like to simplify redundant matrix transposes: - - Some of them arise from the order requirements of some functions - - some functions (matrix product) admit transposed arguments - -- maybe we don't really need this kind of simplification: - - more complex code - - some computational overhead - - only appreciable gain in code with a lot of redundant transpositions - and cheap matrix computations - -- we could carry both the matrix and its (lazily computed) transpose. - This may save some transpositions, but it is necessary to keep track of the - data which is actually computed to be used by functions like the matrix product - which admit both orders. - -- but if we need the transposed data and it is not in the structure, we must make - sure that we touch the same foreignptr that is used in the computation. - -- a reasonable solution is using two constructors for a matrix. Transposition just - "flips" the constructor. Actual data transposition is not done if followed by a - matrix product or another transpose. - --} - data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) -transOrder RowMajor = ColumnMajor -transOrder ColumnMajor = RowMajor -{- | Matrix representation suitable for BLAS\/LAPACK computations. - -The elements are stored in a continuous memory array. - --} +-- | Matrix representation suitable for BLAS\/LAPACK computations. data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int @@ -83,8 +47,6 @@ data Matrix t = Matrix , xCol :: {-# UNPACK #-} !Int , xdat :: {-# UNPACK #-} !(Vector t) } --- RowMajor: preferred by C, fdat may require a transposition --- ColumnMajor: preferred by LAPACK, cdat may require a transposition rows :: Matrix t -> Int @@ -95,32 +57,55 @@ cols :: Matrix t -> Int cols = icols {-# INLINE cols #-} -rowOrder m = xRow m > 1 +size m = (irows m, icols m) +{-# INLINE size #-} + +rowOrder m = xCol m == 1 || cols m == 1 {-# INLINE rowOrder #-} -isSlice m = cols m < xRow m || rows m < xCol m +colOrder m = xRow m == 1 || rows m == 1 +{-# INLINE colOrder #-} + +is1d (size->(r,c)) = r==1 || c==1 +{-# INLINE is1d #-} + +-- data is not contiguous +isSlice m@(size->(r,c)) = (c < xRow m || r < xCol m) && min r c > 1 {-# INLINE isSlice #-} orderOf :: Matrix t -> MatrixOrder orderOf m = if rowOrder m then RowMajor else ColumnMajor +showInternal :: Storable t => Matrix t -> IO () +showInternal m = printf "%dx%d %s %s %d:%d (%d)\n" r c slc ord xr xc dv + where + r = rows m + c = cols m + xr = xRow m + xc = xCol m + slc = if isSlice m then "slice" else "full" + ord = if is1d m then "1d" else if rowOrder m then "rows" else "cols" + dv = dim (xdat m) + +-------------------------------------------------------------------------------- + -- | Matrix transpose. trans :: Matrix t -> Matrix t -trans m@Matrix { irows = r, icols = c } | rowOrder m = - m { irows = c, icols = r, xRow = 1, xCol = c } -trans m@Matrix { irows = r, icols = c } = - m { irows = c, icols = r, xRow = r, xCol = 1 } +trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } = + m { irows = c, icols = r, xRow = xc, xCol = xr } + cmat :: (Element t) => Matrix t -> Matrix t -cmat m | rowOrder m = m -cmat m@Matrix { irows = r, icols = c, xdat = d } = - m { xdat = transdata r d c, xRow = c, xCol = 1 } +cmat m + | rowOrder m = m + | otherwise = extractAll RowMajor m + fmat :: (Element t) => Matrix t -> Matrix t -fmat m | not (rowOrder m) = m -fmat m@Matrix { irows = r, icols = c, xdat = d} = - m { xdat = transdata c d r, xRow = 1, xCol = r } +fmat m + | colOrder m = m + | otherwise = extractAll ColumnMajor m -- C-Haskell matrix adapters @@ -157,6 +142,11 @@ a # b = apply a b -------------------------------------------------------------------------------- +extractAll ord m = unsafePerformIO $ + extractR ord m + 0 (idxs[0,rows m-1]) + 0 (idxs[0,cols m-1]) + {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. >>> flatten (ident 3) @@ -164,12 +154,14 @@ fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0] -} flatten :: Element t => Matrix t -> Vector t -flatten = xdat . cmat +flatten m + | isSlice m || not (rowOrder m) = xdat (extractAll RowMajor m) + | otherwise = xdat m -- | the inverse of 'Data.Packed.Matrix.fromLists' toLists :: (Element t) => Matrix t -> [[t]] -toLists m = chunksOf (cols m) . toList . flatten $ m +toLists = map toList . toRows @@ -205,6 +197,14 @@ fromRows vs = case compatdim (map dim vs) of -- | extracts the rows of a matrix as a list of vectors toRows :: Element t => Matrix t -> [Vector t] toRows m + | rowOrder m = map sub rowRange + | otherwise = map ext rowRange + where + rowRange = [0..rows m-1] + sub k = subVector (k*xRow m) (cols m) (xdat m) + ext k = xdat $ unsafePerformIO $ extractR RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) + +{- | c == 0 = replicate r (fromList[]) | otherwise = toRows' 0 where @@ -213,6 +213,7 @@ toRows m c = cols m toRows' k | k == r*c = [] | otherwise = subVector k c v : toRows' (k+c) +-} -- | Creates a matrix from a list of vectors, as columns fromColumns :: Element t => [Vector t] -> Matrix t @@ -240,7 +241,7 @@ matrixFromVector o r c v | r * c == dim v = m | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m where - m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 } + m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 } | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r } -- allocates memory for a new matrix @@ -263,31 +264,26 @@ reshape :: Storable t => Int -> Vector t -> Matrix t reshape 0 v = matrixFromVector RowMajor 0 0 v reshape c v = matrixFromVector RowMajor (dim v `div` c) c v ---singleton x = reshape 1 (fromList [x]) -- | application of a vector function on the flattened matrix elements -liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b -liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} = matrixFromVector (orderOf m) r c (f d) +liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b +liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} + | isSlice m = matrixFromVector RowMajor r c (f (flatten m)) + | otherwise = matrixFromVector (orderOf m) r c (f d) -- | application of a vector function on the flattened matrices elements liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t -liftMatrix2 f m1 m2 - | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" - | otherwise = case orderOf m1 of - RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2)) - ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2)) - - -compat :: Matrix a -> Matrix b -> Bool -compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 +liftMatrix2 f m1@(size->(r,c)) m2 + | (r,c)/=size m2 = error "nonconformant matrices in liftMatrix2" + | rowOrder m1 = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2)) + | otherwise = matrixFromVector ColumnMajor r c (f (flatten (trans m1)) (flatten (trans m2))) ------------------------------------------------------------------ -- | Supported matrix elements. class (Storable a) => Element a where - transdata :: Int -> Vector a -> Int -> Vector a constantD :: a -> Int -> Vector a - extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) + extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) setRect :: Int -> Int -> Matrix a -> Matrix a -> IO () sortI :: Ord a => Vector a -> Vector CInt sortV :: Ord a => Vector a -> Vector a @@ -299,7 +295,6 @@ class (Storable a) => Element a where instance Element Float where - transdata = transdataAux ctransF constantD = constantAux cconstantF extractR = extractAux c_extractF setRect = setRectAux c_setRectF @@ -312,7 +307,6 @@ instance Element Float where gemm = gemmg c_gemmF instance Element Double where - transdata = transdataAux ctransR constantD = constantAux cconstantR extractR = extractAux c_extractD setRect = setRectAux c_setRectD @@ -325,7 +319,6 @@ instance Element Double where gemm = gemmg c_gemmD instance Element (Complex Float) where - transdata = transdataAux ctransQ constantD = constantAux cconstantQ extractR = extractAux c_extractQ setRect = setRectAux c_setRectQ @@ -338,7 +331,6 @@ instance Element (Complex Float) where gemm = gemmg c_gemmQ instance Element (Complex Double) where - transdata = transdataAux ctransC constantD = constantAux cconstantC extractR = extractAux c_extractC setRect = setRectAux c_setRectC @@ -351,7 +343,6 @@ instance Element (Complex Double) where gemm = gemmg c_gemmC instance Element (CInt) where - transdata = transdataAux ctransI constantD = constantAux cconstantI extractR = extractAux c_extractI setRect = setRectAux c_setRectI @@ -364,7 +355,6 @@ instance Element (CInt) where gemm = gemmg c_gemmI instance Element Z where - transdata = transdataAux ctransL constantD = constantAux cconstantL extractR = extractAux c_extractL setRect = setRectAux c_setRectL @@ -378,32 +368,6 @@ instance Element Z where ------------------------------------------------------------------- -transdataAux fun c1 d c2 = - if noneed - then d - else unsafePerformIO $ do - -- putStrLn "T" - v <- createVector (dim d) - unsafeWith d $ \pd -> - unsafeWith v $ \pv -> - fun (fi r1) (fi c1) pd (fi r2) (fi c2) pv // check "transdataAux" - return v - where r1 = dim d `div` c1 - r2 = dim d `div` c2 - noneed = dim d == 0 || r1 == 1 || c1 == 1 - - -type TMM t = t ..> t ..> Ok - -foreign import ccall unsafe "transF" ctransF :: TMM Float -foreign import ccall unsafe "transR" ctransR :: TMM Double -foreign import ccall unsafe "transQ" ctransQ :: TMM (Complex Float) -foreign import ccall unsafe "transC" ctransC :: TMM (Complex Double) -foreign import ccall unsafe "transI" ctransI :: TMM CInt -foreign import ccall unsafe "transL" ctransL :: TMM Z - ----------------------------------------------------------------------- - subMatrix :: Element a => (Int,Int) -- ^ (r0,c0) starting position -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix @@ -411,9 +375,8 @@ subMatrix :: Element a -> Matrix a -- ^ result subMatrix (r0,c0) (rt,ct) m | 0 <= r0 && 0 <= rt && r0+rt <= rows m && - 0 <= c0 && 0 <= ct && c0+ct <= cols m = unsafePerformIO $ extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) - | otherwise = error $ "wrong subMatrix "++ - show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) + 0 <= c0 && 0 <= ct && c0+ct <= cols m = unsafePerformIO $ extractR RowMajor m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) + | otherwise = error $ "wrong subMatrix "++show ((r0,c0),(rt,ct))++" of "++shSize m sliceMatrix :: Element a @@ -424,11 +387,12 @@ sliceMatrix :: Element a sliceMatrix (r0,c0) (rt,ct) m | 0 <= r0 && 0 <= rt && r0+rt <= rows m && 0 <= c0 && 0 <= ct && c0+ct <= cols m = res - | otherwise = error $ "wrong sliceMatrix "++ - show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) + | otherwise = error $ "wrong sliceMatrix "++show ((r0,c0),(rt,ct))++" of "++shSize m where - t = r0 * xRow m + c0 * xCol m - res = m { irows = rt, icols = ct, xdat = subVector t (rt*ct) (xdat m) } + p = r0 * xRow m + c0 * xCol m + tot | rowOrder m = ct + (rt-1) * xRow m + | otherwise = rt + (ct-1) * xCol m + res = m { irows = rt, icols = ct, xdat = subVector p tot (xdat m) } -------------------------------------------------------------------------- @@ -449,7 +413,7 @@ conformMTo (r,c) m | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) | size m == (r,1) = repCols c m | size m == (1,c) = repRows r m - | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" + | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) conformVTo n v | dim v == n = v @@ -459,9 +423,9 @@ conformVTo n v repRows n x = fromRows (replicate n (flatten x)) repCols n x = fromColumns (replicate n (flatten x)) -size m = (rows m, cols m) +shSize = shDim . size -shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" +shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" emptyM r c = matrixFromVector RowMajor r c (fromList[]) @@ -477,10 +441,10 @@ instance (Storable t, NFData t) => NFData (Matrix t) --------------------------------------------------------------- -extractAux f m moder vr modec vc = do +extractAux f ord m moder vr modec vc = do let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc - r <- createMatrix RowMajor nr nc + r <- createMatrix ord nr nc f moder modec # vr # vc # m # r #|"extract" return r -- cgit v1.2.3