summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Util.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/Util.hs')
-rw-r--r--packages/base/src/Internal/Util.hs80
1 files changed, 74 insertions, 6 deletions
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs
index b1fb800..7a556e9 100644
--- a/packages/base/src/Internal/Util.hs
+++ b/packages/base/src/Internal/Util.hs
@@ -54,7 +54,7 @@ module Internal.Util(
54 -- ** 2D 54 -- ** 2D
55 corr2, conv2, separable, 55 corr2, conv2, separable,
56 block2x2,block3x3,view1,unView1,foldMatrix, 56 block2x2,block3x3,view1,unView1,foldMatrix,
57 gaussElim 57 gaussElim_1, gaussElim_2, gaussElim
58) where 58) where
59 59
60import Internal.Vector 60import Internal.Vector
@@ -64,17 +64,19 @@ import Internal.Element
64import Internal.Container 64import Internal.Container
65import Internal.Vectorized 65import Internal.Vectorized
66import Internal.IO 66import Internal.IO
67import Internal.Algorithms hiding (i,Normed) 67import Internal.Algorithms hiding (i,Normed,swap)
68import Numeric.Matrix() 68import Numeric.Matrix()
69import Numeric.Vector() 69import Numeric.Vector()
70import Internal.Random 70import Internal.Random
71import Internal.Convolution 71import Internal.Convolution
72import Control.Monad(when) 72import Control.Monad(when,forM_)
73import Text.Printf 73import Text.Printf
74import Data.List.Split(splitOn) 74import Data.List.Split(splitOn)
75import Data.List(intercalate,) 75import Data.List(intercalate,sortBy)
76import Control.Arrow((&&&)) 76import Control.Arrow((&&&))
77import Data.Complex 77import Data.Complex
78import Data.Function(on)
79import Internal.ST
78 80
79type ℝ = Double 81type ℝ = Double
80type ℕ = Int 82type ℕ = Int
@@ -359,6 +361,10 @@ instance Indexable (Vector I) I
359 where 361 where
360 (!) = (@>) 362 (!) = (@>)
361 363
364instance Indexable (Vector Z) Z
365 where
366 (!) = (@>)
367
362instance Indexable (Vector (Complex Double)) (Complex Double) 368instance Indexable (Vector (Complex Double)) (Complex Double)
363 where 369 where
364 (!) = (@>) 370 (!) = (@>)
@@ -550,11 +556,11 @@ down g a = foldMatrix g f a
550-- 556--
551-- @a <> gaussElim a b = b@ 557-- @a <> gaussElim a b = b@
552-- 558--
553gaussElim 559gaussElim_2
554 :: (Eq t, Fractional t, Num (Vector t), Numeric t) 560 :: (Eq t, Fractional t, Num (Vector t), Numeric t)
555 => Matrix t -> Matrix t -> Matrix t 561 => Matrix t -> Matrix t -> Matrix t
556 562
557gaussElim a b = flipudrl r 563gaussElim_2 a b = flipudrl r
558 where 564 where
559 flipudrl = flipud . fliprl 565 flipudrl = flipud . fliprl
560 splitColsAt n = (takeColumns n &&& dropColumns n) 566 splitColsAt n = (takeColumns n &&& dropColumns n)
@@ -564,6 +570,68 @@ gaussElim a b = flipudrl r
564 570
565-------------------------------------------------------------------------------- 571--------------------------------------------------------------------------------
566 572
573gaussElim_1
574 :: (Fractional t, Num (Vector t), Ord t, Indexable (Vector t) t, Numeric t)
575 => Matrix t -> Matrix t -> Matrix t
576
577gaussElim_1 x y = dropColumns (rows x) (flipud $ fromRows s2)
578 where
579 rs = toRows $ fromBlocks [[x , y]]
580 s1 = fromRows $ pivotDown (rows x) 0 rs -- interesting
581 s2 = pivotUp (rows x-1) (toRows $ flipud s1)
582
583pivotDown t n xs
584 | t == n = []
585 | otherwise = y : pivotDown t (n+1) ys
586 where
587 y:ys = redu (pivot n xs)
588
589 pivot k = (const k &&& id)
590 . sortBy (flip compare `on` (abs. (!k)))
591
592 redu (k,x:zs)
593 | p == 0 = error "gauss: singular!" -- FIXME
594 | otherwise = u : map f zs
595 where
596 p = x!k
597 u = scale (recip (x!k)) x
598 f z = z - scale (z!k) u
599 redu (_,[]) = []
600
601
602pivotUp n xs
603 | n == -1 = []
604 | otherwise = y : pivotUp (n-1) ys
605 where
606 y:ys = redu' (n,xs)
607
608 redu' (k,x:zs) = u : map f zs
609 where
610 u = x
611 f z = z - scale (z!k) u
612 redu' (_,[]) = []
613
614--------------------------------------------------------------------------------
615
616gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (fromBlocks [[a,b]])
617
618gaussST (r,_) x = do
619 let n = r-1
620 forM_ [0..n] $ \i -> do
621 c <- maxIndex . abs . flatten <$> extractRect x i n i i
622 swap x i (i+c)
623 a <- readMatrix x i i
624 scal x (recip a) i
625 forM_ [i+1..n] $ \j -> do
626 b <- readMatrix x j i
627 axpy x (-b) i j
628 forM_ [n,n-1..1] $ \i -> do
629 forM_ [i-1,i-2..0] $ \j -> do
630 b <- readMatrix x j i
631 axpy x (-b) i j
632
633--------------------------------------------------------------------------------
634
567instance Testable (Matrix I) where 635instance Testable (Matrix I) where
568 checkT _ = test 636 checkT _ = test
569 637