From 49d718705d205d62aea2762445f95735a671f305 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Tue, 21 Mar 2017 17:35:43 +0000 Subject: Add tridiagonal solver and tests for it and triagonal solver. --- packages/base/src/Internal/Algorithms.hs | 76 ++++++++++++++- packages/base/src/Internal/C/lapack-aux.c | 72 +++++++++++++++ packages/base/src/Internal/LAPACK.hs | 22 +++++ packages/base/src/Numeric/LinearAlgebra.hs | 2 + packages/tests/src/Numeric/LinearAlgebra/Tests.hs | 107 ++++++++++++++++++++++ 5 files changed, 277 insertions(+), 2 deletions(-) diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs index 23a5e13..99c9e34 100644 --- a/packages/base/src/Internal/Algorithms.hs +++ b/packages/base/src/Internal/Algorithms.hs @@ -62,6 +62,7 @@ class (Numeric t, linearSolve' :: Matrix t -> Matrix t -> Matrix t cholSolve' :: Matrix t -> Matrix t -> Matrix t triSolve' :: UpLo -> Matrix t -> Matrix t -> Matrix t + triDiagSolve' :: Vector t -> Vector t -> Vector t -> Matrix t -> Matrix t ldlPacked' :: Matrix t -> (Matrix t, [Int]) ldlSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t @@ -88,6 +89,7 @@ instance Field Double where mbLinearSolve' = mbLinearSolveR cholSolve' = cholSolveR triSolve' = triSolveR + triDiagSolve' = triDiagSolveR linearSolveLS' = linearSolveLSR linearSolveSVD' = linearSolveSVDR Nothing eig' = eigR @@ -118,6 +120,7 @@ instance Field (Complex Double) where mbLinearSolve' = mbLinearSolveC cholSolve' = cholSolveC triSolve' = triSolveC + triDiagSolve' = triDiagSolveC linearSolveLS' = linearSolveLSC linearSolveSVD' = linearSolveSVDC Nothing eig' = eigC @@ -356,10 +359,79 @@ cholSolve -> Matrix t -- ^ solution cholSolve = {-# SCC "cholSolve" #-} cholSolve' --- | Solve a triangular linear system. -triSolve :: Field t => UpLo -> Matrix t -> Matrix t -> Matrix t +-- | Solve a triangular linear system. If `Upper` is specified then +-- all elements below the diagonal are ignored; if `Lower` is +-- specified then all elements above the diagonal are ignored. +triSolve + :: Field t + => UpLo -- ^ `Lower` or `Upper` + -> Matrix t -- ^ coefficient matrix + -> Matrix t -- ^ right hand sides + -> Matrix t -- ^ solution triSolve = {-# SCC "triSolve" #-} triSolve' +-- | Solve a tridiagonal linear system. Suppose you wish to solve \(Ax = b\) where +-- +-- \[ +-- A = +-- \begin{bmatrix} +-- 1.0 & 4.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 +-- \\ 3.0 & 1.0 & 4.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 +-- \\ 0.0 & 3.0 & 1.0 & 4.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 +-- \\ 0.0 & 0.0 & 3.0 & 1.0 & 4.0 & 0.0 & 0.0 & 0.0 & 0.0 +-- \\ 0.0 & 0.0 & 0.0 & 3.0 & 1.0 & 4.0 & 0.0 & 0.0 & 0.0 +-- \\ 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 1.0 & 4.0 & 0.0 & 0.0 +-- \\ 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 1.0 & 4.0 & 0.0 +-- \\ 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 1.0 & 4.0 +-- \\ 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 1.0 +-- \end{bmatrix} +-- \quad +-- b = +-- \begin{bmatrix} +-- 1.0 & 1.0 & 1.0 +-- \\ 1.0 & -1.0 & 2.0 +-- \\ 1.0 & 1.0 & 3.0 +-- \\ 1.0 & -1.0 & 4.0 +-- \\ 1.0 & 1.0 & 5.0 +-- \\ 1.0 & -1.0 & 6.0 +-- \\ 1.0 & 1.0 & 7.0 +-- \\ 1.0 & -1.0 & 8.0 +-- \\ 1.0 & 1.0 & 9.0 +-- \end{bmatrix} +-- \] +-- +-- then +-- +-- @ +-- dL = fromList [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0] +-- d = fromList [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] +-- dU = fromList [4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0] +-- +-- b = (9><3) +-- [ +-- 1.0, 1.0, 1.0, +-- 1.0, -1.0, 2.0, +-- 1.0, 1.0, 3.0, +-- 1.0, -1.0, 4.0, +-- 1.0, 1.0, 5.0, +-- 1.0, -1.0, 6.0, +-- 1.0, 1.0, 7.0, +-- 1.0, -1.0, 8.0, +-- 1.0, 1.0, 9.0 +-- ] +-- +-- x = triDiagSolve dL d dU b +-- @ +-- +triDiagSolve + :: Field t + => Vector t -- ^ lower diagonal: \(n - 1\) elements + -> Vector t -- ^ diagonal: \(n\) elements + -> Vector t -- ^ upper diagonal: \(n - 1\) elements + -> Matrix t -- ^ right hand sides + -> Matrix t -- ^ solution +triDiagSolve = {-# SCC "triDiagSolve" #-} triDiagSolve' + -- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value. linearSolveSVD :: Field t => Matrix t -> Matrix t -> Matrix t linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD' diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index 4a8129c..5018a45 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c @@ -668,6 +668,78 @@ int triSolveC_l_l(KOCMAT(a),OCMAT(b)) { OK } +//////// tridiagonal real linear system //////////// + +int dgttrf_(integer *n, + doublereal *dl, doublereal *d, doublereal *du, doublereal *du2, + integer *ipiv, + integer *info); + +int dgttrs_(char *trans, integer *n, integer *nrhs, + doublereal *dl, doublereal *d, doublereal *du, doublereal *du2, + integer *ipiv, doublereal *b, integer *ldb, + integer *info); + +int triDiagSolveR_l(DVEC(dl), DVEC(d), DVEC(du), ODMAT(b)) { + integer n = dn; + integer nhrs = bc; + REQUIRES(n >= 1 && dln == dn - 1 && dun == dn - 1 && br == n, BAD_SIZE); + DEBUGMSG("triDiagSolveR_l"); + integer res; + integer* ipiv = (integer*)malloc(n*sizeof(integer)); + double* du2 = (double*)malloc((n - 2)*sizeof(double)); + dgttrf_ (&n, + dlp, dp, dup, du2, + ipiv, + &res); + CHECK(res,res); + dgttrs_ ("N", + &n,&nhrs, + dlp, dp, dup, du2, + ipiv, bp, &n, + &res); + CHECK(res,res); + free(ipiv); + free(du2); + OK +} + +//////// tridiagonal complex linear system //////////// + +int zgttrf_(integer *n, + doublecomplex *dl, doublecomplex *d, doublecomplex *du, doublecomplex *du2, + integer *ipiv, + integer *info); + +int zgttrs_(char *trans, integer *n, integer *nrhs, + doublecomplex *dl, doublecomplex *d, doublecomplex *du, doublecomplex *du2, + integer *ipiv, doublecomplex *b, integer *ldb, + integer *info); + +int triDiagSolveC_l(CVEC(dl), CVEC(d), CVEC(du), OCMAT(b)) { + integer n = dn; + integer nhrs = bc; + REQUIRES(n >= 1 && dln == dn - 1 && dun == dn - 1 && br == n, BAD_SIZE); + DEBUGMSG("triDiagSolveC_l"); + integer res; + integer* ipiv = (integer*)malloc(n*sizeof(integer)); + doublecomplex* du2 = (doublecomplex*)malloc((n - 2)*sizeof(doublecomplex)); + zgttrf_ (&n, + dlp, dp, dup, du2, + ipiv, + &res); + CHECK(res,res); + zgttrs_ ("N", + &n,&nhrs, + dlp, dp, dup, du2, + ipiv, bp, &n, + &res); + CHECK(res,res); + free(ipiv); + free(du2); + OK +} + //////////////////// least squares real linear system //////////// int dgels_(char *trans, integer *m, integer *n, integer * diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index b4dd5cf..e306454 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs @@ -436,6 +436,28 @@ triSolveC :: UpLo -> Matrix (Complex Double) -> Matrix (Complex Double) -> Matri triSolveC Lower a b = linearSolveTRAux2 id ztrtrs_l "triSolveC" (fmat a) b triSolveC Upper a b = linearSolveTRAux2 id ztrtrs_u "triSolveC" (fmat a) b +-------------------------------------------------------------------------------- +foreign import ccall unsafe "triDiagSolveR_l" dgttrs :: R :> R :> R :> R ::> Ok +foreign import ccall unsafe "triDiagSolveC_l" zgttrs :: C :> C :> C :> C ::> Ok + +linearSolveGTAux2 g f st dl d du b + | ndl == nd - 1 && + ndu == nd - 1 && + nd == r = unsafePerformIO . g $ do + s <- copy ColumnMajor b + (dl # d # du #! s) f #| st + return s + | otherwise = error $ st ++ " of nonsquare matrix" + where + ndl = dim dl + nd = dim d + ndu = dim du + r = rows b + +-- | Solves a tridiagonal system of linear equations. +triDiagSolveR dl d du b = linearSolveGTAux2 id dgttrs "triDiagSolveR" dl d du b +triDiagSolveC dl d du b = linearSolveGTAux2 id zgttrs "triDiagSolveC" dl d du b + ----------------------------------------------------------------------------------- foreign import ccall unsafe "linearSolveLSR_l" dgels :: R ::> R ::> Ok diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index 869330c..fd100e0 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs @@ -97,6 +97,8 @@ module Numeric.LinearAlgebra ( -- ** Triangular UpLo(..), triSolve, + -- ** Tridiagonal + triDiagSolve, -- ** Sparse cgSolve, cgSolve', diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs index 043ebf3..55a5f74 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs @@ -131,6 +131,111 @@ mbCholTest = utest "mbCholTest" (ok1 && ok2) where ok1 = mbChol (trustSym m1) == Nothing ok2 = mbChol (trustSym m2) == Just (chol $ trustSym m2) +----------------------------------------------------- + +triTest = utest "triTest" ok1 where + + a :: Matrix R + a = (4><4) + [ + 4.30, 0.00, 0.00, 0.00, + -3.96, -4.87, 0.00, 0.00, + 0.40, 0.31, -8.02, 0.00, + -0.27, 0.07, -5.95, 0.12 + ] + + w :: Matrix R + w = (4><2) + [ + -12.90, -21.50, + 16.75, 14.93, + -17.55, 6.33, + -11.04, 8.09 + ] + + v :: Matrix R + v = triSolve Lower a w + + e :: Matrix R + e = (4><2) + [ + -3.0000, -5.0000, + -1.0000, 1.0000, + 2.0000, -1.0000, + 1.0000, 6.0000 + ] + + ok1 = (maximum $ map abs $ concat $ toLists $ e - v) <= 1e-14 + +----------------------------------------------------- + +triDiagTest = utest "triDiagTest" (ok1 && ok2) where + + dL, d, dU :: Vector Double + dL = fromList [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0] + d = fromList [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + dU = fromList [4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0] + + b :: Matrix R + b = (9><3) + [ + 1.0, 1.0, 1.0, + 1.0, -1.0, 2.0, + 1.0, 1.0, 3.0, + 1.0, -1.0, 4.0, + 1.0, 1.0, 5.0, + 1.0, -1.0, 6.0, + 1.0, 1.0, 7.0, + 1.0, -1.0, 8.0, + 1.0, 1.0, 9.0 + ] + + y :: Matrix R + y = (9><9) + [ + 1.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 3.0, 1.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 3.0, 1.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 3.0, 1.0, 4.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 3.0, 1.0, 4.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 3.0, 1.0, 4.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 1.0, 4.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 1.0, 4.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 1.0 + ] + + x :: Matrix R + x = triDiagSolve dL d dU b + + z :: Matrix C + z = (4><4) + [ + 1.0 :+ 1.0, 4.0 :+ 4.0, 0.0 :+ 0.0, 0.0 :+ 0.0, + 3.0 :+ 3.0, 1.0 :+ 1.0, 4.0 :+ 4.0, 0.0 :+ 0.0, + 0.0 :+ 0.0, 3.0 :+ 3.0, 1.0 :+ 1.0, 4.0 :+ 4.0, + 0.0 :+ 0.0, 0.0 :+ 0.0, 3.0 :+ 3.0, 1.0 :+ 1.0 + ] + + zDL, zD, zDu :: Vector C + zDL = fromList [3.0 :+ 3.0, 3.0 :+ 3.0, 3.0 :+ 3.0] + zD = fromList [1.0 :+ 1.0, 1.0 :+ 1.0, 1.0 :+ 1.0, 1.0 :+ 1.0] + zDu = fromList [4.0 :+ 4.0, 4.0 :+ 4.0, 4.0 :+ 4.0] + + zB :: Matrix C + zB = (4><3) + [ + 1.0 :+ 1.0, 1.0 :+ 1.0, 1.0 :+ (-1.0), + 1.0 :+ 1.0, (-1.0) :+ (-1.0), 1.0 :+ (-1.0), + 1.0 :+ 1.0, 1.0 :+ 1.0, 1.0 :+ (-1.0), + 1.0 :+ 1.0, (-1.0) :+ (-1.0), 1.0 :+ (-1.0) + ] + + u :: Matrix C + u = triDiagSolve zDL zD zDu zB + + ok1 = (maximum $ map abs $ concat $ toLists $ b - (y <> x)) <= 1e-15 + ok2 = (maximum $ map magnitude $ concat $ toLists $ zB - (z <> u)) <= 1e-15 + --------------------------------------------------------------------- randomTestGaussian = (unSym c) :~3~: unSym (snd (meanCov dat)) @@ -715,6 +820,8 @@ runTests n = do && rank ((2><3)[1,0,0,1,7*peps,0::Double]) == 2 , utest "block" $ fromBlocks [[ident 3,0],[0,ident 4]] == (ident 7 :: CM) , mbCholTest + , triTest + , triDiagTest , utest "offset" offsetTest , normsVTest , normsMTest -- cgit v1.2.3