diff options
Diffstat (limited to 'lib/Numeric/LinearAlgebra')
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 48 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 116 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 6 |
3 files changed, 89 insertions, 81 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 00a0ab0..fbefa68 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -589,6 +589,7 @@ mulCW a b = toComplex (rr,ri) | |||
589 | -- Direct CBLAS | 589 | -- Direct CBLAS |
590 | ----------------------------------------------------------------------------------- | 590 | ----------------------------------------------------------------------------------- |
591 | 591 | ||
592 | -- taken from Patrick Perry's BLAS package | ||
592 | newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show) | 593 | newtype CBLASOrder = CBLASOrder CInt deriving (Eq, Show) |
593 | newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show) | 594 | newtype CBLASTrans = CBLASTrans CInt deriving (Eq, Show) |
594 | 595 | ||
@@ -602,8 +603,9 @@ trans' = CBLASTrans 112 | |||
602 | conjTrans = CBLASTrans 113 | 603 | conjTrans = CBLASTrans 113 |
603 | 604 | ||
604 | foreign import ccall "cblas.h cblas_dgemm" | 605 | foreign import ccall "cblas.h cblas_dgemm" |
605 | dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double -> Ptr Double -> CInt -> IO () | 606 | dgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt |
606 | 607 | -> Double -> Ptr Double -> CInt -> Ptr Double -> CInt -> Double | |
608 | -> Ptr Double -> CInt -> IO () | ||
607 | 609 | ||
608 | multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double | 610 | multiplyR3 :: Matrix Double -> Matrix Double -> Matrix Double |
609 | multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) | 611 | multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) |
@@ -618,7 +620,9 @@ multiplyR3 a b = multiply3 dgemm "cblas_dgemm" (fmat a) (fmat b) | |||
618 | 620 | ||
619 | 621 | ||
620 | foreign import ccall "cblas.h cblas_zgemm" | 622 | foreign import ccall "cblas.h cblas_zgemm" |
621 | zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> IO () | 623 | zgemm :: CBLASOrder -> CBLASTrans -> CBLASTrans -> CInt -> CInt -> CInt |
624 | -> Ptr (Complex Double) -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) -> CInt -> Ptr (Complex Double) | ||
625 | -> Ptr (Complex Double) -> CInt -> IO () | ||
622 | 626 | ||
623 | multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | 627 | multiplyC3 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) |
624 | multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b) | 628 | multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat b) |
@@ -640,27 +644,27 @@ multiplyC3 a b = unsafePerformIO $ multiply3 zgemm "cblas_zgemm" (fmat a) (fmat | |||
640 | -- BLAS via auxiliary C | 644 | -- BLAS via auxiliary C |
641 | ----------------------------------------------------------------------------------- | 645 | ----------------------------------------------------------------------------------- |
642 | 646 | ||
643 | foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM | 647 | -- foreign import ccall "multiply.h multiplyR2" dgemmc :: TMMM |
644 | foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM | 648 | -- foreign import ccall "multiply.h multiplyC2" zgemmc :: TCMCMCM |
645 | 649 | -- | |
646 | multiply2 f st a b | 650 | -- multiply2 f st a b |
647 | | cols a == rows b = unsafePerformIO $ do | 651 | -- | cols a == rows b = unsafePerformIO $ do |
648 | s <- createMatrix ColumnMajor (rows a) (cols b) | 652 | -- s <- createMatrix ColumnMajor (rows a) (cols b) |
649 | app3 f mat a mat b mat s st | 653 | -- app3 f mat a mat b mat s st |
650 | if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s)) | 654 | -- if toLists s== toLists s then return s else error $ "AYYY " ++ (show (toLists s)) |
651 | | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | 655 | -- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" |
652 | 656 | -- | |
653 | multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double | 657 | -- multiplyR2 :: Matrix Double -> Matrix Double -> Matrix Double |
654 | multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b) | 658 | -- multiplyR2 a b = multiply2 dgemmc "dgemmc" (fmat a) (fmat b) |
655 | 659 | -- | |
656 | multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | 660 | -- multiplyC2 :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) |
657 | multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b) | 661 | -- multiplyC2 a b = multiply2 zgemmc "zgemmc" (fmat a) (fmat b) |
658 | 662 | ||
659 | ----------------------------------------------------------------------------------- | 663 | ----------------------------------------------------------------------------------- |
660 | -- direct C multiplication | 664 | -- direct C multiplication, to expose the NaN bug |
661 | ----------------------------------------------------------------------------------- | 665 | ----------------------------------------------------------------------------------- |
662 | 666 | ||
663 | foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM | 667 | -- foreign import ccall "multiply.h multiplyR" cmultiplyR :: TMMM |
664 | foreign import ccall "multiply.h multiplyC" cmultiplyC :: TCMCMCM | 668 | foreign import ccall "multiply.h multiplyC" cmultiplyC :: TCMCMCM |
665 | 669 | ||
666 | cmultiply f st a b | 670 | cmultiply f st a b |
@@ -674,8 +678,8 @@ cmultiply f st a b | |||
674 | -- return s | 678 | -- return s |
675 | -- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" | 679 | -- | otherwise = error $ st ++ " (matrix product) of nonconformant matrices" |
676 | 680 | ||
677 | multiplyR :: Matrix Double -> Matrix Double -> Matrix Double | 681 | -- multiplyR :: Matrix Double -> Matrix Double -> Matrix Double |
678 | multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b) | 682 | -- multiplyR a b = cmultiply cmultiplyR "cmultiplyR" (cmat a) (cmat b) |
679 | 683 | ||
680 | multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | 684 | multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) |
681 | multiplyC a b = cmultiply cmultiplyC "cmultiplyR" (cmat a) (cmat b) | 685 | multiplyC a b = cmultiply cmultiplyC "cmultiplyR" (cmat a) (cmat b) |
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index 0dccea2..61cb002 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | |||
@@ -817,20 +817,21 @@ int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) { | |||
817 | 817 | ||
818 | //////////////////////////////////////////////////////////// | 818 | //////////////////////////////////////////////////////////// |
819 | 819 | ||
820 | int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)) { | 820 | // int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)) { |
821 | REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | 821 | // REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); |
822 | int i,j,k; | 822 | // int i,j,k; |
823 | for (i=0;i<ar;i++) { | 823 | // for (i=0;i<ar;i++) { |
824 | for(j=0;j<bc;j++) { | 824 | // for(j=0;j<bc;j++) { |
825 | double temp = 0; | 825 | // double temp = 0; |
826 | for(k=0;k<ac;k++) { | 826 | // for(k=0;k<ac;k++) { |
827 | temp += ap[i*ac+k]*bp[k*bc+j]; | 827 | // temp += ap[i*ac+k]*bp[k*bc+j]; |
828 | } | 828 | // } |
829 | rp[i*rc+j] = temp; | 829 | // rp[i*rc+j] = temp; |
830 | } | 830 | // } |
831 | } | 831 | // } |
832 | OK | 832 | // OK |
833 | } | 833 | // } |
834 | |||
834 | 835 | ||
835 | int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)) { | 836 | int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)) { |
836 | REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | 837 | REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); |
@@ -839,52 +840,55 @@ int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)) { | |||
839 | for(j=0;j<bc;j++) { | 840 | for(j=0;j<bc;j++) { |
840 | doublecomplex temp = {0,0}; | 841 | doublecomplex temp = {0,0}; |
841 | for(k=0;k<ac;k++) { | 842 | for(k=0;k<ac;k++) { |
842 | doublecomplex aik = ((doublecomplex*)ap)[i*ac+k]; | 843 | doublecomplex A = ((doublecomplex*)ap)[i*ac+k]; |
843 | doublecomplex bkj = ((doublecomplex*)bp)[k*bc+j]; | 844 | doublecomplex B = ((doublecomplex*)bp)[k*bc+j]; |
844 | //double w = aik.r+aik.i+bkj.r+bkj.i; | 845 | double w = A.r * B.r - A.i * B.i; |
845 | //if (w>w) exit(1); | 846 | double w1 = A.r * B.r; |
846 | //printf("%d",w>w); | 847 | double w2 = A.i * B.i; |
847 | temp.r += aik.r * bkj.r - aik.i * bkj.i; | 848 | if(w != w) { |
848 | temp.i += aik.r * bkj.i + aik.i * bkj.r; | 849 | printf("at : %d %d %d\n",i,j,k); |
849 | //printf("%f %f %f %f \n",aik.r,aik.i,bkj.r,bkj.i); | 850 | printf("%f %f %f\n",A.i,B.i, A.i * B.i); |
850 | //printf("%f %f %f \n",w,temp.r,temp.i); | 851 | printf("%f %f %f\n",A.i,B.i, w2); |
851 | 852 | exit(1); | |
853 | } | ||
854 | temp.r += (w + w1-w2)/2; | ||
855 | //temp.r += w; | ||
856 | temp.i += A.r * B.i + A.i * B.r; | ||
852 | } | 857 | } |
853 | ((doublecomplex*)rp)[i*rc+j] = temp; | 858 | ((doublecomplex*)rp)[i*rc+j] = temp; |
854 | //printf("%f %f\n",temp.r,temp.i); | ||
855 | } | 859 | } |
856 | } | 860 | } |
857 | OK | 861 | OK |
858 | } | 862 | } |
859 | 863 | ||
860 | void dgemm_(char *, char *, integer *, integer *, integer *, | 864 | // void dgemm_(char *, char *, integer *, integer *, integer *, |
861 | double *, const double *, integer *, const double *, | 865 | // double *, const double *, integer *, const double *, |
862 | integer *, double *, double *, integer *); | 866 | // integer *, double *, double *, integer *); |
863 | 867 | // | |
864 | void zgemm_(char *, char *, integer *, integer *, integer *, | 868 | // void zgemm_(char *, char *, integer *, integer *, integer *, |
865 | doublecomplex *, const doublecomplex *, integer *, const doublecomplex *, | 869 | // doublecomplex *, const doublecomplex *, integer *, const doublecomplex *, |
866 | integer *, doublecomplex *, doublecomplex *, integer *); | 870 | // integer *, doublecomplex *, doublecomplex *, integer *); |
867 | 871 | // | |
868 | 872 | // | |
869 | int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)) { | 873 | // int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)) { |
870 | REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | 874 | // REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); |
871 | double alpha = 1; | 875 | // double alpha = 1; |
872 | double beta = 0; | 876 | // double beta = 0; |
873 | integer m = ar; | 877 | // integer m = ar; |
874 | integer n = bc; | 878 | // integer n = bc; |
875 | integer k = ac; | 879 | // integer k = ac; |
876 | int i,j; | 880 | // int i,j; |
877 | dgemm_("N","N",&m,&n,&k,&alpha,ap,&m,bp,&k,&beta,rp,&m); | 881 | // dgemm_("N","N",&m,&n,&k,&alpha,ap,&m,bp,&k,&beta,rp,&m); |
878 | OK | 882 | // OK |
879 | } | 883 | // } |
880 | 884 | // | |
881 | int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)) { | 885 | // int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)) { |
882 | REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | 886 | // REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); |
883 | integer m = ar; | 887 | // integer m = ar; |
884 | integer n = bc; | 888 | // integer n = bc; |
885 | integer k = ac; | 889 | // integer k = ac; |
886 | doublecomplex alpha = {1,0}; | 890 | // doublecomplex alpha = {1,0}; |
887 | doublecomplex beta = {0,0}; | 891 | // doublecomplex beta = {0,0}; |
888 | zgemm_("N","N",&m,&n,&k,&alpha,(doublecomplex*)ap,&m,(doublecomplex*)bp,&k,&beta,(doublecomplex*)rp,&m); | 892 | // zgemm_("N","N",&m,&n,&k,&alpha,(doublecomplex*)ap,&m,(doublecomplex*)bp,&k,&beta,(doublecomplex*)rp,&m); |
889 | OK | 893 | // OK |
890 | } | 894 | // } |
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h index c0361a6..e8cba30 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | |||
@@ -85,8 +85,8 @@ int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)); | |||
85 | int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)); | 85 | int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)); |
86 | int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)); | 86 | int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)); |
87 | 87 | ||
88 | int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)); | 88 | // int multiplyR(KDMAT(a),KDMAT(b),DMAT(r)); |
89 | int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)); | 89 | // int multiplyC(KCMAT(a),KCMAT(b),CMAT(r)); |
90 | 90 | ||
91 | int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)); | 91 | // int multiplyR2(KDMAT(a),KDMAT(b),DMAT(r)); |
92 | int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)); | 92 | int multiplyC2(KCMAT(a),KCMAT(b),CMAT(r)); |