From c77c83f1e442e5fff408d883b7aac5043ba512a9 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Wed, 27 May 2015 10:47:02 +0200 Subject: omat, AT, remap --- packages/base/src/C/lapack-aux.c | 28 ++++++++++++ packages/base/src/C/lapack-aux.h | 18 ++++++++ packages/base/src/Data/Packed/Internal/Matrix.hs | 52 +++++++++++++++++++--- packages/base/src/Data/Packed/Internal/Numeric.hs | 2 +- .../base/src/Data/Packed/Internal/Signatures.hs | 2 + packages/base/src/Data/Packed/Numeric.hs | 8 +++- packages/base/src/Numeric/LinearAlgebra/Data.hs | 2 +- 7 files changed, 103 insertions(+), 9 deletions(-) diff --git a/packages/base/src/C/lapack-aux.c b/packages/base/src/C/lapack-aux.c index 77381cc..a977d5f 100644 --- a/packages/base/src/C/lapack-aux.c +++ b/packages/base/src/C/lapack-aux.c @@ -1670,3 +1670,31 @@ int extractI(int modei, int modej, int tm, KIVEC(i), KIVEC(j), KIMAT(m), IMAT(r) EXTRACT_IMP } +//////////////////////// remap ///////////////////////////////// + +#define REMAP_IMP \ + REQUIRES(ir==jr && ic==jc && ir==rr && ic==rc ,BAD_SIZE); \ + { TRAV(r,a,b) { AT(r,a,b) = AT(m,AT(i,a,b),AT(j,a,b)); } \ + } \ + OK + +int remapD(KOIMAT(i), KOIMAT(j), KODMAT(m), ODMAT(r)) { + REMAP_IMP +} + +int remapF(KOIMAT(i), KOIMAT(j), KOFMAT(m), OFMAT(r)) { + REMAP_IMP +} + +int remapI(KOIMAT(i), KOIMAT(j), KOIMAT(m), OIMAT(r)) { + REMAP_IMP +} + +int remapC(KOIMAT(i), KOIMAT(j), KOCMAT(m), OCMAT(r)) { + REMAP_IMP +} + +int remapQ(KOIMAT(i), KOIMAT(j), KOQMAT(m), OQMAT(r)) { + REMAP_IMP +} + diff --git a/packages/base/src/C/lapack-aux.h b/packages/base/src/C/lapack-aux.h index b49c7c9..6ffbef1 100644 --- a/packages/base/src/C/lapack-aux.h +++ b/packages/base/src/C/lapack-aux.h @@ -42,6 +42,7 @@ typedef short ftnlen; #define QVEC(A) int A##n, complex*A##p #define CVEC(A) int A##n, doublecomplex*A##p #define PVEC(A) int A##n, void* A##p, int A##s + #define IMAT(A) int A##r, int A##c, int* A##p #define FMAT(A) int A##r, int A##c, float* A##p #define DMAT(A) int A##r, int A##c, double* A##p @@ -49,12 +50,20 @@ typedef short ftnlen; #define CMAT(A) int A##r, int A##c, doublecomplex* A##p #define PMAT(A) int A##r, int A##c, void* A##p, int A##s +#define OIMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, int* A##p +#define OFMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, float* A##p +#define ODMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, double* A##p +#define OQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, complex* A##p +#define OCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, doublecomplex* A##p + + #define KIVEC(A) int A##n, const int*A##p #define KFVEC(A) int A##n, const float*A##p #define KDVEC(A) int A##n, const double*A##p #define KQVEC(A) int A##n, const complex*A##p #define KCVEC(A) int A##n, const doublecomplex*A##p #define KPVEC(A) int A##n, const void* A##p, int A##s + #define KIMAT(A) int A##r, int A##c, const int* A##p #define KFMAT(A) int A##r, int A##c, const float* A##p #define KDMAT(A) int A##r, int A##c, const double* A##p @@ -62,3 +71,12 @@ typedef short ftnlen; #define KCMAT(A) int A##r, int A##c, const doublecomplex* A##p #define KPMAT(A) int A##r, int A##c, const void* A##p, int A##s +#define KOIMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, const int* A##p +#define KOFMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, const float* A##p +#define KODMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, const double* A##p +#define KOQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, const complex* A##p +#define KOCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, const doublecomplex* A##p + +#define AT(m,i,j) (m##p[(i)*m##Xr + (j)*m##Xc]) +#define TRAV(m,i,j) int i,j; for (i=0;i MatrixOrder orderOf = order +stepRow :: Matrix t -> CInt +stepRow Matrix {icols = c, order = RowMajor } = fromIntegral c +stepRow _ = 1 + +stepCol :: Matrix t -> CInt +stepCol Matrix {irows = r, order = ColumnMajor } = fromIntegral r +stepCol _ = 1 + -- | Matrix transpose. trans :: Matrix t -> Matrix t @@ -128,6 +136,14 @@ mat a f = g (fi (rows a)) (fi (cols a)) p f m +omat :: (Storable t) => Matrix t -> (((CInt -> CInt -> CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b +omat a f = + unsafeWith (xdat a) $ \p -> do + let m g = do + g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p + f m + + {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. >>> flatten (ident 3) @@ -257,7 +273,7 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 >instance Element Foo -} class (Storable a) => Element a where - subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position + subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix -> Matrix a -> Matrix a subMatrixD = subMatrix' @@ -270,6 +286,7 @@ class (Storable a) => Element a where sortV :: Ord a => Vector a -> Vector a compareV :: Ord a => Vector a -> Vector a -> Vector CInt selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a + remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a instance Element Float where @@ -280,7 +297,7 @@ instance Element Float where sortV = sortValF compareV = compareF selectV = selectF - + remapM = remapF instance Element Double where transdata = transdataAux ctransR @@ -290,6 +307,7 @@ instance Element Double where sortV = sortValD compareV = compareD selectV = selectD + remapM = remapD instance Element (Complex Float) where @@ -300,6 +318,7 @@ instance Element (Complex Float) where sortV = undefined compareV = undefined selectV = selectQ + remapM = remapQ instance Element (Complex Double) where @@ -310,8 +329,8 @@ instance Element (Complex Double) where sortV = undefined compareV = undefined selectV = selectC + remapM = remapC - instance Element (CInt) where transdata = transdataAux ctransI constantD = constantAux cconstantI @@ -320,7 +339,7 @@ instance Element (CInt) where sortV = sortValI compareV = compareI selectV = selectI - + remapM = remapI ------------------------------------------------------------------- @@ -394,7 +413,7 @@ foreign import ccall unsafe "constantP" cconstantP :: Ptr () -> CInt -> Ptr () - -- | Extracts a submatrix from a matrix. subMatrix :: Element a - => (Int,Int) -- ^ (r0,c0) starting position + => (Int,Int) -- ^ (r0,c0) starting position -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix -> Matrix a -- ^ input matrix -> Matrix a -- ^ result @@ -427,7 +446,7 @@ conformMs ms = map (conformMTo (r,c)) ms where r = maxZ (map rows ms) c = maxZ (map cols ms) - + conformVs vs = map (conformVTo n) vs where @@ -554,4 +573,25 @@ foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) +--------------------------------------------------------------------------- + +remapG f i j m = unsafePerformIO $ do + r <- createMatrix RowMajor (rows i) (cols i) + app4 f omat i omat j omat m omat r "remapG" + return r + +remapD = remapG c_remapD +remapF = remapG c_remapF +remapI = remapG c_remapI +remapC = remapG c_remapC +remapQ = remapG c_remapQ + +type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) + +foreign import ccall unsafe "remapD" c_remapD :: Rem Double +foreign import ccall unsafe "remapF" c_remapF :: Rem Float +foreign import ccall unsafe "remapI" c_remapI :: Rem CInt +foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) +foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) + diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs index a241c48..67d047c 100644 --- a/packages/base/src/Data/Packed/Internal/Numeric.hs +++ b/packages/base/src/Data/Packed/Internal/Numeric.hs @@ -39,7 +39,7 @@ module Data.Packed.Internal.Numeric ( roundVector, fromInt, toInt, RealOf, ComplexOf, SingleOf, DoubleOf, IndexOf, - I, Extractor(..), (??), range, idxs, + I, Extractor(..), (??), range, idxs, remapM, module Data.Complex ) where diff --git a/packages/base/src/Data/Packed/Internal/Signatures.hs b/packages/base/src/Data/Packed/Internal/Signatures.hs index 37dac16..5c54498 100644 --- a/packages/base/src/Data/Packed/Internal/Signatures.hs +++ b/packages/base/src/Data/Packed/Internal/Signatures.hs @@ -71,5 +71,7 @@ type TMMCVM = CInt -> CInt -> PD -> TMCVM -- type CM b r = CInt -> CInt -> Ptr b -> r type CV b r = CInt -> Ptr b -> r +type OM b r = CInt -> CInt -> CInt -> CInt -> Ptr b -> r + type CIdxs r = CV CInt r diff --git a/packages/base/src/Data/Packed/Numeric.hs b/packages/base/src/Data/Packed/Numeric.hs index 906bc83..80f1718 100644 --- a/packages/base/src/Data/Packed/Numeric.hs +++ b/packages/base/src/Data/Packed/Numeric.hs @@ -31,7 +31,7 @@ module Data.Packed.Numeric ( diag, ident, ctrans, -- * Generic operations - Container(..), Numeric, Extractor(..), (??), range, idxs, I, + Container(..), Numeric, Extractor(..), (??), range, idxs, I, remap, -- add, mul, sub, divide, equal, scaleRecip, addConstant, scalar, conj, scale, arctan2, cmap, cmod, atIndex, minIndex, maxIndex, minElement, maxElement, @@ -315,3 +315,9 @@ ccompare = ccompare' cselect :: (Container c t) => c I -> c t -> c t -> c t -> c t cselect = cselect' +remap :: Element t => Matrix I -> Matrix I -> Matrix t -> Matrix t +remap i j m + | minElement i >= 0 && maxElement i < fromIntegral (rows m) && + minElement j >= 0 && maxElement j < fromIntegral (cols m) = remapM i j m + | otherwise = error $ "out of range index in rmap" + diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs index 79dd06b..8bacb09 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Data.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs @@ -53,7 +53,7 @@ module Numeric.LinearAlgebra.Data( -- * Matrix extraction Extractor(..), (??), takeRows, dropRows, takeColumns, dropColumns, subMatrix, (?), (¿), fliprl, flipud, - + remap, -- * Block matrix fromBlocks, (|||), (===), diagBlock, repmat, toBlocks, toBlocksEvery, -- cgit v1.2.3