From f4124fa6209cbf8290fed2be51cec8464bf7f1b9 Mon Sep 17 00:00:00 2001 From: Maxim Koltsov Date: Mon, 19 Nov 2018 20:43:22 +0300 Subject: Fix #282 LAPACK routine dgttrf mutates its inputs per documentation. To prevent user-visible breakage input vectors must be copied before sending them to LAPACK. --- packages/base/src/Internal/LAPACK.hs | 4 +++- packages/tests/src/Numeric/LinearAlgebra/Tests.hs | 24 +++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index ff55688..27d1f95 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs @@ -506,8 +506,10 @@ linearSolveGTAux2 g f st dl d du b | ndl == nd - 1 && ndu == nd - 1 && nd == r = unsafePerformIO . g $ do + dl' <- head . toRows <$> copy ColumnMajor (fromRows [dl]) + du' <- head . toRows <$> copy ColumnMajor (fromRows [du]) s <- copy ColumnMajor b - (dl # d # du #! s) f #| st + (dl' # d # du' #! s) f #| st return s | otherwise = error $ st ++ " of nonsquare matrix" where diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs index c0c151a..3c7863f 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs @@ -242,6 +242,29 @@ triDiagTest = utest "triDiagTest" (ok1 && ok2) where --------------------------------------------------------------------- +triDiagRegression = utest "triDiagRegression" ok where + minusOnes, twos :: Vector R + minusOnes = fromList [-1, -1] + twos = fromList [2, 2, 2] + k :: Matrix R + k = (3><3) + [ 2, -1, 0 + , -1, 2, -1 + , 0, -1, 2 + ] + + b :: Matrix R + b = (3><1) [10, 10, 10] + + tridiag = triDiagSolve minusOnes twos minusOnes b + simple = linearSolve k b + + ok = case simple of + Just m -> tridiag |~| m + Nothing -> False + +--------------------------------------------------------------------- + randomTestGaussian = (unSym c) :~3~: unSym (snd (meanCov dat)) where a = (3><3) [1,2,3, @@ -830,6 +853,7 @@ runTests n = do , mbCholTest , triTest , triDiagTest + , triDiagRegression , utest "offset" offsetTest , normsVTest , normsMTest -- cgit v1.2.3