From 8053285df72177dab6b6d86241307d743fa0025f Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Mon, 22 Jun 2015 11:54:16 +0200 Subject: implicit rowOrder --- packages/base/src/Internal/Chain.hs | 2 +- packages/base/src/Internal/LAPACK.hs | 9 ++-- packages/base/src/Internal/Matrix.hs | 83 ++++++++++++++---------------------- packages/base/src/Internal/ST.hs | 10 ++--- 4 files changed, 43 insertions(+), 61 deletions(-) diff --git a/packages/base/src/Internal/Chain.hs b/packages/base/src/Internal/Chain.hs index fa518d1..f87eb02 100644 --- a/packages/base/src/Internal/Chain.hs +++ b/packages/base/src/Internal/Chain.hs @@ -22,7 +22,7 @@ module Internal.Chain ( import Data.Maybe -import Internal.Matrix hiding (order) +import Internal.Matrix import Internal.Numeric import qualified Data.Array.IArray as A diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index 3a9abbb..fc9e3ad 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs @@ -1,4 +1,5 @@ {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} ----------------------------------------------------------------------------- -- | @@ -49,11 +50,11 @@ foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> I ::> I ::> I ::> Ok foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> Z ::> Z ::> Z ::> Ok -isT Matrix{order = ColumnMajor} = 0 -isT Matrix{order = RowMajor} = 1 +isT (rowOrder -> False) = 0 +isT _ = 1 -tt x@Matrix{order = ColumnMajor} = x -tt x@Matrix{order = RowMajor} = trans x +tt x@(rowOrder -> False) = x +tt x = trans x multiplyAux f st a b = unsafePerformIO $ do when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index db0a609..c0d1318 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs @@ -3,7 +3,9 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE ViewPatterns #-} + -- | @@ -74,10 +76,14 @@ The elements are stored in a continuous memory array. -} -data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int - , icols :: {-# UNPACK #-} !Int - , xdat :: {-# UNPACK #-} !(Vector t) - , order :: !MatrixOrder } +data Matrix t = Matrix + { irows :: {-# UNPACK #-} !Int + , icols :: {-# UNPACK #-} !Int + , xRow :: {-# UNPACK #-} !CInt + , xCol :: {-# UNPACK #-} !CInt +-- , rowOrder :: {-# UNPACK #-} !Bool + , xdat :: {-# UNPACK #-} !(Vector t) + } -- RowMajor: preferred by C, fdat may require a transposition -- ColumnMajor: preferred by LAPACK, cdat may require a transposition @@ -88,49 +94,32 @@ rows = irows cols :: Matrix t -> Int cols = icols -orderOf :: Matrix t -> MatrixOrder -orderOf = order - -stepRow :: Matrix t -> CInt -stepRow Matrix {icols = c, order = RowMajor } = fromIntegral c -stepRow _ = 1 +rowOrder m = xRow m > 1 +{-# INLINE rowOrder #-} -stepCol :: Matrix t -> CInt -stepCol Matrix {irows = r, order = ColumnMajor } = fromIntegral r -stepCol _ = 1 +orderOf :: Matrix t -> MatrixOrder +orderOf m = if rowOrder m then RowMajor else ColumnMajor -- | Matrix transpose. trans :: Matrix t -> Matrix t -trans Matrix {irows = r, icols = c, xdat = d, order = o } = Matrix { irows = c, icols = r, xdat = d, order = transOrder o} +trans m@Matrix { irows = r, icols = c } | rowOrder m = + m { irows = c, icols = r, xRow = 1, xCol = fi c } +trans m@Matrix { irows = r, icols = c } = + m { irows = c, icols = r, xRow = fi r, xCol = 1 } cmat :: (Element t) => Matrix t -> Matrix t -cmat m@Matrix{order = RowMajor} = m -cmat Matrix {irows = r, icols = c, xdat = d, order = ColumnMajor } = Matrix { irows = r, icols = c, xdat = transdata r d c, order = RowMajor} +cmat m | rowOrder m = m +cmat m@Matrix { irows = r, icols = c, xdat = d } = + m { xdat = transdata r d c, xRow = fi c, xCol = 1 } fmat :: (Element t) => Matrix t -> Matrix t -fmat m@Matrix{order = ColumnMajor} = m -fmat Matrix {irows = r, icols = c, xdat = d, order = RowMajor } = Matrix { irows = r, icols = c, xdat = transdata c d r, order = ColumnMajor} - --- C-Haskell matrix adapter --- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r - -mat :: (Storable t) => Matrix t -> (((CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b -mat a f = - unsafeWith (xdat a) $ \p -> do - let m g = do - g (fi (rows a)) (fi (cols a)) p - f m - -omat :: (Storable t) => Matrix t -> (((CInt -> CInt -> CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b -omat a f = - unsafeWith (xdat a) $ \p -> do - let m g = do - g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p - f m +fmat m | not (rowOrder m) = m +fmat m@Matrix { irows = r, icols = c, xdat = d} = + m { xdat = transdata c d r, xRow = 1, xCol = fi r } --------------------------------------------------------------------------------- +-- C-Haskell matrix adapters {-# INLINE amatr #-} amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) @@ -144,14 +133,8 @@ amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) where r = fromIntegral (rows x) c = fromIntegral (cols x) - sr = stepRow x - sc = stepCol x - -{-# INLINE arrmat #-} -arrmat :: Storable a => (Ptr CInt -> Ptr a -> b) -> Matrix a -> b -arrmat f x = inlinePerformIO (unsafeWith s (\p -> unsafeWith (xdat x) (return . f p))) - where - s = fromList [fi (rows x), fi (cols x), stepRow x, stepCol x] + sr = xRow x + sc = xCol x instance Storable t => TransArray (Matrix t) @@ -163,8 +146,6 @@ instance Storable t => TransArray (Matrix t) {-# INLINE apply #-} applyRaw = amatr {-# INLINE applyRaw #-} - applyArray = arrmat - {-# INLINE applyArray #-} infixl 1 # a # b = apply a b @@ -246,8 +227,7 @@ m@Matrix {irows = r, icols = c} @@> (i,j) {-# INLINE (@@>) #-} -- Unsafe matrix access without range checking -atM' Matrix {icols = c, xdat = v, order = RowMajor} i j = v `at'` (i*c+j) -atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i) +atM' m i j = xdat m `at'` (i * (ti $ xRow m) + j * (ti $ xCol m)) {-# INLINE atM' #-} ------------------------------------------------------------------ @@ -256,7 +236,8 @@ matrixFromVector o r c v | r * c == dim v = m | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m where - m = Matrix { irows = r, icols = c, xdat = v, order = o } + m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = fi c, xCol = 1 } + | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1 , xCol = fi r } -- allocates memory for a new matrix createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) @@ -282,7 +263,7 @@ reshape c v = matrixFromVector RowMajor (dim v `div` c) c v -- | application of a vector function on the flattened matrix elements liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b -liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d) +liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} = matrixFromVector (orderOf m) r c (f d) -- | application of a vector function on the flattened matrices elements liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index d1defda..c98ff0e 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs @@ -109,13 +109,13 @@ newVector x n = do {-# INLINE ioReadM #-} ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t -ioReadM (Matrix _ nc cv RowMajor) r c = ioReadV cv (r*nc+c) -ioReadM (Matrix nr _ fv ColumnMajor) r c = ioReadV fv (c*nr+r) +ioReadM m r c = ioReadV (xdat m) (r * (ti $ xRow m) + c * (ti $ xCol m)) + {-# INLINE ioWriteM #-} ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () -ioWriteM (Matrix _ nc cv RowMajor) r c val = ioWriteV cv (r*nc+c) val -ioWriteM (Matrix nr _ fv ColumnMajor) r c val = ioWriteV fv (c*nr+r) val +ioWriteM m r c val = ioWriteV (xdat m) (r * (ti $ xRow m) + c * (ti $ xCol m)) val + newtype STMatrix s t = STMatrix (Matrix t) @@ -150,7 +150,7 @@ unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) freezeMatrix m = liftSTMatrix id m -cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o) +cloneMatrix m = cloneVector (xdat m) >>= return . (\d' -> m{xdat = d'}) {-# INLINE safeIndexM #-} safeIndexM f (STMatrix m) r c -- cgit v1.2.3