From c77c83f1e442e5fff408d883b7aac5043ba512a9 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Wed, 27 May 2015 10:47:02 +0200 Subject: omat, AT, remap --- packages/base/src/Data/Packed/Internal/Matrix.hs | 52 +++++++++++++++++++++--- 1 file changed, 46 insertions(+), 6 deletions(-) (limited to 'packages/base/src/Data/Packed/Internal/Matrix.hs') diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs index 82a9d8f..ddeddae 100644 --- a/packages/base/src/Data/Packed/Internal/Matrix.hs +++ b/packages/base/src/Data/Packed/Internal/Matrix.hs @@ -105,6 +105,14 @@ cols = icols orderOf :: Matrix t -> MatrixOrder orderOf = order +stepRow :: Matrix t -> CInt +stepRow Matrix {icols = c, order = RowMajor } = fromIntegral c +stepRow _ = 1 + +stepCol :: Matrix t -> CInt +stepCol Matrix {irows = r, order = ColumnMajor } = fromIntegral r +stepCol _ = 1 + -- | Matrix transpose. trans :: Matrix t -> Matrix t @@ -128,6 +136,14 @@ mat a f = g (fi (rows a)) (fi (cols a)) p f m +omat :: (Storable t) => Matrix t -> (((CInt -> CInt -> CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b +omat a f = + unsafeWith (xdat a) $ \p -> do + let m g = do + g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p + f m + + {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. >>> flatten (ident 3) @@ -257,7 +273,7 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 >instance Element Foo -} class (Storable a) => Element a where - subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position + subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix -> Matrix a -> Matrix a subMatrixD = subMatrix' @@ -270,6 +286,7 @@ class (Storable a) => Element a where sortV :: Ord a => Vector a -> Vector a compareV :: Ord a => Vector a -> Vector a -> Vector CInt selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a + remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a instance Element Float where @@ -280,7 +297,7 @@ instance Element Float where sortV = sortValF compareV = compareF selectV = selectF - + remapM = remapF instance Element Double where transdata = transdataAux ctransR @@ -290,6 +307,7 @@ instance Element Double where sortV = sortValD compareV = compareD selectV = selectD + remapM = remapD instance Element (Complex Float) where @@ -300,6 +318,7 @@ instance Element (Complex Float) where sortV = undefined compareV = undefined selectV = selectQ + remapM = remapQ instance Element (Complex Double) where @@ -310,8 +329,8 @@ instance Element (Complex Double) where sortV = undefined compareV = undefined selectV = selectC + remapM = remapC - instance Element (CInt) where transdata = transdataAux ctransI constantD = constantAux cconstantI @@ -320,7 +339,7 @@ instance Element (CInt) where sortV = sortValI compareV = compareI selectV = selectI - + remapM = remapI ------------------------------------------------------------------- @@ -394,7 +413,7 @@ foreign import ccall unsafe "constantP" cconstantP :: Ptr () -> CInt -> Ptr () - -- | Extracts a submatrix from a matrix. subMatrix :: Element a - => (Int,Int) -- ^ (r0,c0) starting position + => (Int,Int) -- ^ (r0,c0) starting position -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix -> Matrix a -- ^ input matrix -> Matrix a -- ^ result @@ -427,7 +446,7 @@ conformMs ms = map (conformMTo (r,c)) ms where r = maxZ (map rows ms) c = maxZ (map cols ms) - + conformVs vs = map (conformVTo n) vs where @@ -554,4 +573,25 @@ foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) +--------------------------------------------------------------------------- + +remapG f i j m = unsafePerformIO $ do + r <- createMatrix RowMajor (rows i) (cols i) + app4 f omat i omat j omat m omat r "remapG" + return r + +remapD = remapG c_remapD +remapF = remapG c_remapF +remapI = remapG c_remapI +remapC = remapG c_remapC +remapQ = remapG c_remapQ + +type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) + +foreign import ccall unsafe "remapD" c_remapD :: Rem Double +foreign import ccall unsafe "remapF" c_remapF :: Rem Float +foreign import ccall unsafe "remapI" c_remapI :: Rem CInt +foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) +foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) + -- cgit v1.2.3