From 6058e1b17c005be1ea95ebb7d98d9fd15bb538d2 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Thu, 26 Aug 2010 17:49:45 +0000 Subject: Float matrix product --- CHANGES | 2 +- lib/Data/Packed/Internal/Signatures.hs | 4 ++ lib/Data/Packed/Internal/Vector.hs | 20 +++++- lib/Data/Packed/Matrix.hs | 66 ++++++++++++++------ lib/Graphics/Plot.hs | 1 - lib/Numeric/LinearAlgebra/Algorithms.hs | 81 +----------------------- lib/Numeric/LinearAlgebra/Interface.hs | 2 +- lib/Numeric/LinearAlgebra/LAPACK.hs | 18 ++++-- lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 81 ++++++++++++++++++++++-- lib/Numeric/LinearAlgebra/Linear.hs | 90 ++++++++++++++++++++++++++- lib/Numeric/LinearAlgebra/Tests.hs | 16 +++-- lib/Numeric/LinearAlgebra/Tests/Properties.hs | 6 +- 12 files changed, 264 insertions(+), 123 deletions(-) diff --git a/CHANGES b/CHANGES index 20c5f0e..8d09c9f 100644 --- a/CHANGES +++ b/CHANGES @@ -5,7 +5,7 @@ - Vectors typeclass -- Initial support for Vector Float and Vector (Complex Float) +- Support for Float and Complex Float elements (excluding LAPACK computations) - Binary instances for Vector and Matrix diff --git a/lib/Data/Packed/Internal/Signatures.hs b/lib/Data/Packed/Internal/Signatures.hs index 8c1c5f6..b81efa4 100644 --- a/lib/Data/Packed/Internal/Signatures.hs +++ b/lib/Data/Packed/Internal/Signatures.hs @@ -24,12 +24,15 @@ type PQ = Ptr (Complex Float) -- type PC = Ptr (Complex Double) -- type TF = CInt -> PF -> IO CInt -- type TFF = CInt -> PF -> TF -- +type TFV = CInt -> PF -> TV -- +type TVF = CInt -> PD -> TF -- type TFFF = CInt -> PF -> TFF -- type TV = CInt -> PD -> IO CInt -- type TVV = CInt -> PD -> TV -- type TVVV = CInt -> PD -> TVV -- type TFM = CInt -> CInt -> PF -> IO CInt -- type TFMFM = CInt -> CInt -> PF -> TFM -- +type TFMFMFM = CInt -> CInt -> PF -> TFMFM -- type TM = CInt -> CInt -> PD -> IO CInt -- type TMM = CInt -> CInt -> PD -> TM -- type TVMM = CInt -> PD -> TMM -- @@ -61,6 +64,7 @@ type TQVQVQV = CInt -> PQ -> TQVQV -- type TQVF = CInt -> PQ -> TF -- type TQM = CInt -> CInt -> PQ -> IO CInt -- type TQMQM = CInt -> CInt -> PQ -> TQM -- +type TQMQMQM = CInt -> CInt -> PQ -> TQMQM -- type TCMCV = CInt -> CInt -> PC -> TCV -- type TVCV = CInt -> PD -> TCV -- type TCVM = CInt -> PC -> TM -- diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index ac2d0d7..c8cc2c2 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs @@ -21,7 +21,7 @@ module Data.Packed.Internal.Vector ( mapVectorM, mapVectorM_, foldVector, foldVectorG, foldLoop, createVector, vec, - asComplex, asReal, + asComplex, asReal, float2DoubleV, double2FloatV, fwriteVector, freadVector, fprintfVector, fscanfVector, cloneVector, unsafeToForeignPtr, @@ -274,6 +274,24 @@ asComplex :: (RealFloat a, Storable a) => Vector a -> Vector (Complex a) asComplex v = unsafeFromForeignPtr (castForeignPtr fp) (i `div` 2) (n `div` 2) where (fp,i,n) = unsafeToForeignPtr v +--------------------------------------------------------------- + +float2DoubleV :: Vector Float -> Vector Double +float2DoubleV v = unsafePerformIO $ do + r <- createVector (dim v) + app2 c_float2double vec v vec r "float2double" + return r + +double2FloatV :: Vector Double -> Vector Float +double2FloatV v = unsafePerformIO $ do + r <- createVector (dim v) + app2 c_double2float vec v vec r "double2float2" + return r + + +foreign import ccall "float2double" c_float2double:: TFV +foreign import ccall "double2float" c_double2float:: TVF + ---------------------------------------------------------------- cloneVector :: Storable t => Vector t -> IO (Vector t) diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index 8aa1693..8694249 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs @@ -1,6 +1,10 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FunctionalDependencies #-} + + ----------------------------------------------------------------------------- -- | -- Module : Data.Packed.Matrix @@ -16,8 +20,9 @@ ----------------------------------------------------------------------------- module Data.Packed.Matrix ( - Element, Scalar, Container(..), Convert(..), - RealOf, ComplexOf, SingleOf, DoubleOf, ElementOf, AutoReal(..), + Element, RealElement, Container(..), + Convert(..), RealOf, ComplexOf, SingleOf, DoubleOf, ElementOf, + AutoReal(..), Matrix,rows,cols, (><), trans, @@ -51,7 +56,7 @@ import Data.Binary import Foreign.Storable import Control.Monad(replicateM) import Control.Arrow((***)) -import GHC.Float(double2Float,float2Double) +--import GHC.Float(double2Float,float2Double) ------------------------------------------------------------------- @@ -468,17 +473,32 @@ toBlocksEvery r c m = toBlocks rs cs m where -- | conversion utilities -class (Element t, Element (Complex t), Fractional t, RealFloat t) => Scalar t +class (Element t, Element (Complex t), RealFloat t) => RealElement t + +instance RealElement Double +instance RealElement Float + +class (Element s, Element d) => Prec s d | s -> d, d -> s where + double2FloatG :: Vector d -> Vector s + float2DoubleG :: Vector s -> Vector d + +instance Prec Float Double where + double2FloatG = double2FloatV + float2DoubleG = float2DoubleV + +instance Prec (Complex Float) (Complex Double) where + double2FloatG = asComplex . double2FloatV . asReal + float2DoubleG = asComplex . float2DoubleV . asReal -instance Scalar Double -instance Scalar Float class Container c where - toComplex :: (Scalar e) => (c e, c e) -> c (Complex e) - fromComplex :: (Scalar e) => c (Complex e) -> (c e, c e) - comp :: (Scalar e) => c e -> c (Complex e) - conj :: (Scalar e) => c (Complex e) -> c (Complex e) + toComplex :: (RealElement e) => (c e, c e) -> c (Complex e) + fromComplex :: (RealElement e) => c (Complex e) -> (c e, c e) + comp :: (RealElement e) => c e -> c (Complex e) + conj :: (RealElement e) => c (Complex e) -> c (Complex e) cmap :: (Element a, Element b) => (a -> b) -> c a -> c b + single :: Prec a b => c b -> c a + double :: Prec a b => c a -> c b instance Container Vector where toComplex = toComplexV @@ -486,6 +506,8 @@ instance Container Vector where comp v = toComplex (v,constantD 0 (dim v)) conj = conjV cmap = mapVector + single = double2FloatG + double = float2DoubleG instance Container Matrix where toComplex = uncurry $ liftMatrix2 $ curry toComplex @@ -494,6 +516,8 @@ instance Container Matrix where comp = liftMatrix comp conj = liftMatrix conj cmap f = liftMatrix (cmap f) + single = liftMatrix single + double = liftMatrix double ------------------------------------------------------------------- @@ -534,38 +558,40 @@ type instance ElementOf (Matrix a) = a ------------------------------------------------------------------- +-- | generic conversion functions class Convert t where real' :: Container c => c (RealOf t) -> c t complex' :: Container c => c t -> c (ComplexOf t) - single :: Container c => c t -> c (SingleOf t) - double :: Container c => c t -> c (DoubleOf t) + single' :: Container c => c t -> c (SingleOf t) + double' :: Container c => c t -> c (DoubleOf t) instance Convert Double where real' = id complex' = comp - single = cmap double2Float - double = id + single' = single + double' = id instance Convert Float where real' = id complex' = comp - single = id - double = cmap float2Double + single' = id + double' = double instance Convert (Complex Double) where real' = comp complex' = id - single = toComplex . (single *** single) . fromComplex - double = id + single' = single + double' = id instance Convert (Complex Float) where real' = comp complex' = id - single = id - double = toComplex . (double *** double) . fromComplex + single' = id + double' = double ------------------------------------------------------------------- +-- | to be replaced by Convert class AutoReal t where real :: Container c => c Double -> c t complex :: Container c => c t -> c (Complex Double) diff --git a/lib/Graphics/Plot.hs b/lib/Graphics/Plot.hs index b2acc15..2dc0553 100644 --- a/lib/Graphics/Plot.hs +++ b/lib/Graphics/Plot.hs @@ -29,7 +29,6 @@ module Graphics.Plot( ) where import Data.Packed -import Numeric.LinearAlgebra(outer) import Numeric.LinearAlgebra.Linear import Data.List(intersperse) import System.Process (system) diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index f4b7ee9..8962c60 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs @@ -21,9 +21,6 @@ imported from "Numeric.LinearAlgebra.LAPACK". module Numeric.LinearAlgebra.Algorithms ( -- * Supported types Field(), --- * Products - multiply, -- dot, moved dot to typeclass - outer, kronecker, -- * Linear Systems linearSolve, luSolve, @@ -64,7 +61,6 @@ module Numeric.LinearAlgebra.Algorithms ( -- * Norms Normed(..), NormType(..), -- * Misc - ctrans, eps, i, -- * Util haussholder, @@ -86,7 +82,7 @@ import Data.List(foldl1') import Data.Array -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. -class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where +class (Prod t, Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where svd' :: Matrix t -> (Matrix t, Vector Double, Matrix t) thinSVD' :: Matrix t -> (Matrix t, Vector Double, Matrix t) sv' :: Matrix t -> Vector Double @@ -105,8 +101,6 @@ class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where qr' :: Matrix t -> (Matrix t, Matrix t) hess' :: Matrix t -> (Matrix t, Matrix t) schur' :: Matrix t -> (Matrix t, Matrix t) - ctrans' :: Matrix t -> Matrix t - multiply' :: Matrix t -> Matrix t -> Matrix t instance Field Double where @@ -119,7 +113,6 @@ instance Field Double where cholSolve' = cholSolveR linearSolveLS' = linearSolveLSR linearSolveSVD' = linearSolveSVDR Nothing - ctrans' = trans eig' = eigR eigSH'' = eigS eigOnly = eigOnlyR @@ -129,7 +122,6 @@ instance Field Double where qr' = unpackQR . qrR hess' = unpackHess hessR schur' = schurR - multiply' = multiplyR instance Field (Complex Double) where #ifdef NOZGESDD @@ -146,7 +138,6 @@ instance Field (Complex Double) where cholSolve' = cholSolveC linearSolveLS' = linearSolveLSC linearSolveSVD' = linearSolveSVDC Nothing - ctrans' = conj . trans eig' = eigC eigOnly = eigOnlyC eigSH'' = eigH @@ -156,7 +147,6 @@ instance Field (Complex Double) where qr' = unpackQR . qrC hess' = unpackHess hessC schur' = schurC - multiply' = multiplyC -------------------------------------------------------------- @@ -324,13 +314,6 @@ hess = hess' schur :: Field t => Matrix t -> (Matrix t, Matrix t) schur = schur' --- | Generic conjugate transpose. -ctrans :: Field t => Matrix t -> Matrix t -ctrans = ctrans' - --- | Matrix product. -multiply :: Field t => Matrix t -> Matrix t -> Matrix t -multiply = {-# SCC "multiply" #-} multiply' -- | Similar to 'cholSH', but instead of an error (e.g., caused by a matrix not positive definite) it returns 'Nothing'. mbCholSH :: Field t => Matrix t -> Maybe (Matrix t) @@ -404,20 +387,6 @@ peps x = 2.0**(fromIntegral $ 1-floatDigits x) i :: Complex Double i = 0:+1 - --- matrix product -mXm :: (Num t, Field t) => Matrix t -> Matrix t -> Matrix t -mXm = multiply - --- matrix - vector product -mXv :: (Num t, Field t) => Matrix t -> Vector t -> Vector t -mXv m v = flatten $ m `mXm` (asColumn v) - --- vector - matrix product -vXm :: (Num t, Field t) => Vector t -> Matrix t -> Vector t -vXm v m = flatten $ (asRow v) `mXm` m - - --------------------------------------------------------------------------- norm2 :: Vector Double -> Double @@ -723,51 +692,3 @@ luFact (l_u,perm) | r <= c = (l ,u ,p, s) (|*|) = mul -------------------------------------------------- - -{- moved to Numeric.LinearAlgebra.Interface Vector typeclass --- | Euclidean inner product. -dot :: (Field t) => Vector t -> Vector t -> t -dot u v = multiply r c @@> (0,0) - where r = asRow u - c = asColumn v --} - -{- | Outer product of two vectors. - -@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] -(3><3) - [ 5.0, 2.0, 3.0 - , 10.0, 4.0, 6.0 - , 15.0, 6.0, 9.0 ]@ --} -outer :: (Field t) => Vector t -> Vector t -> Matrix t -outer u v = asColumn u `multiply` asRow v - -{- | Kronecker product of two matrices. - -@m1=(2><3) - [ 1.0, 2.0, 0.0 - , 0.0, -1.0, 3.0 ] -m2=(4><3) - [ 1.0, 2.0, 3.0 - , 4.0, 5.0, 6.0 - , 7.0, 8.0, 9.0 - , 10.0, 11.0, 12.0 ]@ - -@\> kronecker m1 m2 -(8><9) - [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0 - , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0 - , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0 - , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0 - , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0 - , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0 - , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0 - , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@ --} -kronecker :: (Field t) => Matrix t -> Matrix t -> Matrix t -kronecker a b = fromBlocks - . splitEvery (cols a) - . map (reshape (cols b)) - . toRows - $ flatten a `outer` flatten b diff --git a/lib/Numeric/LinearAlgebra/Interface.hs b/lib/Numeric/LinearAlgebra/Interface.hs index f8917a0..6df782f 100644 --- a/lib/Numeric/LinearAlgebra/Interface.hs +++ b/lib/Numeric/LinearAlgebra/Interface.hs @@ -35,7 +35,7 @@ import Numeric.LinearAlgebra.Linear class Mul a b c | a b -> c where infixl 7 <> -- | Matrix-matrix, matrix-vector, and vector-matrix products. - (<>) :: Field t => a t -> b t -> c t + (<>) :: Prod t => a t -> b t -> c t instance Mul Matrix Matrix Matrix where (<>) = multiply diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index 7f057ba..eec3035 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs @@ -14,7 +14,7 @@ module Numeric.LinearAlgebra.LAPACK ( -- * Matrix product - multiplyR, multiplyC, + multiplyR, multiplyC, multiplyF, multiplyQ, -- * Linear systems linearSolveR, linearSolveC, lusR, lusC, @@ -51,8 +51,10 @@ import Control.Monad(when) ----------------------------------------------------------------------------------- -foreign import ccall "LAPACK/lapack-aux.h multiplyR" dgemmc :: CInt -> CInt -> TMMM -foreign import ccall "LAPACK/lapack-aux.h multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM +foreign import ccall "multiplyR" dgemmc :: CInt -> CInt -> TMMM +foreign import ccall "multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM +foreign import ccall "multiplyF" sgemmc :: CInt -> CInt -> TFMFMFM +foreign import ccall "multiplyQ" cgemmc :: CInt -> CInt -> TQMQMQM isT MF{} = 0 isT MC{} = 1 @@ -69,12 +71,20 @@ multiplyAux f st a b = unsafePerformIO $ do -- | Matrix product based on BLAS's /dgemm/. multiplyR :: Matrix Double -> Matrix Double -> Matrix Double -multiplyR a b = multiplyAux dgemmc "dgemmc" a b +multiplyR a b = {-# SCC "multiplyR" #-} multiplyAux dgemmc "dgemmc" a b -- | Matrix product based on BLAS's /zgemm/. multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) multiplyC a b = multiplyAux zgemmc "zgemmc" a b +-- | Matrix product based on BLAS's /sgemm/. +multiplyF :: Matrix Float -> Matrix Float -> Matrix Float +multiplyF a b = multiplyAux sgemmc "sgemmc" a b + +-- | Matrix product based on BLAS's /cgemm/. +multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float) +multiplyQ a b = multiplyAux cgemmc "cgemmc" a b + ----------------------------------------------------------------------------- foreign import ccall "svd_l_R" dgesvd :: TMMVM foreign import ccall "svd_l_C" zgesvd :: TCMCMVCM diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index 7a40991..9e44431 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c @@ -11,15 +11,25 @@ #define MIN(A,B) ((A)<(B)?(A):(B)) #define MAX(A,B) ((A)>(B)?(A):(B)) - + +// #define DBGL + #ifdef DBGL -#define DEBUGMSG(M) printf("LAPACK Wrapper "M"\n: "); size_t t0 = time(NULL); -#define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;); +#define DEBUGMSG(M) printf("\nLAPACK "M"\n"); #else #define DEBUGMSG(M) -#define OK return 0; #endif +#define OK return 0; + +// #ifdef DBGL +// #define DEBUGMSG(M) printf("LAPACK Wrapper "M"\n: "); size_t t0 = time(NULL); +// #define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;); +// #else +// #define DEBUGMSG(M) +// #define OK return 0; +// #endif + #define TRACEMAT(M) {int q; printf(" %d x %d: ",M##r,M##c); \ for(q=0;q Linear c e where +class (Element e, AutoReal e, Container c) => Linear c e where -- | create a structure with a single element scalar :: e -> c e scale :: e -> c e -> c e @@ -184,3 +188,83 @@ linspace :: (Enum e, Linear Vector e) => Int -> (e, e) -> Vector e linspace n (a,b) = addConstant a $ scale s $ fromList [0 .. fromIntegral n-1] where s = (b-a)/fromIntegral (n-1) +---------------------------------------------------- + +-- reference multiply +mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ] + where doth u v = sum $ zipWith (*) (toList u) (toList v) + +class Element t => Prod t where + multiply :: Matrix t -> Matrix t -> Matrix t + multiply = mulH + ctrans :: Matrix t -> Matrix t + +instance Prod Double where + multiply = multiplyR + ctrans = trans + +instance Prod (Complex Double) where + multiply = multiplyC + ctrans = conj . trans + +instance Prod Float where + multiply = multiplyF + ctrans = trans + +instance Prod (Complex Float) where + multiply = multiplyQ + ctrans = conj . trans + +---------------------------------------------------------- + +-- synonym for matrix product +mXm :: Prod t => Matrix t -> Matrix t -> Matrix t +mXm = multiply + +-- matrix - vector product +mXv :: Prod t => Matrix t -> Vector t -> Vector t +mXv m v = flatten $ m `mXm` (asColumn v) + +-- vector - matrix product +vXm :: Prod t => Vector t -> Matrix t -> Vector t +vXm v m = flatten $ (asRow v) `mXm` m + +{- | Outer product of two vectors. + +@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] +(3><3) + [ 5.0, 2.0, 3.0 + , 10.0, 4.0, 6.0 + , 15.0, 6.0, 9.0 ]@ +-} +outer :: (Prod t) => Vector t -> Vector t -> Matrix t +outer u v = asColumn u `multiply` asRow v + +{- | Kronecker product of two matrices. + +@m1=(2><3) + [ 1.0, 2.0, 0.0 + , 0.0, -1.0, 3.0 ] +m2=(4><3) + [ 1.0, 2.0, 3.0 + , 4.0, 5.0, 6.0 + , 7.0, 8.0, 9.0 + , 10.0, 11.0, 12.0 ]@ + +@\> kronecker m1 m2 +(8><9) + [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0 + , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0 + , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0 + , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0 + , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0 + , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0 + , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0 + , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]@ +-} +kronecker :: (Prod t) => Matrix t -> Matrix t -> Matrix t +kronecker a b = fromBlocks + . splitEvery (cols a) + . map (reshape (cols b)) + . toRows + $ flatten a `outer` flatten b diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index e3b6e1f..91f6742 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs @@ -34,6 +34,7 @@ import qualified Prelude import System.CPUTime import Text.Printf import Data.Packed.Development(unsafeFromForeignPtr,unsafeToForeignPtr) +import Control.Arrow((***)) #include "Tests/quickCheckCompat.h" @@ -224,11 +225,16 @@ runTests :: Int -- ^ maximum dimension runTests n = do setErrorHandlerOff let test p = qCheck n p - putStrLn "------ mult" - test (multProp1 . rConsist) - test (multProp1 . cConsist) - test (multProp2 . rConsist) - test (multProp2 . cConsist) + putStrLn "------ mult Double" + test (multProp1 10 . rConsist) + test (multProp1 10 . cConsist) + test (multProp2 10 . rConsist) + test (multProp2 10 . cConsist) + putStrLn "------ mult Float" + test (multProp1 6 . (single *** single) . rConsist) + test (multProp1 6 . (single *** single) . cConsist) + test (multProp2 6 . (single *** single) . rConsist) + test (multProp2 6 . (single *** single) . cConsist) putStrLn "------ sub-trans" test (subProp . rM) test (subProp . cM) diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs index d29e19a..f7a948e 100644 --- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs +++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs @@ -42,7 +42,7 @@ module Numeric.LinearAlgebra.Tests.Properties ( linearSolveProp, linearSolveProp2 ) where -import Numeric.LinearAlgebra +import Numeric.LinearAlgebra hiding (mulH) import Numeric.LinearAlgebra.LAPACK import Debug.Trace #include "quickCheckCompat.h" @@ -237,9 +237,9 @@ expmDiagProp m = expm (logm m) :~ 7 ~: complex m mulH a b = fromLists [[ doth ai bj | bj <- toColumns b] | ai <- toRows a ] where doth u v = sum $ zipWith (*) (toList u) (toList v) -multProp1 (a,b) = a <> b |~| mulH a b +multProp1 p (a,b) = (a <> b) :~p~: (mulH a b) -multProp2 (a,b) = ctrans (a <> b) |~| ctrans b <> ctrans a +multProp2 p (a,b) = (ctrans (a <> b)) :~p~: (ctrans b <> ctrans a) linearSolveProp f m = f m m |~| ident (rows m) -- cgit v1.2.3