summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Slagle <kjslag@gmail.com>2016-12-19 09:40:37 -0500
committerKevin Slagle <kjslag@gmail.com>2016-12-19 09:40:37 -0500
commit0431e82183a925e63472bbc9a17db4eb84f904a6 (patch)
treefaeb333ead212fbe704640b32b9638f1328e5946
parent66f0174cec6b9b3a329321a435d0c7841f396077 (diff)
add reorderVector function for tensor libraries (e.g. hTensor) to implement tensor transpose
-rw-r--r--packages/base/src/Internal/C/vector-aux.c51
-rw-r--r--packages/base/src/Internal/Matrix.hs34
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Devel.hs2
3 files changed, 86 insertions, 1 deletions
diff --git a/packages/base/src/Internal/C/vector-aux.c b/packages/base/src/Internal/C/vector-aux.c
index 1cef27d..dcd6c0b 100644
--- a/packages/base/src/Internal/C/vector-aux.c
+++ b/packages/base/src/Internal/C/vector-aux.c
@@ -1533,3 +1533,54 @@ int chooseQ(KIVEC(cond),KQVEC(lt),KQVEC(eq),KQVEC(gt),QVEC(r)) {
1533 CHOOSE_IMP 1533 CHOOSE_IMP
1534} 1534}
1535 1535
1536//////////////////// reorder /////////////////////////
1537
1538#define REORDER_IMP \
1539 REQUIRES(kn == stridesn && stridesn == dimsn ,BAD_SIZE); \
1540 int i,j,l; \
1541 for (i=1,j=0,l=0;l<kn;++l) { \
1542 kp[l] = 0; \
1543 i *= dimsp[l]; \
1544 j += (dimsp[l]-1) * stridesp[l]; \
1545 } \
1546 REQUIRES(i <= vn && j < rn ,BAD_SIZE); \
1547 for (i=0,j=0;;i++) { \
1548 rp[i] = vp[j]; \
1549 for(l=kn-1;;l--) { \
1550 ++kp[l]; \
1551 if (kp[l] < dimsp[l]) { \
1552 j += stridesp[l]; \
1553 break; \
1554 } else { \
1555 if (l == 0) { \
1556 return 0; \
1557 } \
1558 kp[l] = 0; \
1559 j -= (dimsp[l]-1) * stridesp[l]; \
1560 } \
1561 } \
1562 }
1563
1564int reorderF(IVEC(k), KIVEC(strides),KIVEC(dims),KFVEC(v),FVEC(r)) {
1565 REORDER_IMP
1566}
1567
1568int reorderD(IVEC(k), KIVEC(strides),KIVEC(dims),KDVEC(v),DVEC(r)) {
1569 REORDER_IMP
1570}
1571
1572int reorderI(IVEC(k), KIVEC(strides),KIVEC(dims),KIVEC(v),IVEC(r)) {
1573 REORDER_IMP
1574}
1575
1576int reorderL(IVEC(k), KIVEC(strides),KIVEC(dims),KLVEC(v),LVEC(r)) {
1577 REORDER_IMP
1578}
1579
1580int reorderC(IVEC(k), KIVEC(strides),KIVEC(dims),KCVEC(v),CVEC(r)) {
1581 REORDER_IMP
1582}
1583
1584int reorderQ(IVEC(k), KIVEC(strides),KIVEC(dims),KQVEC(v),QVEC(r)) {
1585 REORDER_IMP
1586}
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 0135288..fbf8295 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -285,6 +285,7 @@ class (Storable a) => Element a where
285 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a 285 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
286 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () 286 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
287 gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () 287 gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
288 reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation
288 289
289 290
290instance Element Float where 291instance Element Float where
@@ -298,6 +299,7 @@ instance Element Float where
298 remapM = remapF 299 remapM = remapF
299 rowOp = rowOpAux c_rowOpF 300 rowOp = rowOpAux c_rowOpF
300 gemm = gemmg c_gemmF 301 gemm = gemmg c_gemmF
302 reorderV = reorderAux c_reorderF
301 303
302instance Element Double where 304instance Element Double where
303 constantD = constantAux cconstantR 305 constantD = constantAux cconstantR
@@ -310,6 +312,7 @@ instance Element Double where
310 remapM = remapD 312 remapM = remapD
311 rowOp = rowOpAux c_rowOpD 313 rowOp = rowOpAux c_rowOpD
312 gemm = gemmg c_gemmD 314 gemm = gemmg c_gemmD
315 reorderV = reorderAux c_reorderD
313 316
314instance Element (Complex Float) where 317instance Element (Complex Float) where
315 constantD = constantAux cconstantQ 318 constantD = constantAux cconstantQ
@@ -322,6 +325,7 @@ instance Element (Complex Float) where
322 remapM = remapQ 325 remapM = remapQ
323 rowOp = rowOpAux c_rowOpQ 326 rowOp = rowOpAux c_rowOpQ
324 gemm = gemmg c_gemmQ 327 gemm = gemmg c_gemmQ
328 reorderV = reorderAux c_reorderQ
325 329
326instance Element (Complex Double) where 330instance Element (Complex Double) where
327 constantD = constantAux cconstantC 331 constantD = constantAux cconstantC
@@ -334,6 +338,7 @@ instance Element (Complex Double) where
334 remapM = remapC 338 remapM = remapC
335 rowOp = rowOpAux c_rowOpC 339 rowOp = rowOpAux c_rowOpC
336 gemm = gemmg c_gemmC 340 gemm = gemmg c_gemmC
341 reorderV = reorderAux c_reorderC
337 342
338instance Element (CInt) where 343instance Element (CInt) where
339 constantD = constantAux cconstantI 344 constantD = constantAux cconstantI
@@ -346,6 +351,7 @@ instance Element (CInt) where
346 remapM = remapI 351 remapM = remapI
347 rowOp = rowOpAux c_rowOpI 352 rowOp = rowOpAux c_rowOpI
348 gemm = gemmg c_gemmI 353 gemm = gemmg c_gemmI
354 reorderV = reorderAux c_reorderI
349 355
350instance Element Z where 356instance Element Z where
351 constantD = constantAux cconstantL 357 constantD = constantAux cconstantL
@@ -358,6 +364,7 @@ instance Element Z where
358 remapM = remapL 364 remapM = remapL
359 rowOp = rowOpAux c_rowOpL 365 rowOp = rowOpAux c_rowOpL
360 gemm = gemmg c_gemmL 366 gemm = gemmg c_gemmL
367 reorderV = reorderAux c_reorderL
361 368
362------------------------------------------------------------------- 369-------------------------------------------------------------------
363 370
@@ -580,6 +587,33 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z
580 587
581-------------------------------------------------------------------------------- 588--------------------------------------------------------------------------------
582 589
590reorderAux f s d v = unsafePerformIO $ do
591 k <- createVector (dim s)
592 r <- createVector (dim v)
593 (k # s # d # v #! r) f #| "reorderV"
594 return r
595
596type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt)))))
597
598foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double
599foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float
600foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt
601foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double)
602foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float)
603foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z
604
605-- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices,
606-- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@
607-- This function is intended to be used internally by tensor libraries..
608reorderVector :: Element a
609 => Vector CInt -- ^ @strides@: array strides
610 -> Vector CInt -- ^ @dims@: array dimensions of new array @v@
611 -> Vector a -- ^ @v@: flattened input array
612 -> Vector a -- ^ @v'@: flattened output array
613reorderVector = reorderV
614
615--------------------------------------------------------------------------------
616
583foreign import ccall unsafe "saveMatrix" c_saveMatrix 617foreign import ccall unsafe "saveMatrix" c_saveMatrix
584 :: CString -> CString -> Double ::> Ok 618 :: CString -> CString -> Double ::> Ok
585 619
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs
index 941b597..e974fc4 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs
@@ -55,7 +55,7 @@ module Numeric.LinearAlgebra.Devel(
55 GMatrix(..), 55 GMatrix(..),
56 56
57 -- * Misc 57 -- * Misc
58 toByteString, fromByteString, showInternal 58 toByteString, fromByteString, showInternal, reorderVector
59 59
60) where 60) where
61 61