From 9c05df0cd663bafaf0b69eafee53fce45569dc95 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Mon, 29 Jun 2015 16:37:51 +0200 Subject: pass copied slice in linearSolve --- packages/base/src/Internal/C/lapack-aux.c | 165 ++++++++++-------------------- packages/base/src/Internal/LAPACK.hs | 74 +++++++++----- 2 files changed, 98 insertions(+), 141 deletions(-) diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index 1018126..ca60846 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c @@ -314,22 +314,19 @@ int svd_l_Cdd(OCMAT(a),OCMAT(u), DVEC(s),OCMAT(v)) { } double *rwk = (double*)malloc(lrwk*sizeof(double));; CHECK(!rwk,MEM); - //printf("%s %ld %d\n",jobz,q,lrwk); integer lwk = -1; integer res; // ask for optimal lwk doublecomplex ans; zgesdd_ (jobz,&m,&n,ap,&m,sp,up,&m,vp,&ldvt,&ans,&lwk,rwk,iwk,&res); lwk = ans.r; - //printf("lwk = %ld\n",lwk); doublecomplex * workv = (doublecomplex*)malloc(lwk*sizeof(doublecomplex)); CHECK(!workv,MEM); zgesdd_ (jobz,&m,&n,ap,&m,sp,up,&m,vp,&ldvt,workv,&lwk,rwk,iwk,&res); - //printf("res = %ld\n",res); CHECK(res,res); - free(workv); // printf("freed workv\n"); - free(rwk); // printf("freed rwk\n"); - free(iwk); // printf("freed iwk\n"); + free(workv); + free(rwk); + free(iwk); OK } @@ -498,80 +495,72 @@ int eig_l_H(int wantV,DVEC(s),OCMAT(v)) { //////////////////// general real linear system //////////// -/* Subroutine */ int dgesv_(integer *n, integer *nrhs, doublereal *a, integer +int dgesv_(integer *n, integer *nrhs, doublereal *a, integer *lda, integer *ipiv, doublereal *b, integer *ldb, integer *info); -int linearSolveR_l(KODMAT(a),KODMAT(b),ODMAT(x)) { +int linearSolveR_l(ODMAT(a),ODMAT(b)) { integer n = ar; integer nhrs = bc; REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); DEBUGMSG("linearSolveR_l"); - double*AC = (double*)malloc(n*n*sizeof(double)); - memcpy(AC,ap,n*n*sizeof(double)); - memcpy(xp,bp,n*nhrs*sizeof(double)); integer * ipiv = (integer*)malloc(n*sizeof(integer)); integer res; dgesv_ (&n,&nhrs, - AC, &n, + ap, &n, ipiv, - xp, &n, + bp, &n, &res); if(res>0) { return SINGULAR; } CHECK(res,res); free(ipiv); - free(AC); OK } //////////////////// general complex linear system //////////// -/* Subroutine */ int zgesv_(integer *n, integer *nrhs, doublecomplex *a, +int zgesv_(integer *n, integer *nrhs, doublecomplex *a, integer *lda, integer *ipiv, doublecomplex *b, integer *ldb, integer * info); -int linearSolveC_l(KOCMAT(a),KOCMAT(b),OCMAT(x)) { +int linearSolveC_l(OCMAT(a),OCMAT(b)) { integer n = ar; integer nhrs = bc; REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); DEBUGMSG("linearSolveC_l"); - doublecomplex*AC = (doublecomplex*)malloc(n*n*sizeof(doublecomplex)); - memcpy(AC,ap,n*n*sizeof(doublecomplex)); - memcpy(xp,bp,n*nhrs*sizeof(doublecomplex)); integer * ipiv = (integer*)malloc(n*sizeof(integer)); integer res; zgesv_ (&n,&nhrs, - AC, &n, + ap, &n, ipiv, - xp, &n, + bp, &n, &res); if(res>0) { return SINGULAR; } CHECK(res,res); free(ipiv); - free(AC); OK } //////// symmetric positive definite real linear system using Cholesky //////////// -/* Subroutine */ int dpotrs_(char *uplo, integer *n, integer *nrhs, +int dpotrs_(char *uplo, integer *n, integer *nrhs, doublereal *a, integer *lda, doublereal *b, integer *ldb, integer * info); -int cholSolveR_l(KODMAT(a),KODMAT(b),ODMAT(x)) { +int cholSolveR_l(KODMAT(a),ODMAT(b)) { integer n = ar; + integer lda = aXc; integer nhrs = bc; REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); DEBUGMSG("cholSolveR_l"); - memcpy(xp,bp,n*nhrs*sizeof(double)); integer res; dpotrs_ ("U", &n,&nhrs, - (double*)ap, &n, - xp, &n, + (double*)ap, &lda, + bp, &n, &res); CHECK(res,res); OK @@ -579,21 +568,21 @@ int cholSolveR_l(KODMAT(a),KODMAT(b),ODMAT(x)) { //////// Hermitian positive definite real linear system using Cholesky //////////// -/* Subroutine */ int zpotrs_(char *uplo, integer *n, integer *nrhs, +int zpotrs_(char *uplo, integer *n, integer *nrhs, doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, integer *info); -int cholSolveC_l(KOCMAT(a),KOCMAT(b),OCMAT(x)) { +int cholSolveC_l(KOCMAT(a),OCMAT(b)) { integer n = ar; + integer lda = aXc; integer nhrs = bc; REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); DEBUGMSG("cholSolveC_l"); - memcpy(xp,bp,n*nhrs*sizeof(doublecomplex)); integer res; zpotrs_ ("U", &n,&nhrs, - (doublecomplex*)ap, &n, - xp, &n, + (doublecomplex*)ap, &lda, + bp, &n, &res); CHECK(res,res); OK @@ -601,41 +590,30 @@ int cholSolveC_l(KOCMAT(a),KOCMAT(b),OCMAT(x)) { //////////////////// least squares real linear system //////////// -/* Subroutine */ int dgels_(char *trans, integer *m, integer *n, integer * +int dgels_(char *trans, integer *m, integer *n, integer * nrhs, doublereal *a, integer *lda, doublereal *b, integer *ldb, doublereal *work, integer *lwork, integer *info); -int linearSolveLSR_l(KODMAT(a),KODMAT(b),ODMAT(x)) { +int linearSolveLSR_l(ODMAT(a),ODMAT(b)) { integer m = ar; integer n = ac; integer nrhs = bc; - integer ldb = xr; - REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE); + integer ldb = bXc; + REQUIRES(m>=1 && n>=1 && br==MAX(m,n), BAD_SIZE); DEBUGMSG("linearSolveLSR_l"); - double*AC = (double*)malloc(m*n*sizeof(double)); - memcpy(AC,ap,m*n*sizeof(double)); - if (m>=n) { - memcpy(xp,bp,m*nrhs*sizeof(double)); - } else { - int k; - for(k = 0; k0) { @@ -643,47 +621,35 @@ int linearSolveLSR_l(KODMAT(a),KODMAT(b),ODMAT(x)) { } CHECK(res,res); free(work); - free(AC); OK } //////////////////// least squares complex linear system //////////// -/* Subroutine */ int zgels_(char *trans, integer *m, integer *n, integer * +int zgels_(char *trans, integer *m, integer *n, integer * nrhs, doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, doublecomplex *work, integer *lwork, integer *info); -int linearSolveLSC_l(KOCMAT(a),KOCMAT(b),OCMAT(x)) { +int linearSolveLSC_l(OCMAT(a),OCMAT(b)) { integer m = ar; integer n = ac; integer nrhs = bc; - integer ldb = xr; - REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE); + integer ldb = bXc; + REQUIRES(m>=1 && n>=1 && br==MAX(m,n), BAD_SIZE); DEBUGMSG("linearSolveLSC_l"); - doublecomplex*AC = (doublecomplex*)malloc(m*n*sizeof(doublecomplex)); - memcpy(AC,ap,m*n*sizeof(doublecomplex)); - if (m>=n) { - memcpy(xp,bp,m*nrhs*sizeof(doublecomplex)); - } else { - int k; - for(k = 0; k0) { @@ -691,52 +657,40 @@ int linearSolveLSC_l(KOCMAT(a),KOCMAT(b),OCMAT(x)) { } CHECK(res,res); free(work); - free(AC); OK } //////////////////// least squares real linear system using SVD //////////// -/* Subroutine */ int dgelss_(integer *m, integer *n, integer *nrhs, +int dgelss_(integer *m, integer *n, integer *nrhs, doublereal *a, integer *lda, doublereal *b, integer *ldb, doublereal * s, doublereal *rcond, integer *rank, doublereal *work, integer *lwork, integer *info); -int linearSolveSVDR_l(double rcond,KODMAT(a),KODMAT(b),ODMAT(x)) { +int linearSolveSVDR_l(double rcond,ODMAT(a),ODMAT(b)) { integer m = ar; integer n = ac; integer nrhs = bc; - integer ldb = xr; - REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE); + integer ldb = bXc; + REQUIRES(m>=1 && n>=1 && br==MAX(m,n), BAD_SIZE); DEBUGMSG("linearSolveSVDR_l"); - double*AC = (double*)malloc(m*n*sizeof(double)); double*S = (double*)malloc(MIN(m,n)*sizeof(double)); - memcpy(AC,ap,m*n*sizeof(double)); - if (m>=n) { - memcpy(xp,bp,m*nrhs*sizeof(double)); - } else { - int k; - for(k = 0; k=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE); + integer ldb = bXc; + REQUIRES(m>=1 && n>=1 && br==MAX(m,n), BAD_SIZE); DEBUGMSG("linearSolveSVDC_l"); - doublecomplex*AC = (doublecomplex*)malloc(m*n*sizeof(doublecomplex)); double*S = (double*)malloc(MIN(m,n)*sizeof(double)); double*RWORK = (double*)malloc(5*MIN(m,n)*sizeof(double)); - memcpy(AC,ap,m*n*sizeof(doublecomplex)); - if (m>=n) { - memcpy(xp,bp,m*nrhs*sizeof(doublecomplex)); - } else { - int k; - for(k = 0; k R ::> Ok +foreign import ccall unsafe "linearSolveC_l" zgesv :: C ::> C ::> Ok linearSolveSQAux g f st a b | n1==n2 && n1==r = unsafePerformIO . g $ do - s <- createMatrix ColumnMajor r c - f # a # b # s #| st + a' <- copy ColumnMajor a + s <- copy ColumnMajor b + f # a' # s #| st return s | otherwise = error $ st ++ " of nonsquare matrix" where n1 = rows a n2 = cols a r = rows b - c = cols b -- | Solve a real linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, based on LAPACK's /dgesv/. For underconstrained or overconstrained systems use 'linearSolveLSR' or 'linearSolveSVDR'. See also 'lusR'. linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double -linearSolveR a b = linearSolveSQAux id dgesv "linearSolveR" (fmat a) (fmat b) +linearSolveR a b = linearSolveSQAux id dgesv "linearSolveR" a b mbLinearSolveR :: Matrix Double -> Matrix Double -> Maybe (Matrix Double) -mbLinearSolveR a b = linearSolveSQAux mbCatch dgesv "linearSolveR" (fmat a) (fmat b) +mbLinearSolveR a b = linearSolveSQAux mbCatch dgesv "linearSolveR" a b -- | Solve a complex linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, based on LAPACK's /zgesv/. For underconstrained or overconstrained systems use 'linearSolveLSC' or 'linearSolveSVDC'. See also 'lusC'. linearSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) -linearSolveC a b = linearSolveSQAux id zgesv "linearSolveC" (fmat a) (fmat b) +linearSolveC a b = linearSolveSQAux id zgesv "linearSolveC" a b mbLinearSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Maybe (Matrix (Complex Double)) -mbLinearSolveC a b = linearSolveSQAux mbCatch zgesv "linearSolveC" (fmat a) (fmat b) +mbLinearSolveC a b = linearSolveSQAux mbCatch zgesv "linearSolveC" a b + +-------------------------------------------------------------------------------- +foreign import ccall unsafe "cholSolveR_l" dpotrs :: R ::> R ::> Ok +foreign import ccall unsafe "cholSolveC_l" zpotrs :: C ::> C ::> Ok + + +linearSolveSQAux2 g f st a b + | n1==n2 && n1==r = unsafePerformIO . g $ do + s <- copy ColumnMajor b + f # a # s #| st + return s + | otherwise = error $ st ++ " of nonsquare matrix" + where + n1 = rows a + n2 = cols a + r = rows b -- | Solves a symmetric positive definite system of linear equations using a precomputed Cholesky factorization obtained by 'cholS'. cholSolveR :: Matrix Double -> Matrix Double -> Matrix Double -cholSolveR a b = linearSolveSQAux id dpotrs "cholSolveR" (fmat a) (fmat b) +cholSolveR a b = linearSolveSQAux2 id dpotrs "cholSolveR" (fmat a) b -- | Solves a Hermitian positive definite system of linear equations using a precomputed Cholesky factorization obtained by 'cholH'. cholSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) -cholSolveC a b = linearSolveSQAux id zpotrs "cholSolveC" (fmat a) (fmat b) +cholSolveC a b = linearSolveSQAux2 id zpotrs "cholSolveC" (fmat a) b ----------------------------------------------------------------------------------- -foreign import ccall unsafe "linearSolveLSR_l" dgels :: TMMM R -foreign import ccall unsafe "linearSolveLSC_l" zgels :: TMMM C -foreign import ccall unsafe "linearSolveSVDR_l" dgelss :: Double -> TMMM R -foreign import ccall unsafe "linearSolveSVDC_l" zgelss :: Double -> TMMM C - -linearSolveAux f st a b = unsafePerformIO $ do - r <- createMatrix ColumnMajor (max m n) nrhs - f # a # b # r #| st - return r +foreign import ccall unsafe "linearSolveLSR_l" dgels :: R ::> R ::> Ok +foreign import ccall unsafe "linearSolveLSC_l" zgels :: C ::> C ::> Ok +foreign import ccall unsafe "linearSolveSVDR_l" dgelss :: Double -> R ::> R ::> Ok +foreign import ccall unsafe "linearSolveSVDC_l" zgelss :: Double -> C ::> C ::> Ok + +linearSolveAux f st a b + | m == rows b = unsafePerformIO $ do + a' <- copy ColumnMajor a + r <- createMatrix ColumnMajor (max m n) nrhs + setRect 0 0 b r + f # a' # r #| st + return r + | otherwise = error $ "different number of rows in linearSolve ("++st++")" where m = rows a n = cols a @@ -408,12 +426,12 @@ linearSolveAux f st a b = unsafePerformIO $ do -- | Least squared error solution of an overconstrained real linear system, or the minimum norm solution of an underconstrained system, using LAPACK's /dgels/. For rank-deficient systems use 'linearSolveSVDR'. linearSolveLSR :: Matrix Double -> Matrix Double -> Matrix Double linearSolveLSR a b = subMatrix (0,0) (cols a, cols b) $ - linearSolveAux dgels "linearSolverLSR" (fmat a) (fmat b) + linearSolveAux dgels "linearSolverLSR" a b -- | Least squared error solution of an overconstrained complex linear system, or the minimum norm solution of an underconstrained system, using LAPACK's /zgels/. For rank-deficient systems use 'linearSolveSVDC'. linearSolveLSC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) linearSolveLSC a b = subMatrix (0,0) (cols a, cols b) $ - linearSolveAux zgels "linearSolveLSC" (fmat a) (fmat b) + linearSolveAux zgels "linearSolveLSC" a b -- | Minimum norm solution of a general real linear least squares problem Ax=B using the SVD, based on LAPACK's /dgelss/. Admits rank-deficient systems but it is slower than 'linearSolveLSR'. The effective rank of A is determined by treating as zero those singular valures which are less than rcond times the largest singular value. If rcond == Nothing machine precision is used. linearSolveSVDR :: Maybe Double -- ^ rcond @@ -421,8 +439,8 @@ linearSolveSVDR :: Maybe Double -- ^ rcond -> Matrix Double -- ^ right hand sides (as columns) -> Matrix Double -- ^ solution vectors (as columns) linearSolveSVDR (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ - linearSolveAux (dgelss rcond) "linearSolveSVDR" (fmat a) (fmat b) -linearSolveSVDR Nothing a b = linearSolveSVDR (Just (-1)) (fmat a) (fmat b) + linearSolveAux (dgelss rcond) "linearSolveSVDR" a b +linearSolveSVDR Nothing a b = linearSolveSVDR (Just (-1)) a b -- | Minimum norm solution of a general complex linear least squares problem Ax=B using the SVD, based on LAPACK's /zgelss/. Admits rank-deficient systems but it is slower than 'linearSolveLSC'. The effective rank of A is determined by treating as zero those singular valures which are less than rcond times the largest singular value. If rcond == Nothing machine precision is used. linearSolveSVDC :: Maybe Double -- ^ rcond @@ -430,8 +448,8 @@ linearSolveSVDC :: Maybe Double -- ^ rcond -> Matrix (Complex Double) -- ^ right hand sides (as columns) -> Matrix (Complex Double) -- ^ solution vectors (as columns) linearSolveSVDC (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ - linearSolveAux (zgelss rcond) "linearSolveSVDC" (fmat a) (fmat b) -linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) (fmat a) (fmat b) + linearSolveAux (zgelss rcond) "linearSolveSVDC" a b +linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) a b ----------------------------------------------------------------------------------- -- cgit v1.2.3