summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2014-09-07 14:25:07 +0200
committerAlberto Ruiz <aruiz@um.es>2014-09-07 14:25:07 +0200
commit225901798773228e73b4c98670d56e844c040b3d (patch)
tree4eb48639cfcf6eb7fdc98cf3514231f22591840c /packages/base/src
parentaa9e29a5fc728965811557652a409338c792b609 (diff)
fix diagRectR/C (eye), check zero cols in in gmat(fromList-matrix), and thanks
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static.hs10
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs2
2 files changed, 9 insertions, 3 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs
index cbcd4e2..cc5eb4f 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs
@@ -512,7 +512,10 @@ crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3])
512-------------------------------------------------------------------------------- 512--------------------------------------------------------------------------------
513 513
514diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n 514diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n
515diagRectR x v = r 515diagRectR x v
516 | m' == 1 = mkL (LA.diagRect x ev m' n')
517 | m'*n' > 0 = r
518 | otherwise = matrix []
516 where 519 where
517 r = mkL (asRow (vjoin [scalar x, ev, zeros])) 520 r = mkL (asRow (vjoin [scalar x, ev, zeros]))
518 ev = extract v 521 ev = extract v
@@ -521,7 +524,10 @@ diagRectR x v = r
521 524
522 525
523diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℂ -> C k -> M m n 526diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℂ -> C k -> M m n
524diagRectC x v = r 527diagRectC x v
528 | m' == 1 = mkM (LA.diagRect x ev m' n')
529 | m'*n' > 0 = r
530 | otherwise = fromList []
525 where 531 where
526 r = mkM (asRow (vjoin [scalar x, ev, zeros])) 532 r = mkM (asRow (vjoin [scalar x, ev, zeros]))
527 ev = extract v 533 ev = extract v
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs
index 339ef7d..ec02cf6 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs
@@ -150,7 +150,7 @@ gmat st xs'
150 (xs,rest) = splitAt (m'*n') xs' 150 (xs,rest) = splitAt (m'*n') xs'
151 v = LA.fromList xs 151 v = LA.fromList xs
152 x = reshape n' v 152 x = reshape n' v
153 ok = rem (LA.size v) n' == 0 && LA.size x == (m',n') && null rest 153 ok = null rest && ((n' == 0 && dim v == 0) || n'> 0 && (rem (LA.size v) n' == 0) && LA.size x == (m',n'))
154 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 154 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
155 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 155 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
156 abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info 156 abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info