summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-17 19:35:31 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-17 19:35:31 +0200
commit52009006791ee2b71530a61f4bf9e1c065c04eae (patch)
tree36c4256822d99a3abc34902a8e86150be2a0ea17
parent61d90ff66af8bfe53ef8cdda8dfe1e70463c213c (diff)
improved luSolve', tests
-rw-r--r--packages/base/src/Internal/Modular.hs46
-rw-r--r--packages/base/src/Internal/ST.hs2
-rw-r--r--packages/base/src/Internal/Util.hs87
-rw-r--r--packages/base/src/Numeric/LinearAlgebra.hs53
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests.hs16
5 files changed, 132 insertions, 72 deletions
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index d158111..1cae1ac 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -34,7 +34,9 @@ import Internal.Container
34import Internal.Vectorized (prodI,sumI,prodL,sumL) 34import Internal.Vectorized (prodI,sumI,prodL,sumL)
35import Internal.LAPACK (multiplyI, multiplyL) 35import Internal.LAPACK (multiplyI, multiplyL)
36import Internal.Algorithms(luFact) 36import Internal.Algorithms(luFact)
37import Internal.Util(Normed(..),Indexable(..),gaussElim, gaussElim_1, gaussElim_2,luST, magnit) 37import Internal.Util(Normed(..),Indexable(..),
38 gaussElim, gaussElim_1, gaussElim_2,
39 luST, luSolve', luPacked', magnit)
38import Internal.ST(mutable) 40import Internal.ST(mutable)
39import GHC.TypeLits 41import GHC.TypeLits
40import Data.Proxy(Proxy) 42import Data.Proxy(Proxy)
@@ -350,7 +352,11 @@ test = (ok, info)
350 lg = (3><3) (repeat (3*10^(9::Int))) :: Matrix Z 352 lg = (3><3) (repeat (3*10^(9::Int))) :: Matrix Z
351 lgm = fromZ lg :: Matrix (Mod 10000000000 Z) 353 lgm = fromZ lg :: Matrix (Mod 10000000000 Z)
352 354
353 gen n = diagRect 1 (konst 5 n) n n :: Numeric t => Matrix t 355 gen n = diagRect 1 (konst 5 n) n n :: Numeric t => Matrix t
356
357 rgen n = gen n :: Matrix R
358 cgen n = complex (rgen n) + fliprl (complex (rgen n)) * scalar (0:+1) :: Matrix C
359 sgen n = single (cgen n)
354 360
355 checkGen x = norm_Inf $ flatten $ invg x <> x - ident (rows x) 361 checkGen x = norm_Inf $ flatten $ invg x <> x - ident (rows x)
356 362
@@ -360,6 +366,11 @@ test = (ok, info)
360 where 366 where
361 (l,u,p,_ :: Int) = luFact $ mutable (luST okf) t 367 (l,u,p,_ :: Int) = luFact $ mutable (luST okf) t
362 368
369 checkSolve aa = norm_Inf $ flatten (aa <> x - bb)
370 where
371 bb = flipud aa
372 x = luSolve' (luPacked' aa) bb
373
363 info = do 374 info = do
364 print v 375 print v
365 print m 376 print m
@@ -383,9 +394,9 @@ test = (ok, info)
383 print $ lgm <> lgm 394 print $ lgm <> lgm
384 395
385 print (checkGen (gen 5 :: Matrix R)) 396 print (checkGen (gen 5 :: Matrix R))
386 print (checkGen (gen 5 :: Matrix C))
387 print (checkGen (gen 5 :: Matrix Float)) 397 print (checkGen (gen 5 :: Matrix Float))
388 print (checkGen (gen 5 :: Matrix (Complex Float))) 398 print (checkGen (cgen 5 :: Matrix C))
399 print (checkGen (sgen 5 :: Matrix (Complex Float)))
389 print (invg (gen 5) :: Matrix (Mod 7 I)) 400 print (invg (gen 5) :: Matrix (Mod 7 I))
390 print (invg (gen 5) :: Matrix (Mod 7 Z)) 401 print (invg (gen 5) :: Matrix (Mod 7 Z))
391 402
@@ -394,11 +405,18 @@ test = (ok, info)
394 405
395 print $ checkLU (magnit 0) (gen 5 :: Matrix R) 406 print $ checkLU (magnit 0) (gen 5 :: Matrix R)
396 print $ checkLU (magnit 0) (gen 5 :: Matrix Float) 407 print $ checkLU (magnit 0) (gen 5 :: Matrix Float)
397 print $ checkLU (magnit 0) (gen 5 :: Matrix C) 408 print $ checkLU (magnit 0) (cgen 5 :: Matrix C)
398 print $ checkLU (magnit 0) (gen 5 :: Matrix (Complex Float)) 409 print $ checkLU (magnit 0) (sgen 5 :: Matrix (Complex Float))
399 print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I)) 410 print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I))
400 print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z)) 411 print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z))
401 412
413 print $ checkSolve (gen 5 :: Matrix R)
414 print $ checkSolve (gen 5 :: Matrix Float)
415 print $ checkSolve (cgen 5 :: Matrix C)
416 print $ checkSolve (sgen 5 :: Matrix (Complex Float))
417 print $ checkSolve (gen 5 :: Matrix (Mod 7 I))
418 print $ checkSolve (gen 5 :: Matrix (Mod 7 Z))
419
402 420
403 ok = and 421 ok = and
404 [ toInt (m #> v) == cmod 11 (toInt m #> toInt v ) 422 [ toInt (m #> v) == cmod 11 (toInt m #> toInt v )
@@ -407,16 +425,22 @@ test = (ok, info)
407 , am <> gaussElim am bm == bm 425 , am <> gaussElim am bm == bm
408 , (checkGen (gen 5 :: Matrix R)) < 1E-15 426 , (checkGen (gen 5 :: Matrix R)) < 1E-15
409 , (checkGen (gen 5 :: Matrix Float)) < 2E-7 427 , (checkGen (gen 5 :: Matrix Float)) < 2E-7
410 , (checkGen (gen 5 :: Matrix C)) < 1E-15 428 , (checkGen (cgen 5 :: Matrix C)) < 1E-15
411 , (checkGen (gen 5 :: Matrix (Complex Float))) < 2E-7 429 , (checkGen (sgen 5 :: Matrix (Complex Float))) < 2E-7
412 , (checkGen (gen 5 :: Matrix (Mod 7 I))) == 0 430 , (checkGen (gen 5 :: Matrix (Mod 7 I))) == 0
413 , (checkGen (gen 5 :: Matrix (Mod 7 Z))) == 0 431 , (checkGen (gen 5 :: Matrix (Mod 7 Z))) == 0
414 , (checkLU (magnit 1E-10) (gen 5 :: Matrix R)) < 1E-15 432 , (checkLU (magnit 1E-10) (gen 5 :: Matrix R)) < 2E-15
415 , (checkLU (magnit 1E-5) (gen 5 :: Matrix Float)) < 1E-6 433 , (checkLU (magnit 1E-5) (gen 5 :: Matrix Float)) < 1E-6
416 , (checkLU (magnit 1E-10) (gen 5 :: Matrix C)) < 1E-15 434 , (checkLU (magnit 1E-10) (cgen 5 :: Matrix C)) < 5E-15
417 , (checkLU (magnit 1E-5) (gen 5 :: Matrix (Complex Float))) < 1E-6 435 , (checkLU (magnit 1E-5) (sgen 5 :: Matrix (Complex Float))) < 1E-6
418 , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I))) == 0 436 , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I))) == 0
419 , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z))) == 0 437 , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z))) == 0
438 , checkSolve (gen 5 :: Matrix R) < 2E-15
439 , checkSolve (gen 5 :: Matrix Float) < 1E-6
440 , checkSolve (cgen 5 :: Matrix C) < 4E-15
441 , checkSolve (sgen 5 :: Matrix (Complex Float)) < 1E-6
442 , checkSolve (gen 5 :: Matrix (Mod 7 I)) == 0
443 , checkSolve (gen 5 :: Matrix (Mod 7 Z)) == 0
420 , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I)) 444 , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I))
421 , gm <> gm == konst 0 (3,3) 445 , gm <> gm == konst 0 (3,3)
422 , lgm <> lgm == konst 0 (3,3) 446 , lgm <> lgm == konst 0 (3,3)
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 25e7f03..d1defda 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -228,10 +228,12 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (i
228 (i1,i2) = getRowRange (rows m) rr 228 (i1,i2) = getRowRange (rows m) rr
229 (j1,j2) = getColRange (cols m) rc 229 (j1,j2) = getColRange (cols m) rc
230 230
231-- | r0 c0 height width
231data Slice s t = Slice (STMatrix s t) Int Int Int Int 232data Slice s t = Slice (STMatrix s t) Int Int Int Int
232 233
233slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1]) 234slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1])
234 235
236gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s ()
235gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res 237gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res
236 where 238 where
237 res = unsafeIOToST (gemm u v a b r) 239 res = unsafeIOToST (gemm u v a b r)
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs
index d9777ae..079663d 100644
--- a/packages/base/src/Internal/Util.hs
+++ b/packages/base/src/Internal/Util.hs
@@ -54,7 +54,7 @@ module Internal.Util(
54 -- ** 2D 54 -- ** 2D
55 corr2, conv2, separable, 55 corr2, conv2, separable,
56 block2x2,block3x3,view1,unView1,foldMatrix, 56 block2x2,block3x3,view1,unView1,foldMatrix,
57 gaussElim_1, gaussElim_2, gaussElim, luST, luSolve' 57 gaussElim_1, gaussElim_2, gaussElim, luST, luSolve', luSolve'', luPacked'
58) where 58) where
59 59
60import Internal.Vector 60import Internal.Vector
@@ -64,7 +64,7 @@ import Internal.Element
64import Internal.Container 64import Internal.Container
65import Internal.Vectorized 65import Internal.Vectorized
66import Internal.IO 66import Internal.IO
67import Internal.Algorithms hiding (Normed,linearSolve',luSolve') 67import Internal.Algorithms hiding (Normed,linearSolve',luSolve', luPacked')
68import Numeric.Matrix() 68import Numeric.Matrix()
69import Numeric.Vector() 69import Numeric.Vector()
70import Internal.Random 70import Internal.Random
@@ -686,6 +686,35 @@ luST ok (r,_) x = do
686 v <- unsafeFreezeVector p 686 v <- unsafeFreezeVector p
687 return (toList v) 687 return (toList v)
688 688
689{- | Experimental implementation of 'luPacked'
690 for any Fractional element type, including 'Mod' n 'I' and 'Mod' n 'Z'.
691
692>>> let m = ident 5 + (5><5) [0..] :: Matrix (Z ./. 17)
693(5><5)
694 [ 1, 1, 2, 3, 4
695 , 5, 7, 7, 8, 9
696 , 10, 11, 13, 13, 14
697 , 15, 16, 0, 2, 2
698 , 3, 4, 5, 6, 8 ]
699
700>>> let (l,u,p,s) = luFact $ luPacked' m
701>>> l
702(5><5)
703 [ 1, 0, 0, 0, 0
704 , 6, 1, 0, 0, 0
705 , 12, 7, 1, 0, 0
706 , 7, 10, 7, 1, 0
707 , 8, 2, 6, 11, 1 ]
708>>> u
709(5><5)
710 [ 15, 16, 0, 2, 2
711 , 0, 13, 7, 13, 14
712 , 0, 0, 15, 0, 11
713 , 0, 0, 0, 15, 15
714 , 0, 0, 0, 0, 1 ]
715
716-}
717luPacked' x = mutable (luST (magnit 0)) x
689 718
690-------------------------------------------------------------------------------- 719--------------------------------------------------------------------------------
691 720
@@ -693,35 +722,79 @@ rowRange m = [0..rows m -1]
693 722
694at k = Pos (idxs[k]) 723at k = Pos (idxs[k])
695 724
696backSust lup rhs = foldl' f (rhs?[]) (reverse ls) 725backSust' lup rhs = foldl' f (rhs?[]) (reverse ls)
697 where 726 where
698 ls = [ (d k , u k , b k) | k <- rowRange lup ] 727 ls = [ (d k , u k , b k) | k <- rowRange lup ]
699 where 728 where
700 d k = lup ?? (at k, at k) 729 d k = lup ?? (at k, at k)
701 u k = lup ?? (at k, Drop (k+1)) 730 u k = lup ?? (at k, Drop (k+1))
702 b k = rhs ?? (at k, All) 731 b k = rhs ?? (at k, All)
703 732
704 f x (d,u,b) = (b - u<>x) / d 733 f x (d,u,b) = (b - u<>x) / d
705 === 734 ===
706 x 735 x
707 736
708 737
709forwSust lup rhs = foldl' f (rhs?[]) ls 738forwSust' lup rhs = foldl' f (rhs?[]) ls
710 where 739 where
711 ls = [ (l k , b k) | k <- rowRange lup ] 740 ls = [ (l k , b k) | k <- rowRange lup ]
712 where 741 where
713 l k = lup ?? (at k, Take k) 742 l k = lup ?? (at k, Take k)
714 b k = rhs ?? (at k, All) 743 b k = rhs ?? (at k, All)
715 744
716 f x (l,b) = x 745 f x (l,b) = x
717 === 746 ===
718 (b - l<>x) 747 (b - l<>x)
719 748
720 749
721luSolve' (lup,p) b = backSust lup (forwSust lup pb) 750luSolve'' (lup,p) b = backSust' lup (forwSust' lup pb)
722 where 751 where
723 pb = b ?? (Pos (fixPerm' p), All) 752 pb = b ?? (Pos (fixPerm' p), All)
724 753
754--------------------------------------------------------------------------------
755
756forwSust lup rhs = fst $ mutable f rhs
757 where
758 f (r,c) x = do
759 l <- unsafeThawMatrix lup
760 let go k = gemmm 1 (Slice x k 0 1 c) (-1) (Slice l k 0 1 k) (Slice x 0 0 k c)
761 mapM_ go [0..r-1]
762
763
764backSust lup rhs = fst $ mutable f rhs
765 where
766 f (r,c) m = do
767 l <- unsafeThawMatrix lup
768 let d k = recip (lup `atIndex` (k,k))
769 u k = Slice l k (k+1) 1 (r-1-k)
770 b k = Slice m k 0 1 c
771 x k = Slice m (k+1) 0 (r-1-k) c
772 scal k = rowOper (SCAL (d k) (Row k) AllCols) m
773
774 go k = gemmm 1 (b k) (-1) (u k) (x k) >> scal k
775 mapM_ go [r-1,r-2..0]
776
777
778{- | Experimental implementation of 'luSolve' for any Fractional element type, including 'Mod' n 'I' and 'Mod' n 'Z'.
779
780>>> let a = (2><2) [1,2,3,5] :: Matrix (Z ./. 13)
781(2><2)
782 [ 1, 2
783 , 3, 5 ]
784>>> b
785(2><3)
786 [ 5, 1, 3
787 , 8, 6, 3 ]
788
789>>> luSolve' (luPacked' a) b
790(2><3)
791 [ 4, 7, 4
792 , 7, 10, 6 ]
793
794-}
795luSolve' (lup,p) b = backSust lup (forwSust lup pb)
796 where
797 pb = b ?? (Pos (fixPerm' p), All)
725 798
726-------------------------------------------------------------------------------- 799--------------------------------------------------------------------------------
727 800
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs
index e899445..0b8abbb 100644
--- a/packages/base/src/Numeric/LinearAlgebra.hs
+++ b/packages/base/src/Numeric/LinearAlgebra.hs
@@ -81,7 +81,6 @@ module Numeric.LinearAlgebra (
81 cholSolve, 81 cholSolve,
82 cgSolve, 82 cgSolve,
83 cgSolve', 83 cgSolve',
84 linearSolve',
85 84
86 -- * Inverse and pseudoinverse 85 -- * Inverse and pseudoinverse
87 inv, pinv, pinvTol, 86 inv, pinv, pinvTol,
@@ -123,7 +122,7 @@ module Numeric.LinearAlgebra (
123 schur, 122 schur,
124 123
125 -- * LU 124 -- * LU
126 lu, luPacked, luFact, luPacked', 125 lu, luPacked, luPacked', luFact,
127 126
128 -- * Matrix functions 127 -- * Matrix functions
129 expm, 128 expm,
@@ -166,7 +165,6 @@ import Internal.Random
166import Internal.Sparse((!#>)) 165import Internal.Sparse((!#>))
167import Internal.CG 166import Internal.CG
168import Internal.Conversion 167import Internal.Conversion
169import Internal.ST(mutable)
170 168
171{- | infix synonym of 'mul' 169{- | infix synonym of 'mul'
172 170
@@ -241,53 +239,4 @@ nullspace m = nullspaceSVD (Left (1*eps)) m (rightSV m)
241-- | return an orthonormal basis of the range space of a matrix. See also 'orthSVD'. 239-- | return an orthonormal basis of the range space of a matrix. See also 'orthSVD'.
242orth m = orthSVD (Left (1*eps)) m (leftSV m) 240orth m = orthSVD (Left (1*eps)) m (leftSV m)
243 241
244{- | Experimental implementation of LU factorization
245 working on any Fractional element type, including 'Mod' n 'I' and 'Mod' n 'Z'.
246
247>>> let m = ident 5 + (5><5) [0..] :: Matrix (Z ./. 17)
248(5><5)
249 [ 1, 1, 2, 3, 4
250 , 5, 7, 7, 8, 9
251 , 10, 11, 13, 13, 14
252 , 15, 16, 0, 2, 2
253 , 3, 4, 5, 6, 8 ]
254
255>>> let (l,u,p,s) = luFact $ luPacked' m
256>>> l
257(5><5)
258 [ 1, 0, 0, 0, 0
259 , 6, 1, 0, 0, 0
260 , 12, 7, 1, 0, 0
261 , 7, 10, 7, 1, 0
262 , 8, 2, 6, 11, 1 ]
263>>> u
264(5><5)
265 [ 15, 16, 0, 2, 2
266 , 0, 13, 7, 13, 14
267 , 0, 0, 15, 0, 11
268 , 0, 0, 0, 15, 15
269 , 0, 0, 0, 0, 1 ]
270
271-}
272luPacked' x = mutable (luST (magnit 0)) x
273
274{- | Experimental implementation of gaussian elimination
275 working on any Fractional element type, including 'Mod' n 'I' and 'Mod' n 'Z'.
276
277>>> let a = (2><2) [1,2,3,5] :: Matrix (Z ./. 13)
278(2><2)
279 [ 1, 2
280 , 3, 5 ]
281>>> b
282(2><3)
283 [ 5, 1, 3
284 , 8, 6, 3 ]
285
286>>> let x = linearSolve' a b
287(2><3)
288 [ 4, 7, 4
289 , 7, 10, 6 ]
290
291-}
292linearSolve' x y = gaussElim x y
293 242
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
index 148bbb9..b1428fb 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
@@ -611,9 +611,10 @@ runBenchmarks = do
611 mkVecBench 611 mkVecBench
612 multBench 612 multBench
613 cholBench 613 cholBench
614 luBench
615 luBench_2
614 svdBench 616 svdBench
615 eigBench 617 eigBench
616 luBench
617 putStrLn "" 618 putStrLn ""
618 619
619-------------------------------- 620--------------------------------
@@ -778,6 +779,17 @@ luBench = do
778 luBenchN luPacked' 1000 (5::Mod 9973 I) "luPacked' I mod 9973" 779 luBenchN luPacked' 1000 (5::Mod 9973 I) "luPacked' I mod 9973"
779 luBenchN luPacked' 1000 (5::Mod 9973 Z) "luPacked' Z mod 9973" 780 luBenchN luPacked' 1000 (5::Mod 9973 Z) "luPacked' Z mod 9973"
780 781
781 782luBenchN_2 f g n x msg = do
783 let m = diagRect 1 (fromList (replicate n x)) n n
784 b = flipud m
785 m `seq` b `seq` putStr ""
786 time (msg ++ " "++ show n) (f (g m) b)
787
788luBench_2 = do
789 putStrLn ""
790 luBenchN_2 luSolve luPacked 500 (5::R) "luSolve .luPacked Double "
791 luBenchN_2 luSolve' luPacked' 500 (5::R) "luSolve'.luPacked' Double "
792 luBenchN_2 luSolve' luPacked' 500 (5::Mod 9973 I) "luSolve'.luPacked' I mod 9973"
793 luBenchN_2 luSolve' luPacked' 500 (5::Mod 9973 Z) "luSolve'.luPacked' Z mod 9973"
782 794
783 795