summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-20 13:49:48 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-20 13:49:48 +0200
commite09104974d13f66f4ec2e7ae2310346996cf2356 (patch)
treee1203e13f2c1e0587bdcb3938160546dadf0fcb1
parentdb50bc11dafa6834a4367427156306674063ed6b (diff)
NFData for Mod, alternative luPacked'' using gaxpy
-rw-r--r--packages/base/src/Internal/Modular.hs5
-rw-r--r--packages/base/src/Internal/Util.hs41
-rw-r--r--packages/tests/hmatrix-tests.cabal2
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests.hs3
4 files changed, 47 insertions, 4 deletions
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index 1cae1ac..098817e 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -44,6 +44,7 @@ import Foreign.ForeignPtr(castForeignPtr)
44import Foreign.Storable 44import Foreign.Storable
45import Data.Ratio 45import Data.Ratio
46import Data.Complex 46import Data.Complex
47import Control.DeepSeq ( NFData(..) )
47 48
48 49
49 50
@@ -51,6 +52,10 @@ import Data.Complex
51newtype Mod (n :: Nat) t = Mod {unMod:: t} 52newtype Mod (n :: Nat) t = Mod {unMod:: t}
52 deriving (Storable) 53 deriving (Storable)
53 54
55instance (NFData t) => NFData (Mod n t)
56 where
57 rnf (Mod x) = rnf x
58
54infixr 5 ./. 59infixr 5 ./.
55type (./.) x n = Mod n x 60type (./.) x n = Mod n x
56 61
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs
index 924ca4c..bf6c8b6 100644
--- a/packages/base/src/Internal/Util.hs
+++ b/packages/base/src/Internal/Util.hs
@@ -54,7 +54,7 @@ module Internal.Util(
54 -- ** 2D 54 -- ** 2D
55 corr2, conv2, separable, 55 corr2, conv2, separable,
56 block2x2,block3x3,view1,unView1,foldMatrix, 56 block2x2,block3x3,view1,unView1,foldMatrix,
57 gaussElim_1, gaussElim_2, gaussElim, luST, luSolve', luSolve'', luPacked' 57 gaussElim_1, gaussElim_2, gaussElim, luST, luSolve', luSolve'', luPacked', luPacked''
58) where 58) where
59 59
60import Internal.Vector 60import Internal.Vector
@@ -73,7 +73,7 @@ import Control.Monad(when,forM_)
73import Text.Printf 73import Text.Printf
74import Data.List.Split(splitOn) 74import Data.List.Split(splitOn)
75import Data.List(intercalate,sortBy,foldl') 75import Data.List(intercalate,sortBy,foldl')
76import Control.Arrow((&&&)) 76import Control.Arrow((&&&),(***))
77import Data.Complex 77import Data.Complex
78import Data.Function(on) 78import Data.Function(on)
79import Internal.ST 79import Internal.ST
@@ -714,6 +714,43 @@ luPacked' x = mutable (luST (magnit 0)) x
714 714
715-------------------------------------------------------------------------------- 715--------------------------------------------------------------------------------
716 716
717scalS a (Slice x r0 c0 nr nc) = rowOper (SCAL a (RowRange r0 (r0+nr-1)) (ColRange c0 (c0+nc-1))) x
718
719view x k r = do
720 d <- readMatrix x k k
721 let rr = r-1-k
722 o = if k < r-1 then 1 else 0
723 s = Slice x (k+1) (k+1) rr rr
724 u = Slice x k (k+1) o rr
725 l = Slice x (k+1) k rr o
726 return (d,u,l,s)
727
728withVec r f = \s x -> do
729 p <- newUndefinedVector r
730 _ <- f s x p
731 v <- unsafeFreezeVector p
732 return v
733
734
735luPacked'' m = (id *** toList) (mutable (withVec (rows m) lu2) m)
736 where
737 lu2 (r,_) x p = do
738 forM_ [0..r-1] $ \k -> do
739 pivot x p k
740 (d,u,l,s) <- view x k r
741 when (magnit 0 d) $ do
742 scalS (recip d) l
743 gemmm 1 s (-1) l u
744
745 pivot x p k = do
746 j <- maxIndex . abs . flatten <$> extractMatrix x (FromRow k) (Col k)
747 writeVector p k (j+k)
748 swap k (k+j)
749 where
750 swap i j = rowOper (SWAP i j AllCols) x
751
752--------------------------------------------------------------------------------
753
717rowRange m = [0..rows m -1] 754rowRange m = [0..rows m -1]
718 755
719at k = Pos (idxs[k]) 756at k = Pos (idxs[k])
diff --git a/packages/tests/hmatrix-tests.cabal b/packages/tests/hmatrix-tests.cabal
index de796e8..49e0640 100644
--- a/packages/tests/hmatrix-tests.cabal
+++ b/packages/tests/hmatrix-tests.cabal
@@ -26,7 +26,7 @@ flag gsl
26 26
27library 27library
28 28
29 Build-Depends: base >= 4 && < 5, 29 Build-Depends: base >= 4 && < 5, deepseq,
30 QuickCheck >= 2, HUnit, random, 30 QuickCheck >= 2, HUnit, random,
31 hmatrix >= 0.17 31 hmatrix >= 0.17
32 if flag(gsl) 32 if flag(gsl)
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
index d1fc6ec..b226c9f 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
@@ -47,6 +47,7 @@ import Debug.Trace
47import Control.Monad(when) 47import Control.Monad(when)
48import Control.Applicative 48import Control.Applicative
49import Control.Monad(ap) 49import Control.Monad(ap)
50import Control.DeepSeq ( NFData(..) )
50 51
51import Test.QuickCheck(Arbitrary,arbitrary,coarbitrary,choose,vector 52import Test.QuickCheck(Arbitrary,arbitrary,coarbitrary,choose,vector
52 ,sized,classify,Testable,Property 53 ,sized,classify,Testable,Property
@@ -770,7 +771,7 @@ cholBench = do
770luBenchN f n x msg = do 771luBenchN f n x msg = do
771 let m = diagRect 1 (fromList (replicate n x)) n n 772 let m = diagRect 1 (fromList (replicate n x)) n n
772 m `seq` putStr "" 773 m `seq` putStr ""
773 time (msg ++ " "++ show n) (f m) 774 time (msg ++ " "++ show n) (rnf $ f m)
774 775
775luBench = do 776luBench = do
776 putStrLn "" 777 putStrLn ""