From 4b3e29097aa272d429f8005fe17b459cf0c049c8 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Fri, 12 Jun 2015 20:58:13 +0200 Subject: row ops in ST --- packages/base/hmatrix.cabal | 1 + packages/base/src/Internal/C/lapack-aux.c | 87 ++++++++++++++++++++++++ packages/base/src/Internal/C/lapack-aux.h | 1 + packages/base/src/Internal/C/vector-aux.c | 7 +- packages/base/src/Internal/Element.hs | 3 +- packages/base/src/Internal/Matrix.hs | 32 ++++++++- packages/base/src/Internal/Modular.hs | 43 ++++++++++-- packages/base/src/Internal/ST.hs | 26 ++++++- packages/base/src/Internal/Util.hs | 80 ++++++++++++++++++++-- packages/base/src/Numeric/LinearAlgebra.hs | 2 +- packages/base/src/Numeric/LinearAlgebra/Devel.hs | 1 + 11 files changed, 263 insertions(+), 20 deletions(-) (limited to 'packages') diff --git a/packages/base/hmatrix.cabal b/packages/base/hmatrix.cabal index 0ab4821..f725341 100644 --- a/packages/base/hmatrix.cabal +++ b/packages/base/hmatrix.cabal @@ -81,6 +81,7 @@ library ghc-options: -Wall -fno-warn-missing-signatures -fno-warn-orphans + -fprof-auto cc-options: -O4 -Wall diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index dcce1c5..e42889d 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c @@ -4,6 +4,12 @@ #include #include #include +#include + +typedef double complex TCD; +typedef float complex TCF; + +#undef complex #include "lapack-aux.h" @@ -46,6 +52,10 @@ #define NODEFPOS 2006 #define NOSPRTD 2007 +inline int mod (int a, int b); + +inline int64_t mod_l (int64_t a, int64_t b); + //--------------------------------------- void asm_finit() { #ifdef i386 @@ -1310,6 +1320,83 @@ int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) { int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP +/////////////////////////////// inplace row ops //////////////////////////////// + +#define AXPY_IMP { \ + int j; \ + for(j=j1; j<=j2; j++) { \ + AT(r,i2,j) += a*AT(r,i1,j); \ + } OK } + +#define AXPY_MOD_IMP(M) { \ + int j; \ + for(j=j1; j<=j2; j++) { \ + AT(r,i2,j) = M(AT(r,i2,j) + M(a*AT(r,i1,j), m) , m); \ + } OK } + + +#define SCAL_IMP { \ + int i,j; \ + for(i=i1; i<=i2; i++) { \ + for(j=j1; j<=j2; j++) { \ + AT(r,i,j) = a*AT(r,i,j); \ + } \ + } OK } + +#define SCAL_MOD_IMP(M) { \ + int i,j; \ + for(i=i1; i<=i2; i++) { \ + for(j=j1; j<=j2; j++) { \ + AT(r,i,j) = M(a*AT(r,i,j) , m); \ + } \ + } OK } + + +#define SWAP_IMP(T) { \ + T aux; \ + int k; \ + if (i1 != i2) { \ + for (k=j1; k<=j2; k++) { \ + aux = AT(r,i1,k); \ + AT(r,i1,k) = AT(r,i2,k); \ + AT(r,i2,k) = aux; \ + } \ + } OK } + + +#define ROWOP_IMP(T) { \ + T a = *pa; \ + switch(code) { \ + case 0: AXPY_IMP \ + case 1: SCAL_IMP \ + case 2: SWAP_IMP(T) \ + default: ERROR(BAD_CODE); \ + } \ +} + +#define ROWOP_MOD_IMP(T,M) { \ + T a = *pa; \ + switch(code) { \ + case 0: AXPY_MOD_IMP(M) \ + case 1: SCAL_MOD_IMP(M) \ + case 2: SWAP_IMP(T) \ + default: ERROR(BAD_CODE); \ + } \ +} + + +#define ROWOP(T) int rowop_##T(int code, T* pa, int i1, int i2, int j1, int j2, MATG(T,r)) ROWOP_IMP(T) + +#define ROWOP_MOD(T,M) int rowop_mod_##T(T m, int code, T* pa, int i1, int i2, int j1, int j2, MATG(T,r)) ROWOP_MOD_IMP(T,M) + +ROWOP(double) +ROWOP(float) +ROWOP(TCD) +ROWOP(TCF) +ROWOP(int32_t) +ROWOP(int64_t) +ROWOP_MOD(int32_t,mod) +ROWOP_MOD(int64_t,mod_l) ////////////////// sparse matrix-product /////////////////////////////////////// diff --git a/packages/base/src/Internal/C/lapack-aux.h b/packages/base/src/Internal/C/lapack-aux.h index 1549bb5..e4d95bc 100644 --- a/packages/base/src/Internal/C/lapack-aux.h +++ b/packages/base/src/Internal/C/lapack-aux.h @@ -59,6 +59,7 @@ typedef short ftnlen; #define OQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, complex* A##p #define OCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, doublecomplex* A##p +#define MATG(T,A) int A##r, int A##c, int A##Xr, int A##Xc, T* A##p #define KIVEC(A) int A##n, const int*A##p #define KLVEC(A) int A##n, const int64_t*A##p diff --git a/packages/base/src/Internal/C/vector-aux.c b/packages/base/src/Internal/C/vector-aux.c index c161556..5528a9d 100644 --- a/packages/base/src/Internal/C/vector-aux.c +++ b/packages/base/src/Internal/C/vector-aux.c @@ -716,6 +716,7 @@ int mapValF(int code, float* pval, KFVEC(x), FVEC(r)) { } } +inline int mod (int a, int b) { int m = a % b; if (b>0) { @@ -741,7 +742,7 @@ int mapValI(int code, int* pval, KIVEC(x), IVEC(r)) { } } - +inline int64_t mod_l (int64_t a, int64_t b) { int64_t m = a % b; if (b>0) { @@ -1230,7 +1231,7 @@ int round_vector_i(KDVEC(v),IVEC(r)) { int mod_vector(int m, KIVEC(v), IVEC(r)) { int k; for(k=0; k Element a where transdata :: Int -> Vector a -> Int -> Vector a constantD :: a -> Int -> Vector a - extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> Matrix a + extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) sortI :: Ord a => Vector a -> Vector CInt sortV :: Ord a => Vector a -> Vector a compareV :: Ord a => Vector a -> Vector a -> Vector CInt selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a + rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () instance Element Float where @@ -290,6 +292,7 @@ instance Element Float where compareV = compareF selectV = selectF remapM = remapF + rowOp = rowOpAux c_rowOpF instance Element Double where transdata = transdataAux ctransR @@ -300,6 +303,7 @@ instance Element Double where compareV = compareD selectV = selectD remapM = remapD + rowOp = rowOpAux c_rowOpD instance Element (Complex Float) where @@ -311,6 +315,7 @@ instance Element (Complex Float) where compareV = undefined selectV = selectQ remapM = remapQ + rowOp = rowOpAux c_rowOpQ instance Element (Complex Double) where @@ -322,6 +327,7 @@ instance Element (Complex Double) where compareV = undefined selectV = selectC remapM = remapC + rowOp = rowOpAux c_rowOpC instance Element (CInt) where transdata = transdataAux ctransI @@ -332,6 +338,7 @@ instance Element (CInt) where compareV = compareI selectV = selectI remapM = remapI + rowOp = rowOpAux c_rowOpI instance Element Z where transdata = transdataAux ctransL @@ -342,6 +349,7 @@ instance Element Z where compareV = compareL selectV = selectL remapM = remapL + rowOp = rowOpAux c_rowOpL ------------------------------------------------------------------- @@ -379,7 +387,7 @@ subMatrix :: Element a -> Matrix a -- ^ result subMatrix (r0,c0) (rt,ct) m | 0 <= r0 && 0 <= rt && r0+rt <= rows m && - 0 <= c0 && 0 <= ct && c0+ct <= cols m = extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) + 0 <= c0 && 0 <= ct && c0+ct <= cols m = unsafePerformIO $ extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) | otherwise = error $ "wrong subMatrix "++ show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) @@ -430,7 +438,7 @@ instance (Storable t, NFData t) => NFData (Matrix t) --------------------------------------------------------------- -extractAux f m moder vr modec vc = unsafePerformIO $ do +extractAux f m moder vr modec vc = do let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc r <- createMatrix RowMajor nr nc @@ -538,6 +546,24 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z -------------------------------------------------------------------------------- +rowOpAux f c x i1 i2 j1 j2 m = do + px <- newArray [x] + app1 (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) omat m "rowOp" + free px + +type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok + +foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R +foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float +foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C +foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float) +foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I +foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z +foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I +foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z + +-------------------------------------------------------------------------------- + foreign import ccall unsafe "saveMatrix" c_saveMatrix :: CString -> CString -> Double ..> Ok diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 1289a21..824fc57 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs @@ -111,18 +111,46 @@ instance forall n t . (Integral t, KnownNat n) => Num (Mod n t) fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) +instance KnownNat m => Element (Mod m I) + where + transdata n v m = i2f (transdata n (f2i v) m) + constantD x n = i2f (constantD (unMod x) n) + extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js + sortI = sortI . f2i + sortV = i2f . sortV . f2i + compareV u v = compareV (f2i u) (f2i v) + selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) + remapM i j m = i2fM (remap i j (f2iM m)) + rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) -instance (Ord t, Element t) => Element (Mod n t) +instance KnownNat m => Element (Mod m Z) where transdata n v m = i2f (transdata n (f2i v) m) constantD x n = i2f (constantD (unMod x) n) - extractR m mi is mj js = i2fM (extractR (f2iM m) mi is mj js) + extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js sortI = sortI . f2i sortV = i2f . sortV . f2i compareV u v = compareV (f2i u) (f2i v) selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) remapM i j m = i2fM (remap i j (f2iM m)) + rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) +{- +instance (Ord t, Element t) => Element (Mod m t) + where + transdata n v m = i2f (transdata n (f2i v) m) + constantD x n = i2f (constantD (unMod x) n) + extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js + sortI = sortI . f2i + sortV = i2f . sortV . f2i + compareV u v = compareV (f2i u) (f2i v) + selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) + remapM i j m = i2fM (remap i j (f2iM m)) +-} instance forall m . KnownNat m => Container Vector (Mod m I) where @@ -205,12 +233,10 @@ instance forall m . KnownNat m => Container Vector (Mod m Z) toZ' = f2i - instance (Storable t, Indexable (Vector t) t) => Indexable (Vector (Mod m t)) (Mod m t) where (!) = (@>) - type instance RealOf (Mod n I) = I type instance RealOf (Mod n Z) = Z @@ -270,6 +296,15 @@ instance forall m . KnownNat m => Num (Vector (Mod m I)) negate = lift1 negate fromInteger x = fromInt (fromInteger x) +instance forall m . KnownNat m => Num (Vector (Mod m Z)) + where + (+) = lift2 (+) + (*) = lift2 (*) + (-) = lift2 (-) + abs = lift1 abs + signum = lift1 signum + negate = lift1 negate + fromInteger x = fromZ (fromInteger x) -------------------------------------------------------------------------------- diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index ae75a1b..107d3c3 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs @@ -1,5 +1,6 @@ {-# LANGUAGE Rank2Types #-} {-# LANGUAGE BangPatterns #-} + ----------------------------------------------------------------------------- -- | -- Module : Internal.ST @@ -20,6 +21,8 @@ module Internal.ST ( -- * Mutable Matrices STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, + axpy, scal, swap, extractRect, + mutable, -- * Unsafe functions newUndefinedVector, unsafeReadVector, unsafeWriteVector, @@ -34,8 +37,6 @@ import Internal.Matrix import Internal.Vectorized import Control.Monad.ST(ST, runST) import Foreign.Storable(Storable, peekElemOff, pokeElemOff) - - import Control.Monad.ST.Unsafe(unsafeIOToST) {-# INLINE ioReadV #-} @@ -144,6 +145,7 @@ liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x + freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) freezeMatrix m = liftSTMatrix id m @@ -171,3 +173,23 @@ newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t) newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) +-------------------------------------------------------------------------------- + +rowOpST :: Element t => Int -> t -> Int -> Int -> Int -> Int -> STMatrix s t -> ST s () +rowOpST c x i1 i2 j1 j2 (STMatrix m) = unsafeIOToST (rowOp c x i1 i2 j1 j2 m) + +axpy (STMatrix m) a i j = rowOpST 0 a i j 0 (cols m -1) (STMatrix m) +scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m) +swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m) + +extractRect (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) + +-------------------------------------------------------------------------------- + +mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) +mutable f a = runST $ do + x <- thawMatrix a + info <- f (rows a, cols a) x + r <- unsafeFreezeMatrix x + return (r,info) + diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index b1fb800..7a556e9 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs @@ -54,7 +54,7 @@ module Internal.Util( -- ** 2D corr2, conv2, separable, block2x2,block3x3,view1,unView1,foldMatrix, - gaussElim + gaussElim_1, gaussElim_2, gaussElim ) where import Internal.Vector @@ -64,17 +64,19 @@ import Internal.Element import Internal.Container import Internal.Vectorized import Internal.IO -import Internal.Algorithms hiding (i,Normed) +import Internal.Algorithms hiding (i,Normed,swap) import Numeric.Matrix() import Numeric.Vector() import Internal.Random import Internal.Convolution -import Control.Monad(when) +import Control.Monad(when,forM_) import Text.Printf import Data.List.Split(splitOn) -import Data.List(intercalate,) +import Data.List(intercalate,sortBy) import Control.Arrow((&&&)) import Data.Complex +import Data.Function(on) +import Internal.ST type ℝ = Double type ℕ = Int @@ -359,6 +361,10 @@ instance Indexable (Vector I) I where (!) = (@>) +instance Indexable (Vector Z) Z + where + (!) = (@>) + instance Indexable (Vector (Complex Double)) (Complex Double) where (!) = (@>) @@ -550,11 +556,11 @@ down g a = foldMatrix g f a -- -- @a <> gaussElim a b = b@ -- -gaussElim +gaussElim_2 :: (Eq t, Fractional t, Num (Vector t), Numeric t) => Matrix t -> Matrix t -> Matrix t -gaussElim a b = flipudrl r +gaussElim_2 a b = flipudrl r where flipudrl = flipud . fliprl splitColsAt n = (takeColumns n &&& dropColumns n) @@ -564,6 +570,68 @@ gaussElim a b = flipudrl r -------------------------------------------------------------------------------- +gaussElim_1 + :: (Fractional t, Num (Vector t), Ord t, Indexable (Vector t) t, Numeric t) + => Matrix t -> Matrix t -> Matrix t + +gaussElim_1 x y = dropColumns (rows x) (flipud $ fromRows s2) + where + rs = toRows $ fromBlocks [[x , y]] + s1 = fromRows $ pivotDown (rows x) 0 rs -- interesting + s2 = pivotUp (rows x-1) (toRows $ flipud s1) + +pivotDown t n xs + | t == n = [] + | otherwise = y : pivotDown t (n+1) ys + where + y:ys = redu (pivot n xs) + + pivot k = (const k &&& id) + . sortBy (flip compare `on` (abs. (!k))) + + redu (k,x:zs) + | p == 0 = error "gauss: singular!" -- FIXME + | otherwise = u : map f zs + where + p = x!k + u = scale (recip (x!k)) x + f z = z - scale (z!k) u + redu (_,[]) = [] + + +pivotUp n xs + | n == -1 = [] + | otherwise = y : pivotUp (n-1) ys + where + y:ys = redu' (n,xs) + + redu' (k,x:zs) = u : map f zs + where + u = x + f z = z - scale (z!k) u + redu' (_,[]) = [] + +-------------------------------------------------------------------------------- + +gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (fromBlocks [[a,b]]) + +gaussST (r,_) x = do + let n = r-1 + forM_ [0..n] $ \i -> do + c <- maxIndex . abs . flatten <$> extractRect x i n i i + swap x i (i+c) + a <- readMatrix x i i + scal x (recip a) i + forM_ [i+1..n] $ \j -> do + b <- readMatrix x j i + axpy x (-b) i j + forM_ [n,n-1..1] $ \i -> do + forM_ [i-1,i-2..0] $ \j -> do + b <- readMatrix x j i + axpy x (-b) i j + +-------------------------------------------------------------------------------- + instance Testable (Matrix I) where checkT _ = test diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index 56e5053..c97f415 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs @@ -134,7 +134,7 @@ module Numeric.LinearAlgebra ( Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, -- * Misc - meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, gaussElim, + meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, gaussElim, gaussElim_1, gaussElim_2, ℝ,ℂ,iC, -- * Auxiliary classes Element, Container, Product, Numeric, LSDiv, diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs index 1a70663..84763fe 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs @@ -44,6 +44,7 @@ module Numeric.LinearAlgebra.Devel( -- ** Mutable Matrices STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, + axpy,scal,swap, extractRect, mutable, -- ** Unsafe functions newUndefinedVector, unsafeReadVector, unsafeWriteVector, -- cgit v1.2.3