From 717c680a4b65a2226b0dd6fc13f7c63e7bc0431d Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Sat, 13 Jun 2015 19:18:16 +0200 Subject: setRect, general luPacked' based on luST --- packages/base/src/Internal/C/lapack-aux.c | 54 +++++++++-------- packages/base/src/Internal/Matrix.hs | 20 +++++++ packages/base/src/Internal/Modular.hs | 75 +++++++++++++++++++----- packages/base/src/Internal/ST.hs | 7 ++- packages/base/src/Internal/Util.hs | 48 ++++++++++++++- packages/base/src/Numeric/LinearAlgebra.hs | 12 ++-- packages/base/src/Numeric/LinearAlgebra/Devel.hs | 2 +- 7 files changed, 169 insertions(+), 49 deletions(-) diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index e42889d..2843ab5 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c @@ -1448,7 +1448,7 @@ int transL(KLMAT(x),LMAT(t)) TRANS_IMP //////////////////////// extract ///////////////////////////////// -#define EXTRACT_IMP \ +#define EXTRACT_IMP { \ int i,j,si,sj,ni,nj; \ ni = modei ? in : ip[1]-ip[0]+1; \ nj = modej ? jn : jp[1]-jp[0]+1; \ @@ -1461,33 +1461,35 @@ int transL(KLMAT(x),LMAT(t)) TRANS_IMP \ AT(r,i,j) = AT(m,si,sj); \ } \ - } \ - OK - -int extractD(int modei, int modej, KIVEC(i), KIVEC(j), KODMAT(m), ODMAT(r)) { - EXTRACT_IMP -} - -int extractF(int modei, int modej, KIVEC(i), KIVEC(j), KOFMAT(m), OFMAT(r)) { - EXTRACT_IMP -} - -int extractC(int modei, int modej, KIVEC(i), KIVEC(j), KOCMAT(m), OCMAT(r)) { - EXTRACT_IMP -} - -int extractQ(int modei, int modej, KIVEC(i), KIVEC(j), KOQMAT(m), OQMAT(r)) { - EXTRACT_IMP -} - -int extractI(int modei, int modej, KIVEC(i), KIVEC(j), KOIMAT(m), OIMAT(r)) { - EXTRACT_IMP -} + } OK } -int extractL(int modei, int modej, KIVEC(i), KIVEC(j), KOLMAT(m), OLMAT(r)) { - EXTRACT_IMP -} +#define EXTRACT(T) int extract##T(int modei, int modej, KIVEC(i), KIVEC(j), KO##T##MAT(m), O##T##MAT(r)) EXTRACT_IMP + +EXTRACT(D) +EXTRACT(F) +EXTRACT(C) +EXTRACT(Q) +EXTRACT(I) +EXTRACT(L) + +//////////////////////// setRect ///////////////////////////////// + +#define SETRECT(T) \ +int setRect##T(int i, int j, KO##T##MAT(m), O##T##MAT(r)) { \ + { TRAV(m,a,b) { \ + int x = a+i, y = b+j; \ + if(x>=0 && x=0 && y Element a where transdata :: Int -> Vector a -> Int -> Vector a constantD :: a -> Int -> Vector a extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) + setRect :: Int -> Int -> Matrix a -> Matrix a -> IO () sortI :: Ord a => Vector a -> Vector CInt sortV :: Ord a => Vector a -> Vector a compareV :: Ord a => Vector a -> Vector a -> Vector CInt @@ -287,6 +288,7 @@ instance Element Float where transdata = transdataAux ctransF constantD = constantAux cconstantF extractR = extractAux c_extractF + setRect = setRectAux c_setRectF sortI = sortIdxF sortV = sortValF compareV = compareF @@ -298,6 +300,7 @@ instance Element Double where transdata = transdataAux ctransR constantD = constantAux cconstantR extractR = extractAux c_extractD + setRect = setRectAux c_setRectD sortI = sortIdxD sortV = sortValD compareV = compareD @@ -310,6 +313,7 @@ instance Element (Complex Float) where transdata = transdataAux ctransQ constantD = constantAux cconstantQ extractR = extractAux c_extractQ + setRect = setRectAux c_setRectQ sortI = undefined sortV = undefined compareV = undefined @@ -322,6 +326,7 @@ instance Element (Complex Double) where transdata = transdataAux ctransC constantD = constantAux cconstantC extractR = extractAux c_extractC + setRect = setRectAux c_setRectC sortI = undefined sortV = undefined compareV = undefined @@ -333,6 +338,7 @@ instance Element (CInt) where transdata = transdataAux ctransI constantD = constantAux cconstantI extractR = extractAux c_extractI + setRect = setRectAux c_setRectI sortI = sortIdxI sortV = sortValI compareV = compareI @@ -344,6 +350,7 @@ instance Element Z where transdata = transdataAux ctransL constantD = constantAux cconstantL extractR = extractAux c_extractL + setRect = setRectAux c_setRectL sortI = sortIdxL sortV = sortValL compareV = compareL @@ -454,6 +461,19 @@ foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float) foreign import ccall unsafe "extractI" c_extractI :: Extr CInt foreign import ccall unsafe "extractL" c_extractL :: Extr Z +--------------------------------------------------------------- + +setRectAux f i j m r = app2 (f (fi i) (fi j)) omat m omat r "setRect" + +type SetRect x = I -> I -> x ::> x::> Ok + +foreign import ccall unsafe "setRectD" c_setRectD :: SetRect Double +foreign import ccall unsafe "setRectF" c_setRectF :: SetRect Float +foreign import ccall unsafe "setRectC" c_setRectC :: SetRect (Complex Double) +foreign import ccall unsafe "setRectQ" c_setRectQ :: SetRect (Complex Float) +foreign import ccall unsafe "setRectI" c_setRectI :: SetRect I +foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z + -------------------------------------------------------------------------------- sortG f v = unsafePerformIO $ do diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 824fc57..3b27310 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs @@ -33,12 +33,15 @@ import Internal.Element import Internal.Container import Internal.Vectorized (prodI,sumI,prodL,sumL) import Internal.LAPACK (multiplyI, multiplyL) -import Internal.Util(Indexable(..),gaussElim) +import Internal.Algorithms(luFact) +import Internal.Util(Normed(..),Indexable(..),gaussElim, gaussElim_1, gaussElim_2,luST, magnit) +import Internal.ST(mutable) import GHC.TypeLits import Data.Proxy(Proxy) import Foreign.ForeignPtr(castForeignPtr) import Foreign.Storable import Data.Ratio +import Data.Complex @@ -116,6 +119,7 @@ instance KnownNat m => Element (Mod m I) transdata n v m = i2f (transdata n (f2i v) m) constantD x n = i2f (constantD (unMod x) n) extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js + setRect i j m x = setRect i j (f2iM m) (f2iM x) sortI = sortI . f2i sortV = i2f . sortV . f2i compareV u v = compareV (f2i u) (f2i v) @@ -130,6 +134,7 @@ instance KnownNat m => Element (Mod m Z) transdata n v m = i2f (transdata n (f2i v) m) constantD x n = i2f (constantD (unMod x) n) extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js + setRect i j m x = setRect i j (f2iM m) (f2iM x) sortI = sortI . f2i sortV = i2f . sortV . f2i compareV u v = compareV (f2i u) (f2i v) @@ -139,18 +144,6 @@ instance KnownNat m => Element (Mod m Z) where m' = fromIntegral . natVal $ (undefined :: Proxy m) -{- -instance (Ord t, Element t) => Element (Mod m t) - where - transdata n v m = i2f (transdata n (f2i v) m) - constantD x n = i2f (constantD (unMod x) n) - extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js - sortI = sortI . f2i - sortV = i2f . sortV . f2i - compareV u v = compareV (f2i u) (f2i v) - selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) - remapM i j m = i2fM (remap i j (f2iM m)) --} instance forall m . KnownNat m => Container Vector (Mod m I) where @@ -258,6 +251,20 @@ instance KnownNat m => Product (Mod m Z) where where m' = fromIntegral . natVal $ (undefined :: Proxy m) +instance KnownNat m => Normed (Vector (Mod m I)) + where + norm_0 = norm_0 . toInt + norm_1 = norm_1 . toInt + norm_2 = norm_2 . toInt + norm_Inf = norm_Inf . toInt + +instance KnownNat m => Normed (Vector (Mod m Z)) + where + norm_0 = norm_0 . toZ + norm_1 = norm_1 . toZ + norm_2 = norm_2 . toZ + norm_Inf = norm_Inf . toZ + instance KnownNat m => Numeric (Mod m I) instance KnownNat m => Numeric (Mod m Z) @@ -334,6 +341,15 @@ test = (ok, info) lg = (3><3) (repeat (3*10^(9::Int))) :: Matrix Z lgm = fromZ lg :: Matrix (Mod 10000000000 Z) + gen n = diagRect 1 (konst 5 n) n n :: Numeric t => Matrix t + + checkGen x = norm_Inf $ flatten $ invg x <> x - ident (rows x) + + invg t = gaussElim t (ident (rows t)) + + checkLU okf t = norm_Inf $ flatten (l <> u <> p - t) + where + (l,u,p,_ :: Int) = luFact $ mutable (luST okf) t info = do print v @@ -356,11 +372,42 @@ test = (ok, info) print $ lg <> lg print lgm print $ lgm <> lgm + + print (checkGen (gen 5 :: Matrix R)) + print (checkGen (gen 5 :: Matrix C)) + print (checkGen (gen 5 :: Matrix Float)) + print (checkGen (gen 5 :: Matrix (Complex Float))) + print (invg (gen 5) :: Matrix (Mod 7 I)) + print (invg (gen 5) :: Matrix (Mod 7 Z)) + + print $ mutable (luST (const True)) (gen 5 :: Matrix R) + print $ mutable (luST (const True)) (gen 5 :: Matrix (Mod 11 Z)) + + print $ checkLU (magnit 0) (gen 5 :: Matrix R) + print $ checkLU (magnit 0) (gen 5 :: Matrix Float) + print $ checkLU (magnit 0) (gen 5 :: Matrix C) + print $ checkLU (magnit 0) (gen 5 :: Matrix (Complex Float)) + print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I)) + print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z)) ok = and [ toInt (m #> v) == cmod 11 (toInt m #> toInt v ) - , am <> gaussElim am bm == bm + , am <> gaussElim_1 am bm == bm + , am <> gaussElim_2 am bm == bm + , am <> gaussElim am bm == bm + , (checkGen (gen 5 :: Matrix R)) < 1E-15 + , (checkGen (gen 5 :: Matrix Float)) < 1E-7 + , (checkGen (gen 5 :: Matrix C)) < 1E-15 + , (checkGen (gen 5 :: Matrix (Complex Float))) < 1E-7 + , (checkGen (gen 5 :: Matrix (Mod 7 I))) == 0 + , (checkGen (gen 5 :: Matrix (Mod 7 Z))) == 0 + , (checkLU (magnit 1E-10) (gen 5 :: Matrix R)) < 1E-15 + , (checkLU (magnit 1E-5) (gen 5 :: Matrix Float)) < 1E-6 + , (checkLU (magnit 1E-10) (gen 5 :: Matrix C)) < 1E-15 + , (checkLU (magnit 1E-5) (gen 5 :: Matrix (Complex Float))) < 1E-6 + , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I))) == 0 + , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z))) == 0 , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I)) , gm <> gm == konst 0 (3,3) , lgm <> lgm == konst 0 (3,3) diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 107d3c3..a84ca25 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs @@ -21,7 +21,7 @@ module Internal.ST ( -- * Mutable Matrices STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, - axpy, scal, swap, extractRect, + axpy, scal, swap, extractMatrix, setMatrix, rowOpST, mutable, -- * Unsafe functions newUndefinedVector, @@ -166,6 +166,9 @@ readMatrix = safeIndexM unsafeReadMatrix writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () writeMatrix = safeIndexM unsafeWriteMatrix +setMatrix :: Element t => STMatrix s t -> Int -> Int -> Matrix t -> ST s () +setMatrix (STMatrix x) i j m = unsafeIOToST $ setRect i j m x + newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c @@ -182,7 +185,7 @@ axpy (STMatrix m) a i j = rowOpST 0 a i j 0 (cols m -1) (STMatrix m) scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m) swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m) -extractRect (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) +extractMatrix (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) -------------------------------------------------------------------------------- diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 7a556e9..2650ac8 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs @@ -41,6 +41,7 @@ module Internal.Util( norm, ℕ,ℤ,ℝ,ℂ,iC, Normed(..), norm_Frob, norm_nuclear, + magnit, unitary, mt, (~!~), @@ -54,7 +55,7 @@ module Internal.Util( -- ** 2D corr2, conv2, separable, block2x2,block3x3,view1,unView1,foldMatrix, - gaussElim_1, gaussElim_2, gaussElim + gaussElim_1, gaussElim_2, gaussElim, luST ) where import Internal.Vector @@ -300,6 +301,26 @@ instance Normed (Vector I) norm_2 v = sqrt . fromIntegral $ dot v v norm_Inf = fromIntegral . normInf +instance Normed (Vector Z) + where + norm_0 = fromIntegral . sumElements . step . abs + norm_1 = fromIntegral . norm1 + norm_2 v = sqrt . fromIntegral $ dot v v + norm_Inf = fromIntegral . normInf + +instance Normed (Vector Float) + where + norm_0 = norm_0 . double + norm_1 = norm_1 . double + norm_2 = norm_2 . double + norm_Inf = norm_Inf . double + +instance Normed (Vector (Complex Float)) + where + norm_0 = norm_0 . double + norm_1 = norm_1 . double + norm_2 = norm_2 . double + norm_Inf = norm_Inf . double norm_Frob :: (Normed (Vector t), Element t) => Matrix t -> ℝ @@ -308,6 +329,9 @@ norm_Frob = norm_2 . flatten norm_nuclear :: Field t => Matrix t -> ℝ norm_nuclear = sumElements . singularValues +magnit :: (Element t, Normed (Vector t)) => R -> t -> Bool +magnit e x = norm_1 (fromList [x]) > e + -- | Obtains a vector in the same direction with 2-norm=1 unitary :: Vector Double -> Vector Double @@ -618,9 +642,10 @@ gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (fromBlocks [[a,b]] gaussST (r,_) x = do let n = r-1 forM_ [0..n] $ \i -> do - c <- maxIndex . abs . flatten <$> extractRect x i n i i + c <- maxIndex . abs . flatten <$> extractMatrix x i n i i swap x i (i+c) a <- readMatrix x i i + when (a == 0) $ error "singular!" scal x (recip a) i forM_ [i+1..n] $ \j -> do b <- readMatrix x j i @@ -630,6 +655,25 @@ gaussST (r,_) x = do b <- readMatrix x j i axpy x (-b) i j + +luST ok (r,c) x = do + let n = r-1 + axpy' m a i j = rowOpST 0 a i j (i+1) (c-1) m + p <- thawMatrix . asColumn . range $ r + forM_ [0..n] $ \i -> do + k <- maxIndex . abs . flatten <$> extractMatrix x i n i i + writeMatrix p i 0 (fi (k+i)) + swap x i (i+k) + a <- readMatrix x i i + when (ok a) $ do + forM_ [i+1..n] $ \j -> do + b <- (/a) <$> readMatrix x j i + axpy' x (-b) i j + writeMatrix x j i b + v <- unsafeFreezeMatrix p + return (map ti $ toList $ flatten v) + + -------------------------------------------------------------------------------- instance Testable (Matrix I) where diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index c97f415..0f8efa4 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE FlexibleContexts #-} + ----------------------------------------------------------------------------- {- | Module : Numeric.LinearAlgebra @@ -119,7 +121,7 @@ module Numeric.LinearAlgebra ( schur, -- * LU - lu, luPacked, + lu, luPacked, luFact, luPacked', -- * Matrix functions expm, @@ -134,7 +136,7 @@ module Numeric.LinearAlgebra ( Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, -- * Misc - meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, gaussElim, gaussElim_1, gaussElim_2, + meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, gaussElim, luST, magnit, ℝ,ℂ,iC, -- * Auxiliary classes Element, Container, Product, Numeric, LSDiv, @@ -142,7 +144,6 @@ module Numeric.LinearAlgebra ( RealOf, ComplexOf, SingleOf, DoubleOf, IndexOf, Field, --- Normed, Transposable, CGState(..), Testable(..) @@ -155,13 +156,14 @@ import Numeric.Vector() import Internal.Matrix import Internal.Container hiding ((<>)) import Internal.Numeric hiding (mul) -import Internal.Algorithms hiding (linearSolve,Normed,orth) +import Internal.Algorithms hiding (linearSolve,Normed,orth,luPacked') import qualified Internal.Algorithms as A import Internal.Util import Internal.Random import Internal.Sparse((!#>)) import Internal.CG import Internal.Conversion +import Internal.ST(mutable) {- | infix synonym of 'mul' @@ -236,3 +238,5 @@ nullspace m = nullspaceSVD (Left (1*eps)) m (rightSV m) -- | return an orthonormal basis of the range space of a matrix. See also 'orthSVD'. orth m = orthSVD (Left (1*eps)) m (leftSV m) +luPacked' x = mutable (luST (magnit 0)) x + diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs index 84763fe..f572656 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs @@ -44,7 +44,7 @@ module Numeric.LinearAlgebra.Devel( -- ** Mutable Matrices STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, - axpy,scal,swap, extractRect, mutable, + axpy,scal,swap, extractMatrix, setMatrix, mutable, rowOpST, -- ** Unsafe functions newUndefinedVector, unsafeReadVector, unsafeWriteVector, -- cgit v1.2.3