summaryrefslogtreecommitdiff
path: root/packages/base
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base')
-rw-r--r--packages/base/src/Internal/Modular.hs5
-rw-r--r--packages/base/src/Internal/Util.hs41
2 files changed, 44 insertions, 2 deletions
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index 1cae1ac..098817e 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -44,6 +44,7 @@ import Foreign.ForeignPtr(castForeignPtr)
44import Foreign.Storable 44import Foreign.Storable
45import Data.Ratio 45import Data.Ratio
46import Data.Complex 46import Data.Complex
47import Control.DeepSeq ( NFData(..) )
47 48
48 49
49 50
@@ -51,6 +52,10 @@ import Data.Complex
51newtype Mod (n :: Nat) t = Mod {unMod:: t} 52newtype Mod (n :: Nat) t = Mod {unMod:: t}
52 deriving (Storable) 53 deriving (Storable)
53 54
55instance (NFData t) => NFData (Mod n t)
56 where
57 rnf (Mod x) = rnf x
58
54infixr 5 ./. 59infixr 5 ./.
55type (./.) x n = Mod n x 60type (./.) x n = Mod n x
56 61
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs
index 924ca4c..bf6c8b6 100644
--- a/packages/base/src/Internal/Util.hs
+++ b/packages/base/src/Internal/Util.hs
@@ -54,7 +54,7 @@ module Internal.Util(
54 -- ** 2D 54 -- ** 2D
55 corr2, conv2, separable, 55 corr2, conv2, separable,
56 block2x2,block3x3,view1,unView1,foldMatrix, 56 block2x2,block3x3,view1,unView1,foldMatrix,
57 gaussElim_1, gaussElim_2, gaussElim, luST, luSolve', luSolve'', luPacked' 57 gaussElim_1, gaussElim_2, gaussElim, luST, luSolve', luSolve'', luPacked', luPacked''
58) where 58) where
59 59
60import Internal.Vector 60import Internal.Vector
@@ -73,7 +73,7 @@ import Control.Monad(when,forM_)
73import Text.Printf 73import Text.Printf
74import Data.List.Split(splitOn) 74import Data.List.Split(splitOn)
75import Data.List(intercalate,sortBy,foldl') 75import Data.List(intercalate,sortBy,foldl')
76import Control.Arrow((&&&)) 76import Control.Arrow((&&&),(***))
77import Data.Complex 77import Data.Complex
78import Data.Function(on) 78import Data.Function(on)
79import Internal.ST 79import Internal.ST
@@ -714,6 +714,43 @@ luPacked' x = mutable (luST (magnit 0)) x
714 714
715-------------------------------------------------------------------------------- 715--------------------------------------------------------------------------------
716 716
717scalS a (Slice x r0 c0 nr nc) = rowOper (SCAL a (RowRange r0 (r0+nr-1)) (ColRange c0 (c0+nc-1))) x
718
719view x k r = do
720 d <- readMatrix x k k
721 let rr = r-1-k
722 o = if k < r-1 then 1 else 0
723 s = Slice x (k+1) (k+1) rr rr
724 u = Slice x k (k+1) o rr
725 l = Slice x (k+1) k rr o
726 return (d,u,l,s)
727
728withVec r f = \s x -> do
729 p <- newUndefinedVector r
730 _ <- f s x p
731 v <- unsafeFreezeVector p
732 return v
733
734
735luPacked'' m = (id *** toList) (mutable (withVec (rows m) lu2) m)
736 where
737 lu2 (r,_) x p = do
738 forM_ [0..r-1] $ \k -> do
739 pivot x p k
740 (d,u,l,s) <- view x k r
741 when (magnit 0 d) $ do
742 scalS (recip d) l
743 gemmm 1 s (-1) l u
744
745 pivot x p k = do
746 j <- maxIndex . abs . flatten <$> extractMatrix x (FromRow k) (Col k)
747 writeVector p k (j+k)
748 swap k (k+j)
749 where
750 swap i j = rowOper (SWAP i j AllCols) x
751
752--------------------------------------------------------------------------------
753
717rowRange m = [0..rows m -1] 754rowRange m = [0..rows m -1]
718 755
719at k = Pos (idxs[k]) 756at k = Pos (idxs[k])