From 89651db9f2577ba42dbbb91c85565a12f34d0fb2 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Mon, 10 Sep 2007 09:13:20 +0000 Subject: simplified --- lib/Data/Packed/Internal/Matrix.hs | 398 +++++++++++++++---------------------- 1 file changed, 163 insertions(+), 235 deletions(-) (limited to 'lib/Data') diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 6ba2d06..ba32a67 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs @@ -1,4 +1,4 @@ -{-# OPTIONS_GHC -fglasgow-exts -fallow-overlapping-instances #-} +{-# OPTIONS_GHC -fglasgow-exts #-} ----------------------------------------------------------------------------- -- | -- Module : Data.Packed.Internal.Matrix @@ -22,65 +22,10 @@ import Foreign hiding (xor) import Complex import Control.Monad(when) import Data.List(transpose,intersperse) ---import Data.Typeable import Data.Maybe(fromJust) ----------------------------------------------------------------- - --- the condition Storable a => Field a means that we can only put --- in Field types that are in Storable, and therefore Storable a --- is not required in signatures if we have a Field a. - -class Storable a => Field a where - constant :: a -> Int -> Vector a - transdata :: Int -> Vector a -> Int -> Vector a - multiplyD :: MatrixOrder -> Matrix a -> Matrix a -> Matrix a - subMatrix :: (Int,Int) -- ^ (r0,c0) starting position - -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix - -> Matrix a -> Matrix a - diag :: Vector a -> Matrix a - -instance Field Double where - constant = constantR - transdata = transdataR - multiplyD = multiplyR - subMatrix = subMatrixR - diag = diagR - -instance Field (Complex Double) where - constant = constantC - transdata = transdataC - multiplyD = multiplyC - subMatrix = subMatrixC - diag = diagC - ----------------------------------------------------------------- -transdataR :: Int -> Vector Double -> Int -> Vector Double -transdataR = transdataAux ctransR - -transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) -transdataC = transdataAux ctransC - -transdataAux fun c1 d c2 = - if noneed - then d - else unsafePerformIO $ do - v <- createVector (dim d) - fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] - --putStrLn "---> transdataAux" - return v - where r1 = dim d `div` c1 - r2 = dim d `div` c2 - noneed = r1 == 1 || c1 == 1 - -foreign import ccall safe "aux.h transR" - ctransR :: TMM -- Double ::> Double ::> IO Int -foreign import ccall safe "aux.h transC" - ctransC :: TCMCM -- Complex Double ::> Complex Double ::> IO Int - -transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d - {- Design considerations for the Matrix Type ----------------------------------------- @@ -111,103 +56,79 @@ transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) -{- - - - -data Matrix t = M { rows :: Int - , cols :: Int - , dat :: Vector t - , tdat :: Vector t - , isTrans :: Bool - , order :: MatrixOrder - } -- deriving Typeable --} - -data Matrix t = MC { rows :: Int, cols :: Int, dat :: Vector t } -- row major order - | MF { rows :: Int, cols :: Int, dat :: Vector t } -- column major order +data Matrix t = MC { rows :: Int, cols :: Int, cdat :: Vector t, fdat :: Vector t } + | MF { rows :: Int, cols :: Int, fdat :: Vector t, cdat :: Vector t } -- transposition just changes the data order trans :: Matrix t -> Matrix t -trans MC {rows = r, cols = c, dat = d} = MF {rows = c, cols = r, dat = d} -trans MF {rows = r, cols = c, dat = d} = MC {rows = c, cols = r, dat = d} - -viewC m@MC{} = m -viewC MF {rows = r, cols = c, dat = d} = MC {rows = r, cols = c, dat = transdata r d c} +trans MC {rows = r, cols = c, cdat = d, fdat = dt } = MF {rows = c, cols = r, fdat = d, cdat = dt } +trans MF {rows = r, cols = c, fdat = d, cdat = dt } = MC {rows = c, cols = r, cdat = d, fdat = dt } -viewF m@MF{} = m -viewF MC {rows = r, cols = c, dat = d} = MF {rows = r, cols = c, dat = transdata c d r} +dat MC { cdat = d } = d +dat MF { fdat = d } = d ---fortran m = order m == ColumnMajor - -cdat m = dat (viewC m) -fdat m = dat (viewF m) +mat d m f = f (rows m) (cols m) (ptr (d m)) type Mt t s = Int -> Int -> Ptr t -> s -- not yet admitted by my haddock version -- infixr 6 ::> -- type t ::> s = Mt t s -mat d m f = f (rows m) (cols m) (ptr (d m)) ---mat m f = f (rows m) (cols m) (ptr (dat m)) ---matC m f = f (rows m) (cols m) (ptr (cdat m)) +-- | the inverse of 'fromLists' +toLists :: (Field t) => Matrix t -> [[t]] +toLists m = partit (cols m) . toList . cdat $ m +-- | creates a Matrix from a list of vectors +fromRows :: Field t => [Vector t] -> Matrix t +fromRows vs = case common dim vs of + Nothing -> error "fromRows applied to [] or to vectors with different sizes" + Just c -> reshape c (join vs) ---toLists :: (Storable t) => Matrix t -> [[t]] -toLists m = partit (cols m) . toList . cdat $ m +-- | extracts the rows of a matrix as a list of vectors +toRows :: Field t => Matrix t -> [Vector t] +toRows m = toRows' 0 where + v = cdat m + r = rows m + c = cols m + toRows' k | k == r*c = [] + | otherwise = subVector k c v : toRows' (k+c) -instance (Show a, Field a) => (Show (Matrix a)) where - show m = (sizes++) . dsp . map (map show) . toLists $ m - where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" +-- | Creates a matrix from a list of vectors, as columns +fromColumns :: Field t => [Vector t] -> Matrix t +fromColumns m = trans . fromRows $ m -dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp - where - mt = transpose as - longs = map (maximum . map length) mt - mtp = zipWith (\a b -> map (pad a) b) longs mt - pad n str = replicate (n - length str) ' ' ++ str - unwords' = concat . intersperse ", " +-- | Creates a list of vectors from the columns of a matrix +toColumns :: Field t => Matrix t -> [Vector t] +toColumns m = toRows . trans $ m -{- -matrixFromVector RowMajor c v = - M { rows = r - , cols = c - , dat = v - , tdat = transdata c v r - , order = RowMajor - , isTrans = False - } where (d,m) = dim v `divMod` c - r | m==0 = d - | otherwise = error "matrixFromVector" - -matrixFromVector ColumnMajor c v = - M { rows = r - , cols = c - , dat = v - , tdat = transdata r v c - , order = ColumnMajor - , isTrans = False - } where (d,m) = dim v `divMod` c - r | m==0 = d - | otherwise = error "matrixFromVector" --} +-- | Reads a matrix position. +(@@>) :: Storable t => Matrix t -> (Int,Int) -> t +infixl 9 @@> +--m@M {rows = r, cols = c} @@> (i,j) +-- | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" +-- | otherwise = cdat m `at` (i*c+j) -matrixFromVector RowMajor c v = MC { rows = r, cols = c, dat = v} +MC {rows = r, cols = c, cdat = v} @@> (i,j) + | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" + | otherwise = v `at` (i*c+j) + +MF {rows = r, cols = c, fdat = v} @@> (i,j) + | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" + | otherwise = v `at` (j*r+i) + +------------------------------------------------------------------ + +matrixFromVector RowMajor c v = MC { rows = r, cols = c, cdat = v, fdat = transdata c v r } where (d,m) = dim v `divMod` c r | m==0 = d | otherwise = error "matrixFromVector" -matrixFromVector ColumnMajor c v = MF { rows = r, cols = c, dat = v} +matrixFromVector ColumnMajor c v = MF { rows = r, cols = c, fdat = v, cdat = transdata r v c } where (d,m) = dim v `divMod` c r | m==0 = d | otherwise = error "matrixFromVector" - - - - - createMatrix order r c = do p <- createVector (r*c) return (matrixFromVector order c p) @@ -226,45 +147,94 @@ reshape c v = matrixFromVector RowMajor c v singleton x = reshape 1 (fromList [x]) ---liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b -liftMatrix f m = reshape (cols m) (f (cdat m)) +liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b +liftMatrix f MC { cols = c, cdat = d } = matrixFromVector RowMajor c (f d) +liftMatrix f MF { cols = c, fdat = d } = matrixFromVector ColumnMajor c (f d) + + +liftMatrix2 :: (Field t, Field a, Field b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t +liftMatrix2 f m1 m2 + | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" + | otherwise = case m1 of + MC {} -> matrixFromVector RowMajor (cols m1) (f (cdat m1) (cdat m2)) + MF {} -> matrixFromVector ColumnMajor (cols m1) (f (fdat m1) (fdat m2)) ---liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t -liftMatrix2 f m1 m2 | compat m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) - | otherwise = error "nonconformant matrices in liftMatrix2" ------------------------------------------------------------------- compat :: Matrix a -> Matrix b -> Bool compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 -dotL a b = sum (zipWith (*) a b) +---------------------------------------------------------------- -multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a] - | otherwise = error "inconsistent dimensions in contraction " - where ok = case common length a of - Nothing -> False - Just c -> c == length b +-- | element types for which optimized matrix computations are provided +class Storable a => Field a where + -- | @constant val n@ creates a vector with @n@ elements, all equal to @val@. + constant :: a -> Int -> Vector a + transdata :: Int -> Vector a -> Int -> Vector a + multiplyD :: MatrixOrder -> Matrix a -> Matrix a -> Matrix a + -- | extracts a submatrix froma a matrix + subMatrix :: (Int,Int) -- ^ (r0,c0) starting position + -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix + -> Matrix a -> Matrix a + -- | creates a square matrix with the given diagonal + diag :: Vector a -> Matrix a -transL m = matrixFromVector RowMajor (rows m) $ transdata (cols m) (cdat m) (rows m) +instance Field Double where + constant = constantR + transdata = transdataR + multiplyD = multiplyR + subMatrix = subMatrixR + diag = diagR -multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) +instance Field (Complex Double) where + constant = constantC + transdata = transdataC + multiplyD = multiplyC + subMatrix = subMatrixC + diag = diagC ------------------------------------------------------------------ -{- -gmatC m f | fortran m = - if (isTrans m) - then f 0 (rows m) (cols m) (ptr (dat m)) - else f 1 (cols m) (rows m) (ptr (dat m)) - | otherwise = - if isTrans m - then f 1 (cols m) (rows m) (ptr (dat m)) - else f 0 (rows m) (cols m) (ptr (dat m)) --} +instance (Show a, Field a) => (Show (Matrix a)) where + show m = (sizes++) . dsp . map (map show) . toLists $ m + where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" + +dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp + where + mt = transpose as + longs = map (maximum . map length) mt + mtp = zipWith (\a b -> map (pad a) b) longs mt + pad n str = replicate (n - length str) ' ' ++ str + unwords' = concat . intersperse ", " -gmatC MF {rows = r, cols = c, dat = d} f = f 1 c r (ptr d) -gmatC MC {rows = r, cols = c, dat = d} f = f 0 r c (ptr d) -{-# INLINE gmatC #-} +------------------------------------------------------------------ + +transdataR :: Int -> Vector Double -> Int -> Vector Double +transdataR = transdataAux ctransR + +transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) +transdataC = transdataAux ctransC + +transdataAux fun c1 d c2 = + if noneed + then d + else unsafePerformIO $ do + v <- createVector (dim d) + fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] + --putStrLn "---> transdataAux" + return v + where r1 = dim d `div` c1 + r2 = dim d `div` c2 + noneed = r1 == 1 || c1 == 1 + +foreign import ccall safe "aux.h transR" + ctransR :: TMM -- Double ::> Double ::> IO Int +foreign import ccall safe "aux.h transC" + ctransC :: TCMCM -- Complex Double ::> Complex Double ::> IO Int + +------------------------------------------------------------------ + +gmatC MF {rows = r, cols = c, fdat = d} f = f 1 c r (ptr d) +gmatC MC {rows = r, cols = c, cdat = d} f = f 0 r c (ptr d) multiplyAux fun order a b = unsafePerformIO $ do when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ @@ -272,14 +242,15 @@ multiplyAux fun order a b = unsafePerformIO $ do r <- createMatrix order (rows a) (cols b) fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] return r -{-# INLINE multiplyAux #-} +multiplyR = multiplyAux cmultiplyR foreign import ccall safe "aux.h multiplyR" cmultiplyR :: Int -> Int -> Int -> Ptr Double -> Int -> Int -> Int -> Ptr Double -> Int -> Int -> Ptr Double -> IO Int +multiplyC = multiplyAux cmultiplyC foreign import ccall safe "aux.h multiplyC" cmultiplyC :: Int -> Int -> Int -> Ptr (Complex Double) -> Int -> Int -> Int -> Ptr (Complex Double) @@ -288,14 +259,9 @@ foreign import ccall safe "aux.h multiplyC" multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a multiply RowMajor a b = multiplyD RowMajor a b -multiply ColumnMajor a b = MF {rows = c, cols = r, dat = d} - where MC {rows = r, cols = c, dat = d } = multiplyD RowMajor (trans b) (trans a) - - -multiplyR = multiplyAux cmultiplyR' -multiplyC = multiplyAux cmultiplyC - -cmultiplyR' p1 p2 p3 p4 q1 q2 q3 q4 r1 r2 r3 = {-# SCC "mulR" #-} cmultiplyR p1 p2 p3 p4 q1 q2 q3 q4 r1 r2 r3 +multiply ColumnMajor a b = MF {rows = c, cols = r, fdat = d, cdat = dt } + where MC {rows = r, cols = c, cdat = d, fdat = dt } = multiplyD RowMajor (trans b) (trans a) +-- FIXME using MatrixFromVector ---------------------------------------------------------------------- @@ -318,18 +284,6 @@ subMatrixC (r0,c0) (rt,ct) x = subMatrixR (r0,2*c0) (rt,2*ct) . reshape (2*cols x) . asReal . cdat $ x ---subMatrix :: (Field a) --- => (Int,Int) -- ^ (r0,c0) starting position --- -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix --- -> Matrix a -> Matrix a ---subMatrix st sz m --- | isReal (baseOf.dat) m = scast $ subMatrixR st sz (scast m) --- | isComp (baseOf.dat) m = scast $ subMatrixC st sz (scast m) --- | otherwise = subMatrixG st sz m - -subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) - where subList s n = take n . drop s - --------------------------------------------------------------------- diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do @@ -347,66 +301,7 @@ diagC :: Vector (Complex Double) -> Matrix (Complex Double) diagC = diagAux c_diagC "diagC" foreign import ccall "aux.h diagC" c_diagC :: TCVCM --- | diagonal matrix from a vector ---diag :: (Num a, Field a) => Vector a -> Matrix a ---diag v --- | isReal (baseOf) v = scast $ diagR (scast v) --- | isComp (baseOf) v = scast $ diagC (scast v) --- | otherwise = diagG v - -diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] - where c = dim v - l = toList v - delta i j | i==j = 1 - | otherwise = 0 - --- | creates a Matrix from a list of vectors ---fromRows :: Field t => [Vector t] -> Matrix t -fromRows vs = case common dim vs of - Nothing -> error "fromRows applied to [] or to vectors with different sizes" - Just c -> reshape c (join vs) - --- | extracts the rows of a matrix as a list of vectors ---toRows :: Storable t => Matrix t -> [Vector t] -toRows m = toRows' 0 where - v = cdat m - r = rows m - c = cols m - toRows' k | k == r*c = [] - | otherwise = subVector k c v : toRows' (k+c) - --- | Creates a matrix from a list of vectors, as columns -fromColumns :: Field t => [Vector t] -> Matrix t -fromColumns m = trans . fromRows $ m - --- | Creates a list of vectors from the columns of a matrix -toColumns :: Field t => Matrix t -> [Vector t] -toColumns m = toRows . trans $ m - - --- | Reads a matrix position. -(@@>) :: Storable t => Matrix t -> (Int,Int) -> t -infixl 9 @@> ---m@M {rows = r, cols = c} @@> (i,j) --- | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" --- | otherwise = cdat m `at` (i*c+j) - -MC {rows = r, cols = c, dat = v} @@> (i,j) - | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" - | otherwise = v `at` (i*c+j) - -MF {rows = r, cols = c, dat = v} @@> (i,j) - | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" - | otherwise = v `at` (j*r+i) - - ------------------------------------------------------------------- - -constantR :: Double -> Int -> Vector Double -constantR = constantAux cconstantR - -constantC :: Complex Double -> Int -> Vector (Complex Double) -constantC = constantAux cconstantC +------------------------------------------------------------------------ constantAux fun x n = unsafePerformIO $ do v <- createVector n @@ -415,8 +310,41 @@ constantAux fun x n = unsafePerformIO $ do free px return v +constantR :: Double -> Int -> Vector Double +constantR = constantAux cconstantR foreign import ccall safe "aux.h constantR" cconstantR :: Ptr Double -> TV -- Double :> IO Int +constantC :: Complex Double -> Int -> Vector (Complex Double) +constantC = constantAux cconstantC foreign import ccall safe "aux.h constantC" cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int + +------------------------------------------------------------------------- + +-- Generic definitions + +{- +transL m = matrixFromVector RowMajor (rows m) $ transdata (cols m) (cdat m) (rows m) + +subMatrixG (r0,c0) (rt,ct) x = matrixFromVector RowMajor ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) + where subList s n = take n . drop s + +diagG v = matrixFromVector RowMajor c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] + where c = dim v + l = toList v + delta i j | i==j = 1 + | otherwise = 0 +-} + +transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d + +dotL a b = sum (zipWith (*) a b) + +multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) + +multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a] + | otherwise = error "inconsistent dimensions in contraction " + where ok = case common length a of + Nothing -> False + Just c -> c == length b -- cgit v1.2.3