From c5795a191ded450987a30302c1d1fa4a265350ff Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Wed, 27 May 2015 09:10:22 +0200 Subject: ccompare, cselect, toInt --- packages/base/src/C/lapack-aux.c | 7 +++ packages/base/src/Data/Packed/Internal/Matrix.hs | 53 +++++++++++++++++++++++ packages/base/src/Data/Packed/Internal/Numeric.hs | 44 ++++++++++++++++++- packages/base/src/Data/Packed/Numeric.hs | 11 +++-- packages/base/src/Numeric/LinearAlgebra/Data.hs | 7 +-- packages/base/src/Numeric/LinearAlgebra/Util.hs | 1 - 6 files changed, 115 insertions(+), 8 deletions(-) diff --git a/packages/base/src/C/lapack-aux.c b/packages/base/src/C/lapack-aux.c index af515ca..77381cc 100644 --- a/packages/base/src/C/lapack-aux.c +++ b/packages/base/src/C/lapack-aux.c @@ -1619,6 +1619,13 @@ int chooseI(KIVEC(cond),KIVEC(lt),KIVEC(eq),KIVEC(gt),IVEC(r)) { CHOOSE_IMP } +int chooseC(KIVEC(cond),KCVEC(lt),KCVEC(eq),KCVEC(gt),CVEC(r)) { + CHOOSE_IMP +} + +int chooseQ(KIVEC(cond),KQVEC(lt),KQVEC(eq),KQVEC(gt),QVEC(r)) { + CHOOSE_IMP +} //////////////////////// extract ///////////////////////////////// diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs index 1679ea6..82a9d8f 100644 --- a/packages/base/src/Data/Packed/Internal/Matrix.hs +++ b/packages/base/src/Data/Packed/Internal/Matrix.hs @@ -268,6 +268,9 @@ class (Storable a) => Element a where extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> Matrix a sortI :: Ord a => Vector a -> Vector CInt sortV :: Ord a => Vector a -> Vector a + compareV :: Ord a => Vector a -> Vector a -> Vector CInt + selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a + instance Element Float where transdata = transdataAux ctransF @@ -275,6 +278,9 @@ instance Element Float where extractR = extractAux c_extractF sortI = sortIdxF sortV = sortValF + compareV = compareF + selectV = selectF + instance Element Double where transdata = transdataAux ctransR @@ -282,6 +288,9 @@ instance Element Double where extractR = extractAux c_extractD sortI = sortIdxD sortV = sortValD + compareV = compareD + selectV = selectD + instance Element (Complex Float) where transdata = transdataAux ctransQ @@ -289,6 +298,9 @@ instance Element (Complex Float) where extractR = extractAux c_extractQ sortI = undefined sortV = undefined + compareV = undefined + selectV = selectQ + instance Element (Complex Double) where transdata = transdataAux ctransC @@ -296,6 +308,9 @@ instance Element (Complex Double) where extractR = extractAux c_extractC sortI = undefined sortV = undefined + compareV = undefined + selectV = selectC + instance Element (CInt) where transdata = transdataAux ctransI @@ -303,6 +318,9 @@ instance Element (CInt) where extractR = extractAux c_extractI sortI = sortIdxI sortV = sortValI + compareV = compareI + selectV = selectI + ------------------------------------------------------------------- @@ -502,3 +520,38 @@ foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO -------------------------------------------------------------------------------- +compareG f u v = unsafePerformIO $ do + r <- createVector (dim v) + app3 f vec u vec v vec r "compareG" + return r + +compareD = compareG c_compareD +compareF = compareG c_compareF +compareI = compareG c_compareI + +foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) +foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt))) +foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt))) + +-------------------------------------------------------------------------------- + +selectG f c u v w = unsafePerformIO $ do + r <- createVector (dim v) + app5 f vec c vec u vec v vec w vec r "selectG" + return r + +selectD = selectG c_selectD +selectF = selectG c_selectF +selectI = selectG c_selectI +selectC = selectG c_selectC +selectQ = selectG c_selectQ + +type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) + +foreign import ccall unsafe "chooseD" c_selectD :: Sel Double +foreign import ccall unsafe "chooseF" c_selectF :: Sel Float +foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt +foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) +foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) + + diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs index 51bee5c..a241c48 100644 --- a/packages/base/src/Data/Packed/Internal/Numeric.hs +++ b/packages/base/src/Data/Packed/Internal/Numeric.hs @@ -36,7 +36,7 @@ module Data.Packed.Internal.Numeric ( Convert(..), Complexable(), RealElement(), - roundVector, fromInt, + roundVector, fromInt, toInt, RealOf, ComplexOf, SingleOf, DoubleOf, IndexOf, I, Extractor(..), (??), range, idxs, @@ -171,6 +171,8 @@ class Element e => Container c e -> c e -- ^ e -> c e -- ^ g -> c e -- ^ result + ccompare' :: Ord e => c e -> c e -> c I + cselect' :: c I -> c e -> c e -> c e -> c e find' :: (e -> Bool) -> c e -> [IndexOf c] assoc' :: IndexOf c -- ^ size -> e -- ^ default value @@ -192,6 +194,7 @@ class Element e => Container c e arctan2' :: Fractional e => c e -> c e -> c e cmod' :: Integral e => e -> c e -> c e fromInt' :: c I -> c e + toInt' :: c e -> c I -------------------------------------------------------------------------- @@ -222,6 +225,8 @@ instance Container Vector I assoc' = assocV accum' = accumV cond' = condV condI + ccompare' = compareCV compareV + cselect' = selectCV selectV scaleRecip = undefined -- cannot match divide = undefined arctan2' = undefined @@ -229,6 +234,7 @@ instance Container Vector I | m /= 0 = vectorMapValI ModVS m x | otherwise = error $ "cmod 0 on vector of size "++(show $ dim x) fromInt' = id + toInt' = id instance Container Vector Float where @@ -256,11 +262,14 @@ instance Container Vector Float assoc' = assocV accum' = accumV cond' = condV condF + ccompare' = compareCV compareV + cselect' = selectCV selectV scaleRecip = vectorMapValF Recip divide = vectorZipF Div arctan2' = vectorZipF ATan2 cmod' = undefined fromInt' = int2floatV + toInt' = float2IntV @@ -290,11 +299,14 @@ instance Container Vector Double assoc' = assocV accum' = accumV cond' = condV condD + ccompare' = compareCV compareV + cselect' = selectCV selectV scaleRecip = vectorMapValR Recip divide = vectorZipR Div arctan2' = vectorZipR ATan2 cmod' = undefined fromInt' = int2DoubleV + toInt' = double2IntV instance Container Vector (Complex Double) @@ -323,11 +335,14 @@ instance Container Vector (Complex Double) assoc' = assocV accum' = accumV cond' = undefined -- cannot match + ccompare' = undefined + cselect' = selectCV selectV scaleRecip = vectorMapValC Recip divide = vectorZipC Div arctan2' = vectorZipC ATan2 cmod' = undefined fromInt' = complex . int2DoubleV + toInt' = toInt' . fst . fromComplex instance Container Vector (Complex Float) where @@ -355,11 +370,14 @@ instance Container Vector (Complex Float) assoc' = assocV accum' = accumV cond' = undefined -- cannot match + ccompare' = undefined + cselect' = selectCV selectV scaleRecip = vectorMapValQ Recip divide = vectorZipQ Div arctan2' = vectorZipQ ATan2 cmod' = undefined fromInt' = complex . int2floatV + toInt' = toInt' . fst . fromComplex --------------------------------------------------------------- @@ -391,6 +409,8 @@ instance (Num a, Element a, Container Vector a) => Container Matrix a assoc' = assocM accum' = accumM cond' = condM + ccompare' = compareM + cselect' = selectM scaleRecip x = liftMatrix (scaleRecip x) divide = liftMatrix2 divide arctan2' = liftMatrix2 arctan2' @@ -398,6 +418,7 @@ instance (Num a, Element a, Container Vector a) => Container Matrix a | m /= 0 = liftMatrix (cmod' m) x | otherwise = error $ "cmod 0 on matrix "++shSize x fromInt' = liftMatrix fromInt' + toInt' = liftMatrix toInt' emptyErrorV msg f v = @@ -448,6 +469,9 @@ cmod m = cmod' (fromIntegral m) fromInt :: (Container c e) => c I -> c e fromInt = fromInt' +toInt :: (Container c e) => c e -> c I +toInt = toInt' + -- | like 'fmap' (cannot implement instance Functor because of Element class constraint) cmap :: (Element b, Container c e) => (e -> b) -> c e -> c b @@ -852,6 +876,24 @@ condV f a b l e t = f a' b' l' e' t' where [a', b', l', e', t'] = conformVs [a,b,l,e,t] +compareM a b = matrixFromVector RowMajor (rows a'') (cols a'') $ ccompare' a' b' + where + args@(a'':_) = conformMs [a,b] + [a', b'] = map flatten args + +compareCV f a b = f a' b' + where + [a', b'] = conformVs [a,b] + +selectM c l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cselect' (toInt c') l' e' t' + where + args@(a'':_) = conformMs [fromInt c,l,e,t] + [c', l', e', t'] = map flatten args + +selectCV f c l e t = f (toInt c') l' e' t' + where + [c', l', e', t'] = conformVs [fromInt c,l,e,t] + -------------------------------------------------------------------------------- class Transposable m mt | m -> mt, mt -> m diff --git a/packages/base/src/Data/Packed/Numeric.hs b/packages/base/src/Data/Packed/Numeric.hs index cb449a9..906bc83 100644 --- a/packages/base/src/Data/Packed/Numeric.hs +++ b/packages/base/src/Data/Packed/Numeric.hs @@ -31,12 +31,12 @@ module Data.Packed.Numeric ( diag, ident, ctrans, -- * Generic operations - Container(..), Numeric, + Container(..), Numeric, Extractor(..), (??), range, idxs, I, -- add, mul, sub, divide, equal, scaleRecip, addConstant, scalar, conj, scale, arctan2, cmap, cmod, atIndex, minIndex, maxIndex, minElement, maxElement, sumElements, prodElements, - step, cond, find, assoc, accum, + step, cond, find, assoc, accum, ccompare, cselect, Transposable(..), Linear(..), -- * Matrix product Product(..), udot, dot, (<·>), (#>), (<#), app, @@ -58,7 +58,7 @@ module Data.Packed.Numeric ( Complexable(), RealElement(), RealOf, ComplexOf, SingleOf, DoubleOf, - roundVector, + roundVector,fromInt,toInt, IndexOf, module Data.Complex, -- * IO @@ -309,4 +309,9 @@ sortVector = sortV sortIndex :: (Ord t, Element t) => Vector t -> Vector I sortIndex = sortI +ccompare :: (Ord t, Container c t) => c t -> c t -> c I +ccompare = ccompare' + +cselect :: (Container c t) => c I -> c t -> c t -> c t -> c t +cselect = cselect' diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs index 2aac2e4..79dd06b 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Data.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs @@ -59,7 +59,9 @@ module Numeric.LinearAlgebra.Data( fromBlocks, (|||), (===), diagBlock, repmat, toBlocks, toBlocksEvery, -- * Mapping functions - conj, cmap, cmod, step, cond, + conj, cmap, cmod, + + step, cond, ccompare, cselect, -- * Find elements find, maxIndex, minIndex, maxElement, minElement, @@ -78,7 +80,7 @@ module Numeric.LinearAlgebra.Data( -- * Element conversion Convert(..), roundVector, - fromInt, + fromInt,toInt, -- * Misc arctan2, separable, @@ -95,6 +97,5 @@ import Data.Packed.Numeric import Numeric.LinearAlgebra.Util hiding ((&),(#)) import Data.Complex import Numeric.Sparse -import Data.Packed.Internal.Numeric(I,Extractor(..),(??),fromInt,range,idxs) diff --git a/packages/base/src/Numeric/LinearAlgebra/Util.hs b/packages/base/src/Numeric/LinearAlgebra/Util.hs index eadd2a2..779630f 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util.hs @@ -66,7 +66,6 @@ import Control.Monad(when) import Text.Printf import Data.List.Split(splitOn) import Data.List(intercalate) -import Data.Packed.Internal.Numeric(I) type ℝ = Double type ℕ = Int -- cgit v1.2.3