From 210f0027a7a4614469f8f61eef255852c53e5fb8 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Wed, 15 Oct 2008 15:42:12 +0000 Subject: debug info for the NaN bug --- lib/Numeric/LinearAlgebra/Algorithms.hs | 48 ++++++----- lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 116 +++++++++++++------------- lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 6 +- 3 files changed, 89 insertions(+), 81 deletions(-) (limited to 'lib/Numeric') diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 00a0ab0..fbefa68 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs @@ -589,6 +589,7 @@ mulCW a b = toComplex (rr,ri) -- Direct CBLAS ----------------------------------------------------------------------------------- +-- taken from Patrick Perry's BLAS package newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show) newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show) @@ -602,8 +603,9 @@ trans' = CBLASTrans 112 conjTrans = CBLASTrans 113 foreign import ccall "cblas.h cblas_dgemm" - dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double -> Ptr Double -> CInt -> IO () - + dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt + -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double + -> Ptr Double -> CInt -> IO () multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) @@ -618,7 +620,9 @@ multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) foreign import ccall "cblas.h cblas_zgemm" - zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> IO () + zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt + -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) + -> Ptr (Complex Double) -> CInt -> IO () multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b) @@ -640,27 +644,27 @@ multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat -- BLAS via auxiliary C ----------------------------------------------------------------------------------- -foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM -foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM - -multiply2 f st a b - | cols a == rows b = unsafePerformIO $ do - s <- createMatrix ColumnMajor (rows a) (cols b) - app3 f mat a mat b mat s st - if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s)) - | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" - -multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double -multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b) - -multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) -multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b) +-- foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM +-- foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM +-- +-- multiply2 f st a b +-- | cols a == rows b = unsafePerformIO $ do +-- s <- createMatrix ColumnMajor (rows a) (cols b) +-- app3 f mat a mat b mat s st +-- if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s)) +-- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" +-- +-- multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double +-- multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b) +-- +-- multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) +-- multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b) ----------------------------------------------------------------------------------- --- direct C multiplication +-- direct C multiplication, to expose the NaN bug ----------------------------------------------------------------------------------- -foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM +-- foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM foreign import ccall "multiply.h multiplyC" cmultiplyC :: TCMCMCM cmultiply f st a b @@ -674,8 +678,8 @@ cmultiply f st a b -- return s -- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" -multiplyR :: Matrix Double -> Matrix Double -> Matrix Double -multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b) +-- multiplyR :: Matrix Double -> Matrix Double -> Matrix Double +-- multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b) multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) multiplyC a b = cmultiply cmultiplyC "cmultiplyR" (cmat a) (cmat b) diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index 0dccea2..61cb002 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c @@ -817,20 +817,21 @@ int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) { //////////////////////////////////////////////////////////// -int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)) { - REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); - int i,j,k; - for (i=0;iw) exit(1); - //printf("%d",w>w); - temp.r += aik.r * bkj.r - aik.i * bkj.i; - temp.i += aik.r * bkj.i + aik.i * bkj.r; - //printf("%f %f %f %f \n",aik.r,aik.i,bkj.r,bkj.i); - //printf("%f %f %f \n",w,temp.r,temp.i); - + doublecomplex A = ((doublecomplex*)ap)[i*ac+k]; + doublecomplex B = ((doublecomplex*)bp)[k*bc+j]; + double w = A.r * B.r - A.i * B.i; + double w1 = A.r * B.r; + double w2 = A.i * B.i; + if(w != w) { + printf("at : %d %d %d\n",i,j,k); + printf("%f %f %f\n",A.i,B.i, A.i * B.i); + printf("%f %f %f\n",A.i,B.i, w2); + exit(1); + } + temp.r += (w + w1-w2)/2; + //temp.r += w; + temp.i += A.r * B.i + A.i * B.r; } ((doublecomplex*)rp)[i*rc+j] = temp; - //printf("%f %f\n",temp.r,temp.i); } } OK } -void dgemm_(char *, char *, integer *, integer *, integer *, - double *, const double *, integer *, const double *, - integer *, double *, double *, integer *); - -void zgemm_(char *, char *, integer *, integer *, integer *, - doublecomplex *, const doublecomplex *, integer *, const doublecomplex *, - integer *, doublecomplex *, doublecomplex *, integer *); - - -int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)) { - REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); - double alpha = 1; - double beta = 0; - integer m = ar; - integer n = bc; - integer k = ac; - int i,j; - dgemm_("N","N",&m,&n,&k,&alpha,ap,&m,bp,&k,&beta,rp,&m); - OK -} - -int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)) { - REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); - integer m = ar; - integer n = bc; - integer k = ac; - doublecomplex alpha = {1,0}; - doublecomplex beta = {0,0}; - zgemm_("N","N",&m,&n,&k,&alpha,(doublecomplex*)ap,&m,(doublecomplex*)bp,&k,&beta,(doublecomplex*)rp,&m); - OK -} +// void dgemm_(char *, char *, integer *, integer *, integer *, +// double *, const double *, integer *, const double *, +// integer *, double *, double *, integer *); +// +// void zgemm_(char *, char *, integer *, integer *, integer *, +// doublecomplex *, const doublecomplex *, integer *, const doublecomplex *, +// integer *, doublecomplex *, doublecomplex *, integer *); +// +// +// int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)) { +// REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); +// double alpha = 1; +// double beta = 0; +// integer m = ar; +// integer n = bc; +// integer k = ac; +// int i,j; +// dgemm_("N","N",&m,&n,&k,&alpha,ap,&m,bp,&k,&beta,rp,&m); +// OK +// } +// +// int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)) { +// REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); +// integer m = ar; +// integer n = bc; +// integer k = ac; +// doublecomplex alpha = {1,0}; +// doublecomplex beta = {0,0}; +// zgemm_("N","N",&m,&n,&k,&alpha,(doublecomplex*)ap,&m,(doublecomplex*)bp,&k,&beta,(doublecomplex*)rp,&m); +// OK +// } diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h index c0361a6..e8cba30 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h @@ -85,8 +85,8 @@ int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)); int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)); int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)); -int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)); -int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)); +// int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)); +// int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)); -int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)); +// int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)); int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)); -- cgit v1.2.3