From 59cb364ebd7bff09a19f5f83104752a14f6a5177 Mon Sep 17 00:00:00 2001 From: exfalso <0slemi0@gmail.com> Date: Fri, 7 Oct 2016 16:49:57 +0100 Subject: Redefine (#) --- packages/base/src/Internal/Matrix.hs | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) (limited to 'packages/base/src/Internal/Matrix.hs') diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index c47c625..0135288 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs @@ -22,7 +22,7 @@ module Internal.Matrix where import Internal.Vector import Internal.Devel -import Internal.Vectorized hiding ((#)) +import Internal.Vectorized hiding ((#), (#!)) import Foreign.Marshal.Alloc ( free ) import Foreign.Marshal.Array(newArray) import Foreign.Ptr ( Ptr ) @@ -110,15 +110,15 @@ fmat m -- C-Haskell matrix adapters {-# INLINE amatr #-} -amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b -amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) +amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r +amatr x f g = unsafeWith (xdat x) (f . g r c) where r = fi (rows x) c = fi (cols x) {-# INLINE amat #-} -amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b -amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) +amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r +amat x f g = unsafeWith (xdat x) (f . g r c sr sc) where r = fi (rows x) c = fi (cols x) @@ -135,10 +135,13 @@ instance Storable t => TransArray (Matrix t) applyRaw = amatr {-# INLINE applyRaw #-} -infixl 1 # +infixr 1 # a # b = apply a b {-# INLINE (#) #-} +a #! b = a # b # id +{-# INLINE (#!) #-} + -------------------------------------------------------------------------------- copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) @@ -426,7 +429,8 @@ extractAux f ord m moder vr modec vc = do let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc r <- createMatrix ord nr nc - f moder modec # vr # vc # m # r #|"extract" + (vr # vc # m #! r) (f moder modec) #|"extract" + return r type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) @@ -440,7 +444,7 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z --------------------------------------------------------------- -setRectAux f i j m r = f (fi i) (fi j) # m # r #|"setRect" +setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" type SetRect x = I -> I -> x ::> x::> Ok @@ -455,7 +459,7 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z sortG f v = unsafePerformIO $ do r <- createVector (dim v) - f # v # r #|"sortG" + (v #! r) f #|"sortG" return r sortIdxD = sortG c_sort_indexD @@ -482,7 +486,7 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok compareG f u v = unsafePerformIO $ do r <- createVector (dim v) - f # u # v # r #|"compareG" + (u # v #! r) f #|"compareG" return r compareD = compareG c_compareD @@ -499,7 +503,7 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok selectG f c u v w = unsafePerformIO $ do r <- createVector (dim v) - f # c # u # v # w # r #|"selectG" + (c # u # v # w #! r) f #|"selectG" return r selectD = selectG c_selectD @@ -522,7 +526,7 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z remapG f i j m = unsafePerformIO $ do r <- createMatrix RowMajor (rows i) (cols i) - f # i # j # m # r #|"remapG" + (i # j # m #! r) f #|"remapG" return r remapD = remapG c_remapD @@ -545,7 +549,7 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z rowOpAux f c x i1 i2 j1 j2 m = do px <- newArray [x] - f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2) # m #|"rowOp" + (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" free px type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok @@ -561,7 +565,7 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z -------------------------------------------------------------------------------- -gemmg f v m1 m2 m3 = f # v # m1 # m2 # m3 #|"gemmg" +gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" type Tgemm x = x :> x ::> x ::> x ::> Ok @@ -589,10 +593,9 @@ saveMatrix saveMatrix name format m = do cname <- newCString name cformat <- newCString format - c_saveMatrix cname cformat # m #|"saveMatrix" + (m # id) (c_saveMatrix cname cformat) #|"saveMatrix" free cname free cformat return () -------------------------------------------------------------------------------- - -- cgit v1.2.3