From c5ed204b8d6a36681c7ec6b227c634bfae501435 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Sun, 28 Jun 2015 20:04:02 +0200 Subject: pass copied slice (qr, hess,schur,lu) --- packages/base/src/Internal/C/lapack-aux.c | 136 +++++++++++++----------------- packages/base/src/Internal/LAPACK.hs | 72 ++++++++-------- 2 files changed, 93 insertions(+), 115 deletions(-) (limited to 'packages/base/src/Internal') diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index ab78dac..8524962 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c @@ -901,18 +901,17 @@ int chol_l_S(ODMAT(l)) { //////////////////// QR factorization ///////////////////////// -/* Subroutine */ int dgeqr2_(integer *m, integer *n, doublereal *a, integer * +int dgeqr2_(integer *m, integer *n, doublereal *a, integer * lda, doublereal *tau, doublereal *work, integer *info); -int qr_l_R(KODMAT(a), DVEC(tau), ODMAT(r)) { - integer m = ar; - integer n = ac; +int qr_l_R(DVEC(tau), ODMAT(r)) { + integer m = rr; + integer n = rc; integer mn = MIN(m,n); - REQUIRES(m>=1 && n >=1 && rr== m && rc == n && taun == mn, BAD_SIZE); + REQUIRES(m>=1 && n >=1 && taun == mn, BAD_SIZE); DEBUGMSG("qr_l_R"); double *WORK = (double*)malloc(n*sizeof(double)); CHECK(!WORK,MEM); - memcpy(rp,ap,m*n*sizeof(double)); integer res; dgeqr2_ (&m,&n,rp,&m,taup,WORK,&res); CHECK(res,res); @@ -920,18 +919,17 @@ int qr_l_R(KODMAT(a), DVEC(tau), ODMAT(r)) { OK } -/* Subroutine */ int zgeqr2_(integer *m, integer *n, doublecomplex *a, +int zgeqr2_(integer *m, integer *n, doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex *work, integer *info); -int qr_l_C(KOCMAT(a), CVEC(tau), OCMAT(r)) { - integer m = ar; - integer n = ac; +int qr_l_C(CVEC(tau), OCMAT(r)) { + integer m = rr; + integer n = rc; integer mn = MIN(m,n); - REQUIRES(m>=1 && n >=1 && rr== m && rc == n && taun == mn, BAD_SIZE); + REQUIRES(m>=1 && n >=1 && taun == mn, BAD_SIZE); DEBUGMSG("qr_l_C"); doublecomplex *WORK = (doublecomplex*)malloc(n*sizeof(doublecomplex)); CHECK(!WORK,MEM); - memcpy(rp,ap,m*n*sizeof(doublecomplex)); integer res; zgeqr2_ (&m,&n,rp,&m,taup,WORK,&res); CHECK(res,res); @@ -939,19 +937,18 @@ int qr_l_C(KOCMAT(a), CVEC(tau), OCMAT(r)) { OK } -/* Subroutine */ int dorgqr_(integer *m, integer *n, integer *k, doublereal * +int dorgqr_(integer *m, integer *n, integer *k, doublereal * a, integer *lda, doublereal *tau, doublereal *work, integer *lwork, integer *info); -int c_dorgqr(KODMAT(a), KDVEC(tau), ODMAT(r)) { - integer m = ar; - integer n = MIN(ac,ar); +int c_dorgqr(KDVEC(tau), ODMAT(r)) { + integer m = rr; + integer n = MIN(rc,rr); integer k = taun; DEBUGMSG("c_dorgqr"); integer lwork = 8*n; // FIXME double *WORK = (double*)malloc(lwork*sizeof(double)); CHECK(!WORK,MEM); - memcpy(rp,ap,m*k*sizeof(double)); integer res; dorgqr_ (&m,&n,&k,rp,&m,(double*)taup,WORK,&lwork,&res); CHECK(res,res); @@ -959,19 +956,18 @@ int c_dorgqr(KODMAT(a), KDVEC(tau), ODMAT(r)) { OK } -/* Subroutine */ int zungqr_(integer *m, integer *n, integer *k, +int zungqr_(integer *m, integer *n, integer *k, doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * work, integer *lwork, integer *info); -int c_zungqr(KOCMAT(a), KCVEC(tau), OCMAT(r)) { - integer m = ar; - integer n = MIN(ac,ar); +int c_zungqr(KCVEC(tau), OCMAT(r)) { + integer m = rr; + integer n = MIN(rc,rr); integer k = taun; DEBUGMSG("z_ungqr"); integer lwork = 8*n; // FIXME doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); CHECK(!WORK,MEM); - memcpy(rp,ap,m*k*sizeof(doublecomplex)); integer res; zungqr_ (&m,&n,&k,rp,&m,(doublecomplex*)taup,WORK,&lwork,&res); CHECK(res,res); @@ -982,20 +978,19 @@ int c_zungqr(KOCMAT(a), KCVEC(tau), OCMAT(r)) { //////////////////// Hessenberg factorization ///////////////////////// -/* Subroutine */ int dgehrd_(integer *n, integer *ilo, integer *ihi, +int dgehrd_(integer *n, integer *ilo, integer *ihi, doublereal *a, integer *lda, doublereal *tau, doublereal *work, integer *lwork, integer *info); -int hess_l_R(KODMAT(a), DVEC(tau), ODMAT(r)) { - integer m = ar; - integer n = ac; +int hess_l_R(DVEC(tau), ODMAT(r)) { + integer m = rr; + integer n = rc; integer mn = MIN(m,n); - REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE); + REQUIRES(m>=1 && n == m && taun == mn-1, BAD_SIZE); DEBUGMSG("hess_l_R"); - integer lwork = 5*n; // fixme + integer lwork = 5*n; // FIXME double *WORK = (double*)malloc(lwork*sizeof(double)); CHECK(!WORK,MEM); - memcpy(rp,ap,m*n*sizeof(double)); integer res; integer one = 1; dgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res); @@ -1005,20 +1000,19 @@ int hess_l_R(KODMAT(a), DVEC(tau), ODMAT(r)) { } -/* Subroutine */ int zgehrd_(integer *n, integer *ilo, integer *ihi, +int zgehrd_(integer *n, integer *ilo, integer *ihi, doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex * work, integer *lwork, integer *info); -int hess_l_C(KOCMAT(a), CVEC(tau), OCMAT(r)) { - integer m = ar; - integer n = ac; +int hess_l_C(CVEC(tau), OCMAT(r)) { + integer m = rr; + integer n = rc; integer mn = MIN(m,n); - REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE); + REQUIRES(m>=1 && n == m && taun == mn-1, BAD_SIZE); DEBUGMSG("hess_l_C"); - integer lwork = 5*n; // fixme + integer lwork = 5*n; // FIXME doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); CHECK(!WORK,MEM); - memcpy(rp,ap,m*n*sizeof(doublecomplex)); integer res; integer one = 1; zgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res); @@ -1029,23 +1023,17 @@ int hess_l_C(KOCMAT(a), CVEC(tau), OCMAT(r)) { //////////////////// Schur factorization ///////////////////////// -/* Subroutine */ int dgees_(char *jobvs, char *sort, L_fp select, integer *n, +int dgees_(char *jobvs, char *sort, L_fp select, integer *n, doublereal *a, integer *lda, integer *sdim, doublereal *wr, doublereal *wi, doublereal *vs, integer *ldvs, doublereal *work, integer *lwork, logical *bwork, integer *info); -int schur_l_R(KODMAT(a), ODMAT(u), ODMAT(s)) { - integer m = ar; - integer n = ac; - REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE); +int schur_l_R(ODMAT(u), ODMAT(s)) { + integer m = sr; + integer n = sc; + REQUIRES(m>=1 && n==m && ur==n && uc==n, BAD_SIZE); DEBUGMSG("schur_l_R"); - //int k; - //printf("---------------------------\n"); - //printf("%p: ",ap); for(k=0;k0) { return NOCONVER; } @@ -1069,18 +1054,17 @@ int schur_l_R(KODMAT(a), ODMAT(u), ODMAT(s)) { } -/* Subroutine */ int zgees_(char *jobvs, char *sort, L_fp select, integer *n, +int zgees_(char *jobvs, char *sort, L_fp select, integer *n, doublecomplex *a, integer *lda, integer *sdim, doublecomplex *w, doublecomplex *vs, integer *ldvs, doublecomplex *work, integer *lwork, doublereal *rwork, logical *bwork, integer *info); -int schur_l_C(KOCMAT(a), OCMAT(u), OCMAT(s)) { - integer m = ar; - integer n = ac; - REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE); +int schur_l_C(OCMAT(u), OCMAT(s)) { + integer m = sr; + integer n = sc; + REQUIRES(m>=1 && n==m && ur==n && uc==n, BAD_SIZE); DEBUGMSG("schur_l_C"); - memcpy(sp,ap,n*n*sizeof(doublecomplex)); - integer lwork = 6*n; // fixme + integer lwork = 6*n; // FIXME doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex)); doublecomplex *W = (doublecomplex*)malloc(n*sizeof(doublecomplex)); // W not really required in this call @@ -1103,21 +1087,20 @@ int schur_l_C(KOCMAT(a), OCMAT(u), OCMAT(s)) { //////////////////// LU factorization ///////////////////////// -/* Subroutine */ int dgetrf_(integer *m, integer *n, doublereal *a, integer * +int dgetrf_(integer *m, integer *n, doublereal *a, integer * lda, integer *ipiv, integer *info); -int lu_l_R(KODMAT(a), DVEC(ipiv), ODMAT(r)) { - integer m = ar; - integer n = ac; +int lu_l_R(DVEC(ipiv), ODMAT(r)) { + integer m = rr; + integer n = rc; integer mn = MIN(m,n); REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); DEBUGMSG("lu_l_R"); integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); - memcpy(rp,ap,m*n*sizeof(double)); integer res; dgetrf_ (&m,&n,rp,&m,auxipiv,&res); if(res>0) { - res = 0; // fixme + res = 0; // FIXME } CHECK(res,res); int k; @@ -1129,21 +1112,20 @@ int lu_l_R(KODMAT(a), DVEC(ipiv), ODMAT(r)) { } -/* Subroutine */ int zgetrf_(integer *m, integer *n, doublecomplex *a, +int zgetrf_(integer *m, integer *n, doublecomplex *a, integer *lda, integer *ipiv, integer *info); -int lu_l_C(KOCMAT(a), DVEC(ipiv), OCMAT(r)) { - integer m = ar; - integer n = ac; +int lu_l_C(DVEC(ipiv), OCMAT(r)) { + integer m = rr; + integer n = rc; integer mn = MIN(m,n); REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); DEBUGMSG("lu_l_C"); integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); - memcpy(rp,ap,m*n*sizeof(doublecomplex)); integer res; zgetrf_ (&m,&n,rp,&m,auxipiv,&res); if(res>0) { - res = 0; // fixme + res = 0; // FIXME } CHECK(res,res); int k; @@ -1157,11 +1139,11 @@ int lu_l_C(KOCMAT(a), DVEC(ipiv), OCMAT(r)) { //////////////////// LU substitution ///////////////////////// -/* Subroutine */ int dgetrs_(char *trans, integer *n, integer *nrhs, +int dgetrs_(char *trans, integer *n, integer *nrhs, doublereal *a, integer *lda, integer *ipiv, doublereal *b, integer * ldb, integer *info); -int luS_l_R(KODMAT(a), KDVEC(ipiv), KODMAT(b), ODMAT(x)) { +int luS_l_R(KODMAT(a), KDVEC(ipiv), ODMAT(b)) { integer m = ar; integer n = ac; integer mrhs = br; @@ -1174,19 +1156,18 @@ int luS_l_R(KODMAT(a), KDVEC(ipiv), KODMAT(b), ODMAT(x)) { auxipiv[k] = (integer)ipivp[k]; } integer res; - memcpy(xp,bp,mrhs*nrhs*sizeof(double)); - dgetrs_ ("N",&n,&nrhs,(/*no const (!?)*/ double*)ap,&m,auxipiv,xp,&mrhs,&res); + dgetrs_ ("N",&n,&nrhs,(/*no const (!?)*/ double*)ap,&m,auxipiv,bp,&mrhs,&res); CHECK(res,res); free(auxipiv); OK } -/* Subroutine */ int zgetrs_(char *trans, integer *n, integer *nrhs, +int zgetrs_(char *trans, integer *n, integer *nrhs, doublecomplex *a, integer *lda, integer *ipiv, doublecomplex *b, integer *ldb, integer *info); -int luS_l_C(KOCMAT(a), KDVEC(ipiv), KOCMAT(b), OCMAT(x)) { +int luS_l_C(KOCMAT(a), KDVEC(ipiv), OCMAT(b)) { integer m = ar; integer n = ac; integer mrhs = br; @@ -1199,8 +1180,7 @@ int luS_l_C(KOCMAT(a), KDVEC(ipiv), KOCMAT(b), OCMAT(x)) { auxipiv[k] = (integer)ipivp[k]; } integer res; - memcpy(xp,bp,mrhs*nrhs*sizeof(doublecomplex)); - zgetrs_ ("N",&n,&nrhs,(doublecomplex*)ap,&m,auxipiv,xp,&mrhs,&res); + zgetrs_ ("N",&n,&nrhs,(doublecomplex*)ap,&m,auxipiv,bp,&mrhs,&res); CHECK(res,res); free(auxipiv); OK diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index ce00c16..65deceb 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs @@ -455,29 +455,29 @@ mbCholS = unsafePerformIO . mbCatch . cholAux dpotrf "cholS" type TMVM t = t ::> t :> t ::> Ok -foreign import ccall unsafe "qr_l_R" dgeqr2 :: TMVM R -foreign import ccall unsafe "qr_l_C" zgeqr2 :: TMVM C +foreign import ccall unsafe "qr_l_R" dgeqr2 :: R :> R ::> Ok +foreign import ccall unsafe "qr_l_C" zgeqr2 :: C :> C ::> Ok -- | QR factorization of a real matrix, using LAPACK's /dgeqr2/. qrR :: Matrix Double -> (Matrix Double, Vector Double) -qrR = qrAux dgeqr2 "qrR" . fmat +qrR = qrAux dgeqr2 "qrR" -- | QR factorization of a complex matrix, using LAPACK's /zgeqr2/. qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) -qrC = qrAux zgeqr2 "qrC" . fmat +qrC = qrAux zgeqr2 "qrC" qrAux f st a = unsafePerformIO $ do - r <- createMatrix ColumnMajor m n + r <- copy ColumnMajor a tau <- createVector mn - f # a # tau # r #| st + f # tau # r #| st return (r,tau) where m = rows a n = cols a mn = min m n -foreign import ccall unsafe "c_dorgqr" dorgqr :: TMVM R -foreign import ccall unsafe "c_zungqr" zungqr :: TMVM C +foreign import ccall unsafe "c_dorgqr" dorgqr :: R :> R ::> Ok +foreign import ccall unsafe "c_zungqr" zungqr :: C :> C ::> Ok -- | build rotation from reflectors qrgrR :: Int -> (Matrix Double, Vector Double) -> Matrix Double @@ -487,28 +487,28 @@ qrgrC :: Int -> (Matrix (Complex Double), Vector (Complex Double)) -> Matrix (Co qrgrC = qrgrAux zungqr "qrgrC" qrgrAux f st n (a, tau) = unsafePerformIO $ do - res <- createMatrix ColumnMajor (rows a) n - f # (fmat a) # (subVector 0 n tau') # res #| st + res <- copy ColumnMajor (sliceMatrix (0,0) (rows a,n) a) + f # (subVector 0 n tau') # res #| st return res where tau' = vjoin [tau, constantD 0 n] ----------------------------------------------------------------------------------- -foreign import ccall unsafe "hess_l_R" dgehrd :: TMVM R -foreign import ccall unsafe "hess_l_C" zgehrd :: TMVM C +foreign import ccall unsafe "hess_l_R" dgehrd :: R :> R ::> Ok +foreign import ccall unsafe "hess_l_C" zgehrd :: C :> C ::> Ok -- | Hessenberg factorization of a square real matrix, using LAPACK's /dgehrd/. hessR :: Matrix Double -> (Matrix Double, Vector Double) -hessR = hessAux dgehrd "hessR" . fmat +hessR = hessAux dgehrd "hessR" -- | Hessenberg factorization of a square complex matrix, using LAPACK's /zgehrd/. hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) -hessC = hessAux zgehrd "hessC" . fmat +hessC = hessAux zgehrd "hessC" hessAux f st a = unsafePerformIO $ do - r <- createMatrix ColumnMajor m n + r <- copy ColumnMajor a tau <- createVector (mn-1) - f # a # tau # r #| st + f # tau # r #| st return (r,tau) where m = rows a @@ -516,28 +516,28 @@ hessAux f st a = unsafePerformIO $ do mn = min m n ----------------------------------------------------------------------------------- -foreign import ccall unsafe "schur_l_R" dgees :: R ::> R ::> R ::> Ok -foreign import ccall unsafe "schur_l_C" zgees :: C ::> C ::> C ::> Ok +foreign import ccall unsafe "schur_l_R" dgees :: R ::> R ::> Ok +foreign import ccall unsafe "schur_l_C" zgees :: C ::> C ::> Ok -- | Schur factorization of a square real matrix, using LAPACK's /dgees/. schurR :: Matrix Double -> (Matrix Double, Matrix Double) -schurR = schurAux dgees "schurR" . fmat +schurR = schurAux dgees "schurR" -- | Schur factorization of a square complex matrix, using LAPACK's /zgees/. schurC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double)) -schurC = schurAux zgees "schurC" . fmat +schurC = schurAux zgees "schurC" schurAux f st a = unsafePerformIO $ do u <- createMatrix ColumnMajor n n - s <- createMatrix ColumnMajor n n - f # a # u # s #| st + s <- copy ColumnMajor a + f # u # s #| st return (u,s) where n = rows a ----------------------------------------------------------------------------------- -foreign import ccall unsafe "lu_l_R" dgetrf :: TMVM R -foreign import ccall unsafe "lu_l_C" zgetrf :: C ::> R :> C ::> Ok +foreign import ccall unsafe "lu_l_R" dgetrf :: R :> R ::> Ok +foreign import ccall unsafe "lu_l_C" zgetrf :: R :> C ::> Ok -- | LU factorization of a general real matrix, using LAPACK's /dgetrf/. luR :: Matrix Double -> (Matrix Double, [Int]) @@ -548,9 +548,9 @@ luC :: Matrix (Complex Double) -> (Matrix (Complex Double), [Int]) luC = luAux zgetrf "luC" . fmat luAux f st a = unsafePerformIO $ do - lu <- createMatrix ColumnMajor n m + lu <- copy ColumnMajor a piv <- createVector (min n m) - f # a # piv # lu #| st + f # piv # lu #| st return (lu, map (pred.round) (toList piv)) where n = rows a @@ -558,10 +558,8 @@ luAux f st a = unsafePerformIO $ do ----------------------------------------------------------------------------------- -type Tlus t = t ::> Double :> t ::> t ::> Ok - -foreign import ccall unsafe "luS_l_R" dgetrs :: Tlus R -foreign import ccall unsafe "luS_l_C" zgetrs :: Tlus C +foreign import ccall unsafe "luS_l_R" dgetrs :: R ::> R :> R ::> Ok +foreign import ccall unsafe "luS_l_C" zgetrs :: C ::> R :> C ::> Ok -- | Solve a real linear system from a precomputed LU decomposition ('luR'), using LAPACK's /dgetrs/. lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double @@ -573,13 +571,13 @@ lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b) lusAux f st a piv b | n1==n2 && n2==n =unsafePerformIO $ do - x <- createMatrix ColumnMajor n m - f # a # piv' # b # x #| st + x <- copy ColumnMajor b + f # a # piv' # x #| st return x | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" - where n1 = rows a - n2 = cols a - n = rows b - m = cols b - piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double + where + n1 = rows a + n2 = cols a + n = rows b + piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double -- cgit v1.2.3