From db50bc11dafa6834a4367427156306674063ed6b Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Fri, 19 Jun 2015 13:55:39 +0200 Subject: removed the annoying appN adapter for the foreign functions. replaced by several overloaded app variants in the style of the module Internal.Foreign contributed by Mike Ledger. --- packages/base/src/Internal/Matrix.hs | 72 +++++++++++++++++++++++++++--------- 1 file changed, 54 insertions(+), 18 deletions(-) (limited to 'packages/base/src/Internal/Matrix.hs') diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 8f8c219..db0a609 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs @@ -3,6 +3,8 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilies #-} + -- | -- Module : Internal.Matrix @@ -18,7 +20,7 @@ module Internal.Matrix where import Internal.Vector import Internal.Devel -import Internal.Vectorized +import Internal.Vectorized hiding ((#)) import Foreign.Marshal.Alloc ( free ) import Foreign.Marshal.Array(newArray) import Foreign.Ptr ( Ptr ) @@ -79,8 +81,6 @@ data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int -- RowMajor: preferred by C, fdat may require a transposition -- ColumnMajor: preferred by LAPACK, cdat may require a transposition ---cdat = xdat ---fdat = xdat rows :: Matrix t -> Int rows = irows @@ -129,6 +129,48 @@ omat a f = g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p f m +-------------------------------------------------------------------------------- + +{-# INLINE amatr #-} +amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b +amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) + where + r = fromIntegral (rows x) + c = fromIntegral (cols x) + +{-# INLINE amat #-} +amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b +amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) + where + r = fromIntegral (rows x) + c = fromIntegral (cols x) + sr = stepRow x + sc = stepCol x + +{-# INLINE arrmat #-} +arrmat :: Storable a => (Ptr CInt -> Ptr a -> b) -> Matrix a -> b +arrmat f x = inlinePerformIO (unsafeWith s (\p -> unsafeWith (xdat x) (return . f p))) + where + s = fromList [fi (rows x), fi (cols x), stepRow x, stepCol x] + + +instance Storable t => TransArray (Matrix t) + where + type Elem (Matrix t) = t + type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b + type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b + apply = amat + {-# INLINE apply #-} + applyRaw = amatr + {-# INLINE applyRaw #-} + applyArray = arrmat + {-# INLINE applyArray #-} + +infixl 1 # +a # b = apply a b +{-# INLINE (#) #-} + +-------------------------------------------------------------------------------- {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. @@ -139,12 +181,6 @@ fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0] flatten :: Element t => Matrix t -> Vector t flatten = xdat . cmat -{- -type Mt t s = Int -> Int -> Ptr t -> s - -infixr 6 ::> -type t ::> s = Mt t s --} -- | the inverse of 'Data.Packed.Matrix.fromLists' toLists :: (Element t) => Matrix t -> [[t]] @@ -445,7 +481,7 @@ extractAux f m moder vr modec vc = do let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc r <- createMatrix RowMajor nr nc - app4 (f moder modec) vec vr vec vc omat m omat r "extractAux" + f moder modec # vr # vc # m # r #|"extract" return r type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) @@ -459,7 +495,7 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z --------------------------------------------------------------- -setRectAux f i j m r = app2 (f (fi i) (fi j)) omat m omat r "setRect" +setRectAux f i j m r = f (fi i) (fi j) # m # r #|"setRect" type SetRect x = I -> I -> x ::> x::> Ok @@ -474,7 +510,7 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z sortG f v = unsafePerformIO $ do r <- createVector (dim v) - app2 f vec v vec r "sortG" + f # v # r #|"sortG" return r sortIdxD = sortG c_sort_indexD @@ -501,7 +537,7 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok compareG f u v = unsafePerformIO $ do r <- createVector (dim v) - app3 f vec u vec v vec r "compareG" + f # u # v # r #|"compareG" return r compareD = compareG c_compareD @@ -518,7 +554,7 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok selectG f c u v w = unsafePerformIO $ do r <- createVector (dim v) - app5 f vec c vec u vec v vec w vec r "selectG" + f # c # u # v # w # r #|"selectG" return r selectD = selectG c_selectD @@ -541,7 +577,7 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z remapG f i j m = unsafePerformIO $ do r <- createMatrix RowMajor (rows i) (cols i) - app4 f omat i omat j omat m omat r "remapG" + f # i # j # m # r #|"remapG" return r remapD = remapG c_remapD @@ -564,7 +600,7 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z rowOpAux f c x i1 i2 j1 j2 m = do px <- newArray [x] - app1 (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) omat m "rowOp" + f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2) # m #|"rowOp" free px type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok @@ -580,7 +616,7 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z -------------------------------------------------------------------------------- -gemmg f u v m1 m2 m3 = app5 f vec u vec v omat m1 omat m2 omat m3 "gemmg" +gemmg f u v m1 m2 m3 = f # u # v # m1 # m2 # m3 #|"gemmg" type Tgemm x = x :> I :> x ::> x ::> x ::> Ok @@ -608,7 +644,7 @@ saveMatrix saveMatrix name format m = do cname <- newCString name cformat <- newCString format - app1 (c_saveMatrix cname cformat) mat m "saveMatrix" + c_saveMatrix cname cformat `applyRaw` m #|"saveMatrix" free cname free cformat return () -- cgit v1.2.3