From db50bc11dafa6834a4367427156306674063ed6b Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Fri, 19 Jun 2015 13:55:39 +0200 Subject: removed the annoying appN adapter for the foreign functions. replaced by several overloaded app variants in the style of the module Internal.Foreign contributed by Mike Ledger. --- packages/base/src/Internal/Devel.hs | 89 ++++++++++++-------------------- packages/base/src/Internal/LAPACK.hs | 54 +++++++++++-------- packages/base/src/Internal/Matrix.hs | 72 +++++++++++++++++++------- packages/base/src/Internal/Sparse.hs | 4 +- packages/base/src/Internal/Util.hs | 6 +-- packages/base/src/Internal/Vector.hs | 12 ++++- packages/base/src/Internal/Vectorized.hs | 38 ++++++++------ 7 files changed, 153 insertions(+), 122 deletions(-) (limited to 'packages/base/src/Internal') diff --git a/packages/base/src/Internal/Devel.hs b/packages/base/src/Internal/Devel.hs index b8e04ef..4be0afd 100644 --- a/packages/base/src/Internal/Devel.hs +++ b/packages/base/src/Internal/Devel.hs @@ -1,4 +1,5 @@ -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilies #-} -- | -- Module : Internal.Devel @@ -16,68 +17,14 @@ import Foreign.C.Types ( CInt ) --import Foreign.Storable.Complex () import Foreign.Ptr(Ptr) import Control.Exception as E ( SomeException, catch ) - +import Internal.Vector(Vector,avec,arrvec) +import Foreign.Storable(Storable) -- | postfix function application (@flip ($)@) (//) :: x -> (x -> y) -> y infixl 0 // (//) = flip ($) --- hmm.. -ww2 w1 o1 w2 o2 f = w1 o1 $ w2 o2 . f -ww3 w1 o1 w2 o2 w3 o3 f = w1 o1 $ ww2 w2 o2 w3 o3 . f -ww4 w1 o1 w2 o2 w3 o3 w4 o4 f = w1 o1 $ ww3 w2 o2 w3 o3 w4 o4 . f -ww5 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 f = w1 o1 $ ww4 w2 o2 w3 o3 w4 o4 w5 o5 . f -ww6 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 f = w1 o1 $ ww5 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 . f -ww7 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 f = w1 o1 $ ww6 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 . f -ww8 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 f = w1 o1 $ ww7 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 . f -ww9 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 f = w1 o1 $ ww8 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 . f -ww10 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 f = w1 o1 $ ww9 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 . f - -type Adapt f t r = t -> ((f -> r) -> IO()) -> IO() - -type Adapt1 f t1 = Adapt f t1 (IO CInt) -> t1 -> String -> IO() -type Adapt2 f t1 r1 t2 = Adapt f t1 r1 -> t1 -> Adapt1 r1 t2 -type Adapt3 f t1 r1 t2 r2 t3 = Adapt f t1 r1 -> t1 -> Adapt2 r1 t2 r2 t3 -type Adapt4 f t1 r1 t2 r2 t3 r3 t4 = Adapt f t1 r1 -> t1 -> Adapt3 r1 t2 r2 t3 r3 t4 -type Adapt5 f t1 r1 t2 r2 t3 r3 t4 r4 t5 = Adapt f t1 r1 -> t1 -> Adapt4 r1 t2 r2 t3 r3 t4 r4 t5 -type Adapt6 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 = Adapt f t1 r1 -> t1 -> Adapt5 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 -type Adapt7 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 = Adapt f t1 r1 -> t1 -> Adapt6 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 -type Adapt8 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 = Adapt f t1 r1 -> t1 -> Adapt7 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 -type Adapt9 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 = Adapt f t1 r1 -> t1 -> Adapt8 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 -type Adapt10 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10 = Adapt f t1 r1 -> t1 -> Adapt9 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10 - -app1 :: f -> Adapt1 f t1 -app2 :: f -> Adapt2 f t1 r1 t2 -app3 :: f -> Adapt3 f t1 r1 t2 r2 t3 -app4 :: f -> Adapt4 f t1 r1 t2 r2 t3 r3 t4 -app5 :: f -> Adapt5 f t1 r1 t2 r2 t3 r3 t4 r4 t5 -app6 :: f -> Adapt6 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 -app7 :: f -> Adapt7 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 -app8 :: f -> Adapt8 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 -app9 :: f -> Adapt9 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 -app10 :: f -> Adapt10 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10 - -app1 f w1 o1 s = w1 o1 $ \a1 -> f // a1 // check s -app2 f w1 o1 w2 o2 s = ww2 w1 o1 w2 o2 $ \a1 a2 -> f // a1 // a2 // check s -app3 f w1 o1 w2 o2 w3 o3 s = ww3 w1 o1 w2 o2 w3 o3 $ - \a1 a2 a3 -> f // a1 // a2 // a3 // check s -app4 f w1 o1 w2 o2 w3 o3 w4 o4 s = ww4 w1 o1 w2 o2 w3 o3 w4 o4 $ - \a1 a2 a3 a4 -> f // a1 // a2 // a3 // a4 // check s -app5 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 s = ww5 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 $ - \a1 a2 a3 a4 a5 -> f // a1 // a2 // a3 // a4 // a5 // check s -app6 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 s = ww6 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 $ - \a1 a2 a3 a4 a5 a6 -> f // a1 // a2 // a3 // a4 // a5 // a6 // check s -app7 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 s = ww7 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 $ - \a1 a2 a3 a4 a5 a6 a7 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // check s -app8 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 s = ww8 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 $ - \a1 a2 a3 a4 a5 a6 a7 a8 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // check s -app9 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 s = ww9 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 $ - \a1 a2 a3 a4 a5 a6 a7 a8 a9 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // a9 // check s -app10 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 s = ww10 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 $ - \a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // a9 // a10 // check s - - -- GSL error codes are <= 1024 -- | error codes for the auxiliary functions required by the wrappers @@ -104,6 +51,11 @@ check msg f = do when (err/=0) $ error (msg++": "++errorCode err) return () + +-- | postfix error code check +infixl 0 #| +(#|) = flip check + -- | Error capture and conversion to Maybe mbCatch :: IO x -> IO (Maybe x) mbCatch act = E.catch (Just `fmap` act) f @@ -124,4 +76,27 @@ type (:>) t r = CV t r type (::>) t r = OM t r type (..>) t r = CM t r +class TransArray c + where + type Trans c b + type TransRaw c b + type Elem c + apply :: (Trans c b) -> c -> b + applyRaw :: (TransRaw c b) -> c -> b + applyArray :: (Ptr CInt -> Ptr (Elem c) -> b) -> c -> b + infixl 1 `apply`, `applyRaw`, `applyArray` + +instance Storable t => TransArray (Vector t) + where + type Trans (Vector t) b = CInt -> Ptr t -> b + type TransRaw (Vector t) b = CInt -> Ptr t -> b + type Elem (Vector t) = t + apply = avec + {-# INLINE apply #-} + applyRaw = avec + {-# INLINE applyRaw #-} + applyArray = arrvec + {-# INLINE applyArray #-} + + diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index 8df568d..3a9abbb 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs @@ -17,7 +17,7 @@ module Internal.LAPACK where import Internal.Devel import Internal.Vector -import Internal.Matrix +import Internal.Matrix hiding ((#)) import Internal.Conversion import Internal.Element import Foreign.Ptr(nullPtr) @@ -27,6 +27,16 @@ import System.IO.Unsafe(unsafePerformIO) ----------------------------------------------------------------------------------- +infixl 1 # +a # b = applyRaw a b +{-# INLINE (#) #-} + +infixl 1 #! +a #! b = apply a b +{-# INLINE (#!) #-} + +----------------------------------------------------------------------------------- + type TMMM t = t ..> t ..> t ..> Ok type F = Float @@ -49,7 +59,7 @@ multiplyAux f st a b = unsafePerformIO $ do when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ show (rows a,cols a) ++ " x " ++ show (rows b, cols b) s <- createMatrix ColumnMajor (rows a) (cols b) - app3 (f (isT a) (isT b)) mat (tt a) mat (tt b) mat s st + f (isT a) (isT b) # (tt a) # (tt b) # s #| st return s -- | Matrix product based on BLAS's /dgemm/. @@ -73,7 +83,7 @@ multiplyI m a b = unsafePerformIO $ do when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b s <- createMatrix ColumnMajor (rows a) (cols b) - app3 (c_multiplyI m) omat a omat b omat s "c_multiplyI" + c_multiplyI m #! a #! b #! s #|"c_multiplyI" return s multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z @@ -81,7 +91,7 @@ multiplyL m a b = unsafePerformIO $ do when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b s <- createMatrix ColumnMajor (rows a) (cols b) - app3 (c_multiplyL m) omat a omat b omat s "c_multiplyL" + c_multiplyL m #! a #! b #! s #|"c_multiplyL" return s ----------------------------------------------------------------------------- @@ -113,7 +123,7 @@ svdAux f st x = unsafePerformIO $ do u <- createMatrix ColumnMajor r r s <- createVector (min r c) v <- createMatrix ColumnMajor c c - app4 f mat x mat u vec s mat v st + f # x # u # s # v #| st return (u,s,v) where r = rows x c = cols x @@ -139,7 +149,7 @@ thinSVDAux f st x = unsafePerformIO $ do u <- createMatrix ColumnMajor r q s <- createVector q v <- createMatrix ColumnMajor q c - app4 f mat x mat u vec s mat v st + f # x # u # s # v #| st return (u,s,v) where r = rows x c = cols x @@ -164,7 +174,7 @@ svCd = svAux zgesdd "svCd" . fmat svAux f st x = unsafePerformIO $ do s <- createVector q - app2 g mat x vec s st + g # x # s #| st return s where r = rows x c = cols x @@ -183,7 +193,7 @@ rightSVC = rightSVAux zgesvd "rightSVC" . fmat rightSVAux f st x = unsafePerformIO $ do s <- createVector q v <- createMatrix ColumnMajor c c - app3 g mat x vec s mat v st + g # x # s # v #| st return (s,v) where r = rows x c = cols x @@ -202,7 +212,7 @@ leftSVC = leftSVAux zgesvd "leftSVC" . fmat leftSVAux f st x = unsafePerformIO $ do u <- createMatrix ColumnMajor r r s <- createVector q - app3 g mat x mat u vec s st + g # x # u # s #| st return (u,s) where r = rows x c = cols x @@ -219,7 +229,7 @@ foreign import ccall unsafe "eig_l_H" zheev :: CInt -> C ..> R :> C ..> Ok eigAux f st m = unsafePerformIO $ do l <- createVector r v <- createMatrix ColumnMajor r r - app3 g mat m vec l mat v st + g # m # l # v #| st return (l,v) where r = rows m g ra ca pa = f ra ca pa 0 0 nullPtr @@ -232,7 +242,7 @@ eigC = eigAux zgeev "eigC" . fmat eigOnlyAux f st m = unsafePerformIO $ do l <- createVector r - app2 g mat m vec l st + g # m # l #| st return l where r = rows m g ra ca pa nl pl = f ra ca pa 0 0 nullPtr nl pl 0 0 nullPtr @@ -255,7 +265,7 @@ eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double) eigRaux m = unsafePerformIO $ do l <- createVector r v <- createMatrix ColumnMajor r r - app3 g mat m vec l mat v "eigR" + g # m # l # v #| "eigR" return (l,v) where r = rows m g ra ca pa = dgeev ra ca pa 0 0 nullPtr @@ -282,7 +292,7 @@ eigOnlyR = fixeig1 . eigOnlyAux dgeev "eigOnlyR" . fmat eigSHAux f st m = unsafePerformIO $ do l <- createVector r v <- createMatrix ColumnMajor r r - app3 f mat m vec l mat v st + f # m # l # v #| st return (l,v) where r = rows m @@ -332,7 +342,7 @@ foreign import ccall unsafe "cholSolveC_l" zpotrs :: TMMM C linearSolveSQAux g f st a b | n1==n2 && n1==r = unsafePerformIO . g $ do s <- createMatrix ColumnMajor r c - app3 f mat a mat b mat s st + f # a # b # s #| st return s | otherwise = error $ st ++ " of nonsquare matrix" where n1 = rows a @@ -371,7 +381,7 @@ foreign import ccall unsafe "linearSolveSVDC_l" zgelss :: Double -> TMMM C linearSolveAux f st a b = unsafePerformIO $ do r <- createMatrix ColumnMajor (max m n) nrhs - app3 f mat a mat b mat r st + f # a # b # r #| st return r where m = rows a n = cols a @@ -412,7 +422,7 @@ foreign import ccall unsafe "chol_l_S" dpotrf :: TMM R cholAux f st a = do r <- createMatrix ColumnMajor n n - app2 f mat a mat r st + f # a # r #| st return r where n = rows a @@ -450,7 +460,7 @@ qrC = qrAux zgeqr2 "qrC" . fmat qrAux f st a = unsafePerformIO $ do r <- createMatrix ColumnMajor m n tau <- createVector mn - app3 f mat a vec tau mat r st + f # a # tau # r #| st return (r,tau) where m = rows a @@ -469,7 +479,7 @@ qrgrC = qrgrAux zungqr "qrgrC" qrgrAux f st n (a, tau) = unsafePerformIO $ do res <- createMatrix ColumnMajor (rows a) n - app3 f mat (fmat a) vec (subVector 0 n tau') mat res st + f # (fmat a) # (subVector 0 n tau') # res #| st return res where tau' = vjoin [tau, constantD 0 n] @@ -489,7 +499,7 @@ hessC = hessAux zgehrd "hessC" . fmat hessAux f st a = unsafePerformIO $ do r <- createMatrix ColumnMajor m n tau <- createVector (mn-1) - app3 f mat a vec tau mat r st + f # a # tau # r #| st return (r,tau) where m = rows a n = cols a @@ -510,7 +520,7 @@ schurC = schurAux zgees "schurC" . fmat schurAux f st a = unsafePerformIO $ do u <- createMatrix ColumnMajor n n s <- createMatrix ColumnMajor n n - app3 f mat a mat u mat s st + f # a # u # s #| st return (u,s) where n = rows a @@ -529,7 +539,7 @@ luC = luAux zgetrf "luC" . fmat luAux f st a = unsafePerformIO $ do lu <- createMatrix ColumnMajor n m piv <- createVector (min n m) - app3 f mat a vec piv mat lu st + f # a # piv # lu #| st return (lu, map (pred.round) (toList piv)) where n = rows a m = cols a @@ -552,7 +562,7 @@ lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b) lusAux f st a piv b | n1==n2 && n2==n =unsafePerformIO $ do x <- createMatrix ColumnMajor n m - app4 f mat a vec piv' mat b mat x st + f # a # piv' # b # x #| st return x | otherwise = error $ st ++ " on LU factorization of nonsquare matrix" where n1 = rows a diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 8f8c219..db0a609 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs @@ -3,6 +3,8 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilies #-} + -- | -- Module : Internal.Matrix @@ -18,7 +20,7 @@ module Internal.Matrix where import Internal.Vector import Internal.Devel -import Internal.Vectorized +import Internal.Vectorized hiding ((#)) import Foreign.Marshal.Alloc ( free ) import Foreign.Marshal.Array(newArray) import Foreign.Ptr ( Ptr ) @@ -79,8 +81,6 @@ data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int -- RowMajor: preferred by C, fdat may require a transposition -- ColumnMajor: preferred by LAPACK, cdat may require a transposition ---cdat = xdat ---fdat = xdat rows :: Matrix t -> Int rows = irows @@ -129,6 +129,48 @@ omat a f = g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p f m +-------------------------------------------------------------------------------- + +{-# INLINE amatr #-} +amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b +amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) + where + r = fromIntegral (rows x) + c = fromIntegral (cols x) + +{-# INLINE amat #-} +amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b +amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) + where + r = fromIntegral (rows x) + c = fromIntegral (cols x) + sr = stepRow x + sc = stepCol x + +{-# INLINE arrmat #-} +arrmat :: Storable a => (Ptr CInt -> Ptr a -> b) -> Matrix a -> b +arrmat f x = inlinePerformIO (unsafeWith s (\p -> unsafeWith (xdat x) (return . f p))) + where + s = fromList [fi (rows x), fi (cols x), stepRow x, stepCol x] + + +instance Storable t => TransArray (Matrix t) + where + type Elem (Matrix t) = t + type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b + type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b + apply = amat + {-# INLINE apply #-} + applyRaw = amatr + {-# INLINE applyRaw #-} + applyArray = arrmat + {-# INLINE applyArray #-} + +infixl 1 # +a # b = apply a b +{-# INLINE (#) #-} + +-------------------------------------------------------------------------------- {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. @@ -139,12 +181,6 @@ fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0] flatten :: Element t => Matrix t -> Vector t flatten = xdat . cmat -{- -type Mt t s = Int -> Int -> Ptr t -> s - -infixr 6 ::> -type t ::> s = Mt t s --} -- | the inverse of 'Data.Packed.Matrix.fromLists' toLists :: (Element t) => Matrix t -> [[t]] @@ -445,7 +481,7 @@ extractAux f m moder vr modec vc = do let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc r <- createMatrix RowMajor nr nc - app4 (f moder modec) vec vr vec vc omat m omat r "extractAux" + f moder modec # vr # vc # m # r #|"extract" return r type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) @@ -459,7 +495,7 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z --------------------------------------------------------------- -setRectAux f i j m r = app2 (f (fi i) (fi j)) omat m omat r "setRect" +setRectAux f i j m r = f (fi i) (fi j) # m # r #|"setRect" type SetRect x = I -> I -> x ::> x::> Ok @@ -474,7 +510,7 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z sortG f v = unsafePerformIO $ do r <- createVector (dim v) - app2 f vec v vec r "sortG" + f # v # r #|"sortG" return r sortIdxD = sortG c_sort_indexD @@ -501,7 +537,7 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok compareG f u v = unsafePerformIO $ do r <- createVector (dim v) - app3 f vec u vec v vec r "compareG" + f # u # v # r #|"compareG" return r compareD = compareG c_compareD @@ -518,7 +554,7 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok selectG f c u v w = unsafePerformIO $ do r <- createVector (dim v) - app5 f vec c vec u vec v vec w vec r "selectG" + f # c # u # v # w # r #|"selectG" return r selectD = selectG c_selectD @@ -541,7 +577,7 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z remapG f i j m = unsafePerformIO $ do r <- createMatrix RowMajor (rows i) (cols i) - app4 f omat i omat j omat m omat r "remapG" + f # i # j # m # r #|"remapG" return r remapD = remapG c_remapD @@ -564,7 +600,7 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z rowOpAux f c x i1 i2 j1 j2 m = do px <- newArray [x] - app1 (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) omat m "rowOp" + f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2) # m #|"rowOp" free px type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok @@ -580,7 +616,7 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z -------------------------------------------------------------------------------- -gemmg f u v m1 m2 m3 = app5 f vec u vec v omat m1 omat m2 omat m3 "gemmg" +gemmg f u v m1 m2 m3 = f # u # v # m1 # m2 # m3 #|"gemmg" type Tgemm x = x :> I :> x ::> x ::> x ::> Ok @@ -608,7 +644,7 @@ saveMatrix saveMatrix name format m = do cname <- newCString name cformat <- newCString format - app1 (c_saveMatrix cname cformat) mat m "saveMatrix" + c_saveMatrix cname cformat `applyRaw` m #|"saveMatrix" free cname free cformat return () diff --git a/packages/base/src/Internal/Sparse.hs b/packages/base/src/Internal/Sparse.hs index b365c15..eb4ee1b 100644 --- a/packages/base/src/Internal/Sparse.hs +++ b/packages/base/src/Internal/Sparse.hs @@ -145,13 +145,13 @@ gmXv :: GMatrix -> Vector Double -> Vector Double gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) r <- createVector nRows - app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" + c_smXv # csrVals # csrCols # csrRows # v # r #|"CSRXv" return r gmXv SparseC { gmCSC = CSC{..}, .. } v = unsafePerformIO $ do dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) r <- createVector nRows - app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" + c_smTXv # cscVals # cscRows # cscCols # v # r #|"CSCXv" return r gmXv Diag{..} v diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 079663d..924ca4c 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs @@ -31,7 +31,7 @@ module Internal.Util( diagl, row, col, - (&), (¦), (|||), (——), (===), (#), + (&), (¦), (|||), (——), (===), (?), (¿), Indexable(..), size, Numeric, @@ -185,10 +185,6 @@ infixl 2 —— (——) = (===) -(#) :: Matrix Double -> Matrix Double -> Matrix Double -infixl 2 # -a # b = fromBlocks [[a],[b]] - -- | create a single row real matrix from a list -- -- >>> row [2,3,1,8] diff --git a/packages/base/src/Internal/Vector.hs b/packages/base/src/Internal/Vector.hs index 0e9161d..e5ac440 100644 --- a/packages/base/src/Internal/Vector.hs +++ b/packages/base/src/Internal/Vector.hs @@ -14,7 +14,7 @@ module Internal.Vector( I,Z,R,C, fi,ti, Vector, fromList, unsafeToForeignPtr, unsafeFromForeignPtr, unsafeWith, - createVector, vec, + createVector, vec, avec, arrvec, inlinePerformIO, toList, dim, (@>), at', (|>), vjoin, subVector, takesV, idxs, buildVector, @@ -75,6 +75,16 @@ vec x f = unsafeWith x $ \p -> do f v {-# INLINE vec #-} +{-# INLINE avec #-} +avec :: Storable a => (CInt -> Ptr a -> b) -> Vector a -> b +avec f v = inlinePerformIO (unsafeWith v (return . f (fromIntegral (Vector.length v)))) +infixl 1 `avec` + +{-# INLINE arrvec #-} +arrvec :: Storable a => (Ptr CInt -> Ptr a -> b) -> Vector a -> b +arrvec f v = inlinePerformIO (unsafeWith (idxs [1,dim v]) (\p -> unsafeWith v (return . f p))) + + -- allocates memory for a new vector createVector :: Storable a => Int -> IO (Vector a) diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs index 5c89ac9..03bcf90 100644 --- a/packages/base/src/Internal/Vectorized.hs +++ b/packages/base/src/Internal/Vectorized.hs @@ -1,4 +1,5 @@ -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilies #-} ----------------------------------------------------------------------------- -- | @@ -26,7 +27,9 @@ import Foreign.C.String import System.IO.Unsafe(unsafePerformIO) import Control.Monad(when) - +infixl 1 # +a # b = applyRaw a b +{-# INLINE (#) #-} fromei x = fromIntegral (fromEnum x) :: CInt @@ -100,7 +103,7 @@ sumL m = sumg (c_sumL m) sumg f x = unsafePerformIO $ do r <- createVector 1 - app2 f vec x vec r "sum" + f # x # r #| "sum" return $ r @> 0 type TVV t = t :> t :> Ok @@ -128,14 +131,15 @@ prodQ = prodg c_prodQ prodC :: Vector (Complex Double) -> Complex Double prodC = prodg c_prodC - +prodI :: I-> Vector I -> I prodI = prodg . c_prodI +prodL :: Z-> Vector Z -> Z prodL = prodg . c_prodL prodg f x = unsafePerformIO $ do r <- createVector 1 - app2 f vec x vec r "prod" + f # x # r #| "prod" return $ r @> 0 @@ -150,24 +154,24 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z toScalarAux fun code v = unsafePerformIO $ do r <- createVector 1 - app2 (fun (fromei code)) vec v vec r "toScalarAux" + fun (fromei code) # v # r #|"toScalarAux" return (r @> 0) vectorMapAux fun code v = unsafePerformIO $ do r <- createVector (dim v) - app2 (fun (fromei code)) vec v vec r "vectorMapAux" + fun (fromei code) # v # r #|"vectorMapAux" return r vectorMapValAux fun code val v = unsafePerformIO $ do r <- createVector (dim v) pval <- newArray [val] - app2 (fun (fromei code) pval) vec v vec r "vectorMapValAux" + fun (fromei code) pval # v # r #|"vectorMapValAux" free pval return r vectorZipAux fun code u v = unsafePerformIO $ do r <- createVector (dim u) - app3 (fun (fromei code)) vec u vec v vec r "vectorZipAux" + fun (fromei code) # u # v # r #|"vectorZipAux" return r --------------------------------------------------------------------- @@ -364,7 +368,7 @@ randomVector :: Seed -> Vector Double randomVector seed dist n = unsafePerformIO $ do r <- createVector n - app1 (c_random_vector (fi seed) ((fi.fromEnum) dist)) vec r "randomVector" + c_random_vector (fi seed) ((fi.fromEnum) dist) # r #|"randomVector" return r foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> Double :> Ok @@ -373,7 +377,7 @@ foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> D roundVector v = unsafePerformIO $ do r <- createVector (dim v) - app2 c_round_vector vec v vec r "roundVector" + c_round_vector # v # r #|"roundVector" return r foreign import ccall unsafe "round_vector" c_round_vector :: TVV Double @@ -387,7 +391,7 @@ foreign import ccall unsafe "round_vector" c_round_vector :: TVV Double range :: Int -> Vector I range n = unsafePerformIO $ do r <- createVector n - app1 c_range_vector vec r "range" + c_range_vector # r #|"range" return r foreign import ccall unsafe "range_vector" c_range_vector :: CInt :> Ok @@ -427,7 +431,7 @@ long2intV = tog c_long2int tog f v = unsafePerformIO $ do r <- createVector (dim v) - app2 f vec v vec r "tog" + f # v # r #|"tog" return r foreign import ccall unsafe "float2double" c_float2double :: Float :> Double :> Ok @@ -446,7 +450,7 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok stepg f v = unsafePerformIO $ do r <- createVector (dim v) - app2 f vec v vec r "step" + f # v # r #|"step" return r stepD :: Vector Double -> Vector Double @@ -471,7 +475,7 @@ foreign import ccall unsafe "stepL" c_stepL :: TVV Z conjugateAux fun x = unsafePerformIO $ do v <- createVector (dim x) - app2 fun vec x vec v "conjugateAux" + fun # x # v #|"conjugateAux" return v conjugateQ :: Vector (Complex Float) -> Vector (Complex Float) @@ -489,7 +493,7 @@ cloneVector v = do let n = dim v r <- createVector n let f _ s _ d = copyArray d s n >> return 0 - app2 f vec v vec r "cloneVector" + f # v # r #|"cloneVector" return r -------------------------------------------------------------------------------- @@ -497,7 +501,7 @@ cloneVector v = do constantAux fun x n = unsafePerformIO $ do v <- createVector n px <- newArray [x] - app1 (fun px) vec v "constantAux" + fun px # v #|"constantAux" free px return v -- cgit v1.2.3