From fa1642dcf26f1da0a6f4c1324bcd1e8baf9fd478 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Fri, 17 Mar 2017 14:20:07 +0000 Subject: Support triangular matrices. --- packages/base/src/Internal/LAPACK.hs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) (limited to 'packages/base/src/Internal/LAPACK.hs') diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index 231109a..b4dd5cf 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs @@ -406,6 +406,36 @@ cholSolveR a b = linearSolveSQAux2 id dpotrs "cholSolveR" (fmat a) b cholSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) cholSolveC a b = linearSolveSQAux2 id zpotrs "cholSolveC" (fmat a) b +-------------------------------------------------------------------------------- +foreign import ccall unsafe "triSolveR_l_u" dtrtrs_u :: R ::> R ::> Ok +foreign import ccall unsafe "triSolveC_l_u" ztrtrs_u :: C ::> C ::> Ok +foreign import ccall unsafe "triSolveR_l_l" dtrtrs_l :: R ::> R ::> Ok +foreign import ccall unsafe "triSolveC_l_l" ztrtrs_l :: C ::> C ::> Ok + + +linearSolveTRAux2 g f st a b + | n1==n2 && n1==r = unsafePerformIO . g $ do + s <- copy ColumnMajor b + (a #! s) f #| st + return s + | otherwise = error $ st ++ " of nonsquare matrix" + where + n1 = rows a + n2 = cols a + r = rows b + +data UpLo = Lower | Upper + +-- | Solves a triangular system of linear equations. +triSolveR :: UpLo -> Matrix Double -> Matrix Double -> Matrix Double +triSolveR Lower a b = linearSolveTRAux2 id dtrtrs_l "triSolveR" (fmat a) b +triSolveR Upper a b = linearSolveTRAux2 id dtrtrs_u "triSolveR" (fmat a) b + +-- | Solves a triangular system of linear equations. +triSolveC :: UpLo -> Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) +triSolveC Lower a b = linearSolveTRAux2 id ztrtrs_l "triSolveC" (fmat a) b +triSolveC Upper a b = linearSolveTRAux2 id ztrtrs_u "triSolveC" (fmat a) b + ----------------------------------------------------------------------------------- foreign import ccall unsafe "linearSolveLSR_l" dgels :: R ::> R ::> Ok -- cgit v1.2.3