From 225901798773228e73b4c98670d56e844c040b3d Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Sun, 7 Sep 2014 14:25:07 +0200 Subject: fix diagRectR/C (eye), check zero cols in in gmat(fromList-matrix), and thanks --- packages/base/THANKS.md | 2 ++ packages/base/src/Numeric/LinearAlgebra/Static.hs | 10 ++++++++-- packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs | 2 +- 3 files changed, 11 insertions(+), 3 deletions(-) (limited to 'packages/base') diff --git a/packages/base/THANKS.md b/packages/base/THANKS.md index 2e9574b..805a19e 100644 --- a/packages/base/THANKS.md +++ b/packages/base/THANKS.md @@ -157,6 +157,8 @@ module reorganization, monadic mapVectorM, and many other improvements. - Denis Laxalde separated the gsl tests from the base ones. +- "idontgetoutmuch" reported a bug in the static diagonal creation functions. + - Dylan Thurston reported an error in the glpk documentation. - Ian Ross reported the max/minIndex bug. diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index cbcd4e2..cc5eb4f 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs @@ -512,7 +512,10 @@ crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3]) -------------------------------------------------------------------------------- diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n -diagRectR x v = r +diagRectR x v + | m' == 1 = mkL (LA.diagRect x ev m' n') + | m'*n' > 0 = r + | otherwise = matrix [] where r = mkL (asRow (vjoin [scalar x, ev, zeros])) ev = extract v @@ -521,7 +524,10 @@ diagRectR x v = r diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℂ -> C k -> M m n -diagRectC x v = r +diagRectC x v + | m' == 1 = mkM (LA.diagRect x ev m' n') + | m'*n' > 0 = r + | otherwise = fromList [] where r = mkM (asRow (vjoin [scalar x, ev, zeros])) ev = extract v diff --git a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs index 339ef7d..ec02cf6 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs @@ -150,7 +150,7 @@ gmat st xs' (xs,rest) = splitAt (m'*n') xs' v = LA.fromList xs x = reshape n' v - ok = rem (LA.size v) n' == 0 && LA.size x == (m',n') && null rest + ok = null rest && ((n' == 0 && dim v == 0) || n'> 0 && (rem (LA.size v) n' == 0) && LA.size x == (m',n')) m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info -- cgit v1.2.3