From 8ede2ed162f3d00172ee3fa4835e3ee2184bcd99 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Sun, 24 May 2015 12:45:23 +0200 Subject: joint extractor of rows and columns --- packages/base/src/Data/Packed/Internal/Matrix.hs | 44 +++++++-------- packages/base/src/Data/Packed/Internal/Numeric.hs | 66 ++++++++++++++--------- 2 files changed, 63 insertions(+), 47 deletions(-) (limited to 'packages/base/src/Data/Packed') diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs index 1aee7d3..76d2204 100644 --- a/packages/base/src/Data/Packed/Internal/Matrix.hs +++ b/packages/base/src/Data/Packed/Internal/Matrix.hs @@ -265,32 +265,32 @@ class (Storable a) => Element a where transdata = transdataP -- transdata' constantD :: a -> Int -> Vector a constantD = constantP -- constant' - extractR :: Matrix a -> CInt -> Idxs -> Matrix a + extractR :: Matrix a -> CInt -> Idxs -> CInt -> Idxs -> Matrix a instance Element Float where transdata = transdataAux ctransF constantD = constantAux cconstantF - extractR = extractAux c_extractRF + extractR = extractAux c_extractF instance Element Double where transdata = transdataAux ctransR constantD = constantAux cconstantR - extractR = extractAux c_extractRD + extractR = extractAux c_extractD instance Element (Complex Float) where transdata = transdataAux ctransQ constantD = constantAux cconstantQ - extractR = extractAux c_extractRQ + extractR = extractAux c_extractQ instance Element (Complex Double) where transdata = transdataAux ctransC constantD = constantAux cconstantC - extractR = extractAux c_extractRC + extractR = extractAux c_extractC instance Element (CInt) where transdata = transdataAux ctransI constantD = constantAux cconstantI - extractR = extractAux c_extractRI + extractR = extractAux c_extractI ------------------------------------------------------------------- @@ -443,26 +443,26 @@ isT Matrix{order = RowMajor} = 0 tt x@Matrix{order = ColumnMajor} = trans x tt x@Matrix{order = RowMajor} = x ---extractAux :: Matrix Double -> Idxs -> Matrix Double -extractAux f m mode v = unsafePerformIO $ do - let nr | mode == 0 = fromIntegral $ max 0 (v@>1 - v@>0 + 1) - | otherwise = dim v - r <- createMatrix RowMajor nr (cols m) - app3 (f mode (isT m)) vec v mat (tt m) mat r "extractAux" + +extractAux f m moder vr modec vc = unsafePerformIO $ do + let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr + nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc + r <- createMatrix RowMajor nr nc + app4 (f moder modec (isT m)) vec vr vec vc mat (tt m) mat r "extractAux" return r -foreign import ccall unsafe "extractRD" c_extractRD - :: CInt -> CInt -> CIdxs (CM Double (CM Double (IO CInt))) +foreign import ccall unsafe "extractD" c_extractD + :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM Double (CM Double (IO CInt)))) -foreign import ccall unsafe "extractRF" c_extractRF - :: CInt -> CInt -> CIdxs (CM Float (CM Float (IO CInt))) +foreign import ccall unsafe "extractF" c_extractF + :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM Float (CM Float (IO CInt)))) -foreign import ccall unsafe "extractRC" c_extractRC - :: CInt -> CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt))) +foreign import ccall unsafe "extractC" c_extractC + :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt)))) -foreign import ccall unsafe "extractRQ" c_extractRQ - :: CInt -> CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt))) +foreign import ccall unsafe "extractQ" c_extractQ + :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt)))) -foreign import ccall unsafe "extractRI" c_extractRI - :: CInt -> CInt -> CIdxs (CM CInt (CM CInt (IO CInt))) +foreign import ccall unsafe "extractI" c_extractI + :: CInt -> CInt -> CInt -> CIdxs (CIdxs (CM CInt (CM CInt (IO CInt)))) diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs index 00ec70c..353877a 100644 --- a/packages/base/src/Data/Packed/Internal/Numeric.hs +++ b/packages/base/src/Data/Packed/Internal/Numeric.hs @@ -39,7 +39,7 @@ module Data.Packed.Internal.Numeric ( roundVector, RealOf, ComplexOf, SingleOf, DoubleOf, IndexOf, - CInt, Extractor(..), (??),(¿¿), + CInt, Extractor(..), (??), module Data.Complex ) where @@ -69,27 +69,50 @@ type instance ArgOf Matrix a = a -> a -> a -------------------------------------------------------------------------- -data Extractor = All | Range Int Int | At [Int] | AtCyc [Int] | Take Int | Drop Int +data Extractor + = All + | Range Int Int + | At [Int] + | AtCyc [Int] + | Take Int + | Drop Int + deriving Show idxs js = fromList (map fromIntegral js) :: Idxs -infixl 9 ??, ¿¿ -(??),(¿¿) :: Element t => Matrix t -> Extractor -> Matrix t - -m ?? All = m -m ?? Take 0 = (0>= rows m = m -m ?? Drop 0 = m -m ?? Drop n | abs n >= rows m = (0> b = m ?? Take 0 -m ?? Range a b | a < 0 || b >= cols m = error $ - printf "can't extract rows %d to %d from matrix %dx%d" a b (rows m) (cols m) -m ?? At ps | minimum ps < 0 || maximum ps >= rows m = error $ - printf "can't extract rows %s from matrix %dx%d" (show ps) (rows m) (cols m) - -m ?? er = extractR m mode js +infixl 9 ?? +(??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t + + +extractError m e = error $ printf "can't extract %s from matrix %dx%d" (show e) (rows m) (cols m) + +m ?? e@(Range a b,_) | a < 0 || b >= rows m = extractError m e +m ?? e@(_,Range a b) | a < 0 || b >= cols m = extractError m e +m ?? e@(At ps,_) | minimum ps < 0 || maximum ps >= rows m = extractError m e +m ?? e@(_,At ps) | minimum ps < 0 || maximum ps >= cols m = extractError m e + +m ?? (All,All) = m + +m ?? (Range a b,e) | a > b = m ?? (Take 0,e) +m ?? (e,Range a b) | a > b = m ?? (e,Take 0) + +m ?? (Take 0,e) = (0><0) [] ?? (e,All) + +m ?? (Take n,e) | abs n > rows m = m ?? (All,e) +m ?? (e,Take n) | abs n > cols m = m ?? (e,All) + +m ?? (Drop 0,e) = m ?? (All,e) +m ?? (e,Drop 0) = m ?? (e,All) + +m ?? (Drop n,e) | abs n > rows m = m ?? (Take 0,e) +m ?? (e,Drop n) | abs n > cols m = m ?? (e,Take 0) + + +m ?? (er,ec) = extractR m moder rs modec cs where - (mode,js) = mkExt (rows m) er + (moder,rs) = mkExt (rows m) er + (modec,cs) = mkExt (cols m) ec ran a b = (0, idxs [a,b]) pos ks = (1, idxs ks) mkExt _ (At ks) = pos ks @@ -104,13 +127,6 @@ m ?? er = extractR m mode js | otherwise = mkExt n (Take (n+k)) -m ¿¿ Range a b | a < 0 || b > cols m -1 = error $ - printf "can't extract columns %d to %d from matrix %dx%d" a b (rows m) (cols m) - -m ¿¿ At ps | minimum ps < 0 || maximum ps >= cols m = error $ - printf "can't extract columns %s from matrix %dx%d" (show ps) (rows m) (cols m) -m ¿¿ ec = trans (trans m ?? ec) - ------------------------------------------------------------------- -- cgit v1.2.3