From 4b3e29097aa272d429f8005fe17b459cf0c049c8 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Fri, 12 Jun 2015 20:58:13 +0200 Subject: row ops in ST --- packages/base/src/Internal/Matrix.hs | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) (limited to 'packages/base/src/Internal/Matrix.hs') diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 8de06ce..fa1aad6 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs @@ -20,6 +20,7 @@ import Internal.Vector import Internal.Devel import Internal.Vectorized import Foreign.Marshal.Alloc ( free ) +import Foreign.Marshal.Array(newArray) import Foreign.Ptr ( Ptr ) import Foreign.Storable ( Storable ) import Data.Complex ( Complex ) @@ -273,12 +274,13 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 class (Storable a) => Element a where transdata :: Int -> Vector a -> Int -> Vector a constantD :: a -> Int -> Vector a - extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> Matrix a + extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) sortI :: Ord a => Vector a -> Vector CInt sortV :: Ord a => Vector a -> Vector a compareV :: Ord a => Vector a -> Vector a -> Vector CInt selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a + rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () instance Element Float where @@ -290,6 +292,7 @@ instance Element Float where compareV = compareF selectV = selectF remapM = remapF + rowOp = rowOpAux c_rowOpF instance Element Double where transdata = transdataAux ctransR @@ -300,6 +303,7 @@ instance Element Double where compareV = compareD selectV = selectD remapM = remapD + rowOp = rowOpAux c_rowOpD instance Element (Complex Float) where @@ -311,6 +315,7 @@ instance Element (Complex Float) where compareV = undefined selectV = selectQ remapM = remapQ + rowOp = rowOpAux c_rowOpQ instance Element (Complex Double) where @@ -322,6 +327,7 @@ instance Element (Complex Double) where compareV = undefined selectV = selectC remapM = remapC + rowOp = rowOpAux c_rowOpC instance Element (CInt) where transdata = transdataAux ctransI @@ -332,6 +338,7 @@ instance Element (CInt) where compareV = compareI selectV = selectI remapM = remapI + rowOp = rowOpAux c_rowOpI instance Element Z where transdata = transdataAux ctransL @@ -342,6 +349,7 @@ instance Element Z where compareV = compareL selectV = selectL remapM = remapL + rowOp = rowOpAux c_rowOpL ------------------------------------------------------------------- @@ -379,7 +387,7 @@ subMatrix :: Element a -> Matrix a -- ^ result subMatrix (r0,c0) (rt,ct) m | 0 <= r0 && 0 <= rt && r0+rt <= rows m && - 0 <= c0 && 0 <= ct && c0+ct <= cols m = extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) + 0 <= c0 && 0 <= ct && c0+ct <= cols m = unsafePerformIO $ extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) | otherwise = error $ "wrong subMatrix "++ show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) @@ -430,7 +438,7 @@ instance (Storable t, NFData t) => NFData (Matrix t) --------------------------------------------------------------- -extractAux f m moder vr modec vc = unsafePerformIO $ do +extractAux f m moder vr modec vc = do let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc r <- createMatrix RowMajor nr nc @@ -538,6 +546,24 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z -------------------------------------------------------------------------------- +rowOpAux f c x i1 i2 j1 j2 m = do + px <- newArray [x] + app1 (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) omat m "rowOp" + free px + +type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok + +foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R +foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float +foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C +foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float) +foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I +foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z +foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I +foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z + +-------------------------------------------------------------------------------- + foreign import ccall unsafe "saveMatrix" c_saveMatrix :: CString -> CString -> Double ..> Ok -- cgit v1.2.3