diff options
author | Kevin Slagle <kjslag@gmail.com> | 2016-12-19 09:40:37 -0500 |
---|---|---|
committer | Kevin Slagle <kjslag@gmail.com> | 2016-12-19 09:40:37 -0500 |
commit | 0431e82183a925e63472bbc9a17db4eb84f904a6 (patch) | |
tree | faeb333ead212fbe704640b32b9638f1328e5946 | |
parent | 66f0174cec6b9b3a329321a435d0c7841f396077 (diff) |
add reorderVector function for tensor libraries (e.g. hTensor) to implement tensor transpose
-rw-r--r-- | packages/base/src/Internal/C/vector-aux.c | 51 | ||||
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 34 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Devel.hs | 2 |
3 files changed, 86 insertions, 1 deletions
diff --git a/packages/base/src/Internal/C/vector-aux.c b/packages/base/src/Internal/C/vector-aux.c index 1cef27d..dcd6c0b 100644 --- a/packages/base/src/Internal/C/vector-aux.c +++ b/packages/base/src/Internal/C/vector-aux.c | |||
@@ -1533,3 +1533,54 @@ int chooseQ(KIVEC(cond),KQVEC(lt),KQVEC(eq),KQVEC(gt),QVEC(r)) { | |||
1533 | CHOOSE_IMP | 1533 | CHOOSE_IMP |
1534 | } | 1534 | } |
1535 | 1535 | ||
1536 | //////////////////// reorder ///////////////////////// | ||
1537 | |||
1538 | #define REORDER_IMP \ | ||
1539 | REQUIRES(kn == stridesn && stridesn == dimsn ,BAD_SIZE); \ | ||
1540 | int i,j,l; \ | ||
1541 | for (i=1,j=0,l=0;l<kn;++l) { \ | ||
1542 | kp[l] = 0; \ | ||
1543 | i *= dimsp[l]; \ | ||
1544 | j += (dimsp[l]-1) * stridesp[l]; \ | ||
1545 | } \ | ||
1546 | REQUIRES(i <= vn && j < rn ,BAD_SIZE); \ | ||
1547 | for (i=0,j=0;;i++) { \ | ||
1548 | rp[i] = vp[j]; \ | ||
1549 | for(l=kn-1;;l--) { \ | ||
1550 | ++kp[l]; \ | ||
1551 | if (kp[l] < dimsp[l]) { \ | ||
1552 | j += stridesp[l]; \ | ||
1553 | break; \ | ||
1554 | } else { \ | ||
1555 | if (l == 0) { \ | ||
1556 | return 0; \ | ||
1557 | } \ | ||
1558 | kp[l] = 0; \ | ||
1559 | j -= (dimsp[l]-1) * stridesp[l]; \ | ||
1560 | } \ | ||
1561 | } \ | ||
1562 | } | ||
1563 | |||
1564 | int reorderF(IVEC(k), KIVEC(strides),KIVEC(dims),KFVEC(v),FVEC(r)) { | ||
1565 | REORDER_IMP | ||
1566 | } | ||
1567 | |||
1568 | int reorderD(IVEC(k), KIVEC(strides),KIVEC(dims),KDVEC(v),DVEC(r)) { | ||
1569 | REORDER_IMP | ||
1570 | } | ||
1571 | |||
1572 | int reorderI(IVEC(k), KIVEC(strides),KIVEC(dims),KIVEC(v),IVEC(r)) { | ||
1573 | REORDER_IMP | ||
1574 | } | ||
1575 | |||
1576 | int reorderL(IVEC(k), KIVEC(strides),KIVEC(dims),KLVEC(v),LVEC(r)) { | ||
1577 | REORDER_IMP | ||
1578 | } | ||
1579 | |||
1580 | int reorderC(IVEC(k), KIVEC(strides),KIVEC(dims),KCVEC(v),CVEC(r)) { | ||
1581 | REORDER_IMP | ||
1582 | } | ||
1583 | |||
1584 | int reorderQ(IVEC(k), KIVEC(strides),KIVEC(dims),KQVEC(v),QVEC(r)) { | ||
1585 | REORDER_IMP | ||
1586 | } | ||
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 0135288..fbf8295 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -285,6 +285,7 @@ class (Storable a) => Element a where | |||
285 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a | 285 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a |
286 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () | 286 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () |
287 | gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () | 287 | gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () |
288 | reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation | ||
288 | 289 | ||
289 | 290 | ||
290 | instance Element Float where | 291 | instance Element Float where |
@@ -298,6 +299,7 @@ instance Element Float where | |||
298 | remapM = remapF | 299 | remapM = remapF |
299 | rowOp = rowOpAux c_rowOpF | 300 | rowOp = rowOpAux c_rowOpF |
300 | gemm = gemmg c_gemmF | 301 | gemm = gemmg c_gemmF |
302 | reorderV = reorderAux c_reorderF | ||
301 | 303 | ||
302 | instance Element Double where | 304 | instance Element Double where |
303 | constantD = constantAux cconstantR | 305 | constantD = constantAux cconstantR |
@@ -310,6 +312,7 @@ instance Element Double where | |||
310 | remapM = remapD | 312 | remapM = remapD |
311 | rowOp = rowOpAux c_rowOpD | 313 | rowOp = rowOpAux c_rowOpD |
312 | gemm = gemmg c_gemmD | 314 | gemm = gemmg c_gemmD |
315 | reorderV = reorderAux c_reorderD | ||
313 | 316 | ||
314 | instance Element (Complex Float) where | 317 | instance Element (Complex Float) where |
315 | constantD = constantAux cconstantQ | 318 | constantD = constantAux cconstantQ |
@@ -322,6 +325,7 @@ instance Element (Complex Float) where | |||
322 | remapM = remapQ | 325 | remapM = remapQ |
323 | rowOp = rowOpAux c_rowOpQ | 326 | rowOp = rowOpAux c_rowOpQ |
324 | gemm = gemmg c_gemmQ | 327 | gemm = gemmg c_gemmQ |
328 | reorderV = reorderAux c_reorderQ | ||
325 | 329 | ||
326 | instance Element (Complex Double) where | 330 | instance Element (Complex Double) where |
327 | constantD = constantAux cconstantC | 331 | constantD = constantAux cconstantC |
@@ -334,6 +338,7 @@ instance Element (Complex Double) where | |||
334 | remapM = remapC | 338 | remapM = remapC |
335 | rowOp = rowOpAux c_rowOpC | 339 | rowOp = rowOpAux c_rowOpC |
336 | gemm = gemmg c_gemmC | 340 | gemm = gemmg c_gemmC |
341 | reorderV = reorderAux c_reorderC | ||
337 | 342 | ||
338 | instance Element (CInt) where | 343 | instance Element (CInt) where |
339 | constantD = constantAux cconstantI | 344 | constantD = constantAux cconstantI |
@@ -346,6 +351,7 @@ instance Element (CInt) where | |||
346 | remapM = remapI | 351 | remapM = remapI |
347 | rowOp = rowOpAux c_rowOpI | 352 | rowOp = rowOpAux c_rowOpI |
348 | gemm = gemmg c_gemmI | 353 | gemm = gemmg c_gemmI |
354 | reorderV = reorderAux c_reorderI | ||
349 | 355 | ||
350 | instance Element Z where | 356 | instance Element Z where |
351 | constantD = constantAux cconstantL | 357 | constantD = constantAux cconstantL |
@@ -358,6 +364,7 @@ instance Element Z where | |||
358 | remapM = remapL | 364 | remapM = remapL |
359 | rowOp = rowOpAux c_rowOpL | 365 | rowOp = rowOpAux c_rowOpL |
360 | gemm = gemmg c_gemmL | 366 | gemm = gemmg c_gemmL |
367 | reorderV = reorderAux c_reorderL | ||
361 | 368 | ||
362 | ------------------------------------------------------------------- | 369 | ------------------------------------------------------------------- |
363 | 370 | ||
@@ -580,6 +587,33 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z | |||
580 | 587 | ||
581 | -------------------------------------------------------------------------------- | 588 | -------------------------------------------------------------------------------- |
582 | 589 | ||
590 | reorderAux f s d v = unsafePerformIO $ do | ||
591 | k <- createVector (dim s) | ||
592 | r <- createVector (dim v) | ||
593 | (k # s # d # v #! r) f #| "reorderV" | ||
594 | return r | ||
595 | |||
596 | type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt))))) | ||
597 | |||
598 | foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double | ||
599 | foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float | ||
600 | foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt | ||
601 | foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) | ||
602 | foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) | ||
603 | foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z | ||
604 | |||
605 | -- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, | ||
606 | -- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ | ||
607 | -- This function is intended to be used internally by tensor libraries.. | ||
608 | reorderVector :: Element a | ||
609 | => Vector CInt -- ^ @strides@: array strides | ||
610 | -> Vector CInt -- ^ @dims@: array dimensions of new array @v@ | ||
611 | -> Vector a -- ^ @v@: flattened input array | ||
612 | -> Vector a -- ^ @v'@: flattened output array | ||
613 | reorderVector = reorderV | ||
614 | |||
615 | -------------------------------------------------------------------------------- | ||
616 | |||
583 | foreign import ccall unsafe "saveMatrix" c_saveMatrix | 617 | foreign import ccall unsafe "saveMatrix" c_saveMatrix |
584 | :: CString -> CString -> Double ::> Ok | 618 | :: CString -> CString -> Double ::> Ok |
585 | 619 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs index 941b597..e974fc4 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs | |||
@@ -55,7 +55,7 @@ module Numeric.LinearAlgebra.Devel( | |||
55 | GMatrix(..), | 55 | GMatrix(..), |
56 | 56 | ||
57 | -- * Misc | 57 | -- * Misc |
58 | toByteString, fromByteString, showInternal | 58 | toByteString, fromByteString, showInternal, reorderVector |
59 | 59 | ||
60 | ) where | 60 | ) where |
61 | 61 | ||