From fdf8d8778d52cf14aec493ef5ab18d363b900ed7 Mon Sep 17 00:00:00 2001 From: Reiner Pope Date: Sat, 7 Jan 2012 11:47:06 +1100 Subject: Make Matrix a product type --- lib/Data/Packed/Internal/Matrix.hs | 74 +++++++++++++++---------------------- lib/Data/Packed/ST.hs | 11 +++--- lib/Numeric/LinearAlgebra/LAPACK.hs | 8 ++-- 3 files changed, 39 insertions(+), 54 deletions(-) (limited to 'lib') diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index a39c0f0..28bebbc 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs @@ -18,7 +18,7 @@ -- #hide module Data.Packed.Internal.Matrix( - Matrix(..), rows, cols, + Matrix(..), rows, cols, cdat, fdat, MatrixOrder(..), orderOf, createMatrix, mat, cmat, fmat, @@ -82,21 +82,23 @@ import System.IO.Unsafe(unsafePerformIO) data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) +transOrder RowMajor = ColumnMajor +transOrder ColumnMajor = RowMajor {- | Matrix representation suitable for GSL and LAPACK computations. The elements are stored in a continuous memory array. -} -data Matrix t = MC { irows :: {-# UNPACK #-} !Int - , icols :: {-# UNPACK #-} !Int - , cdat :: {-# UNPACK #-} !(Vector t) } - | MF { irows :: {-# UNPACK #-} !Int - , icols :: {-# UNPACK #-} !Int - , fdat :: {-# UNPACK #-} !(Vector t) } +data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int + , icols :: {-# UNPACK #-} !Int + , xdat :: {-# UNPACK #-} !(Vector t) + , order :: !MatrixOrder } +-- RowMajor: preferred by C, fdat may require a transposition +-- ColumnMajor: preferred by LAPACK, cdat may require a transposition --- MC: preferred by C, fdat may require a transposition --- MF: preferred by LAPACK, cdat may require a transposition +cdat = xdat +fdat = xdat rows :: Matrix t -> Int rows = irows @@ -104,25 +106,21 @@ rows = irows cols :: Matrix t -> Int cols = icols -xdat MC {cdat = d } = d -xdat MF {fdat = d } = d - orderOf :: Matrix t -> MatrixOrder -orderOf MF{} = ColumnMajor -orderOf MC{} = RowMajor +orderOf = order + -- | Matrix transpose. trans :: Matrix t -> Matrix t -trans MC {irows = r, icols = c, cdat = d } = MF {irows = c, icols = r, fdat = d } -trans MF {irows = r, icols = c, fdat = d } = MC {irows = c, icols = r, cdat = d } +trans Matrix {irows = r, icols = c, xdat = d, order = o } = Matrix { irows = c, icols = r, xdat = d, order = transOrder o} cmat :: (Element t) => Matrix t -> Matrix t -cmat m@MC{} = m -cmat MF {irows = r, icols = c, fdat = d } = MC {irows = r, icols = c, cdat = transdata r d c} +cmat m@Matrix{order = RowMajor} = m +cmat Matrix {irows = r, icols = c, xdat = d, order = ColumnMajor } = Matrix { irows = r, icols = c, xdat = transdata r d c, order = RowMajor} fmat :: (Element t) => Matrix t -> Matrix t -fmat m@MF{} = m -fmat MC {irows = r, icols = c, cdat = d } = MF {irows = r, icols = c, fdat = transdata c d r} +fmat m@Matrix{order = ColumnMajor} = m +fmat Matrix {irows = r, icols = c, xdat = d, order = RowMajor } = Matrix { irows = r, icols = c, xdat = transdata c d r, order = ColumnMajor} -- C-Haskell matrix adapter -- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r @@ -140,7 +138,7 @@ mat a f = 9 |> [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@ -} flatten :: Element t => Matrix t -> Vector t -flatten = cdat . cmat +flatten = xdat . cmat type Mt t s = Int -> Int -> Ptr t -> s -- not yet admitted by my haddock version @@ -186,32 +184,21 @@ infixl 9 @@> -- | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" -- | otherwise = cdat m `at` (i*c+j) -MC {irows = r, icols = c, cdat = v} @@> (i,j) - | safe = if i<0 || i>=r || j<0 || j>=c - then error "matrix indexing out of range" - else v `at` (i*c+j) - | otherwise = v `at` (i*c+j) - -MF {irows = r, icols = c, fdat = v} @@> (i,j) +m@Matrix {irows = r, icols = c, xdat = v, order = o} @@> (i,j) | safe = if i<0 || i>=r || j<0 || j>=c then error "matrix indexing out of range" - else v `at` (j*r+i) - | otherwise = v `at` (j*r+i) + else atM' m i j + | otherwise = atM' m i j {-# INLINE (@@>) #-} -- Unsafe matrix access without range checking -atM' MC {icols = c, cdat = v} i j = v `at'` (i*c+j) -atM' MF {irows = r, fdat = v} i j = v `at'` (j*r+i) +atM' Matrix {icols = c, xdat = v, order = RowMajor} i j = v `at'` (i*c+j) +atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i) {-# INLINE atM' #-} ------------------------------------------------------------------ -matrixFromVector RowMajor c v = MC { irows = r, icols = c, cdat = v } - where (d,m) = dim v `quotRem` c - r | m==0 = d - | otherwise = error "matrixFromVector" - -matrixFromVector ColumnMajor c v = MF { irows = r, icols = c, fdat = v } +matrixFromVector o c v = Matrix { irows = r, icols = c, xdat = v, order = o } where (d,m) = dim v `quotRem` c r | m==0 = d | otherwise = error "matrixFromVector" @@ -239,16 +226,15 @@ singleton x = reshape 1 (fromList [x]) -- | application of a vector function on the flattened matrix elements liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b -liftMatrix f MC { icols = c, cdat = d } = matrixFromVector RowMajor c (f d) -liftMatrix f MF { icols = c, fdat = d } = matrixFromVector ColumnMajor c (f d) +liftMatrix f Matrix { icols = c, xdat = d, order = o } = matrixFromVector o c (f d) -- | application of a vector function on the flattened matrices elements liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t liftMatrix2 f m1 m2 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" - | otherwise = case m1 of - MC {} -> matrixFromVector RowMajor (cols m1) (f (cdat m1) (flatten m2)) - MF {} -> matrixFromVector ColumnMajor (cols m1) (f (fdat m1) ((fdat.fmat) m2)) + | otherwise = case orderOf m1 of + RowMajor -> matrixFromVector RowMajor (cols m1) (f (xdat m1) (flatten m2)) + ColumnMajor -> matrixFromVector ColumnMajor (cols m1) (f (xdat m1) ((xdat.fmat) m2)) compat :: Matrix a -> Matrix b -> Bool @@ -427,7 +413,7 @@ subMatrix'' (r0,c0) (rt,ct) c v = unsafePerformIO $ do go (rt-1) (ct-1) return w -subMatrix' (r0,c0) (rt,ct) (MC _r c v) = MC rt ct $ subMatrix'' (r0,c0) (rt,ct) c v +subMatrix' (r0,c0) (rt,ct) (Matrix { icols = c, xdat = v, order = RowMajor}) = Matrix rt ct (subMatrix'' (r0,c0) (rt,ct) c v) RowMajor subMatrix' (r0,c0) (rt,ct) m = trans $ subMatrix' (c0,r0) (ct,rt) (trans m) -------------------------------------------------------------------------- diff --git a/lib/Data/Packed/ST.hs b/lib/Data/Packed/ST.hs index 00f5e78..c96a209 100644 --- a/lib/Data/Packed/ST.hs +++ b/lib/Data/Packed/ST.hs @@ -113,13 +113,13 @@ newVector x n = do {-# INLINE ioReadM #-} ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t -ioReadM (MC _ nc cv) r c = ioReadV cv (r*nc+c) -ioReadM (MF nr _ fv) r c = ioReadV fv (c*nr+r) +ioReadM (Matrix _ nc cv RowMajor) r c = ioReadV cv (r*nc+c) +ioReadM (Matrix nr _ fv ColumnMajor) r c = ioReadV fv (c*nr+r) {-# INLINE ioWriteM #-} ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () -ioWriteM (MC _ nc cv) r c val = ioWriteV cv (r*nc+c) val -ioWriteM (MF nr _ fv) r c val = ioWriteV fv (c*nr+r) val +ioWriteM (Matrix _ nc cv RowMajor) r c val = ioWriteV cv (r*nc+c) val +ioWriteM (Matrix nr _ fv ColumnMajor) r c val = ioWriteV fv (c*nr+r) val newtype STMatrix s t = STMatrix (Matrix t) @@ -153,8 +153,7 @@ unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) freezeMatrix m = liftSTMatrix id m -cloneMatrix (MC r c d) = cloneVector d >>= return . MC r c -cloneMatrix (MF r c d) = cloneVector d >>= return . MF r c +cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o) {-# INLINE safeIndexM #-} safeIndexM f (STMatrix m) r c diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index d1aa564..349650c 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs @@ -58,11 +58,11 @@ foreign import ccall "multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM foreign import ccall "multiplyF" sgemmc :: CInt -> CInt -> TFMFMFM foreign import ccall "multiplyQ" cgemmc :: CInt -> CInt -> TQMQMQM -isT MF{} = 0 -isT MC{} = 1 +isT Matrix{order = ColumnMajor} = 0 +isT Matrix{order = RowMajor} = 1 -tt x@MF{} = x -tt x@MC{} = trans x +tt x@Matrix{order = RowMajor} = x +tt x@Matrix{order = ColumnMajor} = trans x multiplyAux f st a b = unsafePerformIO $ do when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ -- cgit v1.2.3