summaryrefslogtreecommitdiff
path: root/lib/Numeric/LinearAlgebra
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Numeric/LinearAlgebra')
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs48
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c116
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h6
3 files changed, 89 insertions, 81 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index 00a0ab0..fbefa68 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -589,6 +589,7 @@ mulCW a b = toComplex (rr,ri)
589-- Direct CBLAS 589-- Direct CBLAS
590----------------------------------------------------------------------------------- 590-----------------------------------------------------------------------------------
591 591
592-- taken from Patrick Perry's BLAS package
592newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show) 593newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show)
593newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show) 594newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show)
594 595
@@ -602,8 +603,9 @@ trans' = CBLASTrans 112
602conjTrans = CBLASTrans 113 603conjTrans = CBLASTrans 113
603 604
604foreign import ccall "cblas.h cblas_dgemm" 605foreign import ccall "cblas.h cblas_dgemm"
605 dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double -> Ptr Double -> CInt -> IO () 606 dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt
606 607 -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double
608 -> Ptr Double -> CInt -> IO ()
607 609
608multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double 610multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double
609multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) 611multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b)
@@ -618,7 +620,9 @@ multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b)
618 620
619 621
620foreign import ccall "cblas.h cblas_zgemm" 622foreign import ccall "cblas.h cblas_zgemm"
621 zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> IO () 623 zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt
624 -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double)
625 -> Ptr (Complex Double) -> CInt -> IO ()
622 626
623multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) 627multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
624multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b) 628multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b)
@@ -640,27 +644,27 @@ multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat
640-- BLAS via auxiliary C 644-- BLAS via auxiliary C
641----------------------------------------------------------------------------------- 645-----------------------------------------------------------------------------------
642 646
643foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM 647-- foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM
644foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM 648-- foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM
645 649--
646multiply2 f st a b 650-- multiply2 f st a b
647 | cols a == rows b = unsafePerformIO $ do 651-- | cols a == rows b = unsafePerformIO $ do
648 s <- createMatrix ColumnMajor (rows a) (cols b) 652-- s <- createMatrix ColumnMajor (rows a) (cols b)
649 app3 f mat a mat b mat s st 653-- app3 f mat a mat b mat s st
650 if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s)) 654-- if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s))
651 | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" 655-- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
652 656--
653multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double 657-- multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double
654multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b) 658-- multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b)
655 659--
656multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) 660-- multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
657multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b) 661-- multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b)
658 662
659----------------------------------------------------------------------------------- 663-----------------------------------------------------------------------------------
660-- direct C multiplication 664-- direct C multiplication, to expose the NaN bug
661----------------------------------------------------------------------------------- 665-----------------------------------------------------------------------------------
662 666
663foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM 667-- foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM
664foreign import ccall "multiply.h multiplyC" cmultiplyC :: TCMCMCM 668foreign import ccall "multiply.h multiplyC" cmultiplyC :: TCMCMCM
665 669
666cmultiply f st a b 670cmultiply f st a b
@@ -674,8 +678,8 @@ cmultiply f st a b
674 -- return s 678 -- return s
675-- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" 679-- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices"
676 680
677multiplyR :: Matrix Double -> Matrix Double -> Matrix Double 681-- multiplyR :: Matrix Double -> Matrix Double -> Matrix Double
678multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b) 682-- multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b)
679 683
680multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) 684multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
681multiplyC a b = cmultiply cmultiplyC "cmultiplyR" (cmat a) (cmat b) 685multiplyC a b = cmultiply cmultiplyC "cmultiplyR" (cmat a) (cmat b)
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
index 0dccea2..61cb002 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
@@ -817,20 +817,21 @@ int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) {
817 817
818//////////////////////////////////////////////////////////// 818////////////////////////////////////////////////////////////
819 819
820int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)) { 820// int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)) {
821 REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); 821// REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
822 int i,j,k; 822// int i,j,k;
823 for (i=0;i<ar;i++) { 823// for (i=0;i<ar;i++) {
824 for(j=0;j<bc;j++) { 824// for(j=0;j<bc;j++) {
825 double temp = 0; 825// double temp = 0;
826 for(k=0;k<ac;k++) { 826// for(k=0;k<ac;k++) {
827 temp += ap[i*ac+k]*bp[k*bc+j]; 827// temp += ap[i*ac+k]*bp[k*bc+j];
828 } 828// }
829 rp[i*rc+j] = temp; 829// rp[i*rc+j] = temp;
830 } 830// }
831 } 831// }
832 OK 832// OK
833} 833// }
834
834 835
835int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)) { 836int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)) {
836 REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); 837 REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
@@ -839,52 +840,55 @@ int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)) {
839 for(j=0;j<bc;j++) { 840 for(j=0;j<bc;j++) {
840 doublecomplex temp = {0,0}; 841 doublecomplex temp = {0,0};
841 for(k=0;k<ac;k++) { 842 for(k=0;k<ac;k++) {
842 doublecomplex aik = ((doublecomplex*)ap)[i*ac+k]; 843 doublecomplex A = ((doublecomplex*)ap)[i*ac+k];
843 doublecomplex bkj = ((doublecomplex*)bp)[k*bc+j]; 844 doublecomplex B = ((doublecomplex*)bp)[k*bc+j];
844 //double w = aik.r+aik.i+bkj.r+bkj.i; 845 double w = A.r * B.r - A.i * B.i;
845 //if (w>w) exit(1); 846 double w1 = A.r * B.r;
846 //printf("%d",w>w); 847 double w2 = A.i * B.i;
847 temp.r += aik.r * bkj.r - aik.i * bkj.i; 848 if(w != w) {
848 temp.i += aik.r * bkj.i + aik.i * bkj.r; 849 printf("at : %d %d %d\n",i,j,k);
849 //printf("%f %f %f %f \n",aik.r,aik.i,bkj.r,bkj.i); 850 printf("%f %f %f\n",A.i,B.i, A.i * B.i);
850 //printf("%f %f %f \n",w,temp.r,temp.i); 851 printf("%f %f %f\n",A.i,B.i, w2);
851 852 exit(1);
853 }
854 temp.r += (w + w1-w2)/2;
855 //temp.r += w;
856 temp.i += A.r * B.i + A.i * B.r;
852 } 857 }
853 ((doublecomplex*)rp)[i*rc+j] = temp; 858 ((doublecomplex*)rp)[i*rc+j] = temp;
854 //printf("%f %f\n",temp.r,temp.i);
855 } 859 }
856 } 860 }
857 OK 861 OK
858} 862}
859 863
860void dgemm_(char *, char *, integer *, integer *, integer *, 864// void dgemm_(char *, char *, integer *, integer *, integer *,
861 double *, const double *, integer *, const double *, 865// double *, const double *, integer *, const double *,
862 integer *, double *, double *, integer *); 866// integer *, double *, double *, integer *);
863 867//
864void zgemm_(char *, char *, integer *, integer *, integer *, 868// void zgemm_(char *, char *, integer *, integer *, integer *,
865 doublecomplex *, const doublecomplex *, integer *, const doublecomplex *, 869// doublecomplex *, const doublecomplex *, integer *, const doublecomplex *,
866 integer *, doublecomplex *, doublecomplex *, integer *); 870// integer *, doublecomplex *, doublecomplex *, integer *);
867 871//
868 872//
869int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)) { 873// int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)) {
870 REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); 874// REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
871 double alpha = 1; 875// double alpha = 1;
872 double beta = 0; 876// double beta = 0;
873 integer m = ar; 877// integer m = ar;
874 integer n = bc; 878// integer n = bc;
875 integer k = ac; 879// integer k = ac;
876 int i,j; 880// int i,j;
877 dgemm_("N","N",&m,&n,&k,&alpha,ap,&m,bp,&k,&beta,rp,&m); 881// dgemm_("N","N",&m,&n,&k,&alpha,ap,&m,bp,&k,&beta,rp,&m);
878 OK 882// OK
879} 883// }
880 884//
881int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)) { 885// int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)) {
882 REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); 886// REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
883 integer m = ar; 887// integer m = ar;
884 integer n = bc; 888// integer n = bc;
885 integer k = ac; 889// integer k = ac;
886 doublecomplex alpha = {1,0}; 890// doublecomplex alpha = {1,0};
887 doublecomplex beta = {0,0}; 891// doublecomplex beta = {0,0};
888 zgemm_("N","N",&m,&n,&k,&alpha,(doublecomplex*)ap,&m,(doublecomplex*)bp,&k,&beta,(doublecomplex*)rp,&m); 892// zgemm_("N","N",&m,&n,&k,&alpha,(doublecomplex*)ap,&m,(doublecomplex*)bp,&k,&beta,(doublecomplex*)rp,&m);
889 OK 893// OK
890} 894// }
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
index c0361a6..e8cba30 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
@@ -85,8 +85,8 @@ int schur_l_C(KCMAT(a), CMAT(u), CMAT(s));
85int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)); 85int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r));
86int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)); 86int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r));
87 87
88int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)); 88// int multiplyR(KDMAT(a),KDMAT(b),DMAT(r));
89int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)); 89// int multiplyC(KCMAT(a),KCMAT(b),CMAT(r));
90 90
91int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)); 91// int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r));
92int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)); 92int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r));