diff options
author | Alberto Ruiz <aruiz@um.es> | 2015-06-20 13:49:48 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2015-06-20 13:49:48 +0200 |
commit | e09104974d13f66f4ec2e7ae2310346996cf2356 (patch) | |
tree | e1203e13f2c1e0587bdcb3938160546dadf0fcb1 /packages/base/src/Internal | |
parent | db50bc11dafa6834a4367427156306674063ed6b (diff) |
NFData for Mod, alternative luPacked'' using gaxpy
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r-- | packages/base/src/Internal/Modular.hs | 5 | ||||
-rw-r--r-- | packages/base/src/Internal/Util.hs | 41 |
2 files changed, 44 insertions, 2 deletions
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 1cae1ac..098817e 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -44,6 +44,7 @@ import Foreign.ForeignPtr(castForeignPtr) | |||
44 | import Foreign.Storable | 44 | import Foreign.Storable |
45 | import Data.Ratio | 45 | import Data.Ratio |
46 | import Data.Complex | 46 | import Data.Complex |
47 | import Control.DeepSeq ( NFData(..) ) | ||
47 | 48 | ||
48 | 49 | ||
49 | 50 | ||
@@ -51,6 +52,10 @@ import Data.Complex | |||
51 | newtype Mod (n :: Nat) t = Mod {unMod:: t} | 52 | newtype Mod (n :: Nat) t = Mod {unMod:: t} |
52 | deriving (Storable) | 53 | deriving (Storable) |
53 | 54 | ||
55 | instance (NFData t) => NFData (Mod n t) | ||
56 | where | ||
57 | rnf (Mod x) = rnf x | ||
58 | |||
54 | infixr 5 ./. | 59 | infixr 5 ./. |
55 | type (./.) x n = Mod n x | 60 | type (./.) x n = Mod n x |
56 | 61 | ||
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 924ca4c..bf6c8b6 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs | |||
@@ -54,7 +54,7 @@ module Internal.Util( | |||
54 | -- ** 2D | 54 | -- ** 2D |
55 | corr2, conv2, separable, | 55 | corr2, conv2, separable, |
56 | block2x2,block3x3,view1,unView1,foldMatrix, | 56 | block2x2,block3x3,view1,unView1,foldMatrix, |
57 | gaussElim_1, gaussElim_2, gaussElim, luST, luSolve', luSolve'', luPacked' | 57 | gaussElim_1, gaussElim_2, gaussElim, luST, luSolve', luSolve'', luPacked', luPacked'' |
58 | ) where | 58 | ) where |
59 | 59 | ||
60 | import Internal.Vector | 60 | import Internal.Vector |
@@ -73,7 +73,7 @@ import Control.Monad(when,forM_) | |||
73 | import Text.Printf | 73 | import Text.Printf |
74 | import Data.List.Split(splitOn) | 74 | import Data.List.Split(splitOn) |
75 | import Data.List(intercalate,sortBy,foldl') | 75 | import Data.List(intercalate,sortBy,foldl') |
76 | import Control.Arrow((&&&)) | 76 | import Control.Arrow((&&&),(***)) |
77 | import Data.Complex | 77 | import Data.Complex |
78 | import Data.Function(on) | 78 | import Data.Function(on) |
79 | import Internal.ST | 79 | import Internal.ST |
@@ -714,6 +714,43 @@ luPacked' x = mutable (luST (magnit 0)) x | |||
714 | 714 | ||
715 | -------------------------------------------------------------------------------- | 715 | -------------------------------------------------------------------------------- |
716 | 716 | ||
717 | scalS a (Slice x r0 c0 nr nc) = rowOper (SCAL a (RowRange r0 (r0+nr-1)) (ColRange c0 (c0+nc-1))) x | ||
718 | |||
719 | view x k r = do | ||
720 | d <- readMatrix x k k | ||
721 | let rr = r-1-k | ||
722 | o = if k < r-1 then 1 else 0 | ||
723 | s = Slice x (k+1) (k+1) rr rr | ||
724 | u = Slice x k (k+1) o rr | ||
725 | l = Slice x (k+1) k rr o | ||
726 | return (d,u,l,s) | ||
727 | |||
728 | withVec r f = \s x -> do | ||
729 | p <- newUndefinedVector r | ||
730 | _ <- f s x p | ||
731 | v <- unsafeFreezeVector p | ||
732 | return v | ||
733 | |||
734 | |||
735 | luPacked'' m = (id *** toList) (mutable (withVec (rows m) lu2) m) | ||
736 | where | ||
737 | lu2 (r,_) x p = do | ||
738 | forM_ [0..r-1] $ \k -> do | ||
739 | pivot x p k | ||
740 | (d,u,l,s) <- view x k r | ||
741 | when (magnit 0 d) $ do | ||
742 | scalS (recip d) l | ||
743 | gemmm 1 s (-1) l u | ||
744 | |||
745 | pivot x p k = do | ||
746 | j <- maxIndex . abs . flatten <$> extractMatrix x (FromRow k) (Col k) | ||
747 | writeVector p k (j+k) | ||
748 | swap k (k+j) | ||
749 | where | ||
750 | swap i j = rowOper (SWAP i j AllCols) x | ||
751 | |||
752 | -------------------------------------------------------------------------------- | ||
753 | |||
717 | rowRange m = [0..rows m -1] | 754 | rowRange m = [0..rows m -1] |
718 | 755 | ||
719 | at k = Pos (idxs[k]) | 756 | at k = Pos (idxs[k]) |