From 0431e82183a925e63472bbc9a17db4eb84f904a6 Mon Sep 17 00:00:00 2001 From: Kevin Slagle Date: Mon, 19 Dec 2016 09:40:37 -0500 Subject: add reorderVector function for tensor libraries (e.g. hTensor) to implement tensor transpose --- packages/base/src/Internal/Matrix.hs | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) (limited to 'packages/base/src/Internal/Matrix.hs') diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 0135288..fbf8295 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs @@ -285,6 +285,7 @@ class (Storable a) => Element a where remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () + reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation instance Element Float where @@ -298,6 +299,7 @@ instance Element Float where remapM = remapF rowOp = rowOpAux c_rowOpF gemm = gemmg c_gemmF + reorderV = reorderAux c_reorderF instance Element Double where constantD = constantAux cconstantR @@ -310,6 +312,7 @@ instance Element Double where remapM = remapD rowOp = rowOpAux c_rowOpD gemm = gemmg c_gemmD + reorderV = reorderAux c_reorderD instance Element (Complex Float) where constantD = constantAux cconstantQ @@ -322,6 +325,7 @@ instance Element (Complex Float) where remapM = remapQ rowOp = rowOpAux c_rowOpQ gemm = gemmg c_gemmQ + reorderV = reorderAux c_reorderQ instance Element (Complex Double) where constantD = constantAux cconstantC @@ -334,6 +338,7 @@ instance Element (Complex Double) where remapM = remapC rowOp = rowOpAux c_rowOpC gemm = gemmg c_gemmC + reorderV = reorderAux c_reorderC instance Element (CInt) where constantD = constantAux cconstantI @@ -346,6 +351,7 @@ instance Element (CInt) where remapM = remapI rowOp = rowOpAux c_rowOpI gemm = gemmg c_gemmI + reorderV = reorderAux c_reorderI instance Element Z where constantD = constantAux cconstantL @@ -358,6 +364,7 @@ instance Element Z where remapM = remapL rowOp = rowOpAux c_rowOpL gemm = gemmg c_gemmL + reorderV = reorderAux c_reorderL ------------------------------------------------------------------- @@ -580,6 +587,33 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z -------------------------------------------------------------------------------- +reorderAux f s d v = unsafePerformIO $ do + k <- createVector (dim s) + r <- createVector (dim v) + (k # s # d # v #! r) f #| "reorderV" + return r + +type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt))))) + +foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double +foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float +foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt +foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) +foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) +foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z + +-- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, +-- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ +-- This function is intended to be used internally by tensor libraries.. +reorderVector :: Element a + => Vector CInt -- ^ @strides@: array strides + -> Vector CInt -- ^ @dims@: array dimensions of new array @v@ + -> Vector a -- ^ @v@: flattened input array + -> Vector a -- ^ @v'@: flattened output array +reorderVector = reorderV + +-------------------------------------------------------------------------------- + foreign import ccall unsafe "saveMatrix" c_saveMatrix :: CString -> CString -> Double ::> Ok -- cgit v1.2.3