From 1871acb835b4fc164bcff3f6e7467884b87fbd0f Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Mon, 25 Jun 2007 07:32:56 +0000 Subject: l.a. algorithms, etc. --- lib/Data/Packed/Instances.hs | 391 +++++++++++++++++++++++++++++++++++++ lib/Data/Packed/Internal/Matrix.hs | 18 +- lib/Data/Packed/Internal/Vector.hs | 6 +- lib/Data/Packed/Matrix.hs | 35 +++- lib/Data/Packed/Vector.hs | 13 +- 5 files changed, 452 insertions(+), 11 deletions(-) create mode 100644 lib/Data/Packed/Instances.hs (limited to 'lib/Data') diff --git a/lib/Data/Packed/Instances.hs b/lib/Data/Packed/Instances.hs new file mode 100644 index 0000000..4478469 --- /dev/null +++ b/lib/Data/Packed/Instances.hs @@ -0,0 +1,391 @@ +{-# OPTIONS_GHC -fglasgow-exts #-} +----------------------------------------------------------------------------- +{- | +Module : Data.Packed.Instances +Copyright : (c) Alberto Ruiz 2006 +License : GPL-style + +Maintainer : Alberto Ruiz (aruiz at um dot es) +Stability : provisional +Portability : uses -fffi and -fglasgow-exts + +Creates reasonable numeric instances for Vectors and Matrices. In the context of the standard numeric operators, one-component vectors and matrices automatically expand to match the dimensions of the other operand. + +-} +----------------------------------------------------------------------------- + +module Data.Packed.Instances( + Contractible(..) +) where + +import Data.Packed.Internal +import Data.Packed.Vector +import Data.Packed.Matrix +import GSL.Vector +import GSL.Matrix +import LinearAlgebra.Algorithms +import Complex + +instance (Eq a, Field a) => Eq (Vector a) where + a == b = dim a == dim b && toList a == toList b + +instance (Num a, Field a) => Num (Vector a) where + (+) = add + (-) = sub + (*) = mul + signum = liftVector signum + abs = liftVector abs + fromInteger = fromList . return . fromInteger + +instance (Eq a, Field a) => Eq (Matrix a) where + a == b = rows a == rows b && cols a == cols b && cdat a == cdat b && fdat a == fdat b + +instance (Num a, Field a) => Num (Matrix a) where + (+) = liftMatrix2 add + (-) = liftMatrix2 sub + (*) = liftMatrix2 mul + signum = liftMatrix signum + abs = liftMatrix abs + fromInteger = (1><1) . return . fromInteger + +--------------------------------------------------- + +adaptScalar f1 f2 f3 x y + | dim x == 1 = f1 (x@>0) y + | dim y == 1 = f3 x (y@>0) + | otherwise = f2 x y + +{- +subvv = vectorZip 4 +subvc v c = addConstant (-c) v +subcv c v = addConstant c (scale (-1) v) + +mul = vectorZip 1 + +instance Num (Vector Double) where + (+) = adaptScalar addConstant add (flip addConstant) + (-) = adaptScalar subcv subvv subvc + (*) = adaptScalar scale mul (flip scale) + abs = vectorMap 3 + signum = vectorMap 15 + fromInteger n = fromList [fromInteger n] + +---------------------------------------------------- + +--addConstantC a = gmap (+a) +--subCvv u v = u `add` scale (-1) v +subCvv = vectorZipComplex 4 -- faster? +subCvc v c = addConstantC (-c) v +subCcv c v = addConstantC c (scale (-1) v) + + +instance Num (Vector (Complex Double)) where + (+) = adaptScalar addConstantC add (flip addConstantC) + (-) = adaptScalar subCcv subCvv subCvc + (*) = adaptScalar scale (vectorZipComplex 1) (flip scale) + abs = gmap abs + signum = gmap signum + fromInteger n = fromList [fromInteger n] + + +-- | adapts a function on two vectors to work on all the elements of two matrices +liftMatrix2' :: (Vector a -> Vector b -> Vector c) -> Matrix a -> Matrix b -> Matrix c +liftMatrix2' f m1@(M r1 c1 _) m2@(M r2 c2 _) + | sameShape m1 m2 || r1*c1==1 || r2*c2==1 + = reshape (max c1 c2) $ f (flatten m1) (flatten m2) + | otherwise = error "inconsistent matrix dimensions" + +--------------------------------------------------- + +instance (Eq a, Field a) => Eq (Matrix a) where + a == b = rows a == rows b && cdat a == cdat b + +instance Num (Matrix Double) where + (+) = liftMatrix2' (+) + (-) = liftMatrix2' (-) + (*) = liftMatrix2' (*) + abs = liftMatrix abs + signum = liftMatrix signum + fromInteger n = fromLists [[fromInteger n]] + +---------------------------------------------------- + +instance Num (Matrix (Complex Double)) where + (+) = liftMatrix2' (+) + (-) = liftMatrix2' (-) + (*) = liftMatrix2' (*) + abs = liftMatrix abs + signum = liftMatrix signum + fromInteger n = fromLists [[fromInteger n]] + +------------------------------------------------------ + +instance Fractional (Vector Double) where + fromRational n = fromList [fromRational n] + (/) = adaptScalar f (vectorZip 2) g where + r `f` v = vectorZip 2 (constant r (dim v)) v + v `g` r = scale (recip r) v + +------------------------------------------------------- + +instance Fractional (Vector (Complex Double)) where + fromRational n = fromList [fromRational n] + (/) = adaptScalar f (vectorZipComplex 2) g where + r `f` v = gmap ((*r).recip) v + v `g` r = gmap (/r) v + +------------------------------------------------------ + +instance Fractional (Matrix Double) where + fromRational n = fromLists [[fromRational n]] + (/) = liftMatrix2' (/) + +------------------------------------------------------- + +instance Fractional (Matrix (Complex Double)) where + fromRational n = fromLists [[fromRational n]] + (/) = liftMatrix2' (/) + +--------------------------------------------------------- + +instance Floating (Vector Double) where + sin = vectorMap 0 + cos = vectorMap 1 + tan = vectorMap 2 + asin = vectorMap 4 + acos = vectorMap 5 + atan = vectorMap 6 + sinh = vectorMap 7 + cosh = vectorMap 8 + tanh = vectorMap 9 + asinh = vectorMap 10 + acosh = vectorMap 11 + atanh = vectorMap 12 + exp = vectorMap 13 + log = vectorMap 14 + sqrt = vectorMap 16 + (**) = adaptScalar f (vectorZip 5) g where f s v = constant s (dim v) ** v + g v s = v ** constant s (dim v) + pi = fromList [pi] + +----------------------------------------------------------- + +instance Floating (Matrix Double) where + sin = liftMatrix sin + cos = liftMatrix cos + tan = liftMatrix tan + asin = liftMatrix asin + acos = liftMatrix acos + atan = liftMatrix atan + sinh = liftMatrix sinh + cosh = liftMatrix cosh + tanh = liftMatrix tanh + asinh = liftMatrix asinh + acosh = liftMatrix acosh + atanh = liftMatrix atanh + exp = liftMatrix exp + log = liftMatrix log + sqrt = liftMatrix sqrt + (**) = liftMatrix2 (**) + pi = fromLists [[pi]] + +------------------------------------------------------------- + +instance Floating (Vector (Complex Double)) where + sin = vectorMapComplex 0 + cos = vectorMapComplex 1 + tan = vectorMapComplex 2 + asin = vectorMapComplex 4 + acos = vectorMapComplex 5 + atan = vectorMapComplex 6 + sinh = vectorMapComplex 7 + cosh = vectorMapComplex 8 + tanh = vectorMapComplex 9 + asinh = vectorMapComplex 10 + acosh = vectorMapComplex 11 + atanh = vectorMapComplex 12 + exp = vectorMapComplex 13 + log = vectorMapComplex 14 + sqrt = vectorMapComplex 16 + (**) = adaptScalar f (vectorZipComplex 5) g where f s v = constantC s (dim v) ** v + g v s = v ** constantC s (dim v) + pi = fromList [pi] + +--------------------------------------------------------------- + +instance Floating (Matrix (Complex Double)) where + sin = liftMatrix sin + cos = liftMatrix cos + tan = liftMatrix tan + asin = liftMatrix asin + acos = liftMatrix acos + atan = liftMatrix atan + sinh = liftMatrix sinh + cosh = liftMatrix cosh + tanh = liftMatrix tanh + asinh = liftMatrix asinh + acosh = liftMatrix acosh + atanh = liftMatrix atanh + exp = liftMatrix exp + log = liftMatrix log + (**) = liftMatrix2 (**) + sqrt = liftMatrix sqrt + pi = fromLists [[pi]] + +--------------------------------------------------------------- +-} + +class Contractible a b c | a b -> c where + infixl 7 <> +{- | An overloaded operator for matrix products, matrix-vector and vector-matrix products, dot products and scaling of vectors and matrices. Type consistency is statically checked. Alternatively, you can use the specific functions described below, but using this operator you can automatically combine real and complex objects. + +@v = 'fromList' [1,2,3] :: Vector Double +cv = 'fromList' [1+'i',2] +m = 'fromLists' [[1,2,3], + [4,5,7]] :: Matrix Double +cm = 'fromLists' [[ 1, 2], + [3+'i',7*'i'], + [ 'i', 1]] +\ +\> m \<\> v +14. 35. +\ +\> cv \<\> m +9.+1.i 12.+2.i 17.+3.i +\ +\> m \<\> cm + 7.+5.i 5.+14.i +19.+12.i 15.+35.i +\ +\> v \<\> 'i' +1.i 2.i 3.i +\ +\> v \<\> v +14.0 +\ +\> cv \<\> cv +4.0 :+ 2.0@ + +-} + (<>) :: a -> b -> c + + +instance Contractible Double Double Double where + (<>) = (*) + +instance Contractible Double (Complex Double) (Complex Double) where + a <> b = (a:+0) * b + +instance Contractible (Complex Double) Double (Complex Double) where + a <> b = a * (b:+0) + +instance Contractible (Complex Double) (Complex Double) (Complex Double) where + (<>) = (*) + +--------------------------------- matrix matrix + +instance Contractible (Matrix Double) (Matrix Double) (Matrix Double) where + (<>) = mXm + +instance Contractible (Matrix (Complex Double)) (Matrix (Complex Double)) (Matrix (Complex Double)) where + (<>) = mXm + +instance Contractible (Matrix (Complex Double)) (Matrix Double) (Matrix (Complex Double)) where + c <> r = c <> liftMatrix comp r + +instance Contractible (Matrix Double) (Matrix (Complex Double)) (Matrix (Complex Double)) where + r <> c = liftMatrix comp r <> c + +--------------------------------- (Matrix Double) (Vector Double) + +instance Contractible (Matrix Double) (Vector Double) (Vector Double) where + (<>) = mXv + +instance Contractible (Matrix (Complex Double)) (Vector (Complex Double)) (Vector (Complex Double)) where + (<>) = mXv + +instance Contractible (Matrix (Complex Double)) (Vector Double) (Vector (Complex Double)) where + m <> v = m <> comp v + +instance Contractible (Matrix Double) (Vector (Complex Double)) (Vector (Complex Double)) where + m <> v = liftMatrix comp m <> v + +--------------------------------- (Vector Double) (Matrix Double) + +instance Contractible (Vector Double) (Matrix Double) (Vector Double) where + (<>) = vXm + +instance Contractible (Vector (Complex Double)) (Matrix (Complex Double)) (Vector (Complex Double)) where + (<>) = vXm + +instance Contractible (Vector (Complex Double)) (Matrix Double) (Vector (Complex Double)) where + v <> m = v <> liftMatrix comp m + +instance Contractible (Vector Double) (Matrix (Complex Double)) (Vector (Complex Double)) where + v <> m = comp v <> m + +--------------------------------- dot product + +instance Contractible (Vector Double) (Vector Double) Double where + (<>) = dot + +instance Contractible (Vector (Complex Double)) (Vector (Complex Double)) (Complex Double) where + (<>) = dot + +instance Contractible (Vector Double) (Vector (Complex Double)) (Complex Double) where + a <> b = comp a <> b + +instance Contractible (Vector (Complex Double)) (Vector Double) (Complex Double) where + (<>) = flip (<>) + +--------------------------------- scaling vectors + +instance Contractible Double (Vector Double) (Vector Double) where + (<>) = scale + +instance Contractible (Vector Double) Double (Vector Double) where + (<>) = flip (<>) + +instance Contractible (Complex Double) (Vector (Complex Double)) (Vector (Complex Double)) where + (<>) = scale + +instance Contractible (Vector (Complex Double)) (Complex Double) (Vector (Complex Double)) where + (<>) = flip (<>) + +instance Contractible Double (Vector (Complex Double)) (Vector (Complex Double)) where + a <> v = (a:+0) <> v + +instance Contractible (Vector (Complex Double)) Double (Vector (Complex Double)) where + (<>) = flip (<>) + +instance Contractible (Complex Double) (Vector Double) (Vector (Complex Double)) where + a <> v = a <> comp v + +instance Contractible (Vector Double) (Complex Double) (Vector (Complex Double)) where + (<>) = flip (<>) + +--------------------------------- scaling matrices + +instance Contractible Double (Matrix Double) (Matrix Double) where + (<>) a = liftMatrix (a <>) + +instance Contractible (Matrix Double) Double (Matrix Double) where + (<>) = flip (<>) + +instance Contractible (Complex Double) (Matrix (Complex Double)) (Matrix (Complex Double)) where + (<>) a = liftMatrix (a <>) + +instance Contractible (Matrix (Complex Double)) (Complex Double) (Matrix (Complex Double)) where + (<>) = flip (<>) + +instance Contractible Double (Matrix (Complex Double)) (Matrix (Complex Double)) where + a <> m = (a:+0) <> m + +instance Contractible (Matrix (Complex Double)) Double (Matrix (Complex Double)) where + (<>) = flip (<>) + +instance Contractible (Complex Double) (Matrix Double) (Matrix (Complex Double)) where + a <> m = a <> liftMatrix comp m + +instance Contractible (Matrix Double) (Complex Double) (Matrix (Complex Double)) where + (<>) = flip (<>) diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 9309d1d..dd33943 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs @@ -93,6 +93,15 @@ createMatrix order r c = do p <- createVector (r*c) return (matrixFromVector order c p) +{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. + +@\> reshape 4 ('fromList' [1..12]) +(3><4) + [ 1.0, 2.0, 3.0, 4.0 + , 5.0, 6.0, 7.0, 8.0 + , 9.0, 10.0, 11.0, 12.0 ]@ + +-} reshape :: (Field t) => Int -> Vector t -> Matrix t reshape c v = matrixFromVector RowMajor c v @@ -140,7 +149,6 @@ liftMatrix f m = m { dat = f (dat m), tdat = f (tdat m) } -- check sizes liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t liftMatrix2 f m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) -- check sizes - ------------------------------------------------------------------ dotL a b = sum (zipWith (*) a b) @@ -200,6 +208,14 @@ multiplyD order a b outer' u v = dat (outer u v) +{- | Outer product of two vectors. + +@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] +(3><3) + [ 5.0, 2.0, 3.0 + , 10.0, 4.0, 6.0 + , 15.0, 6.0, 9.0 ]@ +-} outer :: (Num t, Field t) => Vector t -> Vector t -> Matrix t outer u v = multiply RowMajor r c where r = matrixFromVector RowMajor 1 u diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index 25e848d..f1addf4 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs @@ -48,7 +48,7 @@ fromList l = unsafePerformIO $ do toList :: Storable a => Vector a -> [a] toList v = unsafePerformIO $ peekArray (dim v) (ptr v) -n # l = if length l == n then fromList l else error "# with wrong size" +n |> l = if length l == n then fromList l else error "|> with wrong size" at' :: Storable a => Vector a -> Int -> a at' v n = unsafePerformIO $ peekElemOff (ptr v) n @@ -58,7 +58,7 @@ at v n | n >= 0 && n < dim v = at' v n | otherwise = error "vector index out of range" instance (Show a, Storable a) => (Show (Vector a)) where - show v = (show (dim v))++" # " ++ show (toList v) + show v = (show (dim v))++" |> " ++ show (toList v) -- | creates a Vector taking a number of consecutive toList from another Vector subVector :: Storable t => Int -- ^ index of the starting element @@ -129,3 +129,5 @@ constant x n | isReal id x = scast $ constantR (scast x) n | isComp id x = scast $ constantC (scast x) n | otherwise = constantG x n +liftVector f = fromList . map f . toList +liftVector2 f u v = fromList $ zipWith f (toList u) (toList v) diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index 0f9d998..36bf32e 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs @@ -13,11 +13,11 @@ ----------------------------------------------------------------------------- module Data.Packed.Matrix ( - Matrix(rows,cols), Field, + Matrix(rows,cols), fromLists, toLists, (><), (>|<), (@@>), trans, conjTrans, - reshape, flatten, - fromRows, toRows, fromColumns, toColumns, + reshape, flatten, asRow, asColumn, + fromRows, toRows, fromColumns, toColumns, fromBlocks, joinVert, joinHoriz, flipud, fliprl, liftMatrix, liftMatrix2, @@ -43,6 +43,22 @@ joinVert ms = case common cols ms of joinHoriz :: Field t => [Matrix t] -> Matrix t joinHoriz ms = trans. joinVert . map trans $ ms +{- | Creates a matrix from blocks given as a list of lists of matrices: + +@\> let a = 'diag' $ 'fromList' [5,7,2] +\> let b = 'reshape' 4 $ 'constant' (-1) 12 +\> fromBlocks [[a,b],[b,a]] +(6><7) + [ 5.0, 0.0, 0.0, -1.0, -1.0, -1.0, -1.0 + , 0.0, 7.0, 0.0, -1.0, -1.0, -1.0, -1.0 + , 0.0, 0.0, 2.0, -1.0, -1.0, -1.0, -1.0 + , -1.0, -1.0, -1.0, -1.0, 5.0, 0.0, 0.0 + , -1.0, -1.0, -1.0, -1.0, 0.0, 7.0, 0.0 + , -1.0, -1.0, -1.0, -1.0, 0.0, 0.0, 2.0 ]@ +-} +fromBlocks :: Field t => [[Matrix t]] -> Matrix t +fromBlocks = joinVert . map joinHoriz + -- | Reverse rows flipud :: Field t => Matrix t -> Matrix t flipud m = fromRows . reverse . toRows $ m @@ -98,6 +114,11 @@ dropColumns n mat = subMatrix (0,n) (rows mat, cols mat - n) mat ---------------------------------------------------------------- +{- | Creates a vector by concatenation of rows + +@\> flatten ('ident' 3) +9 # [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@ +-} flatten :: Matrix t -> Vector t flatten = cdat @@ -106,4 +127,10 @@ fromLists :: Field t => [[t]] -> Matrix t fromLists = fromRows . map fromList conjTrans :: Matrix (Complex Double) -> Matrix (Complex Double) -conjTrans = trans . liftMatrix conj \ No newline at end of file +conjTrans = trans . liftMatrix conj + +asRow :: Field a => Vector a -> Matrix a +asRow v = reshape (dim v) v + +asColumn :: Field a => Vector a -> Matrix a +asColumn v = reshape 1 v diff --git a/lib/Data/Packed/Vector.hs b/lib/Data/Packed/Vector.hs index 9d9d879..94f70be 100644 --- a/lib/Data/Packed/Vector.hs +++ b/lib/Data/Packed/Vector.hs @@ -15,7 +15,7 @@ module Data.Packed.Vector ( Vector(dim), Field, fromList, toList, - at, + (@>), subVector, join, constant, toComplex, comp, @@ -26,6 +26,7 @@ module Data.Packed.Vector ( import Data.Packed.Internal import Complex +import GSL.Vector -- | creates a complex vector from vectors with real and imaginary parts toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) @@ -41,10 +42,14 @@ comp v = toComplex (v,constant 0 (dim v)) {- | Creates a real vector containing a range of values: -> > linspace 10 (-2,2) ->-2. -1.556 -1.111 -0.667 -0.222 0.222 0.667 1.111 1.556 2. - +@\> linspace 5 (-3,7) +5 |> [-3.0,-0.5,2.0,4.5,7.0]@ -} linspace :: Int -> (Double, Double) -> Vector Double linspace n (a,b) = fromList [a::Double,a+delta .. b] where delta = (b-a)/(fromIntegral n -1) + +-- | Reads a vector position. +(@>) :: Field t => Vector t -> Int -> t +infixl 9 @> +(@>) = at -- cgit v1.2.3