From e635f3889aed9b4bf7ef02c98945e9065d114df3 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Fri, 22 May 2015 12:00:42 +0200 Subject: extraction modes --- packages/base/src/Data/Packed/Internal/Matrix.hs | 20 +++++---- packages/base/src/Data/Packed/Internal/Numeric.hs | 53 +++++++++++++++++++++-- 2 files changed, 61 insertions(+), 12 deletions(-) (limited to 'packages/base/src/Data/Packed') diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs index be5fb03..1aee7d3 100644 --- a/packages/base/src/Data/Packed/Internal/Matrix.hs +++ b/packages/base/src/Data/Packed/Internal/Matrix.hs @@ -265,7 +265,7 @@ class (Storable a) => Element a where transdata = transdataP -- transdata' constantD :: a -> Int -> Vector a constantD = constantP -- constant' - extractR :: Matrix a -> Idxs -> Matrix a + extractR :: Matrix a -> CInt -> Idxs -> Matrix a instance Element Float where transdata = transdataAux ctransF @@ -444,23 +444,25 @@ tt x@Matrix{order = ColumnMajor} = trans x tt x@Matrix{order = RowMajor} = x --extractAux :: Matrix Double -> Idxs -> Matrix Double -extractAux f m v = unsafePerformIO $ do - r <- createMatrix RowMajor (dim v) (cols m) - app3 (f (isT m)) vec v mat (tt m) mat r "extractAux" +extractAux f m mode v = unsafePerformIO $ do + let nr | mode == 0 = fromIntegral $ max 0 (v@>1 - v@>0 + 1) + | otherwise = dim v + r <- createMatrix RowMajor nr (cols m) + app3 (f mode (isT m)) vec v mat (tt m) mat r "extractAux" return r foreign import ccall unsafe "extractRD" c_extractRD - :: CInt -> CIdxs (CM Double (CM Double (IO CInt))) + :: CInt -> CInt -> CIdxs (CM Double (CM Double (IO CInt))) foreign import ccall unsafe "extractRF" c_extractRF - :: CInt -> CIdxs (CM Float (CM Float (IO CInt))) + :: CInt -> CInt -> CIdxs (CM Float (CM Float (IO CInt))) foreign import ccall unsafe "extractRC" c_extractRC - :: CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt))) + :: CInt -> CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt))) foreign import ccall unsafe "extractRQ" c_extractRQ - :: CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt))) + :: CInt -> CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt))) foreign import ccall unsafe "extractRI" c_extractRI - :: CInt -> CIdxs (CM CInt (CM CInt (IO CInt))) + :: CInt -> CInt -> CIdxs (CM CInt (CM CInt (IO CInt))) diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs index 9b6b55b..f1b4898 100644 --- a/packages/base/src/Data/Packed/Internal/Numeric.hs +++ b/packages/base/src/Data/Packed/Internal/Numeric.hs @@ -39,7 +39,7 @@ module Data.Packed.Internal.Numeric ( roundVector, RealOf, ComplexOf, SingleOf, DoubleOf, IndexOf, - CInt, + CInt, Extractor(..), (??),(¿¿), module Data.Complex ) where @@ -53,6 +53,7 @@ import Data.Complex import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ) import Data.Packed.Internal import Foreign.C.Types(CInt) +import Text.Printf(printf) ------------------------------------------------------------------- @@ -66,8 +67,53 @@ type family ArgOf (c :: * -> *) a type instance ArgOf Vector a = a -> a type instance ArgOf Matrix a = a -> a -> a +-------------------------------------------------------------------------- + +data Extractor = All | Range Int Int | At [Int] | AtCyc [Int] | Take Int | Drop Int + +idxs js = fromList (map fromIntegral js) :: Idxs + +infixl 9 ??, ¿¿ +(??),(¿¿) :: Element t => Matrix t -> Extractor -> Matrix t + +m ?? All = m +m ?? Take 0 = (0>= rows m = m +m ?? Drop 0 = m +m ?? Drop n | abs n >= rows m = (0> b = m ?? Take 0 +m ?? Range a b | a < 0 || b >= cols m = error $ + printf "can't extract rows %d to %d from matrix %dx%d" a b (rows m) (cols m) +m ?? At ps | minimum ps < 0 || maximum ps >= rows m = error $ + printf "can't extract rows %s from matrix %dx%d" (show ps) (rows m) (cols m) + +m ?? er = extractR m mode js + where + (mode,js) = mkExt (rows m) er + ran a b = (0, idxs [a,b]) + pos ks = (1, idxs ks) + mkExt _ (At ks) = pos ks + mkExt n (AtCyc ks) = pos (map (`mod` n) ks) + mkExt n All = ran 0 (n-1) + mkExt _ (Range mn mx) = ran mn mx + mkExt n (Take k) + | k >= 0 = ran 0 (k-1) + | otherwise = mkExt n (Drop (n+k)) + mkExt n (Drop k) + | k >= 0 = ran k (n-1) + | otherwise = mkExt n (Take (n+k)) + + +m ¿¿ Range a b | a < 0 || b > cols m -1 = error $ + printf "can't extract columns %d to %d from matrix %dx%d" a b (rows m) (cols m) + +m ¿¿ At ps | minimum ps < 0 || maximum ps >= cols m = error $ + printf "can't extract columns %s from matrix %dx%d" (show ps) (rows m) (cols m) +m ¿¿ ec = trans (trans m ?? ec) + ------------------------------------------------------------------- + -- | Basic element-by-element functions for numeric containers class Element e => SContainer c e where @@ -123,6 +169,7 @@ class (Complexable c, Fractional e, SContainer c e) => Container c e -- element by element inverse tangent arctan2' :: c e -> c e -> c e + -------------------------------------------------------------------------- instance SContainer Vector CInt @@ -245,14 +292,14 @@ instance SContainer Vector (Complex Double) accum' = accumV cond' = undefined -- cannot match - + instance Container Vector (Complex Double) where scaleRecip = vectorMapValC Recip divide = vectorZipC Div arctan2' = vectorZipC ATan2 conj' = conjugateC - + instance SContainer Vector (Complex Float) where -- cgit v1.2.3