diff options
author | Alberto Ruiz <aruiz@um.es> | 2015-06-17 19:35:31 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2015-06-17 19:35:31 +0200 |
commit | 52009006791ee2b71530a61f4bf9e1c065c04eae (patch) | |
tree | 36c4256822d99a3abc34902a8e86150be2a0ea17 /packages/base | |
parent | 61d90ff66af8bfe53ef8cdda8dfe1e70463c213c (diff) |
improved luSolve', tests
Diffstat (limited to 'packages/base')
-rw-r--r-- | packages/base/src/Internal/Modular.hs | 46 | ||||
-rw-r--r-- | packages/base/src/Internal/ST.hs | 2 | ||||
-rw-r--r-- | packages/base/src/Internal/Util.hs | 87 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra.hs | 53 |
4 files changed, 118 insertions, 70 deletions
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index d158111..1cae1ac 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -34,7 +34,9 @@ import Internal.Container | |||
34 | import Internal.Vectorized (prodI,sumI,prodL,sumL) | 34 | import Internal.Vectorized (prodI,sumI,prodL,sumL) |
35 | import Internal.LAPACK (multiplyI, multiplyL) | 35 | import Internal.LAPACK (multiplyI, multiplyL) |
36 | import Internal.Algorithms(luFact) | 36 | import Internal.Algorithms(luFact) |
37 | import Internal.Util(Normed(..),Indexable(..),gaussElim, gaussElim_1, gaussElim_2,luST, magnit) | 37 | import Internal.Util(Normed(..),Indexable(..), |
38 | gaussElim, gaussElim_1, gaussElim_2, | ||
39 | luST, luSolve', luPacked', magnit) | ||
38 | import Internal.ST(mutable) | 40 | import Internal.ST(mutable) |
39 | import GHC.TypeLits | 41 | import GHC.TypeLits |
40 | import Data.Proxy(Proxy) | 42 | import Data.Proxy(Proxy) |
@@ -350,7 +352,11 @@ test = (ok, info) | |||
350 | lg = (3><3) (repeat (3*10^(9::Int))) :: Matrix Z | 352 | lg = (3><3) (repeat (3*10^(9::Int))) :: Matrix Z |
351 | lgm = fromZ lg :: Matrix (Mod 10000000000 Z) | 353 | lgm = fromZ lg :: Matrix (Mod 10000000000 Z) |
352 | 354 | ||
353 | gen n = diagRect 1 (konst 5 n) n n :: Numeric t => Matrix t | 355 | gen n = diagRect 1 (konst 5 n) n n :: Numeric t => Matrix t |
356 | |||
357 | rgen n = gen n :: Matrix R | ||
358 | cgen n = complex (rgen n) + fliprl (complex (rgen n)) * scalar (0:+1) :: Matrix C | ||
359 | sgen n = single (cgen n) | ||
354 | 360 | ||
355 | checkGen x = norm_Inf $ flatten $ invg x <> x - ident (rows x) | 361 | checkGen x = norm_Inf $ flatten $ invg x <> x - ident (rows x) |
356 | 362 | ||
@@ -360,6 +366,11 @@ test = (ok, info) | |||
360 | where | 366 | where |
361 | (l,u,p,_ :: Int) = luFact $ mutable (luST okf) t | 367 | (l,u,p,_ :: Int) = luFact $ mutable (luST okf) t |
362 | 368 | ||
369 | checkSolve aa = norm_Inf $ flatten (aa <> x - bb) | ||
370 | where | ||
371 | bb = flipud aa | ||
372 | x = luSolve' (luPacked' aa) bb | ||
373 | |||
363 | info = do | 374 | info = do |
364 | print v | 375 | print v |
365 | print m | 376 | print m |
@@ -383,9 +394,9 @@ test = (ok, info) | |||
383 | print $ lgm <> lgm | 394 | print $ lgm <> lgm |
384 | 395 | ||
385 | print (checkGen (gen 5 :: Matrix R)) | 396 | print (checkGen (gen 5 :: Matrix R)) |
386 | print (checkGen (gen 5 :: Matrix C)) | ||
387 | print (checkGen (gen 5 :: Matrix Float)) | 397 | print (checkGen (gen 5 :: Matrix Float)) |
388 | print (checkGen (gen 5 :: Matrix (Complex Float))) | 398 | print (checkGen (cgen 5 :: Matrix C)) |
399 | print (checkGen (sgen 5 :: Matrix (Complex Float))) | ||
389 | print (invg (gen 5) :: Matrix (Mod 7 I)) | 400 | print (invg (gen 5) :: Matrix (Mod 7 I)) |
390 | print (invg (gen 5) :: Matrix (Mod 7 Z)) | 401 | print (invg (gen 5) :: Matrix (Mod 7 Z)) |
391 | 402 | ||
@@ -394,11 +405,18 @@ test = (ok, info) | |||
394 | 405 | ||
395 | print $ checkLU (magnit 0) (gen 5 :: Matrix R) | 406 | print $ checkLU (magnit 0) (gen 5 :: Matrix R) |
396 | print $ checkLU (magnit 0) (gen 5 :: Matrix Float) | 407 | print $ checkLU (magnit 0) (gen 5 :: Matrix Float) |
397 | print $ checkLU (magnit 0) (gen 5 :: Matrix C) | 408 | print $ checkLU (magnit 0) (cgen 5 :: Matrix C) |
398 | print $ checkLU (magnit 0) (gen 5 :: Matrix (Complex Float)) | 409 | print $ checkLU (magnit 0) (sgen 5 :: Matrix (Complex Float)) |
399 | print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I)) | 410 | print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I)) |
400 | print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z)) | 411 | print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z)) |
401 | 412 | ||
413 | print $ checkSolve (gen 5 :: Matrix R) | ||
414 | print $ checkSolve (gen 5 :: Matrix Float) | ||
415 | print $ checkSolve (cgen 5 :: Matrix C) | ||
416 | print $ checkSolve (sgen 5 :: Matrix (Complex Float)) | ||
417 | print $ checkSolve (gen 5 :: Matrix (Mod 7 I)) | ||
418 | print $ checkSolve (gen 5 :: Matrix (Mod 7 Z)) | ||
419 | |||
402 | 420 | ||
403 | ok = and | 421 | ok = and |
404 | [ toInt (m #> v) == cmod 11 (toInt m #> toInt v ) | 422 | [ toInt (m #> v) == cmod 11 (toInt m #> toInt v ) |
@@ -407,16 +425,22 @@ test = (ok, info) | |||
407 | , am <> gaussElim am bm == bm | 425 | , am <> gaussElim am bm == bm |
408 | , (checkGen (gen 5 :: Matrix R)) < 1E-15 | 426 | , (checkGen (gen 5 :: Matrix R)) < 1E-15 |
409 | , (checkGen (gen 5 :: Matrix Float)) < 2E-7 | 427 | , (checkGen (gen 5 :: Matrix Float)) < 2E-7 |
410 | , (checkGen (gen 5 :: Matrix C)) < 1E-15 | 428 | , (checkGen (cgen 5 :: Matrix C)) < 1E-15 |
411 | , (checkGen (gen 5 :: Matrix (Complex Float))) < 2E-7 | 429 | , (checkGen (sgen 5 :: Matrix (Complex Float))) < 2E-7 |
412 | , (checkGen (gen 5 :: Matrix (Mod 7 I))) == 0 | 430 | , (checkGen (gen 5 :: Matrix (Mod 7 I))) == 0 |
413 | , (checkGen (gen 5 :: Matrix (Mod 7 Z))) == 0 | 431 | , (checkGen (gen 5 :: Matrix (Mod 7 Z))) == 0 |
414 | , (checkLU (magnit 1E-10) (gen 5 :: Matrix R)) < 1E-15 | 432 | , (checkLU (magnit 1E-10) (gen 5 :: Matrix R)) < 2E-15 |
415 | , (checkLU (magnit 1E-5) (gen 5 :: Matrix Float)) < 1E-6 | 433 | , (checkLU (magnit 1E-5) (gen 5 :: Matrix Float)) < 1E-6 |
416 | , (checkLU (magnit 1E-10) (gen 5 :: Matrix C)) < 1E-15 | 434 | , (checkLU (magnit 1E-10) (cgen 5 :: Matrix C)) < 5E-15 |
417 | , (checkLU (magnit 1E-5) (gen 5 :: Matrix (Complex Float))) < 1E-6 | 435 | , (checkLU (magnit 1E-5) (sgen 5 :: Matrix (Complex Float))) < 1E-6 |
418 | , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I))) == 0 | 436 | , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I))) == 0 |
419 | , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z))) == 0 | 437 | , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z))) == 0 |
438 | , checkSolve (gen 5 :: Matrix R) < 2E-15 | ||
439 | , checkSolve (gen 5 :: Matrix Float) < 1E-6 | ||
440 | , checkSolve (cgen 5 :: Matrix C) < 4E-15 | ||
441 | , checkSolve (sgen 5 :: Matrix (Complex Float)) < 1E-6 | ||
442 | , checkSolve (gen 5 :: Matrix (Mod 7 I)) == 0 | ||
443 | , checkSolve (gen 5 :: Matrix (Mod 7 Z)) == 0 | ||
420 | , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I)) | 444 | , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I)) |
421 | , gm <> gm == konst 0 (3,3) | 445 | , gm <> gm == konst 0 (3,3) |
422 | , lgm <> lgm == konst 0 (3,3) | 446 | , lgm <> lgm == konst 0 (3,3) |
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 25e7f03..d1defda 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -228,10 +228,12 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (i | |||
228 | (i1,i2) = getRowRange (rows m) rr | 228 | (i1,i2) = getRowRange (rows m) rr |
229 | (j1,j2) = getColRange (cols m) rc | 229 | (j1,j2) = getColRange (cols m) rc |
230 | 230 | ||
231 | -- | r0 c0 height width | ||
231 | data Slice s t = Slice (STMatrix s t) Int Int Int Int | 232 | data Slice s t = Slice (STMatrix s t) Int Int Int Int |
232 | 233 | ||
233 | slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1]) | 234 | slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1]) |
234 | 235 | ||
236 | gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s () | ||
235 | gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res | 237 | gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res |
236 | where | 238 | where |
237 | res = unsafeIOToST (gemm u v a b r) | 239 | res = unsafeIOToST (gemm u v a b r) |
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index d9777ae..079663d 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs | |||
@@ -54,7 +54,7 @@ module Internal.Util( | |||
54 | -- ** 2D | 54 | -- ** 2D |
55 | corr2, conv2, separable, | 55 | corr2, conv2, separable, |
56 | block2x2,block3x3,view1,unView1,foldMatrix, | 56 | block2x2,block3x3,view1,unView1,foldMatrix, |
57 | gaussElim_1, gaussElim_2, gaussElim, luST, luSolve' | 57 | gaussElim_1, gaussElim_2, gaussElim, luST, luSolve', luSolve'', luPacked' |
58 | ) where | 58 | ) where |
59 | 59 | ||
60 | import Internal.Vector | 60 | import Internal.Vector |
@@ -64,7 +64,7 @@ import Internal.Element | |||
64 | import Internal.Container | 64 | import Internal.Container |
65 | import Internal.Vectorized | 65 | import Internal.Vectorized |
66 | import Internal.IO | 66 | import Internal.IO |
67 | import Internal.Algorithms hiding (Normed,linearSolve',luSolve') | 67 | import Internal.Algorithms hiding (Normed,linearSolve',luSolve', luPacked') |
68 | import Numeric.Matrix() | 68 | import Numeric.Matrix() |
69 | import Numeric.Vector() | 69 | import Numeric.Vector() |
70 | import Internal.Random | 70 | import Internal.Random |
@@ -686,6 +686,35 @@ luST ok (r,_) x = do | |||
686 | v <- unsafeFreezeVector p | 686 | v <- unsafeFreezeVector p |
687 | return (toList v) | 687 | return (toList v) |
688 | 688 | ||
689 | {- | Experimental implementation of 'luPacked' | ||
690 | for any Fractional element type, including 'Mod' n 'I' and 'Mod' n 'Z'. | ||
691 | |||
692 | >>> let m = ident 5 + (5><5) [0..] :: Matrix (Z ./. 17) | ||
693 | (5><5) | ||
694 | [ 1, 1, 2, 3, 4 | ||
695 | , 5, 7, 7, 8, 9 | ||
696 | , 10, 11, 13, 13, 14 | ||
697 | , 15, 16, 0, 2, 2 | ||
698 | , 3, 4, 5, 6, 8 ] | ||
699 | |||
700 | >>> let (l,u,p,s) = luFact $ luPacked' m | ||
701 | >>> l | ||
702 | (5><5) | ||
703 | [ 1, 0, 0, 0, 0 | ||
704 | , 6, 1, 0, 0, 0 | ||
705 | , 12, 7, 1, 0, 0 | ||
706 | , 7, 10, 7, 1, 0 | ||
707 | , 8, 2, 6, 11, 1 ] | ||
708 | >>> u | ||
709 | (5><5) | ||
710 | [ 15, 16, 0, 2, 2 | ||
711 | , 0, 13, 7, 13, 14 | ||
712 | , 0, 0, 15, 0, 11 | ||
713 | , 0, 0, 0, 15, 15 | ||
714 | , 0, 0, 0, 0, 1 ] | ||
715 | |||
716 | -} | ||
717 | luPacked' x = mutable (luST (magnit 0)) x | ||
689 | 718 | ||
690 | -------------------------------------------------------------------------------- | 719 | -------------------------------------------------------------------------------- |
691 | 720 | ||
@@ -693,35 +722,79 @@ rowRange m = [0..rows m -1] | |||
693 | 722 | ||
694 | at k = Pos (idxs[k]) | 723 | at k = Pos (idxs[k]) |
695 | 724 | ||
696 | backSust lup rhs = foldl' f (rhs?[]) (reverse ls) | 725 | backSust' lup rhs = foldl' f (rhs?[]) (reverse ls) |
697 | where | 726 | where |
698 | ls = [ (d k , u k , b k) | k <- rowRange lup ] | 727 | ls = [ (d k , u k , b k) | k <- rowRange lup ] |
699 | where | 728 | where |
700 | d k = lup ?? (at k, at k) | 729 | d k = lup ?? (at k, at k) |
701 | u k = lup ?? (at k, Drop (k+1)) | 730 | u k = lup ?? (at k, Drop (k+1)) |
702 | b k = rhs ?? (at k, All) | 731 | b k = rhs ?? (at k, All) |
703 | 732 | ||
704 | f x (d,u,b) = (b - u<>x) / d | 733 | f x (d,u,b) = (b - u<>x) / d |
705 | === | 734 | === |
706 | x | 735 | x |
707 | 736 | ||
708 | 737 | ||
709 | forwSust lup rhs = foldl' f (rhs?[]) ls | 738 | forwSust' lup rhs = foldl' f (rhs?[]) ls |
710 | where | 739 | where |
711 | ls = [ (l k , b k) | k <- rowRange lup ] | 740 | ls = [ (l k , b k) | k <- rowRange lup ] |
712 | where | 741 | where |
713 | l k = lup ?? (at k, Take k) | 742 | l k = lup ?? (at k, Take k) |
714 | b k = rhs ?? (at k, All) | 743 | b k = rhs ?? (at k, All) |
715 | 744 | ||
716 | f x (l,b) = x | 745 | f x (l,b) = x |
717 | === | 746 | === |
718 | (b - l<>x) | 747 | (b - l<>x) |
719 | 748 | ||
720 | 749 | ||
721 | luSolve' (lup,p) b = backSust lup (forwSust lup pb) | 750 | luSolve'' (lup,p) b = backSust' lup (forwSust' lup pb) |
722 | where | 751 | where |
723 | pb = b ?? (Pos (fixPerm' p), All) | 752 | pb = b ?? (Pos (fixPerm' p), All) |
724 | 753 | ||
754 | -------------------------------------------------------------------------------- | ||
755 | |||
756 | forwSust lup rhs = fst $ mutable f rhs | ||
757 | where | ||
758 | f (r,c) x = do | ||
759 | l <- unsafeThawMatrix lup | ||
760 | let go k = gemmm 1 (Slice x k 0 1 c) (-1) (Slice l k 0 1 k) (Slice x 0 0 k c) | ||
761 | mapM_ go [0..r-1] | ||
762 | |||
763 | |||
764 | backSust lup rhs = fst $ mutable f rhs | ||
765 | where | ||
766 | f (r,c) m = do | ||
767 | l <- unsafeThawMatrix lup | ||
768 | let d k = recip (lup `atIndex` (k,k)) | ||
769 | u k = Slice l k (k+1) 1 (r-1-k) | ||
770 | b k = Slice m k 0 1 c | ||
771 | x k = Slice m (k+1) 0 (r-1-k) c | ||
772 | scal k = rowOper (SCAL (d k) (Row k) AllCols) m | ||
773 | |||
774 | go k = gemmm 1 (b k) (-1) (u k) (x k) >> scal k | ||
775 | mapM_ go [r-1,r-2..0] | ||
776 | |||
777 | |||
778 | {- | Experimental implementation of 'luSolve' for any Fractional element type, including 'Mod' n 'I' and 'Mod' n 'Z'. | ||
779 | |||
780 | >>> let a = (2><2) [1,2,3,5] :: Matrix (Z ./. 13) | ||
781 | (2><2) | ||
782 | [ 1, 2 | ||
783 | , 3, 5 ] | ||
784 | >>> b | ||
785 | (2><3) | ||
786 | [ 5, 1, 3 | ||
787 | , 8, 6, 3 ] | ||
788 | |||
789 | >>> luSolve' (luPacked' a) b | ||
790 | (2><3) | ||
791 | [ 4, 7, 4 | ||
792 | , 7, 10, 6 ] | ||
793 | |||
794 | -} | ||
795 | luSolve' (lup,p) b = backSust lup (forwSust lup pb) | ||
796 | where | ||
797 | pb = b ?? (Pos (fixPerm' p), All) | ||
725 | 798 | ||
726 | -------------------------------------------------------------------------------- | 799 | -------------------------------------------------------------------------------- |
727 | 800 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index e899445..0b8abbb 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs | |||
@@ -81,7 +81,6 @@ module Numeric.LinearAlgebra ( | |||
81 | cholSolve, | 81 | cholSolve, |
82 | cgSolve, | 82 | cgSolve, |
83 | cgSolve', | 83 | cgSolve', |
84 | linearSolve', | ||
85 | 84 | ||
86 | -- * Inverse and pseudoinverse | 85 | -- * Inverse and pseudoinverse |
87 | inv, pinv, pinvTol, | 86 | inv, pinv, pinvTol, |
@@ -123,7 +122,7 @@ module Numeric.LinearAlgebra ( | |||
123 | schur, | 122 | schur, |
124 | 123 | ||
125 | -- * LU | 124 | -- * LU |
126 | lu, luPacked, luFact, luPacked', | 125 | lu, luPacked, luPacked', luFact, |
127 | 126 | ||
128 | -- * Matrix functions | 127 | -- * Matrix functions |
129 | expm, | 128 | expm, |
@@ -166,7 +165,6 @@ import Internal.Random | |||
166 | import Internal.Sparse((!#>)) | 165 | import Internal.Sparse((!#>)) |
167 | import Internal.CG | 166 | import Internal.CG |
168 | import Internal.Conversion | 167 | import Internal.Conversion |
169 | import Internal.ST(mutable) | ||
170 | 168 | ||
171 | {- | infix synonym of 'mul' | 169 | {- | infix synonym of 'mul' |
172 | 170 | ||
@@ -241,53 +239,4 @@ nullspace m = nullspaceSVD (Left (1*eps)) m (rightSV m) | |||
241 | -- | return an orthonormal basis of the range space of a matrix. See also 'orthSVD'. | 239 | -- | return an orthonormal basis of the range space of a matrix. See also 'orthSVD'. |
242 | orth m = orthSVD (Left (1*eps)) m (leftSV m) | 240 | orth m = orthSVD (Left (1*eps)) m (leftSV m) |
243 | 241 | ||
244 | {- | Experimental implementation of LU factorization | ||
245 | working on any Fractional element type, including 'Mod' n 'I' and 'Mod' n 'Z'. | ||
246 | |||
247 | >>> let m = ident 5 + (5><5) [0..] :: Matrix (Z ./. 17) | ||
248 | (5><5) | ||
249 | [ 1, 1, 2, 3, 4 | ||
250 | , 5, 7, 7, 8, 9 | ||
251 | , 10, 11, 13, 13, 14 | ||
252 | , 15, 16, 0, 2, 2 | ||
253 | , 3, 4, 5, 6, 8 ] | ||
254 | |||
255 | >>> let (l,u,p,s) = luFact $ luPacked' m | ||
256 | >>> l | ||
257 | (5><5) | ||
258 | [ 1, 0, 0, 0, 0 | ||
259 | , 6, 1, 0, 0, 0 | ||
260 | , 12, 7, 1, 0, 0 | ||
261 | , 7, 10, 7, 1, 0 | ||
262 | , 8, 2, 6, 11, 1 ] | ||
263 | >>> u | ||
264 | (5><5) | ||
265 | [ 15, 16, 0, 2, 2 | ||
266 | , 0, 13, 7, 13, 14 | ||
267 | , 0, 0, 15, 0, 11 | ||
268 | , 0, 0, 0, 15, 15 | ||
269 | , 0, 0, 0, 0, 1 ] | ||
270 | |||
271 | -} | ||
272 | luPacked' x = mutable (luST (magnit 0)) x | ||
273 | |||
274 | {- | Experimental implementation of gaussian elimination | ||
275 | working on any Fractional element type, including 'Mod' n 'I' and 'Mod' n 'Z'. | ||
276 | |||
277 | >>> let a = (2><2) [1,2,3,5] :: Matrix (Z ./. 13) | ||
278 | (2><2) | ||
279 | [ 1, 2 | ||
280 | , 3, 5 ] | ||
281 | >>> b | ||
282 | (2><3) | ||
283 | [ 5, 1, 3 | ||
284 | , 8, 6, 3 ] | ||
285 | |||
286 | >>> let x = linearSolve' a b | ||
287 | (2><3) | ||
288 | [ 4, 7, 4 | ||
289 | , 7, 10, 6 ] | ||
290 | |||
291 | -} | ||
292 | linearSolve' x y = gaussElim x y | ||
293 | 242 | ||