From 6058e1b17c005be1ea95ebb7d98d9fd15bb538d2 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Thu, 26 Aug 2010 17:49:45 +0000 Subject: Float matrix product --- lib/Numeric/LinearAlgebra/Algorithms.hs | 81 +----------------------- lib/Numeric/LinearAlgebra/Interface.hs | 2 +- lib/Numeric/LinearAlgebra/LAPACK.hs | 18 ++++-- lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 81 ++++++++++++++++++++++-- lib/Numeric/LinearAlgebra/Linear.hs | 90 ++++++++++++++++++++++++++- lib/Numeric/LinearAlgebra/Tests.hs | 16 +++-- lib/Numeric/LinearAlgebra/Tests/Properties.hs | 6 +- 7 files changed, 194 insertions(+), 100 deletions(-) (limited to 'lib/Numeric') diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index f4b7ee9..8962c60 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs @@ -21,9 +21,6 @@ imported from "Numeric.LinearAlgebra.LAPACK". module Numeric.LinearAlgebra.Algorithms ( -- * Supported types Field(), --- * Products - multiply, -- dot, moved dot to typeclass - outer, kronecker, -- * Linear Systems linearSolve, luSolve, @@ -64,7 +61,6 @@ module Numeric.LinearAlgebra.Algorithms ( -- * Norms Normed(..), NormType(..), -- * Misc - ctrans, eps, i, -- * Util haussholder, @@ -86,7 +82,7 @@ import Data.List(foldl1') import Data.Array -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. -class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where +class (Prod t, Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where svd' :: Matrix t -> (Matrix t, Vector Double, Matrix t) thinSVD' :: Matrix t -> (Matrix t, Vector Double, Matrix t) sv' :: Matrix t -> Vector Double @@ -105,8 +101,6 @@ class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where qr' :: Matrix t -> (Matrix t, Matrix t) hess' :: Matrix t -> (Matrix t, Matrix t) schur' :: Matrix t -> (Matrix t, Matrix t) - ctrans' :: Matrix t -> Matrix t - multiply' :: Matrix t -> Matrix t -> Matrix t instance Field Double where @@ -119,7 +113,6 @@ instance Field Double where cholSolve' = cholSolveR linearSolveLS' = linearSolveLSR linearSolveSVD' = linearSolveSVDR Nothing - ctrans' = trans eig' = eigR eigSH'' = eigS eigOnly = eigOnlyR @@ -129,7 +122,6 @@ instance Field Double where qr' = unpackQR . qrR hess' = unpackHess hessR schur' = schurR - multiply' = multiplyR instance Field (Complex Double) where #ifdef NOZGESDD @@ -146,7 +138,6 @@ instance Field (Complex Double) where cholSolve' = cholSolveC linearSolveLS' = linearSolveLSC linearSolveSVD' = linearSolveSVDC Nothing - ctrans' = conj . trans eig' = eigC eigOnly = eigOnlyC eigSH'' = eigH @@ -156,7 +147,6 @@ instance Field (Complex Double) where qr' = unpackQR . qrC hess' = unpackHess hessC schur' = schurC - multiply' = multiplyC -------------------------------------------------------------- @@ -324,13 +314,6 @@ hess = hess' schur :: Field t => Matrix t -> (Matrix t, Matrix t) schur = schur' --- | Generic conjugate transpose. -ctrans :: Field t => Matrix t -> Matrix t -ctrans = ctrans' - --- | Matrix product. -multiply :: Field t => Matrix t -> Matrix t -> Matrix t -multiply = {-# SCC "multiply" #-} multiply' -- | Similar to 'cholSH', but instead of an error (e.g., caused by a matrix not positive definite) it returns 'Nothing'. mbCholSH :: Field t => Matrix t -> Maybe (Matrix t) @@ -404,20 +387,6 @@ peps x = 2.0**(fromIntegral $ 1-floatDigits x) i :: Complex Double i = 0:+1 - --- matrix product -mXm :: (Num t, Field t) => Matrix t -> Matrix t -> Matrix t -mXm = multiply - --- matrix - vector product -mXv :: (Num t, Field t) => Matrix t -> Vector t -> Vector t -mXv m v = flatten $ m `mXm` (asColumn v) - --- vector - matrix product -vXm :: (Num t, Field t) => Vector t -> Matrix t -> Vector t -vXm v m = flatten $ (asRow v) `mXm` m - - --------------------------------------------------------------------------- norm2 :: Vector Double -> Double @@ -723,51 +692,3 @@ luFact (l_u,perm) | r <= c = (l ,u ,p, s) (|*|) = mul -------------------------------------------------- - -{- moved to Numeric.LinearAlgebra.Interface Vector typeclass --- | Euclidean inner product. -dot :: (Field t) => Vector t -> Vector t -> t -dot u v = multiply r c @@> (0,0) - where r = asRow u - c = asColumn v --} - -{- | Outer product of two vectors. - -@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] -(3><3) - [ 5.0, 2.0, 3.0 - , 10.0, 4.0, 6.0 - , 15.0, 6.0, 9.0 ]@ --} -outer :: (Field t) => Vector t -> Vector t -> Matrix t -outer u v = asColumn u `multiply` asRow v - -{- | Kronecker product of two matrices. - -@m1=(2><3) - [ 1.0, 2.0, 0.0 - , 0.0, -1.0, 3.0 ] -m2=(4><3) - [ 1.0, 2.0, 3.0 - , 4.0, 5.0, 6.0 - , 7.0, 8.0, 9.0 - , 10.0, 11.0, 12.0 ]@ - -@\> kronecker m1 m2 -(8><9) - [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0 - , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0 - , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0 - , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0 - , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0 - , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0 - , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0 - , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@ --} -kronecker :: (Field t) => Matrix t -> Matrix t -> Matrix t -kronecker a b = fromBlocks - . splitEvery (cols a) - . map (reshape (cols b)) - . toRows - $ flatten a `outer` flatten b diff --git a/lib/Numeric/LinearAlgebra/Interface.hs b/lib/Numeric/LinearAlgebra/Interface.hs index f8917a0..6df782f 100644 --- a/lib/Numeric/LinearAlgebra/Interface.hs +++ b/lib/Numeric/LinearAlgebra/Interface.hs @@ -35,7 +35,7 @@ import Numeric.LinearAlgebra.Linear class Mul a b c | a b -> c where infixl 7 <> -- | Matrix-matrix, matrix-vector, and vector-matrix products. - (<>) :: Field t => a t -> b t -> c t + (<>) :: Prod t => a t -> b t -> c t instance Mul Matrix Matrix Matrix where (<>) = multiply diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index 7f057ba..eec3035 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs @@ -14,7 +14,7 @@ module Numeric.LinearAlgebra.LAPACK ( -- * Matrix product - multiplyR, multiplyC, + multiplyR, multiplyC, multiplyF, multiplyQ, -- * Linear systems linearSolveR, linearSolveC, lusR, lusC, @@ -51,8 +51,10 @@ import Control.Monad(when) ----------------------------------------------------------------------------------- -foreign import ccall "LAPACK/lapack-aux.h multiplyR" dgemmc :: CInt -> CInt -> TMMM -foreign import ccall "LAPACK/lapack-aux.h multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM +foreign import ccall "multiplyR" dgemmc :: CInt -> CInt -> TMMM +foreign import ccall "multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM +foreign import ccall "multiplyF" sgemmc :: CInt -> CInt -> TFMFMFM +foreign import ccall "multiplyQ" cgemmc :: CInt -> CInt -> TQMQMQM isT MF{} = 0 isT MC{} = 1 @@ -69,12 +71,20 @@ multiplyAux f st a b = unsafePerformIO $ do -- | Matrix product based on BLAS's /dgemm/. multiplyR :: Matrix Double -> Matrix Double -> Matrix Double -multiplyR a b = multiplyAux dgemmc "dgemmc" a b +multiplyR a b = {-# SCC "multiplyR" #-} multiplyAux dgemmc "dgemmc" a b -- | Matrix product based on BLAS's /zgemm/. multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) multiplyC a b = multiplyAux zgemmc "zgemmc" a b +-- | Matrix product based on BLAS's /sgemm/. +multiplyF :: Matrix Float -> Matrix Float -> Matrix Float +multiplyF a b = multiplyAux sgemmc "sgemmc" a b + +-- | Matrix product based on BLAS's /cgemm/. +multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float) +multiplyQ a b = multiplyAux cgemmc "cgemmc" a b + ----------------------------------------------------------------------------- foreign import ccall "svd_l_R" dgesvd :: TMMVM foreign import ccall "svd_l_C" zgesvd :: TCMCMVCM diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index 7a40991..9e44431 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c @@ -11,15 +11,25 @@ #define MIN(A,B) ((A)<(B)?(A):(B)) #define MAX(A,B) ((A)>(B)?(A):(B)) - + +// #define DBGL + #ifdef DBGL -#define DEBUGMSG(M) printf("LAPACK Wrapper "M"\n: "); size_t t0 = time(NULL); -#define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;); +#define DEBUGMSG(M) printf("\nLAPACK "M"\n"); #else #define DEBUGMSG(M) -#define OK return 0; #endif +#define OK return 0; + +// #ifdef DBGL +// #define DEBUGMSG(M) printf("LAPACK Wrapper "M"\n: "); size_t t0 = time(NULL); +// #define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;); +// #else +// #define DEBUGMSG(M) +// #define OK return 0; +// #endif + #define TRACEMAT(M) {int q; printf(" %d x %d: ",M##r,M##c); \ for(q=0;q Linear c e where +class (Element e, AutoReal e, Container c) => Linear c e where -- | create a structure with a single element scalar :: e -> c e scale :: e -> c e -> c e @@ -184,3 +188,83 @@ linspace :: (Enum e, Linear Vector e) => Int -> (e, e) -> Vector e linspace n (a,b) = addConstant a $ scale s $ fromList [0 .. fromIntegral n-1] where s = (b-a)/fromIntegral (n-1) +---------------------------------------------------- + +-- reference multiply +mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ] + where doth u v = sum $ zipWith (*) (toList u) (toList v) + +class Element t => Prod t where + multiply :: Matrix t -> Matrix t -> Matrix t + multiply = mulH + ctrans :: Matrix t -> Matrix t + +instance Prod Double where + multiply = multiplyR + ctrans = trans + +instance Prod (Complex Double) where + multiply = multiplyC + ctrans = conj . trans + +instance Prod Float where + multiply = multiplyF + ctrans = trans + +instance Prod (Complex Float) where + multiply = multiplyQ + ctrans = conj . trans + +---------------------------------------------------------- + +-- synonym for matrix product +mXm :: Prod t => Matrix t -> Matrix t -> Matrix t +mXm = multiply + +-- matrix - vector product +mXv :: Prod t => Matrix t -> Vector t -> Vector t +mXv m v = flatten $ m `mXm` (asColumn v) + +-- vector - matrix product +vXm :: Prod t => Vector t -> Matrix t -> Vector t +vXm v m = flatten $ (asRow v) `mXm` m + +{- | Outer product of two vectors. + +@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] +(3><3) + [ 5.0, 2.0, 3.0 + , 10.0, 4.0, 6.0 + , 15.0, 6.0, 9.0 ]@ +-} +outer :: (Prod t) => Vector t -> Vector t -> Matrix t +outer u v = asColumn u `multiply` asRow v + +{- | Kronecker product of two matrices. + +@m1=(2><3) + [ 1.0, 2.0, 0.0 + , 0.0, -1.0, 3.0 ] +m2=(4><3) + [ 1.0, 2.0, 3.0 + , 4.0, 5.0, 6.0 + , 7.0, 8.0, 9.0 + , 10.0, 11.0, 12.0 ]@ + +@\> kronecker m1 m2 +(8><9) + [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0 + , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0 + , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0 + , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0 + , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0 + , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0 + , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0 + , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@ +-} +kronecker :: (Prod t) => Matrix t -> Matrix t -> Matrix t +kronecker a b = fromBlocks + . splitEvery (cols a) + . map (reshape (cols b)) + . toRows + $ flatten a `outer` flatten b diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index e3b6e1f..91f6742 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs @@ -34,6 +34,7 @@ import qualified Prelude import System.CPUTime import Text.Printf import Data.Packed.Development(unsafeFromForeignPtr,unsafeToForeignPtr) +import Control.Arrow((***)) #include "Tests/quickCheckCompat.h" @@ -224,11 +225,16 @@ runTests :: Int -- ^ maximum dimension runTests n = do setErrorHandlerOff let test p = qCheck n p - putStrLn "------ mult" - test (multProp1 . rConsist) - test (multProp1 . cConsist) - test (multProp2 . rConsist) - test (multProp2 . cConsist) + putStrLn "------ mult Double" + test (multProp1 10 . rConsist) + test (multProp1 10 . cConsist) + test (multProp2 10 . rConsist) + test (multProp2 10 . cConsist) + putStrLn "------ mult Float" + test (multProp1 6 . (single *** single) . rConsist) + test (multProp1 6 . (single *** single) . cConsist) + test (multProp2 6 . (single *** single) . rConsist) + test (multProp2 6 . (single *** single) . cConsist) putStrLn "------ sub-trans" test (subProp . rM) test (subProp . cM) diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs index d29e19a..f7a948e 100644 --- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs +++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs @@ -42,7 +42,7 @@ module Numeric.LinearAlgebra.Tests.Properties ( linearSolveProp, linearSolveProp2 ) where -import Numeric.LinearAlgebra +import Numeric.LinearAlgebra hiding (mulH) import Numeric.LinearAlgebra.LAPACK import Debug.Trace #include "quickCheckCompat.h" @@ -237,9 +237,9 @@ expmDiagProp m = expm (logm m) :~ 7 ~: complex m mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ] where doth u v = sum $ zipWith (*) (toList u) (toList v) -multProp1 (a,b) = a <> b |~| mulH a b +multProp1 p (a,b) = (a <> b) :~p~: (mulH a b) -multProp2 (a,b) = ctrans (a <> b) |~| ctrans b <> ctrans a +multProp2 p (a,b) = (ctrans (a <> b)) :~p~: (ctrans b <> ctrans a) linearSolveProp f m = f m m |~| ident (rows m) -- cgit v1.2.3