summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Util.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/Util.hs')
-rw-r--r--packages/base/src/Internal/Util.hs31
1 files changed, 18 insertions, 13 deletions
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs
index 1b639e1..b1fb800 100644
--- a/packages/base/src/Internal/Util.hs
+++ b/packages/base/src/Internal/Util.hs
@@ -512,7 +512,7 @@ block3x3 r nr c nc m = [[m ?? (er !! i, ec !! j) | j <- [0..2] ] | i <- [0..2] ]
512 er = [ Range 0 1 (r-1), Range r 1 (r+nr-1), Drop (nr+r) ] 512 er = [ Range 0 1 (r-1), Range r 1 (r+nr-1), Drop (nr+r) ]
513 ec = [ Range 0 1 (c-1), Range c 1 (c+nc-1), Drop (nc+c) ] 513 ec = [ Range 0 1 (c-1), Range c 1 (c+nc-1), Drop (nc+c) ]
514 514
515view1 :: Numeric t => Matrix t -> Maybe (t, Vector t, Vector t, Matrix t) 515view1 :: Numeric t => Matrix t -> Maybe (View1 t)
516view1 m 516view1 m
517 | rows m > 0 && cols m > 0 = Just (e, flatten m12, flatten m21 , m22) 517 | rows m > 0 && cols m > 0 = Just (e, flatten m12, flatten m21 , m22)
518 | otherwise = Nothing 518 | otherwise = Nothing
@@ -520,21 +520,25 @@ view1 m
520 [[m11,m12],[m21,m22]] = block2x2 1 1 m 520 [[m11,m12],[m21,m22]] = block2x2 1 1 m
521 e = m11 `atIndex` (0, 0) 521 e = m11 `atIndex` (0, 0)
522 522
523unView1 :: Numeric t => (t, Vector t, Vector t, Matrix t) -> Matrix t 523unView1 :: Numeric t => View1 t -> Matrix t
524unView1 (e,r,c,m) = fromBlocks [[scalar e, asRow r],[asColumn c, m]] 524unView1 (e,r,c,m) = fromBlocks [[scalar e, asRow r],[asColumn c, m]]
525 525
526type View1 t = (t, Vector t, Vector t, Matrix t)
526 527
528foldMatrix :: Numeric t => (Matrix t -> Matrix t) -> (View1 t -> View1 t) -> (Matrix t -> Matrix t)
527foldMatrix g f ( (f <$>) . view1 . g -> Just (e,r,c,m)) = unView1 (e, r, c, foldMatrix g f m) 529foldMatrix g f ( (f <$>) . view1 . g -> Just (e,r,c,m)) = unView1 (e, r, c, foldMatrix g f m)
528foldMatrix _ _ m = m 530foldMatrix _ _ m = m
529 531
530sortRowsBy h j m = m ?? (Pos (sortIndex (h (tr m ! j))), All)
531
532splitColsAt n = (takeColumns n &&& dropColumns n)
533 532
533swapMax k m
534 | rows m > 0 && j>0 = (j, m ?? (Pos (idxs swapped), All))
535 | otherwise = (0,m)
536 where
537 j = maxIndex $ abs (tr m ! k)
538 swapped = j:[1..j-1] ++ 0:[j+1..rows m-1]
534 539
535down a = foldMatrix g f a 540down g a = foldMatrix g f a
536 where 541 where
537 g = sortRowsBy (negate.abs) 0
538 f (e,r,c,m) 542 f (e,r,c,m)
539 | e /= 0 = (1, r', 0, m - outer c r') 543 | e /= 0 = (1, r', 0, m - outer c r')
540 | otherwise = error "singular!" 544 | otherwise = error "singular!"
@@ -547,15 +551,16 @@ down a = foldMatrix g f a
547-- @a <> gaussElim a b = b@ 551-- @a <> gaussElim a b = b@
548-- 552--
549gaussElim 553gaussElim
550 :: (Fractional t, Num (Vector t), Ord t, Indexable (Vector t) t, Numeric t) 554 :: (Eq t, Fractional t, Num (Vector t), Numeric t)
551 => Matrix t -> Matrix t -> Matrix t 555 => Matrix t -> Matrix t -> Matrix t
552 556
553gaussElim a b = r 557gaussElim a b = flipudrl r
554 where 558 where
555 go x y = splitColsAt (cols a) (down $ fromBlocks [[x,y]]) 559 flipudrl = flipud . fliprl
556 (a1,b1) = go a b 560 splitColsAt n = (takeColumns n &&& dropColumns n)
557 ( _, r) = go (flipud . fliprl $ a1) (flipud . fliprl $ b1) 561 go f x y = splitColsAt (cols a) (down f $ fromBlocks [[x,y]])
558 562 (a1,b1) = go (snd . swapMax 0) a b
563 ( _, r) = go id (flipudrl $ a1) (flipudrl $ b1)
559 564
560-------------------------------------------------------------------------------- 565--------------------------------------------------------------------------------
561 566