summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Matrix.hs
diff options
context:
space:
mode:
authorKevin Slagle <kjslag@gmail.com>2016-12-19 09:40:37 -0500
committerKevin Slagle <kjslag@gmail.com>2016-12-19 09:40:37 -0500
commit0431e82183a925e63472bbc9a17db4eb84f904a6 (patch)
treefaeb333ead212fbe704640b32b9638f1328e5946 /packages/base/src/Internal/Matrix.hs
parent66f0174cec6b9b3a329321a435d0c7841f396077 (diff)
add reorderVector function for tensor libraries (e.g. hTensor) to implement tensor transpose
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r--packages/base/src/Internal/Matrix.hs34
1 files changed, 34 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 0135288..fbf8295 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -285,6 +285,7 @@ class (Storable a) => Element a where
285 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a 285 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
286 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () 286 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
287 gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () 287 gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
288 reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation
288 289
289 290
290instance Element Float where 291instance Element Float where
@@ -298,6 +299,7 @@ instance Element Float where
298 remapM = remapF 299 remapM = remapF
299 rowOp = rowOpAux c_rowOpF 300 rowOp = rowOpAux c_rowOpF
300 gemm = gemmg c_gemmF 301 gemm = gemmg c_gemmF
302 reorderV = reorderAux c_reorderF
301 303
302instance Element Double where 304instance Element Double where
303 constantD = constantAux cconstantR 305 constantD = constantAux cconstantR
@@ -310,6 +312,7 @@ instance Element Double where
310 remapM = remapD 312 remapM = remapD
311 rowOp = rowOpAux c_rowOpD 313 rowOp = rowOpAux c_rowOpD
312 gemm = gemmg c_gemmD 314 gemm = gemmg c_gemmD
315 reorderV = reorderAux c_reorderD
313 316
314instance Element (Complex Float) where 317instance Element (Complex Float) where
315 constantD = constantAux cconstantQ 318 constantD = constantAux cconstantQ
@@ -322,6 +325,7 @@ instance Element (Complex Float) where
322 remapM = remapQ 325 remapM = remapQ
323 rowOp = rowOpAux c_rowOpQ 326 rowOp = rowOpAux c_rowOpQ
324 gemm = gemmg c_gemmQ 327 gemm = gemmg c_gemmQ
328 reorderV = reorderAux c_reorderQ
325 329
326instance Element (Complex Double) where 330instance Element (Complex Double) where
327 constantD = constantAux cconstantC 331 constantD = constantAux cconstantC
@@ -334,6 +338,7 @@ instance Element (Complex Double) where
334 remapM = remapC 338 remapM = remapC
335 rowOp = rowOpAux c_rowOpC 339 rowOp = rowOpAux c_rowOpC
336 gemm = gemmg c_gemmC 340 gemm = gemmg c_gemmC
341 reorderV = reorderAux c_reorderC
337 342
338instance Element (CInt) where 343instance Element (CInt) where
339 constantD = constantAux cconstantI 344 constantD = constantAux cconstantI
@@ -346,6 +351,7 @@ instance Element (CInt) where
346 remapM = remapI 351 remapM = remapI
347 rowOp = rowOpAux c_rowOpI 352 rowOp = rowOpAux c_rowOpI
348 gemm = gemmg c_gemmI 353 gemm = gemmg c_gemmI
354 reorderV = reorderAux c_reorderI
349 355
350instance Element Z where 356instance Element Z where
351 constantD = constantAux cconstantL 357 constantD = constantAux cconstantL
@@ -358,6 +364,7 @@ instance Element Z where
358 remapM = remapL 364 remapM = remapL
359 rowOp = rowOpAux c_rowOpL 365 rowOp = rowOpAux c_rowOpL
360 gemm = gemmg c_gemmL 366 gemm = gemmg c_gemmL
367 reorderV = reorderAux c_reorderL
361 368
362------------------------------------------------------------------- 369-------------------------------------------------------------------
363 370
@@ -580,6 +587,33 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z
580 587
581-------------------------------------------------------------------------------- 588--------------------------------------------------------------------------------
582 589
590reorderAux f s d v = unsafePerformIO $ do
591 k <- createVector (dim s)
592 r <- createVector (dim v)
593 (k # s # d # v #! r) f #| "reorderV"
594 return r
595
596type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt)))))
597
598foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double
599foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float
600foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt
601foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double)
602foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float)
603foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z
604
605-- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices,
606-- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@
607-- This function is intended to be used internally by tensor libraries..
608reorderVector :: Element a
609 => Vector CInt -- ^ @strides@: array strides
610 -> Vector CInt -- ^ @dims@: array dimensions of new array @v@
611 -> Vector a -- ^ @v@: flattened input array
612 -> Vector a -- ^ @v'@: flattened output array
613reorderVector = reorderV
614
615--------------------------------------------------------------------------------
616
583foreign import ccall unsafe "saveMatrix" c_saveMatrix 617foreign import ccall unsafe "saveMatrix" c_saveMatrix
584 :: CString -> CString -> Double ::> Ok 618 :: CString -> CString -> Double ::> Ok
585 619