diff options
author | Alberto Ruiz <aruiz@um.es> | 2015-06-20 13:49:48 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2015-06-20 13:49:48 +0200 |
commit | e09104974d13f66f4ec2e7ae2310346996cf2356 (patch) | |
tree | e1203e13f2c1e0587bdcb3938160546dadf0fcb1 /packages | |
parent | db50bc11dafa6834a4367427156306674063ed6b (diff) |
NFData for Mod, alternative luPacked'' using gaxpy
Diffstat (limited to 'packages')
-rw-r--r-- | packages/base/src/Internal/Modular.hs | 5 | ||||
-rw-r--r-- | packages/base/src/Internal/Util.hs | 41 | ||||
-rw-r--r-- | packages/tests/hmatrix-tests.cabal | 2 | ||||
-rw-r--r-- | packages/tests/src/Numeric/LinearAlgebra/Tests.hs | 3 |
4 files changed, 47 insertions, 4 deletions
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 1cae1ac..098817e 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -44,6 +44,7 @@ import Foreign.ForeignPtr(castForeignPtr) | |||
44 | import Foreign.Storable | 44 | import Foreign.Storable |
45 | import Data.Ratio | 45 | import Data.Ratio |
46 | import Data.Complex | 46 | import Data.Complex |
47 | import Control.DeepSeq ( NFData(..) ) | ||
47 | 48 | ||
48 | 49 | ||
49 | 50 | ||
@@ -51,6 +52,10 @@ import Data.Complex | |||
51 | newtype Mod (n :: Nat) t = Mod {unMod:: t} | 52 | newtype Mod (n :: Nat) t = Mod {unMod:: t} |
52 | deriving (Storable) | 53 | deriving (Storable) |
53 | 54 | ||
55 | instance (NFData t) => NFData (Mod n t) | ||
56 | where | ||
57 | rnf (Mod x) = rnf x | ||
58 | |||
54 | infixr 5 ./. | 59 | infixr 5 ./. |
55 | type (./.) x n = Mod n x | 60 | type (./.) x n = Mod n x |
56 | 61 | ||
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 924ca4c..bf6c8b6 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs | |||
@@ -54,7 +54,7 @@ module Internal.Util( | |||
54 | -- ** 2D | 54 | -- ** 2D |
55 | corr2, conv2, separable, | 55 | corr2, conv2, separable, |
56 | block2x2,block3x3,view1,unView1,foldMatrix, | 56 | block2x2,block3x3,view1,unView1,foldMatrix, |
57 | gaussElim_1, gaussElim_2, gaussElim, luST, luSolve', luSolve'', luPacked' | 57 | gaussElim_1, gaussElim_2, gaussElim, luST, luSolve', luSolve'', luPacked', luPacked'' |
58 | ) where | 58 | ) where |
59 | 59 | ||
60 | import Internal.Vector | 60 | import Internal.Vector |
@@ -73,7 +73,7 @@ import Control.Monad(when,forM_) | |||
73 | import Text.Printf | 73 | import Text.Printf |
74 | import Data.List.Split(splitOn) | 74 | import Data.List.Split(splitOn) |
75 | import Data.List(intercalate,sortBy,foldl') | 75 | import Data.List(intercalate,sortBy,foldl') |
76 | import Control.Arrow((&&&)) | 76 | import Control.Arrow((&&&),(***)) |
77 | import Data.Complex | 77 | import Data.Complex |
78 | import Data.Function(on) | 78 | import Data.Function(on) |
79 | import Internal.ST | 79 | import Internal.ST |
@@ -714,6 +714,43 @@ luPacked' x = mutable (luST (magnit 0)) x | |||
714 | 714 | ||
715 | -------------------------------------------------------------------------------- | 715 | -------------------------------------------------------------------------------- |
716 | 716 | ||
717 | scalS a (Slice x r0 c0 nr nc) = rowOper (SCAL a (RowRange r0 (r0+nr-1)) (ColRange c0 (c0+nc-1))) x | ||
718 | |||
719 | view x k r = do | ||
720 | d <- readMatrix x k k | ||
721 | let rr = r-1-k | ||
722 | o = if k < r-1 then 1 else 0 | ||
723 | s = Slice x (k+1) (k+1) rr rr | ||
724 | u = Slice x k (k+1) o rr | ||
725 | l = Slice x (k+1) k rr o | ||
726 | return (d,u,l,s) | ||
727 | |||
728 | withVec r f = \s x -> do | ||
729 | p <- newUndefinedVector r | ||
730 | _ <- f s x p | ||
731 | v <- unsafeFreezeVector p | ||
732 | return v | ||
733 | |||
734 | |||
735 | luPacked'' m = (id *** toList) (mutable (withVec (rows m) lu2) m) | ||
736 | where | ||
737 | lu2 (r,_) x p = do | ||
738 | forM_ [0..r-1] $ \k -> do | ||
739 | pivot x p k | ||
740 | (d,u,l,s) <- view x k r | ||
741 | when (magnit 0 d) $ do | ||
742 | scalS (recip d) l | ||
743 | gemmm 1 s (-1) l u | ||
744 | |||
745 | pivot x p k = do | ||
746 | j <- maxIndex . abs . flatten <$> extractMatrix x (FromRow k) (Col k) | ||
747 | writeVector p k (j+k) | ||
748 | swap k (k+j) | ||
749 | where | ||
750 | swap i j = rowOper (SWAP i j AllCols) x | ||
751 | |||
752 | -------------------------------------------------------------------------------- | ||
753 | |||
717 | rowRange m = [0..rows m -1] | 754 | rowRange m = [0..rows m -1] |
718 | 755 | ||
719 | at k = Pos (idxs[k]) | 756 | at k = Pos (idxs[k]) |
diff --git a/packages/tests/hmatrix-tests.cabal b/packages/tests/hmatrix-tests.cabal index de796e8..49e0640 100644 --- a/packages/tests/hmatrix-tests.cabal +++ b/packages/tests/hmatrix-tests.cabal | |||
@@ -26,7 +26,7 @@ flag gsl | |||
26 | 26 | ||
27 | library | 27 | library |
28 | 28 | ||
29 | Build-Depends: base >= 4 && < 5, | 29 | Build-Depends: base >= 4 && < 5, deepseq, |
30 | QuickCheck >= 2, HUnit, random, | 30 | QuickCheck >= 2, HUnit, random, |
31 | hmatrix >= 0.17 | 31 | hmatrix >= 0.17 |
32 | if flag(gsl) | 32 | if flag(gsl) |
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs index d1fc6ec..b226c9f 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs | |||
@@ -47,6 +47,7 @@ import Debug.Trace | |||
47 | import Control.Monad(when) | 47 | import Control.Monad(when) |
48 | import Control.Applicative | 48 | import Control.Applicative |
49 | import Control.Monad(ap) | 49 | import Control.Monad(ap) |
50 | import Control.DeepSeq ( NFData(..) ) | ||
50 | 51 | ||
51 | import Test.QuickCheck(Arbitrary,arbitrary,coarbitrary,choose,vector | 52 | import Test.QuickCheck(Arbitrary,arbitrary,coarbitrary,choose,vector |
52 | ,sized,classify,Testable,Property | 53 | ,sized,classify,Testable,Property |
@@ -770,7 +771,7 @@ cholBench = do | |||
770 | luBenchN f n x msg = do | 771 | luBenchN f n x msg = do |
771 | let m = diagRect 1 (fromList (replicate n x)) n n | 772 | let m = diagRect 1 (fromList (replicate n x)) n n |
772 | m `seq` putStr "" | 773 | m `seq` putStr "" |
773 | time (msg ++ " "++ show n) (f m) | 774 | time (msg ++ " "++ show n) (rnf $ f m) |
774 | 775 | ||
775 | luBench = do | 776 | luBench = do |
776 | putStrLn "" | 777 | putStrLn "" |