{-# OPTIONS_GHC -fglasgow-exts #-} ----------------------------------------------------------------------------- -- | -- Module : Data.Packed.Internal.Matrix -- Copyright : (c) Alberto Ruiz 2007 -- License : GPL-style -- -- Maintainer : Alberto Ruiz -- Stability : provisional -- Portability : portable (uses FFI) -- -- Internal matrix representation -- ----------------------------------------------------------------------------- -- --#hide module Data.Packed.Internal.Matrix where import Data.Packed.Internal.Common import Data.Packed.Internal.Vector import Foreign hiding (xor) import Complex import Control.Monad(when) import Data.List(transpose,intersperse) import Data.Maybe(fromJust) ----------------------------------------------------------------- {- Design considerations for the Matrix Type ----------------------------------------- - we must easily handle both row major and column major order, for bindings to LAPACK and GSL/C - we'd like to simplify redundant matrix transposes: - Some of them arise from the order requirements of some functions - some functions (matrix product) admit transposed arguments - maybe we don't really need this kind of simplification: - more complex code - some computational overhead - only appreciable gain in code with a lot of redundant transpositions and cheap matrix computations - we could carry both the matrix and its (lazily computed) transpose. This may save some transpositions, but it is necessary to keep track of the data which is actually computed to be used by functions like the matrix product which admit both orders. Therefore, maybe it is better to have something like viewC and viewF, which may actually perform a transpose if required. - but if we need the transposed data and it is not in the structure, we must make sure that we touch the same foreignptr that is used in the computation. Access to such pointer cannot be made by creating a new vector. -} data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) data Matrix t = MC { rows :: Int, cols :: Int, cdat :: Vector t, fdat :: Vector t } | MF { rows :: Int, cols :: Int, fdat :: Vector t, cdat :: Vector t } -- transposition just changes the data order trans :: Matrix t -> Matrix t trans MC {rows = r, cols = c, cdat = d, fdat = dt } = MF {rows = c, cols = r, fdat = d, cdat = dt } trans MF {rows = r, cols = c, fdat = d, cdat = dt } = MC {rows = c, cols = r, cdat = d, fdat = dt } dat MC { cdat = d } = d dat MF { fdat = d } = d mat d m f = f (rows m) (cols m) (ptr (d m)) type Mt t s = Int -> Int -> Ptr t -> s -- not yet admitted by my haddock version -- infixr 6 ::> -- type t ::> s = Mt t s -- | the inverse of 'Data.Packed.Matrix.fromLists' toLists :: (Field t) => Matrix t -> [[t]] toLists m = partit (cols m) . toList . cdat $ m -- | creates a Matrix from a list of vectors fromRows :: Field t => [Vector t] -> Matrix t fromRows vs = case common dim vs of Nothing -> error "fromRows applied to [] or to vectors with different sizes" Just c -> reshape c (join vs) -- | extracts the rows of a matrix as a list of vectors toRows :: Field t => Matrix t -> [Vector t] toRows m = toRows' 0 where v = cdat m r = rows m c = cols m toRows' k | k == r*c = [] | otherwise = subVector k c v : toRows' (k+c) -- | Creates a matrix from a list of vectors, as columns fromColumns :: Field t => [Vector t] -> Matrix t fromColumns m = trans . fromRows $ m -- | Creates a list of vectors from the columns of a matrix toColumns :: Field t => Matrix t -> [Vector t] toColumns m = toRows . trans $ m -- | Reads a matrix position. (@@>) :: Storable t => Matrix t -> (Int,Int) -> t infixl 9 @@> --m@M {rows = r, cols = c} @@> (i,j) -- | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" -- | otherwise = cdat m `at` (i*c+j) MC {rows = r, cols = c, cdat = v} @@> (i,j) | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | otherwise = v `at` (i*c+j) MF {rows = r, cols = c, fdat = v} @@> (i,j) | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | otherwise = v `at` (j*r+i) ------------------------------------------------------------------ matrixFromVector RowMajor c v = MC { rows = r, cols = c, cdat = v, fdat = transdata c v r } where (d,m) = dim v `divMod` c r | m==0 = d | otherwise = error "matrixFromVector" matrixFromVector ColumnMajor c v = MF { rows = r, cols = c, fdat = v, cdat = transdata r v c } where (d,m) = dim v `divMod` c r | m==0 = d | otherwise = error "matrixFromVector" createMatrix order r c = do p <- createVector (r*c) return (matrixFromVector order c p) {- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. @\> reshape 4 ('fromList' [1..12]) (3><4) [ 1.0, 2.0, 3.0, 4.0 , 5.0, 6.0, 7.0, 8.0 , 9.0, 10.0, 11.0, 12.0 ]@ -} reshape :: Field t => Int -> Vector t -> Matrix t reshape c v = matrixFromVector RowMajor c v singleton x = reshape 1 (fromList [x]) liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b liftMatrix f MC { cols = c, cdat = d } = matrixFromVector RowMajor c (f d) liftMatrix f MF { cols = c, fdat = d } = matrixFromVector ColumnMajor c (f d) liftMatrix2 :: (Field t, Field a, Field b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t liftMatrix2 f m1 m2 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" | otherwise = case m1 of MC {} -> matrixFromVector RowMajor (cols m1) (f (cdat m1) (cdat m2)) MF {} -> matrixFromVector ColumnMajor (cols m1) (f (fdat m1) (fdat m2)) compat :: Matrix a -> Matrix b -> Bool compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 ---------------------------------------------------------------- -- | element types for which optimized matrix computations are provided class Storable a => Field a where -- | @constant val n@ creates a vector with @n@ elements, all equal to @val@. constant :: a -> Int -> Vector a transdata :: Int -> Vector a -> Int -> Vector a multiplyD :: MatrixOrder -> Matrix a -> Matrix a -> Matrix a -- | extracts a submatrix froma a matrix subMatrix :: (Int,Int) -- ^ (r0,c0) starting position -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix -> Matrix a -> Matrix a -- | creates a square matrix with the given diagonal diag :: Vector a -> Matrix a instance Field Double where constant = constantR transdata = transdataR multiplyD = multiplyR subMatrix = subMatrixR diag = diagR instance Field (Complex Double) where constant = constantC transdata = transdataC multiplyD = multiplyC subMatrix = subMatrixC diag = diagC ------------------------------------------------------------------ instance (Show a, Field a) => (Show (Matrix a)) where show m = (sizes++) . dsp . map (map show) . toLists $ m where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp where mt = transpose as longs = map (maximum . map length) mt mtp = zipWith (\a b -> map (pad a) b) longs mt pad n str = replicate (n - length str) ' ' ++ str unwords' = concat . intersperse ", " ------------------------------------------------------------------ transdataR :: Int -> Vector Double -> Int -> Vector Double transdataR = transdataAux ctransR transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) transdataC = transdataAux ctransC transdataAux fun c1 d c2 = if noneed then d else unsafePerformIO $ do v <- createVector (dim d) fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] --putStrLn "---> transdataAux" return v where r1 = dim d `div` c1 r2 = dim d `div` c2 noneed = r1 == 1 || c1 == 1 foreign import ccall safe "aux.h transR" ctransR :: TMM -- Double ::> Double ::> IO Int foreign import ccall safe "aux.h transC" ctransC :: TCMCM -- Complex Double ::> Complex Double ::> IO Int ------------------------------------------------------------------ gmatC MF {rows = r, cols = c, fdat = d} f = f 1 c r (ptr d) gmatC MC {rows = r, cols = c, cdat = d} f = f 0 r c (ptr d) multiplyAux fun order a b = unsafePerformIO $ do when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ show (rows a,cols a) ++ " x " ++ show (rows b, cols b) r <- createMatrix order (rows a) (cols b) fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] return r multiplyR = multiplyAux cmultiplyR foreign import ccall safe "aux.h multiplyR" cmultiplyR :: Int -> Int -> Int -> Ptr Double -> Int -> Int -> Int -> Ptr Double -> Int -> Int -> Ptr Double -> IO Int multiplyC = multiplyAux cmultiplyC foreign import ccall safe "aux.h multiplyC" cmultiplyC :: Int -> Int -> Int -> Ptr (Complex Double) -> Int -> Int -> Int -> Ptr (Complex Double) -> Int -> Int -> Ptr (Complex Double) -> IO Int multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a multiply RowMajor a b = multiplyD RowMajor a b multiply ColumnMajor a b = MF {rows = c, cols = r, fdat = d, cdat = dt } where MC {rows = r, cols = c, cdat = d, fdat = dt } = multiplyD RowMajor (trans b) (trans a) -- FIXME using MatrixFromVector ---------------------------------------------------------------------- -- | extraction of a submatrix of a real matrix subMatrixR :: (Int,Int) -- ^ (r0,c0) starting position -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix -> Matrix Double -> Matrix Double subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do r <- createMatrix RowMajor rt ct c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat dat r // check "subMatrixR" [dat r] return r foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM -- | extraction of a submatrix of a complex matrix subMatrixC :: (Int,Int) -- ^ (r0,c0) starting position -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix -> Matrix (Complex Double) -> Matrix (Complex Double) subMatrixC (r0,c0) (rt,ct) x = reshape ct . asComplex . cdat . subMatrixR (r0,2*c0) (rt,2*ct) . reshape (2*cols x) . asReal . cdat $ x --------------------------------------------------------------------- diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do m <- createMatrix RowMajor n n fun // vec v // mat cdat m // check msg [dat m] return m -- {tdat = dat m} -- | diagonal matrix from a real vector diagR :: Vector Double -> Matrix Double diagR = diagAux c_diagR "diagR" foreign import ccall "aux.h diagR" c_diagR :: TVM -- | diagonal matrix from a real vector diagC :: Vector (Complex Double) -> Matrix (Complex Double) diagC = diagAux c_diagC "diagC" foreign import ccall "aux.h diagC" c_diagC :: TCVCM ------------------------------------------------------------------------ constantAux fun x n = unsafePerformIO $ do v <- createVector n px <- newArray [x] fun px // vec v // check "constantAux" [] free px return v constantR :: Double -> Int -> Vector Double constantR = constantAux cconstantR foreign import ccall safe "aux.h constantR" cconstantR :: Ptr Double -> TV -- Double :> IO Int constantC :: Complex Double -> Int -> Vector (Complex Double) constantC = constantAux cconstantC foreign import ccall safe "aux.h constantC" cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int ------------------------------------------------------------------------- -- Generic definitions {- transL m = matrixFromVector RowMajor (rows m) $ transdata (cols m) (cdat m) (rows m) subMatrixG (r0,c0) (rt,ct) x = matrixFromVector RowMajor ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) where subList s n = take n . drop s diagG v = matrixFromVector RowMajor c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] where c = dim v l = toList v delta i j | i==j = 1 | otherwise = 0 -} transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d dotL a b = sum (zipWith (*) a b) multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a] | otherwise = error "inconsistent dimensions in contraction " where ok = case common length a of Nothing -> False Just c -> c == length b