diff options
author | Alberto Ruiz <aruiz@um.es> | 2015-06-27 09:15:27 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2015-06-27 09:15:27 +0200 |
commit | 4d96b90c4cfd38cdb51f3dc66a8a644bd87cdbff (patch) | |
tree | d7b82283f08e5947b06fdec4f403a5bc87c09f35 | |
parent | 624046d6b55d37104f950e8888ab68c53a2e6bf0 (diff) |
use slice interface for lapack funcs (wip)
-rw-r--r-- | packages/base/src/Internal/C/lapack-aux.c | 95 | ||||
-rw-r--r-- | packages/base/src/Internal/C/lapack-aux.h | 32 | ||||
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs | 165 | ||||
-rw-r--r-- | packages/base/src/Internal/Modular.hs | 4 | ||||
-rw-r--r-- | packages/base/src/Internal/ST.hs | 1 | ||||
-rw-r--r-- | packages/base/src/Internal/Util.hs | 9 |
6 files changed, 163 insertions, 143 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index cdbaab9..baa0570 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c | |||
@@ -38,6 +38,9 @@ typedef float complex TCF; | |||
38 | // #define OK return 0; | 38 | // #define OK return 0; |
39 | // #endif | 39 | // #endif |
40 | 40 | ||
41 | |||
42 | // printf("%dx%d %d:%d\n",ar,ac,aXr,aXc); | ||
43 | |||
41 | #define TRACEMAT(M) {int q; printf(" %d x %d: ",M##r,M##c); \ | 44 | #define TRACEMAT(M) {int q; printf(" %d x %d: ",M##r,M##c); \ |
42 | for(q=0;q<M##r*M##c;q++) printf("%.1f ",M##p[q]); printf("\n");} | 45 | for(q=0;q<M##r*M##c;q++) printf("%.1f ",M##p[q]); printf("\n");} |
43 | 46 | ||
@@ -56,7 +59,7 @@ inline int mod (int a, int b); | |||
56 | 59 | ||
57 | inline int64_t mod_l (int64_t a, int64_t b); | 60 | inline int64_t mod_l (int64_t a, int64_t b); |
58 | 61 | ||
59 | //--------------------------------------- | 62 | //////////////////////////////////////////////////////////////////////////////// |
60 | void asm_finit() { | 63 | void asm_finit() { |
61 | #ifdef i386 | 64 | #ifdef i386 |
62 | 65 | ||
@@ -78,8 +81,6 @@ void asm_finit() { | |||
78 | #endif | 81 | #endif |
79 | } | 82 | } |
80 | 83 | ||
81 | //--------------------------------------- | ||
82 | |||
83 | #if NANDEBUG | 84 | #if NANDEBUG |
84 | 85 | ||
85 | #define CHECKNANR(M,msg) \ | 86 | #define CHECKNANR(M,msg) \ |
@@ -109,16 +110,16 @@ for(k=0; k<(M##r * M##c); k++) { \ | |||
109 | #define CHECKNANR(M,msg) | 110 | #define CHECKNANR(M,msg) |
110 | #endif | 111 | #endif |
111 | 112 | ||
112 | //--------------------------------------- | ||
113 | 113 | ||
114 | //////////////////// real svd //////////////////////////////////// | 114 | //////////////////////////////////////////////////////////////////////////////// |
115 | //////////////////// real svd /////////////////////////////////////////////////// | ||
115 | 116 | ||
116 | /* Subroutine */ int dgesvd_(char *jobu, char *jobvt, integer *m, integer *n, | 117 | /* Subroutine */ int dgesvd_(char *jobu, char *jobvt, integer *m, integer *n, |
117 | doublereal *a, integer *lda, doublereal *s, doublereal *u, integer * | 118 | doublereal *a, integer *lda, doublereal *s, doublereal *u, integer * |
118 | ldu, doublereal *vt, integer *ldvt, doublereal *work, integer *lwork, | 119 | ldu, doublereal *vt, integer *ldvt, doublereal *work, integer *lwork, |
119 | integer *info); | 120 | integer *info); |
120 | 121 | ||
121 | int svd_l_R(KDMAT(a),DMAT(u), DVEC(s),DMAT(v)) { | 122 | int svd_l_R(KODMAT(a),ODMAT(u), DVEC(s),ODMAT(v)) { |
122 | integer m = ar; | 123 | integer m = ar; |
123 | integer n = ac; | 124 | integer n = ac; |
124 | integer q = MIN(m,n); | 125 | integer q = MIN(m,n); |
@@ -181,7 +182,7 @@ int svd_l_R(KDMAT(a),DMAT(u), DVEC(s),DMAT(v)) { | |||
181 | doublereal *vt, integer *ldvt, doublereal *work, integer *lwork, | 182 | doublereal *vt, integer *ldvt, doublereal *work, integer *lwork, |
182 | integer *iwork, integer *info); | 183 | integer *iwork, integer *info); |
183 | 184 | ||
184 | int svd_l_Rdd(KDMAT(a),DMAT(u), DVEC(s),DMAT(v)) { | 185 | int svd_l_Rdd(KODMAT(a),ODMAT(u), DVEC(s),ODMAT(v)) { |
185 | integer m = ar; | 186 | integer m = ar; |
186 | integer n = ac; | 187 | integer n = ac; |
187 | integer q = MIN(m,n); | 188 | integer q = MIN(m,n); |
@@ -231,7 +232,7 @@ int zgesvd_(char *jobu, char *jobvt, integer *m, integer *n, | |||
231 | integer *ldu, doublecomplex *vt, integer *ldvt, doublecomplex *work, | 232 | integer *ldu, doublecomplex *vt, integer *ldvt, doublecomplex *work, |
232 | integer *lwork, doublereal *rwork, integer *info); | 233 | integer *lwork, doublereal *rwork, integer *info); |
233 | 234 | ||
234 | int svd_l_C(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) { | 235 | int svd_l_C(KOCMAT(a),OCMAT(u), DVEC(s),OCMAT(v)) { |
235 | integer m = ar; | 236 | integer m = ar; |
236 | integer n = ac; | 237 | integer n = ac; |
237 | integer q = MIN(m,n); | 238 | integer q = MIN(m,n); |
@@ -297,7 +298,7 @@ int zgesdd_ (char *jobz, integer *m, integer *n, | |||
297 | integer *ldu, doublecomplex *vt, integer *ldvt, doublecomplex *work, | 298 | integer *ldu, doublecomplex *vt, integer *ldvt, doublecomplex *work, |
298 | integer *lwork, doublereal *rwork, integer* iwork, integer *info); | 299 | integer *lwork, doublereal *rwork, integer* iwork, integer *info); |
299 | 300 | ||
300 | int svd_l_Cdd(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) { | 301 | int svd_l_Cdd(KOCMAT(a),OCMAT(u), DVEC(s),OCMAT(v)) { |
301 | //printf("entro\n"); | 302 | //printf("entro\n"); |
302 | integer m = ar; | 303 | integer m = ar; |
303 | integer n = ac; | 304 | integer n = ac; |
@@ -358,7 +359,7 @@ int svd_l_Cdd(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) { | |||
358 | integer *ldvl, doublecomplex *vr, integer *ldvr, doublecomplex *work, | 359 | integer *ldvl, doublecomplex *vr, integer *ldvr, doublecomplex *work, |
359 | integer *lwork, doublereal *rwork, integer *info); | 360 | integer *lwork, doublereal *rwork, integer *info); |
360 | 361 | ||
361 | int eig_l_C(KCMAT(a), CMAT(u), CVEC(s),CMAT(v)) { | 362 | int eig_l_C(KOCMAT(a), OCMAT(u), CVEC(s),OCMAT(v)) { |
362 | integer n = ar; | 363 | integer n = ar; |
363 | REQUIRES(ac==n && sn==n, BAD_SIZE); | 364 | REQUIRES(ac==n && sn==n, BAD_SIZE); |
364 | REQUIRES(up==NULL || (ur==n && uc==n), BAD_SIZE); | 365 | REQUIRES(up==NULL || (ur==n && uc==n), BAD_SIZE); |
@@ -413,7 +414,7 @@ int eig_l_C(KCMAT(a), CMAT(u), CVEC(s),CMAT(v)) { | |||
413 | integer *ldvl, doublereal *vr, integer *ldvr, doublereal *work, | 414 | integer *ldvl, doublereal *vr, integer *ldvr, doublereal *work, |
414 | integer *lwork, integer *info); | 415 | integer *lwork, integer *info); |
415 | 416 | ||
416 | int eig_l_R(KDMAT(a),DMAT(u), CVEC(s),DMAT(v)) { | 417 | int eig_l_R(KODMAT(a),ODMAT(u), CVEC(s),ODMAT(v)) { |
417 | integer n = ar; | 418 | integer n = ar; |
418 | REQUIRES(ac==n && sn==n, BAD_SIZE); | 419 | REQUIRES(ac==n && sn==n, BAD_SIZE); |
419 | REQUIRES(up==NULL || (ur==n && uc==n), BAD_SIZE); | 420 | REQUIRES(up==NULL || (ur==n && uc==n), BAD_SIZE); |
@@ -461,7 +462,7 @@ int eig_l_R(KDMAT(a),DMAT(u), CVEC(s),DMAT(v)) { | |||
461 | integer *lda, doublereal *w, doublereal *work, integer *lwork, | 462 | integer *lda, doublereal *w, doublereal *work, integer *lwork, |
462 | integer *info); | 463 | integer *info); |
463 | 464 | ||
464 | int eig_l_S(int wantV,KDMAT(a),DVEC(s),DMAT(v)) { | 465 | int eig_l_S(int wantV,KODMAT(a),DVEC(s),ODMAT(v)) { |
465 | integer n = ar; | 466 | integer n = ar; |
466 | REQUIRES(ac==n && sn==n, BAD_SIZE); | 467 | REQUIRES(ac==n && sn==n, BAD_SIZE); |
467 | REQUIRES(vr==n && vc==n, BAD_SIZE); | 468 | REQUIRES(vr==n && vc==n, BAD_SIZE); |
@@ -499,7 +500,7 @@ int eig_l_S(int wantV,KDMAT(a),DVEC(s),DMAT(v)) { | |||
499 | *a, integer *lda, doublereal *w, doublecomplex *work, integer *lwork, | 500 | *a, integer *lda, doublereal *w, doublecomplex *work, integer *lwork, |
500 | doublereal *rwork, integer *info); | 501 | doublereal *rwork, integer *info); |
501 | 502 | ||
502 | int eig_l_H(int wantV,KCMAT(a),DVEC(s),CMAT(v)) { | 503 | int eig_l_H(int wantV,KOCMAT(a),DVEC(s),OCMAT(v)) { |
503 | integer n = ar; | 504 | integer n = ar; |
504 | REQUIRES(ac==n && sn==n, BAD_SIZE); | 505 | REQUIRES(ac==n && sn==n, BAD_SIZE); |
505 | REQUIRES(vr==n && vc==n, BAD_SIZE); | 506 | REQUIRES(vr==n && vc==n, BAD_SIZE); |
@@ -541,7 +542,7 @@ int eig_l_H(int wantV,KCMAT(a),DVEC(s),CMAT(v)) { | |||
541 | /* Subroutine */ int dgesv_(integer *n, integer *nrhs, doublereal *a, integer | 542 | /* Subroutine */ int dgesv_(integer *n, integer *nrhs, doublereal *a, integer |
542 | *lda, integer *ipiv, doublereal *b, integer *ldb, integer *info); | 543 | *lda, integer *ipiv, doublereal *b, integer *ldb, integer *info); |
543 | 544 | ||
544 | int linearSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)) { | 545 | int linearSolveR_l(KODMAT(a),KODMAT(b),ODMAT(x)) { |
545 | integer n = ar; | 546 | integer n = ar; |
546 | integer nhrs = bc; | 547 | integer nhrs = bc; |
547 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); | 548 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); |
@@ -571,7 +572,7 @@ int linearSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)) { | |||
571 | integer *lda, integer *ipiv, doublecomplex *b, integer *ldb, integer * | 572 | integer *lda, integer *ipiv, doublecomplex *b, integer *ldb, integer * |
572 | info); | 573 | info); |
573 | 574 | ||
574 | int linearSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)) { | 575 | int linearSolveC_l(KOCMAT(a),KOCMAT(b),OCMAT(x)) { |
575 | integer n = ar; | 576 | integer n = ar; |
576 | integer nhrs = bc; | 577 | integer nhrs = bc; |
577 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); | 578 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); |
@@ -601,7 +602,7 @@ int linearSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)) { | |||
601 | doublereal *a, integer *lda, doublereal *b, integer *ldb, integer * | 602 | doublereal *a, integer *lda, doublereal *b, integer *ldb, integer * |
602 | info); | 603 | info); |
603 | 604 | ||
604 | int cholSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)) { | 605 | int cholSolveR_l(KODMAT(a),KODMAT(b),ODMAT(x)) { |
605 | integer n = ar; | 606 | integer n = ar; |
606 | integer nhrs = bc; | 607 | integer nhrs = bc; |
607 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); | 608 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); |
@@ -623,7 +624,7 @@ int cholSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)) { | |||
623 | doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, | 624 | doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, |
624 | integer *info); | 625 | integer *info); |
625 | 626 | ||
626 | int cholSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)) { | 627 | int cholSolveC_l(KOCMAT(a),KOCMAT(b),OCMAT(x)) { |
627 | integer n = ar; | 628 | integer n = ar; |
628 | integer nhrs = bc; | 629 | integer nhrs = bc; |
629 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); | 630 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); |
@@ -645,7 +646,7 @@ int cholSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)) { | |||
645 | nrhs, doublereal *a, integer *lda, doublereal *b, integer *ldb, | 646 | nrhs, doublereal *a, integer *lda, doublereal *b, integer *ldb, |
646 | doublereal *work, integer *lwork, integer *info); | 647 | doublereal *work, integer *lwork, integer *info); |
647 | 648 | ||
648 | int linearSolveLSR_l(KDMAT(a),KDMAT(b),DMAT(x)) { | 649 | int linearSolveLSR_l(KODMAT(a),KODMAT(b),ODMAT(x)) { |
649 | integer m = ar; | 650 | integer m = ar; |
650 | integer n = ac; | 651 | integer n = ac; |
651 | integer nrhs = bc; | 652 | integer nrhs = bc; |
@@ -693,7 +694,7 @@ int linearSolveLSR_l(KDMAT(a),KDMAT(b),DMAT(x)) { | |||
693 | nrhs, doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, | 694 | nrhs, doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, |
694 | doublecomplex *work, integer *lwork, integer *info); | 695 | doublecomplex *work, integer *lwork, integer *info); |
695 | 696 | ||
696 | int linearSolveLSC_l(KCMAT(a),KCMAT(b),CMAT(x)) { | 697 | int linearSolveLSC_l(KOCMAT(a),KOCMAT(b),OCMAT(x)) { |
697 | integer m = ar; | 698 | integer m = ar; |
698 | integer n = ac; | 699 | integer n = ac; |
699 | integer nrhs = bc; | 700 | integer nrhs = bc; |
@@ -742,7 +743,7 @@ int linearSolveLSC_l(KCMAT(a),KCMAT(b),CMAT(x)) { | |||
742 | s, doublereal *rcond, integer *rank, doublereal *work, integer *lwork, | 743 | s, doublereal *rcond, integer *rank, doublereal *work, integer *lwork, |
743 | integer *info); | 744 | integer *info); |
744 | 745 | ||
745 | int linearSolveSVDR_l(double rcond,KDMAT(a),KDMAT(b),DMAT(x)) { | 746 | int linearSolveSVDR_l(double rcond,KODMAT(a),KODMAT(b),ODMAT(x)) { |
746 | integer m = ar; | 747 | integer m = ar; |
747 | integer n = ac; | 748 | integer n = ac; |
748 | integer nrhs = bc; | 749 | integer nrhs = bc; |
@@ -801,7 +802,7 @@ int zgelss_(integer *m, integer *n, integer *nhrs, | |||
801 | doublecomplex *work, integer* lwork, doublereal* rwork, | 802 | doublecomplex *work, integer* lwork, doublereal* rwork, |
802 | integer *info); | 803 | integer *info); |
803 | 804 | ||
804 | int linearSolveSVDC_l(double rcond, KCMAT(a),KCMAT(b),CMAT(x)) { | 805 | int linearSolveSVDC_l(double rcond, KOCMAT(a),KOCMAT(b),OCMAT(x)) { |
805 | integer m = ar; | 806 | integer m = ar; |
806 | integer n = ac; | 807 | integer n = ac; |
807 | integer nrhs = bc; | 808 | integer nrhs = bc; |
@@ -859,7 +860,7 @@ int linearSolveSVDC_l(double rcond, KCMAT(a),KCMAT(b),CMAT(x)) { | |||
859 | /* Subroutine */ int zpotrf_(char *uplo, integer *n, doublecomplex *a, | 860 | /* Subroutine */ int zpotrf_(char *uplo, integer *n, doublecomplex *a, |
860 | integer *lda, integer *info); | 861 | integer *lda, integer *info); |
861 | 862 | ||
862 | int chol_l_H(KCMAT(a),CMAT(l)) { | 863 | int chol_l_H(KOCMAT(a),OCMAT(l)) { |
863 | integer n = ar; | 864 | integer n = ar; |
864 | REQUIRES(n>=1 && ac == n && lr==n && lc==n,BAD_SIZE); | 865 | REQUIRES(n>=1 && ac == n && lr==n && lc==n,BAD_SIZE); |
865 | DEBUGMSG("chol_l_H"); | 866 | DEBUGMSG("chol_l_H"); |
@@ -871,9 +872,9 @@ int chol_l_H(KCMAT(a),CMAT(l)) { | |||
871 | CHECK(res,res); | 872 | CHECK(res,res); |
872 | doublecomplex zero = {0.,0.}; | 873 | doublecomplex zero = {0.,0.}; |
873 | int r,c; | 874 | int r,c; |
874 | for (r=0; r<lr-1; r++) { | 875 | for (r=0; r<lr; r++) { |
875 | for(c=r+1; c<lc; c++) { | 876 | for(c=0; c<r; c++) { |
876 | lp[r*lc+c] = zero; | 877 | AT(l,r,c) = zero; |
877 | } | 878 | } |
878 | } | 879 | } |
879 | OK | 880 | OK |
@@ -883,7 +884,7 @@ int chol_l_H(KCMAT(a),CMAT(l)) { | |||
883 | /* Subroutine */ int dpotrf_(char *uplo, integer *n, doublereal *a, integer * | 884 | /* Subroutine */ int dpotrf_(char *uplo, integer *n, doublereal *a, integer * |
884 | lda, integer *info); | 885 | lda, integer *info); |
885 | 886 | ||
886 | int chol_l_S(KDMAT(a),DMAT(l)) { | 887 | int chol_l_S(KODMAT(a),ODMAT(l)) { |
887 | integer n = ar; | 888 | integer n = ar; |
888 | REQUIRES(n>=1 && ac == n && lr==n && lc==n,BAD_SIZE); | 889 | REQUIRES(n>=1 && ac == n && lr==n && lc==n,BAD_SIZE); |
889 | DEBUGMSG("chol_l_S"); | 890 | DEBUGMSG("chol_l_S"); |
@@ -894,9 +895,9 @@ int chol_l_S(KDMAT(a),DMAT(l)) { | |||
894 | CHECK(res>0,NODEFPOS); | 895 | CHECK(res>0,NODEFPOS); |
895 | CHECK(res,res); | 896 | CHECK(res,res); |
896 | int r,c; | 897 | int r,c; |
897 | for (r=0; r<lr-1; r++) { | 898 | for (r=0; r<lr; r++) { |
898 | for(c=r+1; c<lc; c++) { | 899 | for(c=0; c<r; c++) { |
899 | lp[r*lc+c] = 0.; | 900 | AT(l,r,c) = 0.; |
900 | } | 901 | } |
901 | } | 902 | } |
902 | OK | 903 | OK |
@@ -907,7 +908,7 @@ int chol_l_S(KDMAT(a),DMAT(l)) { | |||
907 | /* Subroutine */ int dgeqr2_(integer *m, integer *n, doublereal *a, integer * | 908 | /* Subroutine */ int dgeqr2_(integer *m, integer *n, doublereal *a, integer * |
908 | lda, doublereal *tau, doublereal *work, integer *info); | 909 | lda, doublereal *tau, doublereal *work, integer *info); |
909 | 910 | ||
910 | int qr_l_R(KDMAT(a), DVEC(tau), DMAT(r)) { | 911 | int qr_l_R(KODMAT(a), DVEC(tau), ODMAT(r)) { |
911 | integer m = ar; | 912 | integer m = ar; |
912 | integer n = ac; | 913 | integer n = ac; |
913 | integer mn = MIN(m,n); | 914 | integer mn = MIN(m,n); |
@@ -926,7 +927,7 @@ int qr_l_R(KDMAT(a), DVEC(tau), DMAT(r)) { | |||
926 | /* Subroutine */ int zgeqr2_(integer *m, integer *n, doublecomplex *a, | 927 | /* Subroutine */ int zgeqr2_(integer *m, integer *n, doublecomplex *a, |
927 | integer *lda, doublecomplex *tau, doublecomplex *work, integer *info); | 928 | integer *lda, doublecomplex *tau, doublecomplex *work, integer *info); |
928 | 929 | ||
929 | int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r)) { | 930 | int qr_l_C(KOCMAT(a), CVEC(tau), OCMAT(r)) { |
930 | integer m = ar; | 931 | integer m = ar; |
931 | integer n = ac; | 932 | integer n = ac; |
932 | integer mn = MIN(m,n); | 933 | integer mn = MIN(m,n); |
@@ -946,7 +947,7 @@ int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r)) { | |||
946 | a, integer *lda, doublereal *tau, doublereal *work, integer *lwork, | 947 | a, integer *lda, doublereal *tau, doublereal *work, integer *lwork, |
947 | integer *info); | 948 | integer *info); |
948 | 949 | ||
949 | int c_dorgqr(KDMAT(a), KDVEC(tau), DMAT(r)) { | 950 | int c_dorgqr(KODMAT(a), KDVEC(tau), ODMAT(r)) { |
950 | integer m = ar; | 951 | integer m = ar; |
951 | integer n = MIN(ac,ar); | 952 | integer n = MIN(ac,ar); |
952 | integer k = taun; | 953 | integer k = taun; |
@@ -966,7 +967,7 @@ int c_dorgqr(KDMAT(a), KDVEC(tau), DMAT(r)) { | |||
966 | doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * | 967 | doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * |
967 | work, integer *lwork, integer *info); | 968 | work, integer *lwork, integer *info); |
968 | 969 | ||
969 | int c_zungqr(KCMAT(a), KCVEC(tau), CMAT(r)) { | 970 | int c_zungqr(KOCMAT(a), KCVEC(tau), OCMAT(r)) { |
970 | integer m = ar; | 971 | integer m = ar; |
971 | integer n = MIN(ac,ar); | 972 | integer n = MIN(ac,ar); |
972 | integer k = taun; | 973 | integer k = taun; |
@@ -989,7 +990,7 @@ int c_zungqr(KCMAT(a), KCVEC(tau), CMAT(r)) { | |||
989 | doublereal *a, integer *lda, doublereal *tau, doublereal *work, | 990 | doublereal *a, integer *lda, doublereal *tau, doublereal *work, |
990 | integer *lwork, integer *info); | 991 | integer *lwork, integer *info); |
991 | 992 | ||
992 | int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r)) { | 993 | int hess_l_R(KODMAT(a), DVEC(tau), ODMAT(r)) { |
993 | integer m = ar; | 994 | integer m = ar; |
994 | integer n = ac; | 995 | integer n = ac; |
995 | integer mn = MIN(m,n); | 996 | integer mn = MIN(m,n); |
@@ -1012,7 +1013,7 @@ int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r)) { | |||
1012 | doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * | 1013 | doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * |
1013 | work, integer *lwork, integer *info); | 1014 | work, integer *lwork, integer *info); |
1014 | 1015 | ||
1015 | int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)) { | 1016 | int hess_l_C(KOCMAT(a), CVEC(tau), OCMAT(r)) { |
1016 | integer m = ar; | 1017 | integer m = ar; |
1017 | integer n = ac; | 1018 | integer n = ac; |
1018 | integer mn = MIN(m,n); | 1019 | integer mn = MIN(m,n); |
@@ -1037,7 +1038,7 @@ int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)) { | |||
1037 | doublereal *wi, doublereal *vs, integer *ldvs, doublereal *work, | 1038 | doublereal *wi, doublereal *vs, integer *ldvs, doublereal *work, |
1038 | integer *lwork, logical *bwork, integer *info); | 1039 | integer *lwork, logical *bwork, integer *info); |
1039 | 1040 | ||
1040 | int schur_l_R(KDMAT(a), DMAT(u), DMAT(s)) { | 1041 | int schur_l_R(KODMAT(a), ODMAT(u), ODMAT(s)) { |
1041 | integer m = ar; | 1042 | integer m = ar; |
1042 | integer n = ac; | 1043 | integer n = ac; |
1043 | REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE); | 1044 | REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE); |
@@ -1077,7 +1078,7 @@ int schur_l_R(KDMAT(a), DMAT(u), DMAT(s)) { | |||
1077 | doublecomplex *vs, integer *ldvs, doublecomplex *work, integer *lwork, | 1078 | doublecomplex *vs, integer *ldvs, doublecomplex *work, integer *lwork, |
1078 | doublereal *rwork, logical *bwork, integer *info); | 1079 | doublereal *rwork, logical *bwork, integer *info); |
1079 | 1080 | ||
1080 | int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)) { | 1081 | int schur_l_C(KOCMAT(a), OCMAT(u), OCMAT(s)) { |
1081 | integer m = ar; | 1082 | integer m = ar; |
1082 | integer n = ac; | 1083 | integer n = ac; |
1083 | REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE); | 1084 | REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE); |
@@ -1109,7 +1110,7 @@ int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)) { | |||
1109 | /* Subroutine */ int dgetrf_(integer *m, integer *n, doublereal *a, integer * | 1110 | /* Subroutine */ int dgetrf_(integer *m, integer *n, doublereal *a, integer * |
1110 | lda, integer *ipiv, integer *info); | 1111 | lda, integer *ipiv, integer *info); |
1111 | 1112 | ||
1112 | int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)) { | 1113 | int lu_l_R(KODMAT(a), DVEC(ipiv), ODMAT(r)) { |
1113 | integer m = ar; | 1114 | integer m = ar; |
1114 | integer n = ac; | 1115 | integer n = ac; |
1115 | integer mn = MIN(m,n); | 1116 | integer mn = MIN(m,n); |
@@ -1135,7 +1136,7 @@ int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)) { | |||
1135 | /* Subroutine */ int zgetrf_(integer *m, integer *n, doublecomplex *a, | 1136 | /* Subroutine */ int zgetrf_(integer *m, integer *n, doublecomplex *a, |
1136 | integer *lda, integer *ipiv, integer *info); | 1137 | integer *lda, integer *ipiv, integer *info); |
1137 | 1138 | ||
1138 | int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) { | 1139 | int lu_l_C(KOCMAT(a), DVEC(ipiv), OCMAT(r)) { |
1139 | integer m = ar; | 1140 | integer m = ar; |
1140 | integer n = ac; | 1141 | integer n = ac; |
1141 | integer mn = MIN(m,n); | 1142 | integer mn = MIN(m,n); |
@@ -1164,7 +1165,7 @@ int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) { | |||
1164 | doublereal *a, integer *lda, integer *ipiv, doublereal *b, integer * | 1165 | doublereal *a, integer *lda, integer *ipiv, doublereal *b, integer * |
1165 | ldb, integer *info); | 1166 | ldb, integer *info); |
1166 | 1167 | ||
1167 | int luS_l_R(KDMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)) { | 1168 | int luS_l_R(KODMAT(a), KDVEC(ipiv), KODMAT(b), ODMAT(x)) { |
1168 | integer m = ar; | 1169 | integer m = ar; |
1169 | integer n = ac; | 1170 | integer n = ac; |
1170 | integer mrhs = br; | 1171 | integer mrhs = br; |
@@ -1189,7 +1190,7 @@ int luS_l_R(KDMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)) { | |||
1189 | doublecomplex *a, integer *lda, integer *ipiv, doublecomplex *b, | 1190 | doublecomplex *a, integer *lda, integer *ipiv, doublecomplex *b, |
1190 | integer *ldb, integer *info); | 1191 | integer *ldb, integer *info); |
1191 | 1192 | ||
1192 | int luS_l_C(KCMAT(a), KDVEC(ipiv), KCMAT(b), CMAT(x)) { | 1193 | int luS_l_C(KOCMAT(a), KDVEC(ipiv), KOCMAT(b), OCMAT(x)) { |
1193 | integer m = ar; | 1194 | integer m = ar; |
1194 | integer n = ac; | 1195 | integer n = ac; |
1195 | integer mrhs = br; | 1196 | integer mrhs = br; |
@@ -1215,7 +1216,7 @@ void dgemm_(char *, char *, integer *, integer *, integer *, | |||
1215 | double *, const double *, integer *, const double *, | 1216 | double *, const double *, integer *, const double *, |
1216 | integer *, double *, double *, integer *); | 1217 | integer *, double *, double *, integer *); |
1217 | 1218 | ||
1218 | int multiplyR(int ta, int tb, KDMAT(a),KDMAT(b),DMAT(r)) { | 1219 | int multiplyR(int ta, int tb, KODMAT(a),KODMAT(b),ODMAT(r)) { |
1219 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | 1220 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); |
1220 | DEBUGMSG("dgemm_"); | 1221 | DEBUGMSG("dgemm_"); |
1221 | CHECKNANR(a,"NaN multR Input\n") | 1222 | CHECKNANR(a,"NaN multR Input\n") |
@@ -1237,7 +1238,7 @@ void zgemm_(char *, char *, integer *, integer *, integer *, | |||
1237 | doublecomplex *, const doublecomplex *, integer *, const doublecomplex *, | 1238 | doublecomplex *, const doublecomplex *, integer *, const doublecomplex *, |
1238 | integer *, doublecomplex *, doublecomplex *, integer *); | 1239 | integer *, doublecomplex *, doublecomplex *, integer *); |
1239 | 1240 | ||
1240 | int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r)) { | 1241 | int multiplyC(int ta, int tb, KOCMAT(a),KOCMAT(b),OCMAT(r)) { |
1241 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | 1242 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); |
1242 | DEBUGMSG("zgemm_"); | 1243 | DEBUGMSG("zgemm_"); |
1243 | CHECKNANC(a,"NaN multC Input\n") | 1244 | CHECKNANC(a,"NaN multC Input\n") |
@@ -1262,7 +1263,7 @@ void sgemm_(char *, char *, integer *, integer *, integer *, | |||
1262 | float *, const float *, integer *, const float *, | 1263 | float *, const float *, integer *, const float *, |
1263 | integer *, float *, float *, integer *); | 1264 | integer *, float *, float *, integer *); |
1264 | 1265 | ||
1265 | int multiplyF(int ta, int tb, KFMAT(a),KFMAT(b),FMAT(r)) { | 1266 | int multiplyF(int ta, int tb, KOFMAT(a),KOFMAT(b),OFMAT(r)) { |
1266 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | 1267 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); |
1267 | DEBUGMSG("sgemm_"); | 1268 | DEBUGMSG("sgemm_"); |
1268 | integer m = ta?ac:ar; | 1269 | integer m = ta?ac:ar; |
@@ -1281,7 +1282,7 @@ void cgemm_(char *, char *, integer *, integer *, integer *, | |||
1281 | complex *, const complex *, integer *, const complex *, | 1282 | complex *, const complex *, integer *, const complex *, |
1282 | integer *, complex *, complex *, integer *); | 1283 | integer *, complex *, complex *, integer *); |
1283 | 1284 | ||
1284 | int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) { | 1285 | int multiplyQ(int ta, int tb, KOQMAT(a),KOQMAT(b),OQMAT(r)) { |
1285 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | 1286 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); |
1286 | DEBUGMSG("cgemm_"); | 1287 | DEBUGMSG("cgemm_"); |
1287 | integer m = ta?ac:ar; | 1288 | integer m = ta?ac:ar; |
@@ -1564,13 +1565,13 @@ int remapQ(KOIMAT(i), KOIMAT(j), KOQMAT(m), OQMAT(r)) { | |||
1564 | 1565 | ||
1565 | //////////////////////////////////////////////////////////////////////////////// | 1566 | //////////////////////////////////////////////////////////////////////////////// |
1566 | 1567 | ||
1567 | int saveMatrix(char * file, char * format, KDMAT(a)){ | 1568 | int saveMatrix(char * file, char * format, KODMAT(a)){ |
1568 | FILE * fp; | 1569 | FILE * fp; |
1569 | fp = fopen (file, "w"); | 1570 | fp = fopen (file, "w"); |
1570 | int r, c; | 1571 | int r, c; |
1571 | for (r=0;r<ar; r++) { | 1572 | for (r=0;r<ar; r++) { |
1572 | for (c=0; c<ac; c++) { | 1573 | for (c=0; c<ac; c++) { |
1573 | fprintf(fp,format,ap[r*ac+c]); | 1574 | fprintf(fp,format,AT(a,r,c)); |
1574 | if (c<ac-1) { | 1575 | if (c<ac-1) { |
1575 | fprintf(fp," "); | 1576 | fprintf(fp," "); |
1576 | } else { | 1577 | } else { |
diff --git a/packages/base/src/Internal/C/lapack-aux.h b/packages/base/src/Internal/C/lapack-aux.h index bf8c5e9..b38ca7a 100644 --- a/packages/base/src/Internal/C/lapack-aux.h +++ b/packages/base/src/Internal/C/lapack-aux.h | |||
@@ -52,16 +52,6 @@ typedef short ftnlen; | |||
52 | #define CMAT(A) int A##r, int A##c, doublecomplex* A##p | 52 | #define CMAT(A) int A##r, int A##c, doublecomplex* A##p |
53 | #define PMAT(A) int A##r, int A##c, void* A##p, int A##s | 53 | #define PMAT(A) int A##r, int A##c, void* A##p, int A##s |
54 | 54 | ||
55 | #define OIMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, int* A##p | ||
56 | #define OLMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, int64_t* A##p | ||
57 | #define OFMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, float* A##p | ||
58 | #define ODMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, double* A##p | ||
59 | #define OQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, complex* A##p | ||
60 | #define OCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, doublecomplex* A##p | ||
61 | |||
62 | #define VECG(T,A) int A##n, T* A##p | ||
63 | #define MATG(T,A) int A##r, int A##c, int A##Xr, int A##Xc, T* A##p | ||
64 | |||
65 | #define KIVEC(A) int A##n, const int*A##p | 55 | #define KIVEC(A) int A##n, const int*A##p |
66 | #define KLVEC(A) int A##n, const int64_t*A##p | 56 | #define KLVEC(A) int A##n, const int64_t*A##p |
67 | #define KFVEC(A) int A##n, const float*A##p | 57 | #define KFVEC(A) int A##n, const float*A##p |
@@ -78,12 +68,22 @@ typedef short ftnlen; | |||
78 | #define KCMAT(A) int A##r, int A##c, const doublecomplex* A##p | 68 | #define KCMAT(A) int A##r, int A##c, const doublecomplex* A##p |
79 | #define KPMAT(A) int A##r, int A##c, const void* A##p, int A##s | 69 | #define KPMAT(A) int A##r, int A##c, const void* A##p, int A##s |
80 | 70 | ||
81 | #define KOIMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, const int* A##p | 71 | #define VECG(T,A) int A##n, T* A##p |
82 | #define KOLMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, const int64_t* A##p | 72 | #define MATG(T,A) int A##r, int A##c, int A##Xr, int A##Xc, T* A##p |
83 | #define KOFMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, const float* A##p | 73 | |
84 | #define KODMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, const double* A##p | 74 | #define OIMAT(A) MATG(int,A) |
85 | #define KOQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, const complex* A##p | 75 | #define OLMAT(A) MATG(int64_t,A) |
86 | #define KOCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, const doublecomplex* A##p | 76 | #define OFMAT(A) MATG(float,A) |
77 | #define ODMAT(A) MATG(double,A) | ||
78 | #define OQMAT(A) MATG(complex,A) | ||
79 | #define OCMAT(A) MATG(doublecomplex,A) | ||
80 | |||
81 | #define KOIMAT(A) MATG(const int,A) | ||
82 | #define KOLMAT(A) MATG(const int64_t,A) | ||
83 | #define KOFMAT(A) MATG(const float,A) | ||
84 | #define KODMAT(A) MATG(const double,A) | ||
85 | #define KOQMAT(A) MATG(const complex,A) | ||
86 | #define KOCMAT(A) MATG(const doublecomplex,A) | ||
87 | 87 | ||
88 | #define AT(m,i,j) (m##p[(i)*m##Xr + (j)*m##Xc]) | 88 | #define AT(m,i,j) (m##p[(i)*m##Xr + (j)*m##Xc]) |
89 | #define TRAV(m,i,j) int i,j; for (i=0;i<m##r;i++) for (j=0;j<m##c;j++) | 89 | #define TRAV(m,i,j) int i,j; for (i=0;i<m##r;i++) for (j=0;j<m##c;j++) |
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index 5319e95..2c7148b 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -29,16 +29,12 @@ import System.IO.Unsafe(unsafePerformIO) | |||
29 | ----------------------------------------------------------------------------------- | 29 | ----------------------------------------------------------------------------------- |
30 | 30 | ||
31 | infixl 1 # | 31 | infixl 1 # |
32 | a # b = applyRaw a b | 32 | a # b = apply a b |
33 | {-# INLINE (#) #-} | 33 | {-# INLINE (#) #-} |
34 | 34 | ||
35 | infixl 1 #! | ||
36 | a #! b = apply a b | ||
37 | {-# INLINE (#!) #-} | ||
38 | |||
39 | ----------------------------------------------------------------------------------- | 35 | ----------------------------------------------------------------------------------- |
40 | 36 | ||
41 | type TMMM t = t ..> t ..> t ..> Ok | 37 | type TMMM t = t ::> t ::> t ::> Ok |
42 | 38 | ||
43 | type F = Float | 39 | type F = Float |
44 | type Q = Complex Float | 40 | type Q = Complex Float |
@@ -47,8 +43,8 @@ foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM R | |||
47 | foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C | 43 | foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TMMM C |
48 | foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F | 44 | foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TMMM F |
49 | foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q | 45 | foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TMMM Q |
50 | foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> I ::> I ::> I ::> Ok | 46 | foreign import ccall unsafe "multiplyI" c_multiplyI :: I -> TMMM I |
51 | foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> Z ::> Z ::> Z ::> Ok | 47 | foreign import ccall unsafe "multiplyL" c_multiplyL :: Z -> TMMM Z |
52 | 48 | ||
53 | isT (rowOrder -> False) = 0 | 49 | isT (rowOrder -> False) = 0 |
54 | isT _ = 1 | 50 | isT _ = 1 |
@@ -84,7 +80,7 @@ multiplyI m a b = unsafePerformIO $ do | |||
84 | when (cols a /= rows b) $ error $ | 80 | when (cols a /= rows b) $ error $ |
85 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b | 81 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b |
86 | s <- createMatrix ColumnMajor (rows a) (cols b) | 82 | s <- createMatrix ColumnMajor (rows a) (cols b) |
87 | c_multiplyI m #! a #! b #! s #|"c_multiplyI" | 83 | c_multiplyI m # a # b # s #|"c_multiplyI" |
88 | return s | 84 | return s |
89 | 85 | ||
90 | multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z | 86 | multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z |
@@ -92,12 +88,12 @@ multiplyL m a b = unsafePerformIO $ do | |||
92 | when (cols a /= rows b) $ error $ | 88 | when (cols a /= rows b) $ error $ |
93 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b | 89 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b |
94 | s <- createMatrix ColumnMajor (rows a) (cols b) | 90 | s <- createMatrix ColumnMajor (rows a) (cols b) |
95 | c_multiplyL m #! a #! b #! s #|"c_multiplyL" | 91 | c_multiplyL m # a # b # s #|"c_multiplyL" |
96 | return s | 92 | return s |
97 | 93 | ||
98 | ----------------------------------------------------------------------------- | 94 | ----------------------------------------------------------------------------- |
99 | 95 | ||
100 | type TSVD t = t ..> t ..> R :> t ..> Ok | 96 | type TSVD t = t ::> t ::> R :> t ::> Ok |
101 | 97 | ||
102 | foreign import ccall unsafe "svd_l_R" dgesvd :: TSVD R | 98 | foreign import ccall unsafe "svd_l_R" dgesvd :: TSVD R |
103 | foreign import ccall unsafe "svd_l_C" zgesvd :: TSVD C | 99 | foreign import ccall unsafe "svd_l_C" zgesvd :: TSVD C |
@@ -126,8 +122,9 @@ svdAux f st x = unsafePerformIO $ do | |||
126 | v <- createMatrix ColumnMajor c c | 122 | v <- createMatrix ColumnMajor c c |
127 | f # x # u # s # v #| st | 123 | f # x # u # s # v #| st |
128 | return (u,s,v) | 124 | return (u,s,v) |
129 | where r = rows x | 125 | where |
130 | c = cols x | 126 | r = rows x |
127 | c = cols x | ||
131 | 128 | ||
132 | 129 | ||
133 | -- | Thin SVD of a real matrix, using LAPACK's /dgesvd/ with jobu == jobvt == \'S\'. | 130 | -- | Thin SVD of a real matrix, using LAPACK's /dgesvd/ with jobu == jobvt == \'S\'. |
@@ -152,9 +149,10 @@ thinSVDAux f st x = unsafePerformIO $ do | |||
152 | v <- createMatrix ColumnMajor q c | 149 | v <- createMatrix ColumnMajor q c |
153 | f # x # u # s # v #| st | 150 | f # x # u # s # v #| st |
154 | return (u,s,v) | 151 | return (u,s,v) |
155 | where r = rows x | 152 | where |
156 | c = cols x | 153 | r = rows x |
157 | q = min r c | 154 | c = cols x |
155 | q = min r c | ||
158 | 156 | ||
159 | 157 | ||
160 | -- | Singular values of a real matrix, using LAPACK's /dgesvd/ with jobu == jobvt == \'N\'. | 158 | -- | Singular values of a real matrix, using LAPACK's /dgesvd/ with jobu == jobvt == \'N\'. |
@@ -177,10 +175,11 @@ svAux f st x = unsafePerformIO $ do | |||
177 | s <- createVector q | 175 | s <- createVector q |
178 | g # x # s #| st | 176 | g # x # s #| st |
179 | return s | 177 | return s |
180 | where r = rows x | 178 | where |
181 | c = cols x | 179 | r = rows x |
182 | q = min r c | 180 | c = cols x |
183 | g ra ca pa nb pb = f ra ca pa 0 0 nullPtr nb pb 0 0 nullPtr | 181 | q = min r c |
182 | g ra ca xra xca pa nb pb = f ra ca xra xca pa 0 0 0 0 nullPtr nb pb 0 0 0 0 nullPtr | ||
184 | 183 | ||
185 | 184 | ||
186 | -- | Singular values and all right singular vectors of a real matrix, using LAPACK's /dgesvd/ with jobu == \'N\' and jobvt == \'A\'. | 185 | -- | Singular values and all right singular vectors of a real matrix, using LAPACK's /dgesvd/ with jobu == \'N\' and jobvt == \'A\'. |
@@ -196,10 +195,11 @@ rightSVAux f st x = unsafePerformIO $ do | |||
196 | v <- createMatrix ColumnMajor c c | 195 | v <- createMatrix ColumnMajor c c |
197 | g # x # s # v #| st | 196 | g # x # s # v #| st |
198 | return (s,v) | 197 | return (s,v) |
199 | where r = rows x | 198 | where |
200 | c = cols x | 199 | r = rows x |
201 | q = min r c | 200 | c = cols x |
202 | g ra ca pa = f ra ca pa 0 0 nullPtr | 201 | q = min r c |
202 | g ra ca xra xca pa = f ra ca xra xca pa 0 0 0 0 nullPtr | ||
203 | 203 | ||
204 | 204 | ||
205 | -- | Singular values and all left singular vectors of a real matrix, using LAPACK's /dgesvd/ with jobu == \'A\' and jobvt == \'N\'. | 205 | -- | Singular values and all left singular vectors of a real matrix, using LAPACK's /dgesvd/ with jobu == \'A\' and jobvt == \'N\'. |
@@ -215,25 +215,27 @@ leftSVAux f st x = unsafePerformIO $ do | |||
215 | s <- createVector q | 215 | s <- createVector q |
216 | g # x # u # s #| st | 216 | g # x # u # s #| st |
217 | return (u,s) | 217 | return (u,s) |
218 | where r = rows x | 218 | where |
219 | c = cols x | 219 | r = rows x |
220 | q = min r c | 220 | c = cols x |
221 | g ra ca pa ru cu pu nb pb = f ra ca pa ru cu pu nb pb 0 0 nullPtr | 221 | q = min r c |
222 | g ra ca xra xca pa ru cu xru xcu pu nb pb = f ra ca xra xca pa ru cu xru xcu pu nb pb 0 0 0 0 nullPtr | ||
222 | 223 | ||
223 | ----------------------------------------------------------------------------- | 224 | ----------------------------------------------------------------------------- |
224 | 225 | ||
225 | foreign import ccall unsafe "eig_l_R" dgeev :: R ..> R ..> C :> R ..> Ok | 226 | foreign import ccall unsafe "eig_l_R" dgeev :: R ::> R ::> C :> R ::> Ok |
226 | foreign import ccall unsafe "eig_l_C" zgeev :: C ..> C ..> C :> C ..> Ok | 227 | foreign import ccall unsafe "eig_l_C" zgeev :: C ::> C ::> C :> C ::> Ok |
227 | foreign import ccall unsafe "eig_l_S" dsyev :: CInt -> R ..> R :> R ..> Ok | 228 | foreign import ccall unsafe "eig_l_S" dsyev :: CInt -> R ::> R :> R ::> Ok |
228 | foreign import ccall unsafe "eig_l_H" zheev :: CInt -> C ..> R :> C ..> Ok | 229 | foreign import ccall unsafe "eig_l_H" zheev :: CInt -> C ::> R :> C ::> Ok |
229 | 230 | ||
230 | eigAux f st m = unsafePerformIO $ do | 231 | eigAux f st m = unsafePerformIO $ do |
231 | l <- createVector r | 232 | l <- createVector r |
232 | v <- createMatrix ColumnMajor r r | 233 | v <- createMatrix ColumnMajor r r |
233 | g # m # l # v #| st | 234 | g # m # l # v #| st |
234 | return (l,v) | 235 | return (l,v) |
235 | where r = rows m | 236 | where |
236 | g ra ca pa = f ra ca pa 0 0 nullPtr | 237 | r = rows m |
238 | g ra ca xra xca pa = f ra ca xra xca pa 0 0 0 0 nullPtr | ||
237 | 239 | ||
238 | 240 | ||
239 | -- | Eigenvalues and right eigenvectors of a general complex matrix, using LAPACK's /zgeev/. | 241 | -- | Eigenvalues and right eigenvectors of a general complex matrix, using LAPACK's /zgeev/. |
@@ -242,11 +244,12 @@ eigC :: Matrix (Complex Double) -> (Vector (Complex Double), Matrix (Complex Dou | |||
242 | eigC = eigAux zgeev "eigC" . fmat | 244 | eigC = eigAux zgeev "eigC" . fmat |
243 | 245 | ||
244 | eigOnlyAux f st m = unsafePerformIO $ do | 246 | eigOnlyAux f st m = unsafePerformIO $ do |
245 | l <- createVector r | 247 | l <- createVector r |
246 | g # m # l #| st | 248 | g # m # l #| st |
247 | return l | 249 | return l |
248 | where r = rows m | 250 | where |
249 | g ra ca pa nl pl = f ra ca pa 0 0 nullPtr nl pl 0 0 nullPtr | 251 | r = rows m |
252 | g ra ca xra xca pa nl pl = f ra ca xra xca pa 0 0 0 0 nullPtr nl pl 0 0 0 0 nullPtr | ||
250 | 253 | ||
251 | -- | Eigenvalues of a general complex matrix, using LAPACK's /zgeev/ with jobz == \'N\'. | 254 | -- | Eigenvalues of a general complex matrix, using LAPACK's /zgeev/ with jobz == \'N\'. |
252 | -- The eigenvalues are not sorted. | 255 | -- The eigenvalues are not sorted. |
@@ -264,12 +267,13 @@ eigR m = (s', v'') | |||
264 | 267 | ||
265 | eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double) | 268 | eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double) |
266 | eigRaux m = unsafePerformIO $ do | 269 | eigRaux m = unsafePerformIO $ do |
267 | l <- createVector r | 270 | l <- createVector r |
268 | v <- createMatrix ColumnMajor r r | 271 | v <- createMatrix ColumnMajor r r |
269 | g # m # l # v #| "eigR" | 272 | g # m # l # v #| "eigR" |
270 | return (l,v) | 273 | return (l,v) |
271 | where r = rows m | 274 | where |
272 | g ra ca pa = dgeev ra ca pa 0 0 nullPtr | 275 | r = rows m |
276 | g ra ca xra xca pa = dgeev ra ca xra xca pa 0 0 0 0 nullPtr | ||
273 | 277 | ||
274 | fixeig1 s = toComplex' (subVector 0 r (asReal s), subVector r r (asReal s)) | 278 | fixeig1 s = toComplex' (subVector 0 r (asReal s), subVector r r (asReal s)) |
275 | where r = dim s | 279 | where r = dim s |
@@ -291,11 +295,12 @@ eigOnlyR = fixeig1 . eigOnlyAux dgeev "eigOnlyR" . fmat | |||
291 | ----------------------------------------------------------------------------- | 295 | ----------------------------------------------------------------------------- |
292 | 296 | ||
293 | eigSHAux f st m = unsafePerformIO $ do | 297 | eigSHAux f st m = unsafePerformIO $ do |
294 | l <- createVector r | 298 | l <- createVector r |
295 | v <- createMatrix ColumnMajor r r | 299 | v <- createMatrix ColumnMajor r r |
296 | f # m # l # v #| st | 300 | f # m # l # v #| st |
297 | return (l,v) | 301 | return (l,v) |
298 | where r = rows m | 302 | where |
303 | r = rows m | ||
299 | 304 | ||
300 | -- | Eigenvalues and right eigenvectors of a symmetric real matrix, using LAPACK's /dsyev/. | 305 | -- | Eigenvalues and right eigenvectors of a symmetric real matrix, using LAPACK's /dsyev/. |
301 | -- The eigenvectors are the columns of v. | 306 | -- The eigenvectors are the columns of v. |
@@ -314,8 +319,9 @@ eigS' = eigSHAux (dsyev 1) "eigS'" . fmat | |||
314 | -- The eigenvalues are sorted in descending order (use 'eigH'' for ascending order). | 319 | -- The eigenvalues are sorted in descending order (use 'eigH'' for ascending order). |
315 | eigH :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double)) | 320 | eigH :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double)) |
316 | eigH m = (s', fliprl v) | 321 | eigH m = (s', fliprl v) |
317 | where (s,v) = eigH' (fmat m) | 322 | where |
318 | s' = fromList . reverse . toList $ s | 323 | (s,v) = eigH' (fmat m) |
324 | s' = fromList . reverse . toList $ s | ||
319 | 325 | ||
320 | -- | 'eigH' in ascending order | 326 | -- | 'eigH' in ascending order |
321 | eigH' :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double)) | 327 | eigH' :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double)) |
@@ -346,10 +352,11 @@ linearSolveSQAux g f st a b | |||
346 | f # a # b # s #| st | 352 | f # a # b # s #| st |
347 | return s | 353 | return s |
348 | | otherwise = error $ st ++ " of nonsquare matrix" | 354 | | otherwise = error $ st ++ " of nonsquare matrix" |
349 | where n1 = rows a | 355 | where |
350 | n2 = cols a | 356 | n1 = rows a |
351 | r = rows b | 357 | n2 = cols a |
352 | c = cols b | 358 | r = rows b |
359 | c = cols b | ||
353 | 360 | ||
354 | -- | Solve a real linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, based on LAPACK's /dgesv/. For underconstrained or overconstrained systems use 'linearSolveLSR' or 'linearSolveSVDR'. See also 'lusR'. | 361 | -- | Solve a real linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, based on LAPACK's /dgesv/. For underconstrained or overconstrained systems use 'linearSolveLSR' or 'linearSolveSVDR'. See also 'lusR'. |
355 | linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double | 362 | linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double |
@@ -375,6 +382,7 @@ cholSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Comp | |||
375 | cholSolveC a b = linearSolveSQAux id zpotrs "cholSolveC" (fmat a) (fmat b) | 382 | cholSolveC a b = linearSolveSQAux id zpotrs "cholSolveC" (fmat a) (fmat b) |
376 | 383 | ||
377 | ----------------------------------------------------------------------------------- | 384 | ----------------------------------------------------------------------------------- |
385 | |||
378 | foreign import ccall unsafe "linearSolveLSR_l" dgels :: TMMM R | 386 | foreign import ccall unsafe "linearSolveLSR_l" dgels :: TMMM R |
379 | foreign import ccall unsafe "linearSolveLSC_l" zgels :: TMMM C | 387 | foreign import ccall unsafe "linearSolveLSC_l" zgels :: TMMM C |
380 | foreign import ccall unsafe "linearSolveSVDR_l" dgelss :: Double -> TMMM R | 388 | foreign import ccall unsafe "linearSolveSVDR_l" dgelss :: Double -> TMMM R |
@@ -384,9 +392,10 @@ linearSolveAux f st a b = unsafePerformIO $ do | |||
384 | r <- createMatrix ColumnMajor (max m n) nrhs | 392 | r <- createMatrix ColumnMajor (max m n) nrhs |
385 | f # a # b # r #| st | 393 | f # a # b # r #| st |
386 | return r | 394 | return r |
387 | where m = rows a | 395 | where |
388 | n = cols a | 396 | m = rows a |
389 | nrhs = cols b | 397 | n = cols a |
398 | nrhs = cols b | ||
390 | 399 | ||
391 | -- | Least squared error solution of an overconstrained real linear system, or the minimum norm solution of an underconstrained system, using LAPACK's /dgels/. For rank-deficient systems use 'linearSolveSVDR'. | 400 | -- | Least squared error solution of an overconstrained real linear system, or the minimum norm solution of an underconstrained system, using LAPACK's /dgels/. For rank-deficient systems use 'linearSolveSVDR'. |
392 | linearSolveLSR :: Matrix Double -> Matrix Double -> Matrix Double | 401 | linearSolveLSR :: Matrix Double -> Matrix Double -> Matrix Double |
@@ -418,7 +427,7 @@ linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) (fmat a) (fmat b) | |||
418 | 427 | ||
419 | ----------------------------------------------------------------------------------- | 428 | ----------------------------------------------------------------------------------- |
420 | 429 | ||
421 | type TMM t = t ..> t ..> Ok | 430 | type TMM t = t ::> t ::> Ok |
422 | 431 | ||
423 | foreign import ccall unsafe "chol_l_H" zpotrf :: TMM C | 432 | foreign import ccall unsafe "chol_l_H" zpotrf :: TMM C |
424 | foreign import ccall unsafe "chol_l_S" dpotrf :: TMM R | 433 | foreign import ccall unsafe "chol_l_S" dpotrf :: TMM R |
@@ -427,7 +436,8 @@ cholAux f st a = do | |||
427 | r <- createMatrix ColumnMajor n n | 436 | r <- createMatrix ColumnMajor n n |
428 | f # a # r #| st | 437 | f # a # r #| st |
429 | return r | 438 | return r |
430 | where n = rows a | 439 | where |
440 | n = rows a | ||
431 | 441 | ||
432 | -- | Cholesky factorization of a complex Hermitian positive definite matrix, using LAPACK's /zpotrf/. | 442 | -- | Cholesky factorization of a complex Hermitian positive definite matrix, using LAPACK's /zpotrf/. |
433 | cholH :: Matrix (Complex Double) -> Matrix (Complex Double) | 443 | cholH :: Matrix (Complex Double) -> Matrix (Complex Double) |
@@ -447,7 +457,7 @@ mbCholS = unsafePerformIO . mbCatch . cholAux dpotrf "cholS" . fmat | |||
447 | 457 | ||
448 | ----------------------------------------------------------------------------------- | 458 | ----------------------------------------------------------------------------------- |
449 | 459 | ||
450 | type TMVM t = t ..> t :> t ..> Ok | 460 | type TMVM t = t ::> t :> t ::> Ok |
451 | 461 | ||
452 | foreign import ccall unsafe "qr_l_R" dgeqr2 :: TMVM R | 462 | foreign import ccall unsafe "qr_l_R" dgeqr2 :: TMVM R |
453 | foreign import ccall unsafe "qr_l_C" zgeqr2 :: TMVM C | 463 | foreign import ccall unsafe "qr_l_C" zgeqr2 :: TMVM C |
@@ -504,13 +514,14 @@ hessAux f st a = unsafePerformIO $ do | |||
504 | tau <- createVector (mn-1) | 514 | tau <- createVector (mn-1) |
505 | f # a # tau # r #| st | 515 | f # a # tau # r #| st |
506 | return (r,tau) | 516 | return (r,tau) |
507 | where m = rows a | 517 | where |
508 | n = cols a | 518 | m = rows a |
509 | mn = min m n | 519 | n = cols a |
520 | mn = min m n | ||
510 | 521 | ||
511 | ----------------------------------------------------------------------------------- | 522 | ----------------------------------------------------------------------------------- |
512 | foreign import ccall unsafe "schur_l_R" dgees :: TMMM R | 523 | foreign import ccall unsafe "schur_l_R" dgees :: R ::> R ::> R ::> Ok |
513 | foreign import ccall unsafe "schur_l_C" zgees :: TMMM C | 524 | foreign import ccall unsafe "schur_l_C" zgees :: C ::> C ::> C ::> Ok |
514 | 525 | ||
515 | -- | Schur factorization of a square real matrix, using LAPACK's /dgees/. | 526 | -- | Schur factorization of a square real matrix, using LAPACK's /dgees/. |
516 | schurR :: Matrix Double -> (Matrix Double, Matrix Double) | 527 | schurR :: Matrix Double -> (Matrix Double, Matrix Double) |
@@ -525,11 +536,12 @@ schurAux f st a = unsafePerformIO $ do | |||
525 | s <- createMatrix ColumnMajor n n | 536 | s <- createMatrix ColumnMajor n n |
526 | f # a # u # s #| st | 537 | f # a # u # s #| st |
527 | return (u,s) | 538 | return (u,s) |
528 | where n = rows a | 539 | where |
540 | n = rows a | ||
529 | 541 | ||
530 | ----------------------------------------------------------------------------------- | 542 | ----------------------------------------------------------------------------------- |
531 | foreign import ccall unsafe "lu_l_R" dgetrf :: TMVM R | 543 | foreign import ccall unsafe "lu_l_R" dgetrf :: TMVM R |
532 | foreign import ccall unsafe "lu_l_C" zgetrf :: C ..> R :> C ..> Ok | 544 | foreign import ccall unsafe "lu_l_C" zgetrf :: C ::> R :> C ::> Ok |
533 | 545 | ||
534 | -- | LU factorization of a general real matrix, using LAPACK's /dgetrf/. | 546 | -- | LU factorization of a general real matrix, using LAPACK's /dgetrf/. |
535 | luR :: Matrix Double -> (Matrix Double, [Int]) | 547 | luR :: Matrix Double -> (Matrix Double, [Int]) |
@@ -544,12 +556,13 @@ luAux f st a = unsafePerformIO $ do | |||
544 | piv <- createVector (min n m) | 556 | piv <- createVector (min n m) |
545 | f # a # piv # lu #| st | 557 | f # a # piv # lu #| st |
546 | return (lu, map (pred.round) (toList piv)) | 558 | return (lu, map (pred.round) (toList piv)) |
547 | where n = rows a | 559 | where |
548 | m = cols a | 560 | n = rows a |
561 | m = cols a | ||
549 | 562 | ||
550 | ----------------------------------------------------------------------------------- | 563 | ----------------------------------------------------------------------------------- |
551 | 564 | ||
552 | type Tlus t = t ..> Double :> t ..> t ..> Ok | 565 | type Tlus t = t ::> Double :> t ::> t ::> Ok |
553 | 566 | ||
554 | foreign import ccall unsafe "luS_l_R" dgetrs :: Tlus R | 567 | foreign import ccall unsafe "luS_l_R" dgetrs :: Tlus R |
555 | foreign import ccall unsafe "luS_l_C" zgetrs :: Tlus C | 568 | foreign import ccall unsafe "luS_l_C" zgetrs :: Tlus C |
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index c4f95d8..54d9cb8 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -398,6 +398,7 @@ test = (ok, info) | |||
398 | print lgm | 398 | print lgm |
399 | print $ lgm <> lgm | 399 | print $ lgm <> lgm |
400 | 400 | ||
401 | putStrLn "checkGen" | ||
401 | print (checkGen (gen 5 :: Matrix R)) | 402 | print (checkGen (gen 5 :: Matrix R)) |
402 | print (checkGen (gen 5 :: Matrix Float)) | 403 | print (checkGen (gen 5 :: Matrix Float)) |
403 | print (checkGen (cgen 5 :: Matrix C)) | 404 | print (checkGen (cgen 5 :: Matrix C)) |
@@ -408,6 +409,7 @@ test = (ok, info) | |||
408 | print $ mutable (luST (const True)) (gen 5 :: Matrix R) | 409 | print $ mutable (luST (const True)) (gen 5 :: Matrix R) |
409 | print $ mutable (luST (const True)) (gen 5 :: Matrix (Mod 11 Z)) | 410 | print $ mutable (luST (const True)) (gen 5 :: Matrix (Mod 11 Z)) |
410 | 411 | ||
412 | putStrLn "checkLU" | ||
411 | print $ checkLU (magnit 0) (gen 5 :: Matrix R) | 413 | print $ checkLU (magnit 0) (gen 5 :: Matrix R) |
412 | print $ checkLU (magnit 0) (gen 5 :: Matrix Float) | 414 | print $ checkLU (magnit 0) (gen 5 :: Matrix Float) |
413 | print $ checkLU (magnit 0) (cgen 5 :: Matrix C) | 415 | print $ checkLU (magnit 0) (cgen 5 :: Matrix C) |
@@ -415,6 +417,7 @@ test = (ok, info) | |||
415 | print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I)) | 417 | print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I)) |
416 | print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z)) | 418 | print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z)) |
417 | 419 | ||
420 | putStrLn "checkSolve" | ||
418 | print $ checkSolve (gen 5 :: Matrix R) | 421 | print $ checkSolve (gen 5 :: Matrix R) |
419 | print $ checkSolve (gen 5 :: Matrix Float) | 422 | print $ checkSolve (gen 5 :: Matrix Float) |
420 | print $ checkSolve (cgen 5 :: Matrix C) | 423 | print $ checkSolve (cgen 5 :: Matrix C) |
@@ -422,6 +425,7 @@ test = (ok, info) | |||
422 | print $ checkSolve (gen 5 :: Matrix (Mod 7 I)) | 425 | print $ checkSolve (gen 5 :: Matrix (Mod 7 I)) |
423 | print $ checkSolve (gen 5 :: Matrix (Mod 7 Z)) | 426 | print $ checkSolve (gen 5 :: Matrix (Mod 7 Z)) |
424 | 427 | ||
428 | putStrLn "luSolve'" | ||
425 | print $ luSolve' (luPacked' tmm) (ident (rows tmm)) | 429 | print $ luSolve' (luPacked' tmm) (ident (rows tmm)) |
426 | print $ invershur tmm | 430 | print $ invershur tmm |
427 | 431 | ||
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 73cdf0c..23fda99 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -150,6 +150,7 @@ unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | |||
150 | freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) | 150 | freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) |
151 | freezeMatrix m = liftSTMatrix id m | 151 | freezeMatrix m = liftSTMatrix id m |
152 | 152 | ||
153 | -- FIXME | ||
153 | cloneMatrix m = cloneVector (xdat m) >>= return . (\d' -> m{xdat = d'}) | 154 | cloneMatrix m = cloneVector (xdat m) >>= return . (\d' -> m{xdat = d'}) |
154 | 155 | ||
155 | {-# INLINE safeIndexM #-} | 156 | {-# INLINE safeIndexM #-} |
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 98eb4ef..258c3a3 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs | |||
@@ -845,10 +845,11 @@ viewBlock' r c m | |||
845 | | otherwise = Block m11 m12 m21 m22 | 845 | | otherwise = Block m11 m12 m21 m22 |
846 | where | 846 | where |
847 | (rt,ct) = size m | 847 | (rt,ct) = size m |
848 | m11 = sliceMatrix (0,0) (r,c) m | 848 | m11 = subm (0,0) (r,c) m |
849 | m12 = sliceMatrix (0,c) (r,ct-c) m | 849 | m12 = subm (0,c) (r,ct-c) m |
850 | m21 = sliceMatrix (r,0) (rt-r,c) m | 850 | m21 = subm (r,0) (rt-r,c) m |
851 | m22 = sliceMatrix (r,c) (rt-r,ct-c) m | 851 | m22 = subm (r,c) (rt-r,ct-c) m |
852 | subm = sliceMatrix | ||
852 | 853 | ||
853 | viewBlock m = viewBlock' n n m | 854 | viewBlock m = viewBlock' n n m |
854 | where | 855 | where |