From fa1642dcf26f1da0a6f4c1324bcd1e8baf9fd478 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Fri, 17 Mar 2017 14:20:07 +0000 Subject: Support triangular matrices. --- packages/base/src/Internal/Algorithms.hs | 14 ++++- packages/base/src/Internal/C/lapack-aux.c | 84 ++++++++++++++++++++++++++++++ packages/base/src/Internal/LAPACK.hs | 30 +++++++++++ packages/base/src/Numeric/LinearAlgebra.hs | 3 ++ 4 files changed, 129 insertions(+), 2 deletions(-) (limited to 'packages/base') diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs index 70d65d7..23a5e13 100644 --- a/packages/base/src/Internal/Algorithms.hs +++ b/packages/base/src/Internal/Algorithms.hs @@ -20,13 +20,16 @@ imported from "Numeric.LinearAlgebra.LAPACK". -} ----------------------------------------------------------------------------- -module Internal.Algorithms where +module Internal.Algorithms ( + module Internal.Algorithms, + UpLo(..) +) where import Internal.Vector import Internal.Matrix import Internal.Element import Internal.Conversion -import Internal.LAPACK as LAPACK +import Internal.LAPACK import Internal.Numeric import Data.List(foldl1') import qualified Data.Array as A @@ -58,6 +61,7 @@ class (Numeric t, mbLinearSolve' :: Matrix t -> Matrix t -> Maybe (Matrix t) linearSolve' :: Matrix t -> Matrix t -> Matrix t cholSolve' :: Matrix t -> Matrix t -> Matrix t + triSolve' :: UpLo -> Matrix t -> Matrix t -> Matrix t ldlPacked' :: Matrix t -> (Matrix t, [Int]) ldlSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t @@ -83,6 +87,7 @@ instance Field Double where linearSolve' = linearSolveR -- (luSolve . luPacked) ?? mbLinearSolve' = mbLinearSolveR cholSolve' = cholSolveR + triSolve' = triSolveR linearSolveLS' = linearSolveLSR linearSolveSVD' = linearSolveSVDR Nothing eig' = eigR @@ -112,6 +117,7 @@ instance Field (Complex Double) where linearSolve' = linearSolveC mbLinearSolve' = mbLinearSolveC cholSolve' = cholSolveC + triSolve' = triSolveC linearSolveLS' = linearSolveLSC linearSolveSVD' = linearSolveSVDC Nothing eig' = eigC @@ -350,6 +356,10 @@ cholSolve -> Matrix t -- ^ solution cholSolve = {-# SCC "cholSolve" #-} cholSolve' +-- | Solve a triangular linear system. +triSolve :: Field t => UpLo -> Matrix t -> Matrix t -> Matrix t +triSolve = {-# SCC "triSolve" #-} triSolve' + -- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value. linearSolveSVD :: Field t => Matrix t -> Matrix t -> Matrix t linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD' diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index ff7ad92..4a8129c 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c @@ -584,6 +584,90 @@ int cholSolveC_l(KOCMAT(a),OCMAT(b)) { OK } +//////// triangular real linear system //////////// + +int dtrtrs_(char *uplo, char *trans, char *diag, integer *n, integer *nrhs, + doublereal *a, integer *lda, doublereal *b, integer *ldb, integer * + info); + +int triSolveR_l_u(KODMAT(a),ODMAT(b)) { + integer n = ar; + integer lda = aXc; + integer nhrs = bc; + REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); + DEBUGMSG("triSolveR_l_u"); + integer res; + dtrtrs_ ("U", + "N", + "N", + &n,&nhrs, + (double*)ap, &lda, + bp, &n, + &res); + CHECK(res,res); + OK +} + +int triSolveR_l_l(KODMAT(a),ODMAT(b)) { + integer n = ar; + integer lda = aXc; + integer nhrs = bc; + REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); + DEBUGMSG("triSolveR_l_l"); + integer res; + dtrtrs_ ("L", + "N", + "N", + &n,&nhrs, + (double*)ap, &lda, + bp, &n, + &res); + CHECK(res,res); + OK +} + +//////// triangular complex linear system //////////// + +int ztrtrs_(char *uplo, char *trans, char *diag, integer *n, integer *nrhs, + doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, + integer *info); + +int triSolveC_l_u(KOCMAT(a),OCMAT(b)) { + integer n = ar; + integer lda = aXc; + integer nhrs = bc; + REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); + DEBUGMSG("triSolveC_l_u"); + integer res; + ztrtrs_ ("U", + "N", + "N", + &n,&nhrs, + (doublecomplex*)ap, &lda, + bp, &n, + &res); + CHECK(res,res); + OK +} + +int triSolveC_l_l(KOCMAT(a),OCMAT(b)) { + integer n = ar; + integer lda = aXc; + integer nhrs = bc; + REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); + DEBUGMSG("triSolveC_l_u"); + integer res; + ztrtrs_ ("L", + "N", + "N", + &n,&nhrs, + (doublecomplex*)ap, &lda, + bp, &n, + &res); + CHECK(res,res); + OK +} + //////////////////// least squares real linear system //////////// int dgels_(char *trans, integer *m, integer *n, integer * diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index 231109a..b4dd5cf 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs @@ -406,6 +406,36 @@ cholSolveR a b = linearSolveSQAux2 id dpotrs "cholSolveR" (fmat a) b cholSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) cholSolveC a b = linearSolveSQAux2 id zpotrs "cholSolveC" (fmat a) b +-------------------------------------------------------------------------------- +foreign import ccall unsafe "triSolveR_l_u" dtrtrs_u :: R ::> R ::> Ok +foreign import ccall unsafe "triSolveC_l_u" ztrtrs_u :: C ::> C ::> Ok +foreign import ccall unsafe "triSolveR_l_l" dtrtrs_l :: R ::> R ::> Ok +foreign import ccall unsafe "triSolveC_l_l" ztrtrs_l :: C ::> C ::> Ok + + +linearSolveTRAux2 g f st a b + | n1==n2 && n1==r = unsafePerformIO . g $ do + s <- copy ColumnMajor b + (a #! s) f #| st + return s + | otherwise = error $ st ++ " of nonsquare matrix" + where + n1 = rows a + n2 = cols a + r = rows b + +data UpLo = Lower | Upper + +-- | Solves a triangular system of linear equations. +triSolveR :: UpLo -> Matrix Double -> Matrix Double -> Matrix Double +triSolveR Lower a b = linearSolveTRAux2 id dtrtrs_l "triSolveR" (fmat a) b +triSolveR Upper a b = linearSolveTRAux2 id dtrtrs_u "triSolveR" (fmat a) b + +-- | Solves a triangular system of linear equations. +triSolveC :: UpLo -> Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) +triSolveC Lower a b = linearSolveTRAux2 id ztrtrs_l "triSolveC" (fmat a) b +triSolveC Upper a b = linearSolveTRAux2 id ztrtrs_u "triSolveC" (fmat a) b + ----------------------------------------------------------------------------------- foreign import ccall unsafe "linearSolveLSR_l" dgels :: R ::> R ::> Ok diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index badf8f9..869330c 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs @@ -94,6 +94,9 @@ module Numeric.LinearAlgebra ( ldlSolve, ldlPacked, -- ** Positive definite cholSolve, + -- ** Triangular + UpLo(..), + triSolve, -- ** Sparse cgSolve, cgSolve', -- cgit v1.2.3