summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/base/src/Internal/C/lapack-aux.c136
-rw-r--r--packages/base/src/Internal/LAPACK.hs72
2 files changed, 93 insertions, 115 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c
index ab78dac..8524962 100644
--- a/packages/base/src/Internal/C/lapack-aux.c
+++ b/packages/base/src/Internal/C/lapack-aux.c
@@ -901,18 +901,17 @@ int chol_l_S(ODMAT(l)) {
901 901
902//////////////////// QR factorization ///////////////////////// 902//////////////////// QR factorization /////////////////////////
903 903
904/* Subroutine */ int dgeqr2_(integer *m, integer *n, doublereal *a, integer * 904int dgeqr2_(integer *m, integer *n, doublereal *a, integer *
905 lda, doublereal *tau, doublereal *work, integer *info); 905 lda, doublereal *tau, doublereal *work, integer *info);
906 906
907int qr_l_R(KODMAT(a), DVEC(tau), ODMAT(r)) { 907int qr_l_R(DVEC(tau), ODMAT(r)) {
908 integer m = ar; 908 integer m = rr;
909 integer n = ac; 909 integer n = rc;
910 integer mn = MIN(m,n); 910 integer mn = MIN(m,n);
911 REQUIRES(m>=1 && n >=1 && rr== m && rc == n && taun == mn, BAD_SIZE); 911 REQUIRES(m>=1 && n >=1 && taun == mn, BAD_SIZE);
912 DEBUGMSG("qr_l_R"); 912 DEBUGMSG("qr_l_R");
913 double *WORK = (double*)malloc(n*sizeof(double)); 913 double *WORK = (double*)malloc(n*sizeof(double));
914 CHECK(!WORK,MEM); 914 CHECK(!WORK,MEM);
915 memcpy(rp,ap,m*n*sizeof(double));
916 integer res; 915 integer res;
917 dgeqr2_ (&m,&n,rp,&m,taup,WORK,&res); 916 dgeqr2_ (&m,&n,rp,&m,taup,WORK,&res);
918 CHECK(res,res); 917 CHECK(res,res);
@@ -920,18 +919,17 @@ int qr_l_R(KODMAT(a), DVEC(tau), ODMAT(r)) {
920 OK 919 OK
921} 920}
922 921
923/* Subroutine */ int zgeqr2_(integer *m, integer *n, doublecomplex *a, 922int zgeqr2_(integer *m, integer *n, doublecomplex *a,
924 integer *lda, doublecomplex *tau, doublecomplex *work, integer *info); 923 integer *lda, doublecomplex *tau, doublecomplex *work, integer *info);
925 924
926int qr_l_C(KOCMAT(a), CVEC(tau), OCMAT(r)) { 925int qr_l_C(CVEC(tau), OCMAT(r)) {
927 integer m = ar; 926 integer m = rr;
928 integer n = ac; 927 integer n = rc;
929 integer mn = MIN(m,n); 928 integer mn = MIN(m,n);
930 REQUIRES(m>=1 && n >=1 && rr== m && rc == n && taun == mn, BAD_SIZE); 929 REQUIRES(m>=1 && n >=1 && taun == mn, BAD_SIZE);
931 DEBUGMSG("qr_l_C"); 930 DEBUGMSG("qr_l_C");
932 doublecomplex *WORK = (doublecomplex*)malloc(n*sizeof(doublecomplex)); 931 doublecomplex *WORK = (doublecomplex*)malloc(n*sizeof(doublecomplex));
933 CHECK(!WORK,MEM); 932 CHECK(!WORK,MEM);
934 memcpy(rp,ap,m*n*sizeof(doublecomplex));
935 integer res; 933 integer res;
936 zgeqr2_ (&m,&n,rp,&m,taup,WORK,&res); 934 zgeqr2_ (&m,&n,rp,&m,taup,WORK,&res);
937 CHECK(res,res); 935 CHECK(res,res);
@@ -939,19 +937,18 @@ int qr_l_C(KOCMAT(a), CVEC(tau), OCMAT(r)) {
939 OK 937 OK
940} 938}
941 939
942/* Subroutine */ int dorgqr_(integer *m, integer *n, integer *k, doublereal * 940int dorgqr_(integer *m, integer *n, integer *k, doublereal *
943 a, integer *lda, doublereal *tau, doublereal *work, integer *lwork, 941 a, integer *lda, doublereal *tau, doublereal *work, integer *lwork,
944 integer *info); 942 integer *info);
945 943
946int c_dorgqr(KODMAT(a), KDVEC(tau), ODMAT(r)) { 944int c_dorgqr(KDVEC(tau), ODMAT(r)) {
947 integer m = ar; 945 integer m = rr;
948 integer n = MIN(ac,ar); 946 integer n = MIN(rc,rr);
949 integer k = taun; 947 integer k = taun;
950 DEBUGMSG("c_dorgqr"); 948 DEBUGMSG("c_dorgqr");
951 integer lwork = 8*n; // FIXME 949 integer lwork = 8*n; // FIXME
952 double *WORK = (double*)malloc(lwork*sizeof(double)); 950 double *WORK = (double*)malloc(lwork*sizeof(double));
953 CHECK(!WORK,MEM); 951 CHECK(!WORK,MEM);
954 memcpy(rp,ap,m*k*sizeof(double));
955 integer res; 952 integer res;
956 dorgqr_ (&m,&n,&k,rp,&m,(double*)taup,WORK,&lwork,&res); 953 dorgqr_ (&m,&n,&k,rp,&m,(double*)taup,WORK,&lwork,&res);
957 CHECK(res,res); 954 CHECK(res,res);
@@ -959,19 +956,18 @@ int c_dorgqr(KODMAT(a), KDVEC(tau), ODMAT(r)) {
959 OK 956 OK
960} 957}
961 958
962/* Subroutine */ int zungqr_(integer *m, integer *n, integer *k, 959int zungqr_(integer *m, integer *n, integer *k,
963 doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * 960 doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex *
964 work, integer *lwork, integer *info); 961 work, integer *lwork, integer *info);
965 962
966int c_zungqr(KOCMAT(a), KCVEC(tau), OCMAT(r)) { 963int c_zungqr(KCVEC(tau), OCMAT(r)) {
967 integer m = ar; 964 integer m = rr;
968 integer n = MIN(ac,ar); 965 integer n = MIN(rc,rr);
969 integer k = taun; 966 integer k = taun;
970 DEBUGMSG("z_ungqr"); 967 DEBUGMSG("z_ungqr");
971 integer lwork = 8*n; // FIXME 968 integer lwork = 8*n; // FIXME
972 doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); 969 doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
973 CHECK(!WORK,MEM); 970 CHECK(!WORK,MEM);
974 memcpy(rp,ap,m*k*sizeof(doublecomplex));
975 integer res; 971 integer res;
976 zungqr_ (&m,&n,&k,rp,&m,(doublecomplex*)taup,WORK,&lwork,&res); 972 zungqr_ (&m,&n,&k,rp,&m,(doublecomplex*)taup,WORK,&lwork,&res);
977 CHECK(res,res); 973 CHECK(res,res);
@@ -982,20 +978,19 @@ int c_zungqr(KOCMAT(a), KCVEC(tau), OCMAT(r)) {
982 978
983//////////////////// Hessenberg factorization ///////////////////////// 979//////////////////// Hessenberg factorization /////////////////////////
984 980
985/* Subroutine */ int dgehrd_(integer *n, integer *ilo, integer *ihi, 981int dgehrd_(integer *n, integer *ilo, integer *ihi,
986 doublereal *a, integer *lda, doublereal *tau, doublereal *work, 982 doublereal *a, integer *lda, doublereal *tau, doublereal *work,
987 integer *lwork, integer *info); 983 integer *lwork, integer *info);
988 984
989int hess_l_R(KODMAT(a), DVEC(tau), ODMAT(r)) { 985int hess_l_R(DVEC(tau), ODMAT(r)) {
990 integer m = ar; 986 integer m = rr;
991 integer n = ac; 987 integer n = rc;
992 integer mn = MIN(m,n); 988 integer mn = MIN(m,n);
993 REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE); 989 REQUIRES(m>=1 && n == m && taun == mn-1, BAD_SIZE);
994 DEBUGMSG("hess_l_R"); 990 DEBUGMSG("hess_l_R");
995 integer lwork = 5*n; // fixme 991 integer lwork = 5*n; // FIXME
996 double *WORK = (double*)malloc(lwork*sizeof(double)); 992 double *WORK = (double*)malloc(lwork*sizeof(double));
997 CHECK(!WORK,MEM); 993 CHECK(!WORK,MEM);
998 memcpy(rp,ap,m*n*sizeof(double));
999 integer res; 994 integer res;
1000 integer one = 1; 995 integer one = 1;
1001 dgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res); 996 dgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res);
@@ -1005,20 +1000,19 @@ int hess_l_R(KODMAT(a), DVEC(tau), ODMAT(r)) {
1005} 1000}
1006 1001
1007 1002
1008/* Subroutine */ int zgehrd_(integer *n, integer *ilo, integer *ihi, 1003int zgehrd_(integer *n, integer *ilo, integer *ihi,
1009 doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * 1004 doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex *
1010 work, integer *lwork, integer *info); 1005 work, integer *lwork, integer *info);
1011 1006
1012int hess_l_C(KOCMAT(a), CVEC(tau), OCMAT(r)) { 1007int hess_l_C(CVEC(tau), OCMAT(r)) {
1013 integer m = ar; 1008 integer m = rr;
1014 integer n = ac; 1009 integer n = rc;
1015 integer mn = MIN(m,n); 1010 integer mn = MIN(m,n);
1016 REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE); 1011 REQUIRES(m>=1 && n == m && taun == mn-1, BAD_SIZE);
1017 DEBUGMSG("hess_l_C"); 1012 DEBUGMSG("hess_l_C");
1018 integer lwork = 5*n; // fixme 1013 integer lwork = 5*n; // FIXME
1019 doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); 1014 doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
1020 CHECK(!WORK,MEM); 1015 CHECK(!WORK,MEM);
1021 memcpy(rp,ap,m*n*sizeof(doublecomplex));
1022 integer res; 1016 integer res;
1023 integer one = 1; 1017 integer one = 1;
1024 zgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res); 1018 zgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res);
@@ -1029,23 +1023,17 @@ int hess_l_C(KOCMAT(a), CVEC(tau), OCMAT(r)) {
1029 1023
1030//////////////////// Schur factorization ///////////////////////// 1024//////////////////// Schur factorization /////////////////////////
1031 1025
1032/* Subroutine */ int dgees_(char *jobvs, char *sort, L_fp select, integer *n, 1026int dgees_(char *jobvs, char *sort, L_fp select, integer *n,
1033 doublereal *a, integer *lda, integer *sdim, doublereal *wr, 1027 doublereal *a, integer *lda, integer *sdim, doublereal *wr,
1034 doublereal *wi, doublereal *vs, integer *ldvs, doublereal *work, 1028 doublereal *wi, doublereal *vs, integer *ldvs, doublereal *work,
1035 integer *lwork, logical *bwork, integer *info); 1029 integer *lwork, logical *bwork, integer *info);
1036 1030
1037int schur_l_R(KODMAT(a), ODMAT(u), ODMAT(s)) { 1031int schur_l_R(ODMAT(u), ODMAT(s)) {
1038 integer m = ar; 1032 integer m = sr;
1039 integer n = ac; 1033 integer n = sc;
1040 REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE); 1034 REQUIRES(m>=1 && n==m && ur==n && uc==n, BAD_SIZE);
1041 DEBUGMSG("schur_l_R"); 1035 DEBUGMSG("schur_l_R");
1042 //int k; 1036 integer lwork = 6*n; // FIXME
1043 //printf("---------------------------\n");
1044 //printf("%p: ",ap); for(k=0;k<n*n;k++) printf("%f ",ap[k]); printf("\n");
1045 //printf("%p: ",up); for(k=0;k<n*n;k++) printf("%f ",up[k]); printf("\n");
1046 //printf("%p: ",sp); for(k=0;k<n*n;k++) printf("%f ",sp[k]); printf("\n");
1047 memcpy(sp,ap,n*n*sizeof(double));
1048 integer lwork = 6*n; // fixme
1049 double *WORK = (double*)malloc(lwork*sizeof(double)); 1037 double *WORK = (double*)malloc(lwork*sizeof(double));
1050 double *WR = (double*)malloc(n*sizeof(double)); 1038 double *WR = (double*)malloc(n*sizeof(double));
1051 double *WI = (double*)malloc(n*sizeof(double)); 1039 double *WI = (double*)malloc(n*sizeof(double));
@@ -1054,9 +1042,6 @@ int schur_l_R(KODMAT(a), ODMAT(u), ODMAT(s)) {
1054 integer res; 1042 integer res;
1055 integer sdim; 1043 integer sdim;
1056 dgees_ ("V","N",NULL,&n,sp,&n,&sdim,WR,WI,up,&n,WORK,&lwork,BWORK,&res); 1044 dgees_ ("V","N",NULL,&n,sp,&n,&sdim,WR,WI,up,&n,WORK,&lwork,BWORK,&res);
1057 //printf("%p: ",ap); for(k=0;k<n*n;k++) printf("%f ",ap[k]); printf("\n");
1058 //printf("%p: ",up); for(k=0;k<n*n;k++) printf("%f ",up[k]); printf("\n");
1059 //printf("%p: ",sp); for(k=0;k<n*n;k++) printf("%f ",sp[k]); printf("\n");
1060 if(res>0) { 1045 if(res>0) {
1061 return NOCONVER; 1046 return NOCONVER;
1062 } 1047 }
@@ -1069,18 +1054,17 @@ int schur_l_R(KODMAT(a), ODMAT(u), ODMAT(s)) {
1069} 1054}
1070 1055
1071 1056
1072/* Subroutine */ int zgees_(char *jobvs, char *sort, L_fp select, integer *n, 1057int zgees_(char *jobvs, char *sort, L_fp select, integer *n,
1073 doublecomplex *a, integer *lda, integer *sdim, doublecomplex *w, 1058 doublecomplex *a, integer *lda, integer *sdim, doublecomplex *w,
1074 doublecomplex *vs, integer *ldvs, doublecomplex *work, integer *lwork, 1059 doublecomplex *vs, integer *ldvs, doublecomplex *work, integer *lwork,
1075 doublereal *rwork, logical *bwork, integer *info); 1060 doublereal *rwork, logical *bwork, integer *info);
1076 1061
1077int schur_l_C(KOCMAT(a), OCMAT(u), OCMAT(s)) { 1062int schur_l_C(OCMAT(u), OCMAT(s)) {
1078 integer m = ar; 1063 integer m = sr;
1079 integer n = ac; 1064 integer n = sc;
1080 REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE); 1065 REQUIRES(m>=1 && n==m && ur==n && uc==n, BAD_SIZE);
1081 DEBUGMSG("schur_l_C"); 1066 DEBUGMSG("schur_l_C");
1082 memcpy(sp,ap,n*n*sizeof(doublecomplex)); 1067 integer lwork = 6*n; // FIXME
1083 integer lwork = 6*n; // fixme
1084 doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); 1068 doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
1085 doublecomplex *W = (doublecomplex*)malloc(n*sizeof(doublecomplex)); 1069 doublecomplex *W = (doublecomplex*)malloc(n*sizeof(doublecomplex));
1086 // W not really required in this call 1070 // W not really required in this call
@@ -1103,21 +1087,20 @@ int schur_l_C(KOCMAT(a), OCMAT(u), OCMAT(s)) {
1103 1087
1104//////////////////// LU factorization ///////////////////////// 1088//////////////////// LU factorization /////////////////////////
1105 1089
1106/* Subroutine */ int dgetrf_(integer *m, integer *n, doublereal *a, integer * 1090int dgetrf_(integer *m, integer *n, doublereal *a, integer *
1107 lda, integer *ipiv, integer *info); 1091 lda, integer *ipiv, integer *info);
1108 1092
1109int lu_l_R(KODMAT(a), DVEC(ipiv), ODMAT(r)) { 1093int lu_l_R(DVEC(ipiv), ODMAT(r)) {
1110 integer m = ar; 1094 integer m = rr;
1111 integer n = ac; 1095 integer n = rc;
1112 integer mn = MIN(m,n); 1096 integer mn = MIN(m,n);
1113 REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); 1097 REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE);
1114 DEBUGMSG("lu_l_R"); 1098 DEBUGMSG("lu_l_R");
1115 integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); 1099 integer* auxipiv = (integer*)malloc(mn*sizeof(integer));
1116 memcpy(rp,ap,m*n*sizeof(double));
1117 integer res; 1100 integer res;
1118 dgetrf_ (&m,&n,rp,&m,auxipiv,&res); 1101 dgetrf_ (&m,&n,rp,&m,auxipiv,&res);
1119 if(res>0) { 1102 if(res>0) {
1120 res = 0; // fixme 1103 res = 0; // FIXME
1121 } 1104 }
1122 CHECK(res,res); 1105 CHECK(res,res);
1123 int k; 1106 int k;
@@ -1129,21 +1112,20 @@ int lu_l_R(KODMAT(a), DVEC(ipiv), ODMAT(r)) {
1129} 1112}
1130 1113
1131 1114
1132/* Subroutine */ int zgetrf_(integer *m, integer *n, doublecomplex *a, 1115int zgetrf_(integer *m, integer *n, doublecomplex *a,
1133 integer *lda, integer *ipiv, integer *info); 1116 integer *lda, integer *ipiv, integer *info);
1134 1117
1135int lu_l_C(KOCMAT(a), DVEC(ipiv), OCMAT(r)) { 1118int lu_l_C(DVEC(ipiv), OCMAT(r)) {
1136 integer m = ar; 1119 integer m = rr;
1137 integer n = ac; 1120 integer n = rc;
1138 integer mn = MIN(m,n); 1121 integer mn = MIN(m,n);
1139 REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); 1122 REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE);
1140 DEBUGMSG("lu_l_C"); 1123 DEBUGMSG("lu_l_C");
1141 integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); 1124 integer* auxipiv = (integer*)malloc(mn*sizeof(integer));
1142 memcpy(rp,ap,m*n*sizeof(doublecomplex));
1143 integer res; 1125 integer res;
1144 zgetrf_ (&m,&n,rp,&m,auxipiv,&res); 1126 zgetrf_ (&m,&n,rp,&m,auxipiv,&res);
1145 if(res>0) { 1127 if(res>0) {
1146 res = 0; // fixme 1128 res = 0; // FIXME
1147 } 1129 }
1148 CHECK(res,res); 1130 CHECK(res,res);
1149 int k; 1131 int k;
@@ -1157,11 +1139,11 @@ int lu_l_C(KOCMAT(a), DVEC(ipiv), OCMAT(r)) {
1157 1139
1158//////////////////// LU substitution ///////////////////////// 1140//////////////////// LU substitution /////////////////////////
1159 1141
1160/* Subroutine */ int dgetrs_(char *trans, integer *n, integer *nrhs, 1142int dgetrs_(char *trans, integer *n, integer *nrhs,
1161 doublereal *a, integer *lda, integer *ipiv, doublereal *b, integer * 1143 doublereal *a, integer *lda, integer *ipiv, doublereal *b, integer *
1162 ldb, integer *info); 1144 ldb, integer *info);
1163 1145
1164int luS_l_R(KODMAT(a), KDVEC(ipiv), KODMAT(b), ODMAT(x)) { 1146int luS_l_R(KODMAT(a), KDVEC(ipiv), ODMAT(b)) {
1165 integer m = ar; 1147 integer m = ar;
1166 integer n = ac; 1148 integer n = ac;
1167 integer mrhs = br; 1149 integer mrhs = br;
@@ -1174,19 +1156,18 @@ int luS_l_R(KODMAT(a), KDVEC(ipiv), KODMAT(b), ODMAT(x)) {
1174 auxipiv[k] = (integer)ipivp[k]; 1156 auxipiv[k] = (integer)ipivp[k];
1175 } 1157 }
1176 integer res; 1158 integer res;
1177 memcpy(xp,bp,mrhs*nrhs*sizeof(double)); 1159 dgetrs_ ("N",&n,&nrhs,(/*no const (!?)*/ double*)ap,&m,auxipiv,bp,&mrhs,&res);
1178 dgetrs_ ("N",&n,&nrhs,(/*no const (!?)*/ double*)ap,&m,auxipiv,xp,&mrhs,&res);
1179 CHECK(res,res); 1160 CHECK(res,res);
1180 free(auxipiv); 1161 free(auxipiv);
1181 OK 1162 OK
1182} 1163}
1183 1164
1184 1165
1185/* Subroutine */ int zgetrs_(char *trans, integer *n, integer *nrhs, 1166int zgetrs_(char *trans, integer *n, integer *nrhs,
1186 doublecomplex *a, integer *lda, integer *ipiv, doublecomplex *b, 1167 doublecomplex *a, integer *lda, integer *ipiv, doublecomplex *b,
1187 integer *ldb, integer *info); 1168 integer *ldb, integer *info);
1188 1169
1189int luS_l_C(KOCMAT(a), KDVEC(ipiv), KOCMAT(b), OCMAT(x)) { 1170int luS_l_C(KOCMAT(a), KDVEC(ipiv), OCMAT(b)) {
1190 integer m = ar; 1171 integer m = ar;
1191 integer n = ac; 1172 integer n = ac;
1192 integer mrhs = br; 1173 integer mrhs = br;
@@ -1199,8 +1180,7 @@ int luS_l_C(KOCMAT(a), KDVEC(ipiv), KOCMAT(b), OCMAT(x)) {
1199 auxipiv[k] = (integer)ipivp[k]; 1180 auxipiv[k] = (integer)ipivp[k];
1200 } 1181 }
1201 integer res; 1182 integer res;
1202 memcpy(xp,bp,mrhs*nrhs*sizeof(doublecomplex)); 1183 zgetrs_ ("N",&n,&nrhs,(doublecomplex*)ap,&m,auxipiv,bp,&mrhs,&res);
1203 zgetrs_ ("N",&n,&nrhs,(doublecomplex*)ap,&m,auxipiv,xp,&mrhs,&res);
1204 CHECK(res,res); 1184 CHECK(res,res);
1205 free(auxipiv); 1185 free(auxipiv);
1206 OK 1186 OK
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index ce00c16..65deceb 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -455,29 +455,29 @@ mbCholS = unsafePerformIO . mbCatch . cholAux dpotrf "cholS"
455 455
456type TMVM t = t ::> t :> t ::> Ok 456type TMVM t = t ::> t :> t ::> Ok
457 457
458foreign import ccall unsafe "qr_l_R" dgeqr2 :: TMVM R 458foreign import ccall unsafe "qr_l_R" dgeqr2 :: R :> R ::> Ok
459foreign import ccall unsafe "qr_l_C" zgeqr2 :: TMVM C 459foreign import ccall unsafe "qr_l_C" zgeqr2 :: C :> C ::> Ok
460 460
461-- | QR factorization of a real matrix, using LAPACK's /dgeqr2/. 461-- | QR factorization of a real matrix, using LAPACK's /dgeqr2/.
462qrR :: Matrix Double -> (Matrix Double, Vector Double) 462qrR :: Matrix Double -> (Matrix Double, Vector Double)
463qrR = qrAux dgeqr2 "qrR" . fmat 463qrR = qrAux dgeqr2 "qrR"
464 464
465-- | QR factorization of a complex matrix, using LAPACK's /zgeqr2/. 465-- | QR factorization of a complex matrix, using LAPACK's /zgeqr2/.
466qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) 466qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double))
467qrC = qrAux zgeqr2 "qrC" . fmat 467qrC = qrAux zgeqr2 "qrC"
468 468
469qrAux f st a = unsafePerformIO $ do 469qrAux f st a = unsafePerformIO $ do
470 r <- createMatrix ColumnMajor m n 470 r <- copy ColumnMajor a
471 tau <- createVector mn 471 tau <- createVector mn
472 f # a # tau # r #| st 472 f # tau # r #| st
473 return (r,tau) 473 return (r,tau)
474 where 474 where
475 m = rows a 475 m = rows a
476 n = cols a 476 n = cols a
477 mn = min m n 477 mn = min m n
478 478
479foreign import ccall unsafe "c_dorgqr" dorgqr :: TMVM R 479foreign import ccall unsafe "c_dorgqr" dorgqr :: R :> R ::> Ok
480foreign import ccall unsafe "c_zungqr" zungqr :: TMVM C 480foreign import ccall unsafe "c_zungqr" zungqr :: C :> C ::> Ok
481 481
482-- | build rotation from reflectors 482-- | build rotation from reflectors
483qrgrR :: Int -> (Matrix Double, Vector Double) -> Matrix Double 483qrgrR :: Int -> (Matrix Double, Vector Double) -> Matrix Double
@@ -487,28 +487,28 @@ qrgrC :: Int -> (Matrix (Complex Double), Vector (Complex Double)) -> Matrix (Co
487qrgrC = qrgrAux zungqr "qrgrC" 487qrgrC = qrgrAux zungqr "qrgrC"
488 488
489qrgrAux f st n (a, tau) = unsafePerformIO $ do 489qrgrAux f st n (a, tau) = unsafePerformIO $ do
490 res <- createMatrix ColumnMajor (rows a) n 490 res <- copy ColumnMajor (sliceMatrix (0,0) (rows a,n) a)
491 f # (fmat a) # (subVector 0 n tau') # res #| st 491 f # (subVector 0 n tau') # res #| st
492 return res 492 return res
493 where 493 where
494 tau' = vjoin [tau, constantD 0 n] 494 tau' = vjoin [tau, constantD 0 n]
495 495
496----------------------------------------------------------------------------------- 496-----------------------------------------------------------------------------------
497foreign import ccall unsafe "hess_l_R" dgehrd :: TMVM R 497foreign import ccall unsafe "hess_l_R" dgehrd :: R :> R ::> Ok
498foreign import ccall unsafe "hess_l_C" zgehrd :: TMVM C 498foreign import ccall unsafe "hess_l_C" zgehrd :: C :> C ::> Ok
499 499
500-- | Hessenberg factorization of a square real matrix, using LAPACK's /dgehrd/. 500-- | Hessenberg factorization of a square real matrix, using LAPACK's /dgehrd/.
501hessR :: Matrix Double -> (Matrix Double, Vector Double) 501hessR :: Matrix Double -> (Matrix Double, Vector Double)
502hessR = hessAux dgehrd "hessR" . fmat 502hessR = hessAux dgehrd "hessR"
503 503
504-- | Hessenberg factorization of a square complex matrix, using LAPACK's /zgehrd/. 504-- | Hessenberg factorization of a square complex matrix, using LAPACK's /zgehrd/.
505hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) 505hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double))
506hessC = hessAux zgehrd "hessC" . fmat 506hessC = hessAux zgehrd "hessC"
507 507
508hessAux f st a = unsafePerformIO $ do 508hessAux f st a = unsafePerformIO $ do
509 r <- createMatrix ColumnMajor m n 509 r <- copy ColumnMajor a
510 tau <- createVector (mn-1) 510 tau <- createVector (mn-1)
511 f # a # tau # r #| st 511 f # tau # r #| st
512 return (r,tau) 512 return (r,tau)
513 where 513 where
514 m = rows a 514 m = rows a
@@ -516,28 +516,28 @@ hessAux f st a = unsafePerformIO $ do
516 mn = min m n 516 mn = min m n
517 517
518----------------------------------------------------------------------------------- 518-----------------------------------------------------------------------------------
519foreign import ccall unsafe "schur_l_R" dgees :: R ::> R ::> R ::> Ok 519foreign import ccall unsafe "schur_l_R" dgees :: R ::> R ::> Ok
520foreign import ccall unsafe "schur_l_C" zgees :: C ::> C ::> C ::> Ok 520foreign import ccall unsafe "schur_l_C" zgees :: C ::> C ::> Ok
521 521
522-- | Schur factorization of a square real matrix, using LAPACK's /dgees/. 522-- | Schur factorization of a square real matrix, using LAPACK's /dgees/.
523schurR :: Matrix Double -> (Matrix Double, Matrix Double) 523schurR :: Matrix Double -> (Matrix Double, Matrix Double)
524schurR = schurAux dgees "schurR" . fmat 524schurR = schurAux dgees "schurR"
525 525
526-- | Schur factorization of a square complex matrix, using LAPACK's /zgees/. 526-- | Schur factorization of a square complex matrix, using LAPACK's /zgees/.
527schurC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double)) 527schurC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double))
528schurC = schurAux zgees "schurC" . fmat 528schurC = schurAux zgees "schurC"
529 529
530schurAux f st a = unsafePerformIO $ do 530schurAux f st a = unsafePerformIO $ do
531 u <- createMatrix ColumnMajor n n 531 u <- createMatrix ColumnMajor n n
532 s <- createMatrix ColumnMajor n n 532 s <- copy ColumnMajor a
533 f # a # u # s #| st 533 f # u # s #| st
534 return (u,s) 534 return (u,s)
535 where 535 where
536 n = rows a 536 n = rows a
537 537
538----------------------------------------------------------------------------------- 538-----------------------------------------------------------------------------------
539foreign import ccall unsafe "lu_l_R" dgetrf :: TMVM R 539foreign import ccall unsafe "lu_l_R" dgetrf :: R :> R ::> Ok
540foreign import ccall unsafe "lu_l_C" zgetrf :: C ::> R :> C ::> Ok 540foreign import ccall unsafe "lu_l_C" zgetrf :: R :> C ::> Ok
541 541
542-- | LU factorization of a general real matrix, using LAPACK's /dgetrf/. 542-- | LU factorization of a general real matrix, using LAPACK's /dgetrf/.
543luR :: Matrix Double -> (Matrix Double, [Int]) 543luR :: Matrix Double -> (Matrix Double, [Int])
@@ -548,9 +548,9 @@ luC :: Matrix (Complex Double) -> (Matrix (Complex Double), [Int])
548luC = luAux zgetrf "luC" . fmat 548luC = luAux zgetrf "luC" . fmat
549 549
550luAux f st a = unsafePerformIO $ do 550luAux f st a = unsafePerformIO $ do
551 lu <- createMatrix ColumnMajor n m 551 lu <- copy ColumnMajor a
552 piv <- createVector (min n m) 552 piv <- createVector (min n m)
553 f # a # piv # lu #| st 553 f # piv # lu #| st
554 return (lu, map (pred.round) (toList piv)) 554 return (lu, map (pred.round) (toList piv))
555 where 555 where
556 n = rows a 556 n = rows a
@@ -558,10 +558,8 @@ luAux f st a = unsafePerformIO $ do
558 558
559----------------------------------------------------------------------------------- 559-----------------------------------------------------------------------------------
560 560
561type Tlus t = t ::> Double :> t ::> t ::> Ok 561foreign import ccall unsafe "luS_l_R" dgetrs :: R ::> R :> R ::> Ok
562 562foreign import ccall unsafe "luS_l_C" zgetrs :: C ::> R :> C ::> Ok
563foreign import ccall unsafe "luS_l_R" dgetrs :: Tlus R
564foreign import ccall unsafe "luS_l_C" zgetrs :: Tlus C
565 563
566-- | Solve a real linear system from a precomputed LU decomposition ('luR'), using LAPACK's /dgetrs/. 564-- | Solve a real linear system from a precomputed LU decomposition ('luR'), using LAPACK's /dgetrs/.
567lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double 565lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double
@@ -573,13 +571,13 @@ lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b)
573 571
574lusAux f st a piv b 572lusAux f st a piv b
575 | n1==n2 && n2==n =unsafePerformIO $ do 573 | n1==n2 && n2==n =unsafePerformIO $ do
576 x <- createMatrix ColumnMajor n m 574 x <- copy ColumnMajor b
577 f # a # piv' # b # x #| st 575 f # a # piv' # x #| st
578 return x 576 return x
579 | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" 577 | otherwise = error $ st ++ " on LU factorization of nonsquare matrix"
580 where n1 = rows a 578 where
581 n2 = cols a 579 n1 = rows a
582 n = rows b 580 n2 = cols a
583 m = cols b 581 n = rows b
584 piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double 582 piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double
585 583