diff options
-rw-r--r-- | packages/base/src/Internal/C/lapack-aux.c | 136 | ||||
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs | 72 |
2 files changed, 93 insertions, 115 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index ab78dac..8524962 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c | |||
@@ -901,18 +901,17 @@ int chol_l_S(ODMAT(l)) { | |||
901 | 901 | ||
902 | //////////////////// QR factorization ///////////////////////// | 902 | //////////////////// QR factorization ///////////////////////// |
903 | 903 | ||
904 | /* Subroutine */ int dgeqr2_(integer *m, integer *n, doublereal *a, integer * | 904 | int dgeqr2_(integer *m, integer *n, doublereal *a, integer * |
905 | lda, doublereal *tau, doublereal *work, integer *info); | 905 | lda, doublereal *tau, doublereal *work, integer *info); |
906 | 906 | ||
907 | int qr_l_R(KODMAT(a), DVEC(tau), ODMAT(r)) { | 907 | int qr_l_R(DVEC(tau), ODMAT(r)) { |
908 | integer m = ar; | 908 | integer m = rr; |
909 | integer n = ac; | 909 | integer n = rc; |
910 | integer mn = MIN(m,n); | 910 | integer mn = MIN(m,n); |
911 | REQUIRES(m>=1 && n >=1 && rr== m && rc == n && taun == mn, BAD_SIZE); | 911 | REQUIRES(m>=1 && n >=1 && taun == mn, BAD_SIZE); |
912 | DEBUGMSG("qr_l_R"); | 912 | DEBUGMSG("qr_l_R"); |
913 | double *WORK = (double*)malloc(n*sizeof(double)); | 913 | double *WORK = (double*)malloc(n*sizeof(double)); |
914 | CHECK(!WORK,MEM); | 914 | CHECK(!WORK,MEM); |
915 | memcpy(rp,ap,m*n*sizeof(double)); | ||
916 | integer res; | 915 | integer res; |
917 | dgeqr2_ (&m,&n,rp,&m,taup,WORK,&res); | 916 | dgeqr2_ (&m,&n,rp,&m,taup,WORK,&res); |
918 | CHECK(res,res); | 917 | CHECK(res,res); |
@@ -920,18 +919,17 @@ int qr_l_R(KODMAT(a), DVEC(tau), ODMAT(r)) { | |||
920 | OK | 919 | OK |
921 | } | 920 | } |
922 | 921 | ||
923 | /* Subroutine */ int zgeqr2_(integer *m, integer *n, doublecomplex *a, | 922 | int zgeqr2_(integer *m, integer *n, doublecomplex *a, |
924 | integer *lda, doublecomplex *tau, doublecomplex *work, integer *info); | 923 | integer *lda, doublecomplex *tau, doublecomplex *work, integer *info); |
925 | 924 | ||
926 | int qr_l_C(KOCMAT(a), CVEC(tau), OCMAT(r)) { | 925 | int qr_l_C(CVEC(tau), OCMAT(r)) { |
927 | integer m = ar; | 926 | integer m = rr; |
928 | integer n = ac; | 927 | integer n = rc; |
929 | integer mn = MIN(m,n); | 928 | integer mn = MIN(m,n); |
930 | REQUIRES(m>=1 && n >=1 && rr== m && rc == n && taun == mn, BAD_SIZE); | 929 | REQUIRES(m>=1 && n >=1 && taun == mn, BAD_SIZE); |
931 | DEBUGMSG("qr_l_C"); | 930 | DEBUGMSG("qr_l_C"); |
932 | doublecomplex *WORK = (doublecomplex*)malloc(n*sizeof(doublecomplex)); | 931 | doublecomplex *WORK = (doublecomplex*)malloc(n*sizeof(doublecomplex)); |
933 | CHECK(!WORK,MEM); | 932 | CHECK(!WORK,MEM); |
934 | memcpy(rp,ap,m*n*sizeof(doublecomplex)); | ||
935 | integer res; | 933 | integer res; |
936 | zgeqr2_ (&m,&n,rp,&m,taup,WORK,&res); | 934 | zgeqr2_ (&m,&n,rp,&m,taup,WORK,&res); |
937 | CHECK(res,res); | 935 | CHECK(res,res); |
@@ -939,19 +937,18 @@ int qr_l_C(KOCMAT(a), CVEC(tau), OCMAT(r)) { | |||
939 | OK | 937 | OK |
940 | } | 938 | } |
941 | 939 | ||
942 | /* Subroutine */ int dorgqr_(integer *m, integer *n, integer *k, doublereal * | 940 | int dorgqr_(integer *m, integer *n, integer *k, doublereal * |
943 | a, integer *lda, doublereal *tau, doublereal *work, integer *lwork, | 941 | a, integer *lda, doublereal *tau, doublereal *work, integer *lwork, |
944 | integer *info); | 942 | integer *info); |
945 | 943 | ||
946 | int c_dorgqr(KODMAT(a), KDVEC(tau), ODMAT(r)) { | 944 | int c_dorgqr(KDVEC(tau), ODMAT(r)) { |
947 | integer m = ar; | 945 | integer m = rr; |
948 | integer n = MIN(ac,ar); | 946 | integer n = MIN(rc,rr); |
949 | integer k = taun; | 947 | integer k = taun; |
950 | DEBUGMSG("c_dorgqr"); | 948 | DEBUGMSG("c_dorgqr"); |
951 | integer lwork = 8*n; // FIXME | 949 | integer lwork = 8*n; // FIXME |
952 | double *WORK = (double*)malloc(lwork*sizeof(double)); | 950 | double *WORK = (double*)malloc(lwork*sizeof(double)); |
953 | CHECK(!WORK,MEM); | 951 | CHECK(!WORK,MEM); |
954 | memcpy(rp,ap,m*k*sizeof(double)); | ||
955 | integer res; | 952 | integer res; |
956 | dorgqr_ (&m,&n,&k,rp,&m,(double*)taup,WORK,&lwork,&res); | 953 | dorgqr_ (&m,&n,&k,rp,&m,(double*)taup,WORK,&lwork,&res); |
957 | CHECK(res,res); | 954 | CHECK(res,res); |
@@ -959,19 +956,18 @@ int c_dorgqr(KODMAT(a), KDVEC(tau), ODMAT(r)) { | |||
959 | OK | 956 | OK |
960 | } | 957 | } |
961 | 958 | ||
962 | /* Subroutine */ int zungqr_(integer *m, integer *n, integer *k, | 959 | int zungqr_(integer *m, integer *n, integer *k, |
963 | doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * | 960 | doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * |
964 | work, integer *lwork, integer *info); | 961 | work, integer *lwork, integer *info); |
965 | 962 | ||
966 | int c_zungqr(KOCMAT(a), KCVEC(tau), OCMAT(r)) { | 963 | int c_zungqr(KCVEC(tau), OCMAT(r)) { |
967 | integer m = ar; | 964 | integer m = rr; |
968 | integer n = MIN(ac,ar); | 965 | integer n = MIN(rc,rr); |
969 | integer k = taun; | 966 | integer k = taun; |
970 | DEBUGMSG("z_ungqr"); | 967 | DEBUGMSG("z_ungqr"); |
971 | integer lwork = 8*n; // FIXME | 968 | integer lwork = 8*n; // FIXME |
972 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | 969 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); |
973 | CHECK(!WORK,MEM); | 970 | CHECK(!WORK,MEM); |
974 | memcpy(rp,ap,m*k*sizeof(doublecomplex)); | ||
975 | integer res; | 971 | integer res; |
976 | zungqr_ (&m,&n,&k,rp,&m,(doublecomplex*)taup,WORK,&lwork,&res); | 972 | zungqr_ (&m,&n,&k,rp,&m,(doublecomplex*)taup,WORK,&lwork,&res); |
977 | CHECK(res,res); | 973 | CHECK(res,res); |
@@ -982,20 +978,19 @@ int c_zungqr(KOCMAT(a), KCVEC(tau), OCMAT(r)) { | |||
982 | 978 | ||
983 | //////////////////// Hessenberg factorization ///////////////////////// | 979 | //////////////////// Hessenberg factorization ///////////////////////// |
984 | 980 | ||
985 | /* Subroutine */ int dgehrd_(integer *n, integer *ilo, integer *ihi, | 981 | int dgehrd_(integer *n, integer *ilo, integer *ihi, |
986 | doublereal *a, integer *lda, doublereal *tau, doublereal *work, | 982 | doublereal *a, integer *lda, doublereal *tau, doublereal *work, |
987 | integer *lwork, integer *info); | 983 | integer *lwork, integer *info); |
988 | 984 | ||
989 | int hess_l_R(KODMAT(a), DVEC(tau), ODMAT(r)) { | 985 | int hess_l_R(DVEC(tau), ODMAT(r)) { |
990 | integer m = ar; | 986 | integer m = rr; |
991 | integer n = ac; | 987 | integer n = rc; |
992 | integer mn = MIN(m,n); | 988 | integer mn = MIN(m,n); |
993 | REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE); | 989 | REQUIRES(m>=1 && n == m && taun == mn-1, BAD_SIZE); |
994 | DEBUGMSG("hess_l_R"); | 990 | DEBUGMSG("hess_l_R"); |
995 | integer lwork = 5*n; // fixme | 991 | integer lwork = 5*n; // FIXME |
996 | double *WORK = (double*)malloc(lwork*sizeof(double)); | 992 | double *WORK = (double*)malloc(lwork*sizeof(double)); |
997 | CHECK(!WORK,MEM); | 993 | CHECK(!WORK,MEM); |
998 | memcpy(rp,ap,m*n*sizeof(double)); | ||
999 | integer res; | 994 | integer res; |
1000 | integer one = 1; | 995 | integer one = 1; |
1001 | dgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res); | 996 | dgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res); |
@@ -1005,20 +1000,19 @@ int hess_l_R(KODMAT(a), DVEC(tau), ODMAT(r)) { | |||
1005 | } | 1000 | } |
1006 | 1001 | ||
1007 | 1002 | ||
1008 | /* Subroutine */ int zgehrd_(integer *n, integer *ilo, integer *ihi, | 1003 | int zgehrd_(integer *n, integer *ilo, integer *ihi, |
1009 | doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * | 1004 | doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * |
1010 | work, integer *lwork, integer *info); | 1005 | work, integer *lwork, integer *info); |
1011 | 1006 | ||
1012 | int hess_l_C(KOCMAT(a), CVEC(tau), OCMAT(r)) { | 1007 | int hess_l_C(CVEC(tau), OCMAT(r)) { |
1013 | integer m = ar; | 1008 | integer m = rr; |
1014 | integer n = ac; | 1009 | integer n = rc; |
1015 | integer mn = MIN(m,n); | 1010 | integer mn = MIN(m,n); |
1016 | REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE); | 1011 | REQUIRES(m>=1 && n == m && taun == mn-1, BAD_SIZE); |
1017 | DEBUGMSG("hess_l_C"); | 1012 | DEBUGMSG("hess_l_C"); |
1018 | integer lwork = 5*n; // fixme | 1013 | integer lwork = 5*n; // FIXME |
1019 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | 1014 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); |
1020 | CHECK(!WORK,MEM); | 1015 | CHECK(!WORK,MEM); |
1021 | memcpy(rp,ap,m*n*sizeof(doublecomplex)); | ||
1022 | integer res; | 1016 | integer res; |
1023 | integer one = 1; | 1017 | integer one = 1; |
1024 | zgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res); | 1018 | zgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res); |
@@ -1029,23 +1023,17 @@ int hess_l_C(KOCMAT(a), CVEC(tau), OCMAT(r)) { | |||
1029 | 1023 | ||
1030 | //////////////////// Schur factorization ///////////////////////// | 1024 | //////////////////// Schur factorization ///////////////////////// |
1031 | 1025 | ||
1032 | /* Subroutine */ int dgees_(char *jobvs, char *sort, L_fp select, integer *n, | 1026 | int dgees_(char *jobvs, char *sort, L_fp select, integer *n, |
1033 | doublereal *a, integer *lda, integer *sdim, doublereal *wr, | 1027 | doublereal *a, integer *lda, integer *sdim, doublereal *wr, |
1034 | doublereal *wi, doublereal *vs, integer *ldvs, doublereal *work, | 1028 | doublereal *wi, doublereal *vs, integer *ldvs, doublereal *work, |
1035 | integer *lwork, logical *bwork, integer *info); | 1029 | integer *lwork, logical *bwork, integer *info); |
1036 | 1030 | ||
1037 | int schur_l_R(KODMAT(a), ODMAT(u), ODMAT(s)) { | 1031 | int schur_l_R(ODMAT(u), ODMAT(s)) { |
1038 | integer m = ar; | 1032 | integer m = sr; |
1039 | integer n = ac; | 1033 | integer n = sc; |
1040 | REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE); | 1034 | REQUIRES(m>=1 && n==m && ur==n && uc==n, BAD_SIZE); |
1041 | DEBUGMSG("schur_l_R"); | 1035 | DEBUGMSG("schur_l_R"); |
1042 | //int k; | 1036 | integer lwork = 6*n; // FIXME |
1043 | //printf("---------------------------\n"); | ||
1044 | //printf("%p: ",ap); for(k=0;k<n*n;k++) printf("%f ",ap[k]); printf("\n"); | ||
1045 | //printf("%p: ",up); for(k=0;k<n*n;k++) printf("%f ",up[k]); printf("\n"); | ||
1046 | //printf("%p: ",sp); for(k=0;k<n*n;k++) printf("%f ",sp[k]); printf("\n"); | ||
1047 | memcpy(sp,ap,n*n*sizeof(double)); | ||
1048 | integer lwork = 6*n; // fixme | ||
1049 | double *WORK = (double*)malloc(lwork*sizeof(double)); | 1037 | double *WORK = (double*)malloc(lwork*sizeof(double)); |
1050 | double *WR = (double*)malloc(n*sizeof(double)); | 1038 | double *WR = (double*)malloc(n*sizeof(double)); |
1051 | double *WI = (double*)malloc(n*sizeof(double)); | 1039 | double *WI = (double*)malloc(n*sizeof(double)); |
@@ -1054,9 +1042,6 @@ int schur_l_R(KODMAT(a), ODMAT(u), ODMAT(s)) { | |||
1054 | integer res; | 1042 | integer res; |
1055 | integer sdim; | 1043 | integer sdim; |
1056 | dgees_ ("V","N",NULL,&n,sp,&n,&sdim,WR,WI,up,&n,WORK,&lwork,BWORK,&res); | 1044 | dgees_ ("V","N",NULL,&n,sp,&n,&sdim,WR,WI,up,&n,WORK,&lwork,BWORK,&res); |
1057 | //printf("%p: ",ap); for(k=0;k<n*n;k++) printf("%f ",ap[k]); printf("\n"); | ||
1058 | //printf("%p: ",up); for(k=0;k<n*n;k++) printf("%f ",up[k]); printf("\n"); | ||
1059 | //printf("%p: ",sp); for(k=0;k<n*n;k++) printf("%f ",sp[k]); printf("\n"); | ||
1060 | if(res>0) { | 1045 | if(res>0) { |
1061 | return NOCONVER; | 1046 | return NOCONVER; |
1062 | } | 1047 | } |
@@ -1069,18 +1054,17 @@ int schur_l_R(KODMAT(a), ODMAT(u), ODMAT(s)) { | |||
1069 | } | 1054 | } |
1070 | 1055 | ||
1071 | 1056 | ||
1072 | /* Subroutine */ int zgees_(char *jobvs, char *sort, L_fp select, integer *n, | 1057 | int zgees_(char *jobvs, char *sort, L_fp select, integer *n, |
1073 | doublecomplex *a, integer *lda, integer *sdim, doublecomplex *w, | 1058 | doublecomplex *a, integer *lda, integer *sdim, doublecomplex *w, |
1074 | doublecomplex *vs, integer *ldvs, doublecomplex *work, integer *lwork, | 1059 | doublecomplex *vs, integer *ldvs, doublecomplex *work, integer *lwork, |
1075 | doublereal *rwork, logical *bwork, integer *info); | 1060 | doublereal *rwork, logical *bwork, integer *info); |
1076 | 1061 | ||
1077 | int schur_l_C(KOCMAT(a), OCMAT(u), OCMAT(s)) { | 1062 | int schur_l_C(OCMAT(u), OCMAT(s)) { |
1078 | integer m = ar; | 1063 | integer m = sr; |
1079 | integer n = ac; | 1064 | integer n = sc; |
1080 | REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE); | 1065 | REQUIRES(m>=1 && n==m && ur==n && uc==n, BAD_SIZE); |
1081 | DEBUGMSG("schur_l_C"); | 1066 | DEBUGMSG("schur_l_C"); |
1082 | memcpy(sp,ap,n*n*sizeof(doublecomplex)); | 1067 | integer lwork = 6*n; // FIXME |
1083 | integer lwork = 6*n; // fixme | ||
1084 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); | 1068 | doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); |
1085 | doublecomplex *W = (doublecomplex*)malloc(n*sizeof(doublecomplex)); | 1069 | doublecomplex *W = (doublecomplex*)malloc(n*sizeof(doublecomplex)); |
1086 | // W not really required in this call | 1070 | // W not really required in this call |
@@ -1103,21 +1087,20 @@ int schur_l_C(KOCMAT(a), OCMAT(u), OCMAT(s)) { | |||
1103 | 1087 | ||
1104 | //////////////////// LU factorization ///////////////////////// | 1088 | //////////////////// LU factorization ///////////////////////// |
1105 | 1089 | ||
1106 | /* Subroutine */ int dgetrf_(integer *m, integer *n, doublereal *a, integer * | 1090 | int dgetrf_(integer *m, integer *n, doublereal *a, integer * |
1107 | lda, integer *ipiv, integer *info); | 1091 | lda, integer *ipiv, integer *info); |
1108 | 1092 | ||
1109 | int lu_l_R(KODMAT(a), DVEC(ipiv), ODMAT(r)) { | 1093 | int lu_l_R(DVEC(ipiv), ODMAT(r)) { |
1110 | integer m = ar; | 1094 | integer m = rr; |
1111 | integer n = ac; | 1095 | integer n = rc; |
1112 | integer mn = MIN(m,n); | 1096 | integer mn = MIN(m,n); |
1113 | REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); | 1097 | REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); |
1114 | DEBUGMSG("lu_l_R"); | 1098 | DEBUGMSG("lu_l_R"); |
1115 | integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); | 1099 | integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); |
1116 | memcpy(rp,ap,m*n*sizeof(double)); | ||
1117 | integer res; | 1100 | integer res; |
1118 | dgetrf_ (&m,&n,rp,&m,auxipiv,&res); | 1101 | dgetrf_ (&m,&n,rp,&m,auxipiv,&res); |
1119 | if(res>0) { | 1102 | if(res>0) { |
1120 | res = 0; // fixme | 1103 | res = 0; // FIXME |
1121 | } | 1104 | } |
1122 | CHECK(res,res); | 1105 | CHECK(res,res); |
1123 | int k; | 1106 | int k; |
@@ -1129,21 +1112,20 @@ int lu_l_R(KODMAT(a), DVEC(ipiv), ODMAT(r)) { | |||
1129 | } | 1112 | } |
1130 | 1113 | ||
1131 | 1114 | ||
1132 | /* Subroutine */ int zgetrf_(integer *m, integer *n, doublecomplex *a, | 1115 | int zgetrf_(integer *m, integer *n, doublecomplex *a, |
1133 | integer *lda, integer *ipiv, integer *info); | 1116 | integer *lda, integer *ipiv, integer *info); |
1134 | 1117 | ||
1135 | int lu_l_C(KOCMAT(a), DVEC(ipiv), OCMAT(r)) { | 1118 | int lu_l_C(DVEC(ipiv), OCMAT(r)) { |
1136 | integer m = ar; | 1119 | integer m = rr; |
1137 | integer n = ac; | 1120 | integer n = rc; |
1138 | integer mn = MIN(m,n); | 1121 | integer mn = MIN(m,n); |
1139 | REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); | 1122 | REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); |
1140 | DEBUGMSG("lu_l_C"); | 1123 | DEBUGMSG("lu_l_C"); |
1141 | integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); | 1124 | integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); |
1142 | memcpy(rp,ap,m*n*sizeof(doublecomplex)); | ||
1143 | integer res; | 1125 | integer res; |
1144 | zgetrf_ (&m,&n,rp,&m,auxipiv,&res); | 1126 | zgetrf_ (&m,&n,rp,&m,auxipiv,&res); |
1145 | if(res>0) { | 1127 | if(res>0) { |
1146 | res = 0; // fixme | 1128 | res = 0; // FIXME |
1147 | } | 1129 | } |
1148 | CHECK(res,res); | 1130 | CHECK(res,res); |
1149 | int k; | 1131 | int k; |
@@ -1157,11 +1139,11 @@ int lu_l_C(KOCMAT(a), DVEC(ipiv), OCMAT(r)) { | |||
1157 | 1139 | ||
1158 | //////////////////// LU substitution ///////////////////////// | 1140 | //////////////////// LU substitution ///////////////////////// |
1159 | 1141 | ||
1160 | /* Subroutine */ int dgetrs_(char *trans, integer *n, integer *nrhs, | 1142 | int dgetrs_(char *trans, integer *n, integer *nrhs, |
1161 | doublereal *a, integer *lda, integer *ipiv, doublereal *b, integer * | 1143 | doublereal *a, integer *lda, integer *ipiv, doublereal *b, integer * |
1162 | ldb, integer *info); | 1144 | ldb, integer *info); |
1163 | 1145 | ||
1164 | int luS_l_R(KODMAT(a), KDVEC(ipiv), KODMAT(b), ODMAT(x)) { | 1146 | int luS_l_R(KODMAT(a), KDVEC(ipiv), ODMAT(b)) { |
1165 | integer m = ar; | 1147 | integer m = ar; |
1166 | integer n = ac; | 1148 | integer n = ac; |
1167 | integer mrhs = br; | 1149 | integer mrhs = br; |
@@ -1174,19 +1156,18 @@ int luS_l_R(KODMAT(a), KDVEC(ipiv), KODMAT(b), ODMAT(x)) { | |||
1174 | auxipiv[k] = (integer)ipivp[k]; | 1156 | auxipiv[k] = (integer)ipivp[k]; |
1175 | } | 1157 | } |
1176 | integer res; | 1158 | integer res; |
1177 | memcpy(xp,bp,mrhs*nrhs*sizeof(double)); | 1159 | dgetrs_ ("N",&n,&nrhs,(/*no const (!?)*/ double*)ap,&m,auxipiv,bp,&mrhs,&res); |
1178 | dgetrs_ ("N",&n,&nrhs,(/*no const (!?)*/ double*)ap,&m,auxipiv,xp,&mrhs,&res); | ||
1179 | CHECK(res,res); | 1160 | CHECK(res,res); |
1180 | free(auxipiv); | 1161 | free(auxipiv); |
1181 | OK | 1162 | OK |
1182 | } | 1163 | } |
1183 | 1164 | ||
1184 | 1165 | ||
1185 | /* Subroutine */ int zgetrs_(char *trans, integer *n, integer *nrhs, | 1166 | int zgetrs_(char *trans, integer *n, integer *nrhs, |
1186 | doublecomplex *a, integer *lda, integer *ipiv, doublecomplex *b, | 1167 | doublecomplex *a, integer *lda, integer *ipiv, doublecomplex *b, |
1187 | integer *ldb, integer *info); | 1168 | integer *ldb, integer *info); |
1188 | 1169 | ||
1189 | int luS_l_C(KOCMAT(a), KDVEC(ipiv), KOCMAT(b), OCMAT(x)) { | 1170 | int luS_l_C(KOCMAT(a), KDVEC(ipiv), OCMAT(b)) { |
1190 | integer m = ar; | 1171 | integer m = ar; |
1191 | integer n = ac; | 1172 | integer n = ac; |
1192 | integer mrhs = br; | 1173 | integer mrhs = br; |
@@ -1199,8 +1180,7 @@ int luS_l_C(KOCMAT(a), KDVEC(ipiv), KOCMAT(b), OCMAT(x)) { | |||
1199 | auxipiv[k] = (integer)ipivp[k]; | 1180 | auxipiv[k] = (integer)ipivp[k]; |
1200 | } | 1181 | } |
1201 | integer res; | 1182 | integer res; |
1202 | memcpy(xp,bp,mrhs*nrhs*sizeof(doublecomplex)); | 1183 | zgetrs_ ("N",&n,&nrhs,(doublecomplex*)ap,&m,auxipiv,bp,&mrhs,&res); |
1203 | zgetrs_ ("N",&n,&nrhs,(doublecomplex*)ap,&m,auxipiv,xp,&mrhs,&res); | ||
1204 | CHECK(res,res); | 1184 | CHECK(res,res); |
1205 | free(auxipiv); | 1185 | free(auxipiv); |
1206 | OK | 1186 | OK |
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index ce00c16..65deceb 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -455,29 +455,29 @@ mbCholS = unsafePerformIO . mbCatch . cholAux dpotrf "cholS" | |||
455 | 455 | ||
456 | type TMVM t = t ::> t :> t ::> Ok | 456 | type TMVM t = t ::> t :> t ::> Ok |
457 | 457 | ||
458 | foreign import ccall unsafe "qr_l_R" dgeqr2 :: TMVM R | 458 | foreign import ccall unsafe "qr_l_R" dgeqr2 :: R :> R ::> Ok |
459 | foreign import ccall unsafe "qr_l_C" zgeqr2 :: TMVM C | 459 | foreign import ccall unsafe "qr_l_C" zgeqr2 :: C :> C ::> Ok |
460 | 460 | ||
461 | -- | QR factorization of a real matrix, using LAPACK's /dgeqr2/. | 461 | -- | QR factorization of a real matrix, using LAPACK's /dgeqr2/. |
462 | qrR :: Matrix Double -> (Matrix Double, Vector Double) | 462 | qrR :: Matrix Double -> (Matrix Double, Vector Double) |
463 | qrR = qrAux dgeqr2 "qrR" . fmat | 463 | qrR = qrAux dgeqr2 "qrR" |
464 | 464 | ||
465 | -- | QR factorization of a complex matrix, using LAPACK's /zgeqr2/. | 465 | -- | QR factorization of a complex matrix, using LAPACK's /zgeqr2/. |
466 | qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) | 466 | qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) |
467 | qrC = qrAux zgeqr2 "qrC" . fmat | 467 | qrC = qrAux zgeqr2 "qrC" |
468 | 468 | ||
469 | qrAux f st a = unsafePerformIO $ do | 469 | qrAux f st a = unsafePerformIO $ do |
470 | r <- createMatrix ColumnMajor m n | 470 | r <- copy ColumnMajor a |
471 | tau <- createVector mn | 471 | tau <- createVector mn |
472 | f # a # tau # r #| st | 472 | f # tau # r #| st |
473 | return (r,tau) | 473 | return (r,tau) |
474 | where | 474 | where |
475 | m = rows a | 475 | m = rows a |
476 | n = cols a | 476 | n = cols a |
477 | mn = min m n | 477 | mn = min m n |
478 | 478 | ||
479 | foreign import ccall unsafe "c_dorgqr" dorgqr :: TMVM R | 479 | foreign import ccall unsafe "c_dorgqr" dorgqr :: R :> R ::> Ok |
480 | foreign import ccall unsafe "c_zungqr" zungqr :: TMVM C | 480 | foreign import ccall unsafe "c_zungqr" zungqr :: C :> C ::> Ok |
481 | 481 | ||
482 | -- | build rotation from reflectors | 482 | -- | build rotation from reflectors |
483 | qrgrR :: Int -> (Matrix Double, Vector Double) -> Matrix Double | 483 | qrgrR :: Int -> (Matrix Double, Vector Double) -> Matrix Double |
@@ -487,28 +487,28 @@ qrgrC :: Int -> (Matrix (Complex Double), Vector (Complex Double)) -> Matrix (Co | |||
487 | qrgrC = qrgrAux zungqr "qrgrC" | 487 | qrgrC = qrgrAux zungqr "qrgrC" |
488 | 488 | ||
489 | qrgrAux f st n (a, tau) = unsafePerformIO $ do | 489 | qrgrAux f st n (a, tau) = unsafePerformIO $ do |
490 | res <- createMatrix ColumnMajor (rows a) n | 490 | res <- copy ColumnMajor (sliceMatrix (0,0) (rows a,n) a) |
491 | f # (fmat a) # (subVector 0 n tau') # res #| st | 491 | f # (subVector 0 n tau') # res #| st |
492 | return res | 492 | return res |
493 | where | 493 | where |
494 | tau' = vjoin [tau, constantD 0 n] | 494 | tau' = vjoin [tau, constantD 0 n] |
495 | 495 | ||
496 | ----------------------------------------------------------------------------------- | 496 | ----------------------------------------------------------------------------------- |
497 | foreign import ccall unsafe "hess_l_R" dgehrd :: TMVM R | 497 | foreign import ccall unsafe "hess_l_R" dgehrd :: R :> R ::> Ok |
498 | foreign import ccall unsafe "hess_l_C" zgehrd :: TMVM C | 498 | foreign import ccall unsafe "hess_l_C" zgehrd :: C :> C ::> Ok |
499 | 499 | ||
500 | -- | Hessenberg factorization of a square real matrix, using LAPACK's /dgehrd/. | 500 | -- | Hessenberg factorization of a square real matrix, using LAPACK's /dgehrd/. |
501 | hessR :: Matrix Double -> (Matrix Double, Vector Double) | 501 | hessR :: Matrix Double -> (Matrix Double, Vector Double) |
502 | hessR = hessAux dgehrd "hessR" . fmat | 502 | hessR = hessAux dgehrd "hessR" |
503 | 503 | ||
504 | -- | Hessenberg factorization of a square complex matrix, using LAPACK's /zgehrd/. | 504 | -- | Hessenberg factorization of a square complex matrix, using LAPACK's /zgehrd/. |
505 | hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) | 505 | hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) |
506 | hessC = hessAux zgehrd "hessC" . fmat | 506 | hessC = hessAux zgehrd "hessC" |
507 | 507 | ||
508 | hessAux f st a = unsafePerformIO $ do | 508 | hessAux f st a = unsafePerformIO $ do |
509 | r <- createMatrix ColumnMajor m n | 509 | r <- copy ColumnMajor a |
510 | tau <- createVector (mn-1) | 510 | tau <- createVector (mn-1) |
511 | f # a # tau # r #| st | 511 | f # tau # r #| st |
512 | return (r,tau) | 512 | return (r,tau) |
513 | where | 513 | where |
514 | m = rows a | 514 | m = rows a |
@@ -516,28 +516,28 @@ hessAux f st a = unsafePerformIO $ do | |||
516 | mn = min m n | 516 | mn = min m n |
517 | 517 | ||
518 | ----------------------------------------------------------------------------------- | 518 | ----------------------------------------------------------------------------------- |
519 | foreign import ccall unsafe "schur_l_R" dgees :: R ::> R ::> R ::> Ok | 519 | foreign import ccall unsafe "schur_l_R" dgees :: R ::> R ::> Ok |
520 | foreign import ccall unsafe "schur_l_C" zgees :: C ::> C ::> C ::> Ok | 520 | foreign import ccall unsafe "schur_l_C" zgees :: C ::> C ::> Ok |
521 | 521 | ||
522 | -- | Schur factorization of a square real matrix, using LAPACK's /dgees/. | 522 | -- | Schur factorization of a square real matrix, using LAPACK's /dgees/. |
523 | schurR :: Matrix Double -> (Matrix Double, Matrix Double) | 523 | schurR :: Matrix Double -> (Matrix Double, Matrix Double) |
524 | schurR = schurAux dgees "schurR" . fmat | 524 | schurR = schurAux dgees "schurR" |
525 | 525 | ||
526 | -- | Schur factorization of a square complex matrix, using LAPACK's /zgees/. | 526 | -- | Schur factorization of a square complex matrix, using LAPACK's /zgees/. |
527 | schurC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double)) | 527 | schurC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double)) |
528 | schurC = schurAux zgees "schurC" . fmat | 528 | schurC = schurAux zgees "schurC" |
529 | 529 | ||
530 | schurAux f st a = unsafePerformIO $ do | 530 | schurAux f st a = unsafePerformIO $ do |
531 | u <- createMatrix ColumnMajor n n | 531 | u <- createMatrix ColumnMajor n n |
532 | s <- createMatrix ColumnMajor n n | 532 | s <- copy ColumnMajor a |
533 | f # a # u # s #| st | 533 | f # u # s #| st |
534 | return (u,s) | 534 | return (u,s) |
535 | where | 535 | where |
536 | n = rows a | 536 | n = rows a |
537 | 537 | ||
538 | ----------------------------------------------------------------------------------- | 538 | ----------------------------------------------------------------------------------- |
539 | foreign import ccall unsafe "lu_l_R" dgetrf :: TMVM R | 539 | foreign import ccall unsafe "lu_l_R" dgetrf :: R :> R ::> Ok |
540 | foreign import ccall unsafe "lu_l_C" zgetrf :: C ::> R :> C ::> Ok | 540 | foreign import ccall unsafe "lu_l_C" zgetrf :: R :> C ::> Ok |
541 | 541 | ||
542 | -- | LU factorization of a general real matrix, using LAPACK's /dgetrf/. | 542 | -- | LU factorization of a general real matrix, using LAPACK's /dgetrf/. |
543 | luR :: Matrix Double -> (Matrix Double, [Int]) | 543 | luR :: Matrix Double -> (Matrix Double, [Int]) |
@@ -548,9 +548,9 @@ luC :: Matrix (Complex Double) -> (Matrix (Complex Double), [Int]) | |||
548 | luC = luAux zgetrf "luC" . fmat | 548 | luC = luAux zgetrf "luC" . fmat |
549 | 549 | ||
550 | luAux f st a = unsafePerformIO $ do | 550 | luAux f st a = unsafePerformIO $ do |
551 | lu <- createMatrix ColumnMajor n m | 551 | lu <- copy ColumnMajor a |
552 | piv <- createVector (min n m) | 552 | piv <- createVector (min n m) |
553 | f # a # piv # lu #| st | 553 | f # piv # lu #| st |
554 | return (lu, map (pred.round) (toList piv)) | 554 | return (lu, map (pred.round) (toList piv)) |
555 | where | 555 | where |
556 | n = rows a | 556 | n = rows a |
@@ -558,10 +558,8 @@ luAux f st a = unsafePerformIO $ do | |||
558 | 558 | ||
559 | ----------------------------------------------------------------------------------- | 559 | ----------------------------------------------------------------------------------- |
560 | 560 | ||
561 | type Tlus t = t ::> Double :> t ::> t ::> Ok | 561 | foreign import ccall unsafe "luS_l_R" dgetrs :: R ::> R :> R ::> Ok |
562 | 562 | foreign import ccall unsafe "luS_l_C" zgetrs :: C ::> R :> C ::> Ok | |
563 | foreign import ccall unsafe "luS_l_R" dgetrs :: Tlus R | ||
564 | foreign import ccall unsafe "luS_l_C" zgetrs :: Tlus C | ||
565 | 563 | ||
566 | -- | Solve a real linear system from a precomputed LU decomposition ('luR'), using LAPACK's /dgetrs/. | 564 | -- | Solve a real linear system from a precomputed LU decomposition ('luR'), using LAPACK's /dgetrs/. |
567 | lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double | 565 | lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double |
@@ -573,13 +571,13 @@ lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b) | |||
573 | 571 | ||
574 | lusAux f st a piv b | 572 | lusAux f st a piv b |
575 | | n1==n2 && n2==n =unsafePerformIO $ do | 573 | | n1==n2 && n2==n =unsafePerformIO $ do |
576 | x <- createMatrix ColumnMajor n m | 574 | x <- copy ColumnMajor b |
577 | f # a # piv' # b # x #| st | 575 | f # a # piv' # x #| st |
578 | return x | 576 | return x |
579 | | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" | 577 | | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" |
580 | where n1 = rows a | 578 | where |
581 | n2 = cols a | 579 | n1 = rows a |
582 | n = rows b | 580 | n2 = cols a |
583 | m = cols b | 581 | n = rows b |
584 | piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double | 582 | piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double |
585 | 583 | ||