From b9329d636d19f6a26da1cf1fd7e8d7cbd0b04cce Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Tue, 30 Jun 2015 12:04:21 +0200 Subject: support slice in multiply --- packages/base/src/Internal/C/lapack-aux.c | 104 ++++++--------- packages/base/src/Internal/Matrix.hs | 8 +- packages/base/src/Internal/Modular.hs | 4 +- packages/base/src/Internal/ST.hs | 9 +- packages/tests/src/Numeric/LinearAlgebra/Tests.hs | 151 +++++++++++++++++++++- 5 files changed, 202 insertions(+), 74 deletions(-) diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index ca60846..30689bf 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c @@ -1093,16 +1093,15 @@ void dgemm_(char *, char *, integer *, integer *, integer *, integer *, double *, double *, integer *); int multiplyR(int ta, int tb, KODMAT(a),KODMAT(b),ODMAT(r)) { - //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); DEBUGMSG("dgemm_"); CHECKNANR(a,"NaN multR Input\n") CHECKNANR(b,"NaN multR Input\n") integer m = ta?ac:ar; integer n = tb?br:bc; integer k = ta?ar:ac; - integer lda = ar; - integer ldb = br; - integer ldc = rr; + integer lda = aXc; + integer ldb = bXc; + integer ldc = rXc; double alpha = 1; double beta = 0; dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); @@ -1115,16 +1114,15 @@ void zgemm_(char *, char *, integer *, integer *, integer *, integer *, doublecomplex *, doublecomplex *, integer *); int multiplyC(int ta, int tb, KOCMAT(a),KOCMAT(b),OCMAT(r)) { - //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); DEBUGMSG("zgemm_"); CHECKNANC(a,"NaN multC Input\n") CHECKNANC(b,"NaN multC Input\n") integer m = ta?ac:ar; integer n = tb?br:bc; integer k = ta?ar:ac; - integer lda = ar; - integer ldb = br; - integer ldc = rr; + integer lda = aXc; + integer ldb = bXc; + integer ldc = rXc; doublecomplex alpha = {1,0}; doublecomplex beta = {0,0}; zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, @@ -1140,14 +1138,13 @@ void sgemm_(char *, char *, integer *, integer *, integer *, integer *, float *, float *, integer *); int multiplyF(int ta, int tb, KOFMAT(a),KOFMAT(b),OFMAT(r)) { - //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); DEBUGMSG("sgemm_"); integer m = ta?ac:ar; integer n = tb?br:bc; integer k = ta?ar:ac; - integer lda = ar; - integer ldb = br; - integer ldc = rr; + integer lda = aXc; + integer ldb = bXc; + integer ldc = rXc; float alpha = 1; float beta = 0; sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc); @@ -1159,14 +1156,13 @@ void cgemm_(char *, char *, integer *, integer *, integer *, integer *, complex *, complex *, integer *); int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) { - //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); DEBUGMSG("cgemm_"); integer m = ta?ac:ar; integer n = tb?br:bc; integer k = ta?ar:ac; - integer lda = ar; - integer ldb = br; - integer ldc = rr; + integer lda = aXc; + integer ldb = bXc; + integer ldc = rXc; complex alpha = {1,0}; complex beta = {0,0}; cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha, @@ -1187,15 +1183,15 @@ int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) { } \ } -#define MULT_IMP { \ +#define MULT_IMP(M) { \ if (m==1) { \ MULT_IMP_VER( AT(r,i,j) += AT(a,i,k) * AT(b,k,j); ) \ } else { \ - MULT_IMP_VER( AT(r,i,j) = (AT(r,i,j) + (AT(a,i,k) * AT(b,k,j)) % m) % m ; ) \ + MULT_IMP_VER( AT(r,i,j) = M(AT(r,i,j) + M(AT(a,i,k) * AT(b,k,j), m) , m) ; ) \ } OK } -int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP -int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP +int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP(mod) +int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP(mod_l) /////////////////////////////// inplace row ops //////////////////////////////// @@ -1277,27 +1273,19 @@ ROWOP_MOD(int64_t,mod_l) /////////////////////////////// inplace GEMM //////////////////////////////// -#define GEMM(T) int gemm_##T(VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \ - T a = cp[0], b = cp[1]; \ - int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \ - int r1b = pp[4], c1b = pp[6] ; \ - int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \ - int dra = r1a - r1r; \ - int dcb = c1b-c1r; \ - int nk = c2a-c1a+1; \ - int i,j,k; \ - T t; \ - for (i=r1r; i<=r2r; i++) { \ - for (j=c1r; j<=c2r; j++) { \ - t = 0; \ - for(k=0; kd) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } +matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } matrixFromVector o r c v | r * c == dim v = m | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m @@ -280,7 +282,7 @@ class (Storable a) => Element a where selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () - gemm :: Vector a -> Vector I -> Matrix a -> Matrix a -> Matrix a -> IO () + gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () instance Element Float where @@ -569,9 +571,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z -------------------------------------------------------------------------------- -gemmg f u v m1 m2 m3 = f # u # v # m1 # m2 # m3 #|"gemmg" +gemmg f v m1 m2 m3 = f # v # m1 # m2 # m3 #|"gemmg" -type Tgemm x = x :> I :> x ::> x ::> x ::> Ok +type Tgemm x = x :> x ::> x ::> x ::> Ok foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 54d9cb8..8fa2747 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs @@ -137,7 +137,7 @@ instance KnownNat m => Element (Mod m I) rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) where m' = fromIntegral . natVal $ (undefined :: Proxy m) - gemm u p a b c = gemmg (c_gemmMI m') (f2i u) p (f2iM a) (f2iM b) (f2iM c) + gemm u a b c = gemmg (c_gemmMI m') (f2i u) (f2iM a) (f2iM b) (f2iM c) where m' = fromIntegral . natVal $ (undefined :: Proxy m) @@ -154,7 +154,7 @@ instance KnownNat m => Element (Mod m Z) rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) where m' = fromIntegral . natVal $ (undefined :: Proxy m) - gemm u p a b c = gemmg (c_gemmML m') (f2i u) p (f2iM a) (f2iM b) (f2iM c) + gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c) where m' = fromIntegral . natVal $ (undefined :: Proxy m) diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 91c2a11..62dfddf 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs @@ -231,14 +231,13 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[ -- | r0 c0 height width data Slice s t = Slice (STMatrix s t) Int Int Int Int -slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1]) +slice (Slice (STMatrix m) r0 c0 nr nc) = sliceMatrix (r0,c0) (nr,nc) m gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () -gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res +gemmm beta (slice->r) alpha (slice->a) (slice->b) = res where - res = unsafeIOToST (gemm u v a b r) - u = fromList [alpha,beta] - v = vjoin[pa,pb,pr] + res = unsafeIOToST (gemm v a b r) + v = fromList [alpha,beta] mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs index b226c9f..79cb769 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs @@ -4,6 +4,8 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} ----------------------------------------------------------------------------- {- | @@ -76,7 +78,7 @@ detTest1 = det m == 26 && det mc == 38 :+ (-3) && det (feye 2) == -1 where - m = (3><3) + m = (3><3) [ 1, 2, 3 , 4, 5, 7 , 2, 8, 4 :: Double @@ -357,7 +359,7 @@ accumTest = utest "accum" ok ,0,1,7 ,0,0,4] && - toList (flatten x) == [1,0,0,0,1,0,0,0,1] + toList (flatten x) == [1,0,0,0,1,0,0,0,1] -------------------------------------------------------------------------------- @@ -398,6 +400,150 @@ indexProp g f x = a1 == g a2 && a2 == a3 && b1 == g b2 && b2 == b3 a3 = maxElement x b3 = minElement x +-------------------------------------------------------------------------------- + +sliceTest = utest "slice test" $ and + [ testSlice chol (gen 5 :: Matrix R) + , testSlice chol (gen 5 :: Matrix C) + , testSlice qr (rec :: Matrix R) + , testSlice qr (rec :: Matrix C) + , testSlice hess (agen 5 :: Matrix R) + , testSlice hess (agen 5 :: Matrix C) + , testSlice schur (agen 5 :: Matrix R) + , testSlice schur (agen 5 :: Matrix C) + , testSlice lu (agen 5 :: Matrix R) + , testSlice lu (agen 5 :: Matrix C) + , testSlice (luSolve (luPacked (agen 5 :: Matrix R))) (agen 5) + , testSlice (luSolve (luPacked (agen 5 :: Matrix C))) (agen 5) + , test_lus (agen 5 :: Matrix R) + , test_lus (agen 5 :: Matrix C) + + , testSlice eig (agen 5 :: Matrix R) + , testSlice eig (agen 5 :: Matrix C) + , testSlice eigSH (gen 5 :: Matrix R) + , testSlice eigSH (gen 5 :: Matrix C) + , testSlice eigenvalues (agen 5 :: Matrix R) + , testSlice eigenvalues (agen 5 :: Matrix C) + , testSlice eigenvaluesSH (gen 5 :: Matrix R) + , testSlice eigenvaluesSH (gen 5 :: Matrix C) + + , testSlice svd (rec :: Matrix R) + , testSlice thinSVD (rec :: Matrix R) + , testSlice compactSVD (rec :: Matrix R) + , testSlice leftSV (rec :: Matrix R) + , testSlice rightSV (rec :: Matrix R) + , testSlice singularValues (rec :: Matrix R) + + , testSlice svd (rec :: Matrix C) + , testSlice thinSVD (rec :: Matrix C) + , testSlice compactSVD (rec :: Matrix C) + , testSlice leftSV (rec :: Matrix C) + , testSlice rightSV (rec :: Matrix C) + , testSlice singularValues (rec :: Matrix C) + + , testSlice (linearSolve (agen 5:: Matrix R)) (agen 5) + , testSlice (flip linearSolve (agen 5:: Matrix R)) (agen 5) + + , testSlice (linearSolve (agen 5:: Matrix C)) (agen 5) + , testSlice (flip linearSolve (agen 5:: Matrix C)) (agen 5) + + , testSlice (linearSolveLS (ogen 5:: Matrix R)) (ogen 5) + , testSlice (flip linearSolveLS (ogen 5:: Matrix R)) (ogen 5) + + , testSlice (linearSolveLS (ogen 5:: Matrix C)) (ogen 5) + , testSlice (flip linearSolveLS (ogen 5:: Matrix C)) (ogen 5) + + , testSlice (linearSolveSVD (ogen 5:: Matrix R)) (ogen 5) + , testSlice (flip linearSolveSVD (ogen 5:: Matrix R)) (ogen 5) + + , testSlice (linearSolveSVD (ogen 5:: Matrix C)) (ogen 5) + , testSlice (flip linearSolveSVD (ogen 5:: Matrix C)) (ogen 5) + + , testSlice (linearSolveLS (ugen 5:: Matrix R)) (ugen 5) + , testSlice (flip linearSolveLS (ugen 5:: Matrix R)) (ugen 5) + + , testSlice (linearSolveLS (ugen 5:: Matrix C)) (ugen 5) + , testSlice (flip linearSolveLS (ugen 5:: Matrix C)) (ugen 5) + + , testSlice (linearSolveSVD (ugen 5:: Matrix R)) (ugen 5) + , testSlice (flip linearSolveSVD (ugen 5:: Matrix R)) (ugen 5) + + , testSlice (linearSolveSVD (ugen 5:: Matrix C)) (ugen 5) + , testSlice (flip linearSolveSVD (ugen 5:: Matrix C)) (ugen 5) + + , testSlice ((<>) (ogen 5:: Matrix R)) (gen 5) + , testSlice (flip (<>) (gen 5:: Matrix R)) (ogen 5) + , testSlice ((<>) (ogen 5:: Matrix C)) (gen 5) + , testSlice (flip (<>) (gen 5:: Matrix C)) (ogen 5) + , testSlice ((<>) (ogen 5:: Matrix Float)) (gen 5) + , testSlice (flip (<>) (gen 5:: Matrix Float)) (ogen 5) + , testSlice ((<>) (ogen 5:: Matrix (Complex Float))) (gen 5) + , testSlice (flip (<>) (gen 5:: Matrix (Complex Float))) (ogen 5) + , testSlice ((<>) (ogen 5:: Matrix I)) (gen 5) + , testSlice (flip (<>) (gen 5:: Matrix I)) (ogen 5) + , testSlice ((<>) (ogen 5:: Matrix Z)) (gen 5) + , testSlice (flip (<>) (gen 5:: Matrix Z)) (ogen 5) + + , testSlice ((<>) (ogen 5:: Matrix (I ./. 7))) (gen 5) + , testSlice (flip (<>) (gen 5:: Matrix (I ./. 7))) (ogen 5) + , testSlice ((<>) (ogen 5:: Matrix (Z ./. 7))) (gen 5) + , testSlice (flip (<>) (gen 5:: Matrix (Z ./. 7))) (ogen 5) + + , testSlice (flip cholSolve (agen 5:: Matrix R)) (chol $ gen 5) + , testSlice (flip cholSolve (agen 5:: Matrix C)) (chol $ gen 5) + , testSlice (cholSolve (chol $ gen 5:: Matrix R)) (agen 5) + , testSlice (cholSolve (chol $ gen 5:: Matrix C)) (agen 5) + + , ok_qrgr (rec :: Matrix R) + , ok_qrgr (rec :: Matrix C) + , testSlice (test_qrgr 4 tau1) qrr1 + , testSlice (test_qrgr 4 tau2) qrr2 + ] + where + (qrr1,tau1) = qrRaw (rec :: Matrix R) + (qrr2,tau2) = qrRaw (rec :: Matrix C) + + test_qrgr n t x = qrgr n (x,t) + + ok_qrgr x = simeq 1E-15 q q' + where + (q,_) = qr x + atau = qrRaw x + q' = qrgr (rows q) atau + + simeq eps a b = not $ magnit eps (norm_1 $ flatten (a-b)) + + test_lus m = testSlice f lup + where + f x = luSolve (x,p) m + (lup,p) = luPacked m + + gen :: Numeric t => Int -> Matrix t + gen n = diagRect 1 (konst 5 n) n n + + agen :: (Numeric t, Num (Vector t))=> Int -> Matrix t + agen n = gen n + fromInt ((n> Int -> Matrix t + ogen n = gen n === gen n + + ugen :: (Numeric t, Num (Vector t))=> Int -> Matrix t + ugen n = takeRows 3 (gen n) + + + rec :: Numeric t => Matrix t + rec = subMatrix (0,0) (4,5) (gen 5) + + testSlice f x@(size->sz@(r,c)) = all (==f x) (map f (g y1 ++ g y2)) + where + subm = sliceMatrix + g y = [ subm (a*r,b*c) sz y | a <-[0..2], b <- [0..2]] + h z = fromBlocks (replicate 3 (replicate 3 z)) + y1 = h x + y2 = (tr . h . tr) x + + + -------------------------------------------------------------------------------- -- | All tests must pass with a maximum dimension of about 20 @@ -578,6 +724,7 @@ runTests n = do , staticTest , intTest , modularTest + , sliceTest ] when (errors c + failures c > 0) exitFailure return () -- cgit v1.2.3