From e09104974d13f66f4ec2e7ae2310346996cf2356 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Sat, 20 Jun 2015 13:49:48 +0200 Subject: NFData for Mod, alternative luPacked'' using gaxpy --- packages/base/src/Internal/Modular.hs | 5 +++ packages/base/src/Internal/Util.hs | 41 +++++++++++++++++++++-- packages/tests/hmatrix-tests.cabal | 2 +- packages/tests/src/Numeric/LinearAlgebra/Tests.hs | 3 +- 4 files changed, 47 insertions(+), 4 deletions(-) (limited to 'packages') diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 1cae1ac..098817e 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs @@ -44,6 +44,7 @@ import Foreign.ForeignPtr(castForeignPtr) import Foreign.Storable import Data.Ratio import Data.Complex +import Control.DeepSeq ( NFData(..) ) @@ -51,6 +52,10 @@ import Data.Complex newtype Mod (n :: Nat) t = Mod {unMod:: t} deriving (Storable) +instance (NFData t) => NFData (Mod n t) + where + rnf (Mod x) = rnf x + infixr 5 ./. type (./.) x n = Mod n x diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 924ca4c..bf6c8b6 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs @@ -54,7 +54,7 @@ module Internal.Util( -- ** 2D corr2, conv2, separable, block2x2,block3x3,view1,unView1,foldMatrix, - gaussElim_1, gaussElim_2, gaussElim, luST, luSolve', luSolve'', luPacked' + gaussElim_1, gaussElim_2, gaussElim, luST, luSolve', luSolve'', luPacked', luPacked'' ) where import Internal.Vector @@ -73,7 +73,7 @@ import Control.Monad(when,forM_) import Text.Printf import Data.List.Split(splitOn) import Data.List(intercalate,sortBy,foldl') -import Control.Arrow((&&&)) +import Control.Arrow((&&&),(***)) import Data.Complex import Data.Function(on) import Internal.ST @@ -714,6 +714,43 @@ luPacked' x = mutable (luST (magnit 0)) x -------------------------------------------------------------------------------- +scalS a (Slice x r0 c0 nr nc) = rowOper (SCAL a (RowRange r0 (r0+nr-1)) (ColRange c0 (c0+nc-1))) x + +view x k r = do + d <- readMatrix x k k + let rr = r-1-k + o = if k < r-1 then 1 else 0 + s = Slice x (k+1) (k+1) rr rr + u = Slice x k (k+1) o rr + l = Slice x (k+1) k rr o + return (d,u,l,s) + +withVec r f = \s x -> do + p <- newUndefinedVector r + _ <- f s x p + v <- unsafeFreezeVector p + return v + + +luPacked'' m = (id *** toList) (mutable (withVec (rows m) lu2) m) + where + lu2 (r,_) x p = do + forM_ [0..r-1] $ \k -> do + pivot x p k + (d,u,l,s) <- view x k r + when (magnit 0 d) $ do + scalS (recip d) l + gemmm 1 s (-1) l u + + pivot x p k = do + j <- maxIndex . abs . flatten <$> extractMatrix x (FromRow k) (Col k) + writeVector p k (j+k) + swap k (k+j) + where + swap i j = rowOper (SWAP i j AllCols) x + +-------------------------------------------------------------------------------- + rowRange m = [0..rows m -1] at k = Pos (idxs[k]) diff --git a/packages/tests/hmatrix-tests.cabal b/packages/tests/hmatrix-tests.cabal index de796e8..49e0640 100644 --- a/packages/tests/hmatrix-tests.cabal +++ b/packages/tests/hmatrix-tests.cabal @@ -26,7 +26,7 @@ flag gsl library - Build-Depends: base >= 4 && < 5, + Build-Depends: base >= 4 && < 5, deepseq, QuickCheck >= 2, HUnit, random, hmatrix >= 0.17 if flag(gsl) diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs index d1fc6ec..b226c9f 100644 --- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs +++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs @@ -47,6 +47,7 @@ import Debug.Trace import Control.Monad(when) import Control.Applicative import Control.Monad(ap) +import Control.DeepSeq ( NFData(..) ) import Test.QuickCheck(Arbitrary,arbitrary,coarbitrary,choose,vector ,sized,classify,Testable,Property @@ -770,7 +771,7 @@ cholBench = do luBenchN f n x msg = do let m = diagRect 1 (fromList (replicate n x)) n n m `seq` putStr "" - time (msg ++ " "++ show n) (f m) + time (msg ++ " "++ show n) (rnf $ f m) luBench = do putStrLn "" -- cgit v1.2.3