From 46178222d272a85220bc86b221aa3166edd5bd4a Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Thu, 21 May 2015 20:47:59 +0200 Subject: CInt elements, wip --- packages/base/src/Data/Packed/Internal/Matrix.hs | 45 +++++++++++++++++++++- .../base/src/Data/Packed/Internal/Signatures.hs | 7 +++- packages/base/src/Data/Packed/Internal/Vector.hs | 6 ++- 3 files changed, 55 insertions(+), 3 deletions(-) (limited to 'packages/base/src/Data') diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs index 150b978..be5fb03 100644 --- a/packages/base/src/Data/Packed/Internal/Matrix.hs +++ b/packages/base/src/Data/Packed/Internal/Matrix.hs @@ -265,23 +265,33 @@ class (Storable a) => Element a where transdata = transdataP -- transdata' constantD :: a -> Int -> Vector a constantD = constantP -- constant' - + extractR :: Matrix a -> Idxs -> Matrix a instance Element Float where transdata = transdataAux ctransF constantD = constantAux cconstantF + extractR = extractAux c_extractRF instance Element Double where transdata = transdataAux ctransR constantD = constantAux cconstantR + extractR = extractAux c_extractRD instance Element (Complex Float) where transdata = transdataAux ctransQ constantD = constantAux cconstantQ + extractR = extractAux c_extractRQ instance Element (Complex Double) where transdata = transdataAux ctransC constantD = constantAux cconstantC + extractR = extractAux c_extractRC + +instance Element (CInt) where + transdata = transdataAux ctransI + constantD = constantAux cconstantI + extractR = extractAux c_extractRI + ------------------------------------------------------------------- @@ -289,6 +299,7 @@ transdataAux fun c1 d c2 = if noneed then d else unsafePerformIO $ do + -- putStrLn "T" v <- createVector (dim d) unsafeWith d $ \pd -> unsafeWith v $ \pv -> @@ -317,6 +328,7 @@ foreign import ccall unsafe "transF" ctransF :: TFMFM foreign import ccall unsafe "transR" ctransR :: TMM foreign import ccall unsafe "transQ" ctransQ :: TQMQM foreign import ccall unsafe "transC" ctransC :: TCMCM +foreign import ccall unsafe "transI" ctransI :: CM CInt (CM CInt (IO CInt)) foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt ---------------------------------------------------------------------- @@ -336,6 +348,8 @@ foreign import ccall unsafe "constantQ" cconstantQ :: Ptr (Complex Float) -> TQV foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV +foreign import ccall unsafe "constantI" cconstantI :: Ptr CInt -> CV CInt (IO CInt) + constantP :: Storable a => a -> Int -> Vector a constantP a n = unsafePerformIO $ do let sz = sizeOf a @@ -421,3 +435,32 @@ instance (Storable t, NFData t) => NFData (Matrix t) d = dim v v = xdat m +--------------------------------------------------------------- + +isT Matrix{order = ColumnMajor} = 1 +isT Matrix{order = RowMajor} = 0 + +tt x@Matrix{order = ColumnMajor} = trans x +tt x@Matrix{order = RowMajor} = x + +--extractAux :: Matrix Double -> Idxs -> Matrix Double +extractAux f m v = unsafePerformIO $ do + r <- createMatrix RowMajor (dim v) (cols m) + app3 (f (isT m)) vec v mat (tt m) mat r "extractAux" + return r + +foreign import ccall unsafe "extractRD" c_extractRD + :: CInt -> CIdxs (CM Double (CM Double (IO CInt))) + +foreign import ccall unsafe "extractRF" c_extractRF + :: CInt -> CIdxs (CM Float (CM Float (IO CInt))) + +foreign import ccall unsafe "extractRC" c_extractRC + :: CInt -> CIdxs (CM (Complex Double) (CM (Complex Double) (IO CInt))) + +foreign import ccall unsafe "extractRQ" c_extractRQ + :: CInt -> CIdxs (CM (Complex Float) (CM (Complex Float) (IO CInt))) + +foreign import ccall unsafe "extractRI" c_extractRI + :: CInt -> CIdxs (CM CInt (CM CInt (IO CInt))) + diff --git a/packages/base/src/Data/Packed/Internal/Signatures.hs b/packages/base/src/Data/Packed/Internal/Signatures.hs index acc3070..37dac16 100644 --- a/packages/base/src/Data/Packed/Internal/Signatures.hs +++ b/packages/base/src/Data/Packed/Internal/Signatures.hs @@ -1,6 +1,6 @@ -- | -- Module : Data.Packed.Internal.Signatures --- Copyright : (c) Alberto Ruiz 2009 +-- Copyright : (c) Alberto Ruiz 2009-15 -- License : BSD3 -- Maintainer : Alberto Ruiz -- Stability : provisional @@ -68,3 +68,8 @@ type TCVM = CInt -> PC -> TM -- type TMCVM = CInt -> CInt -> PD -> TCVM -- type TMMCVM = CInt -> CInt -> PD -> TMCVM -- +type CM b r = CInt -> CInt -> Ptr b -> r +type CV b r = CInt -> Ptr b -> r + +type CIdxs r = CV CInt r + diff --git a/packages/base/src/Data/Packed/Internal/Vector.hs b/packages/base/src/Data/Packed/Internal/Vector.hs index b49f379..2a6ed2c 100644 --- a/packages/base/src/Data/Packed/Internal/Vector.hs +++ b/packages/base/src/Data/Packed/Internal/Vector.hs @@ -24,7 +24,8 @@ module Data.Packed.Internal.Vector ( cloneVector, unsafeToForeignPtr, unsafeFromForeignPtr, - unsafeWith + unsafeWith, + Idxs ) where import Data.Packed.Internal.Common @@ -56,6 +57,8 @@ import Data.Vector.Storable(Vector, unsafeWith) +type Idxs = Vector CInt + -- | Number of elements dim :: (Storable t) => Vector t -> Int dim = Vector.length @@ -243,6 +246,7 @@ double2FloatV v = unsafePerformIO $ do foreign import ccall unsafe "float2double" c_float2double:: TFV foreign import ccall unsafe "double2float" c_double2float:: TVF + --------------------------------------------------------------- stepF :: Vector Float -> Vector Float -- cgit v1.2.3