diff options
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 0135288..fbf8295 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -285,6 +285,7 @@ class (Storable a) => Element a where | |||
285 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a | 285 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a |
286 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () | 286 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () |
287 | gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () | 287 | gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () |
288 | reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation | ||
288 | 289 | ||
289 | 290 | ||
290 | instance Element Float where | 291 | instance Element Float where |
@@ -298,6 +299,7 @@ instance Element Float where | |||
298 | remapM = remapF | 299 | remapM = remapF |
299 | rowOp = rowOpAux c_rowOpF | 300 | rowOp = rowOpAux c_rowOpF |
300 | gemm = gemmg c_gemmF | 301 | gemm = gemmg c_gemmF |
302 | reorderV = reorderAux c_reorderF | ||
301 | 303 | ||
302 | instance Element Double where | 304 | instance Element Double where |
303 | constantD = constantAux cconstantR | 305 | constantD = constantAux cconstantR |
@@ -310,6 +312,7 @@ instance Element Double where | |||
310 | remapM = remapD | 312 | remapM = remapD |
311 | rowOp = rowOpAux c_rowOpD | 313 | rowOp = rowOpAux c_rowOpD |
312 | gemm = gemmg c_gemmD | 314 | gemm = gemmg c_gemmD |
315 | reorderV = reorderAux c_reorderD | ||
313 | 316 | ||
314 | instance Element (Complex Float) where | 317 | instance Element (Complex Float) where |
315 | constantD = constantAux cconstantQ | 318 | constantD = constantAux cconstantQ |
@@ -322,6 +325,7 @@ instance Element (Complex Float) where | |||
322 | remapM = remapQ | 325 | remapM = remapQ |
323 | rowOp = rowOpAux c_rowOpQ | 326 | rowOp = rowOpAux c_rowOpQ |
324 | gemm = gemmg c_gemmQ | 327 | gemm = gemmg c_gemmQ |
328 | reorderV = reorderAux c_reorderQ | ||
325 | 329 | ||
326 | instance Element (Complex Double) where | 330 | instance Element (Complex Double) where |
327 | constantD = constantAux cconstantC | 331 | constantD = constantAux cconstantC |
@@ -334,6 +338,7 @@ instance Element (Complex Double) where | |||
334 | remapM = remapC | 338 | remapM = remapC |
335 | rowOp = rowOpAux c_rowOpC | 339 | rowOp = rowOpAux c_rowOpC |
336 | gemm = gemmg c_gemmC | 340 | gemm = gemmg c_gemmC |
341 | reorderV = reorderAux c_reorderC | ||
337 | 342 | ||
338 | instance Element (CInt) where | 343 | instance Element (CInt) where |
339 | constantD = constantAux cconstantI | 344 | constantD = constantAux cconstantI |
@@ -346,6 +351,7 @@ instance Element (CInt) where | |||
346 | remapM = remapI | 351 | remapM = remapI |
347 | rowOp = rowOpAux c_rowOpI | 352 | rowOp = rowOpAux c_rowOpI |
348 | gemm = gemmg c_gemmI | 353 | gemm = gemmg c_gemmI |
354 | reorderV = reorderAux c_reorderI | ||
349 | 355 | ||
350 | instance Element Z where | 356 | instance Element Z where |
351 | constantD = constantAux cconstantL | 357 | constantD = constantAux cconstantL |
@@ -358,6 +364,7 @@ instance Element Z where | |||
358 | remapM = remapL | 364 | remapM = remapL |
359 | rowOp = rowOpAux c_rowOpL | 365 | rowOp = rowOpAux c_rowOpL |
360 | gemm = gemmg c_gemmL | 366 | gemm = gemmg c_gemmL |
367 | reorderV = reorderAux c_reorderL | ||
361 | 368 | ||
362 | ------------------------------------------------------------------- | 369 | ------------------------------------------------------------------- |
363 | 370 | ||
@@ -580,6 +587,33 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z | |||
580 | 587 | ||
581 | -------------------------------------------------------------------------------- | 588 | -------------------------------------------------------------------------------- |
582 | 589 | ||
590 | reorderAux f s d v = unsafePerformIO $ do | ||
591 | k <- createVector (dim s) | ||
592 | r <- createVector (dim v) | ||
593 | (k # s # d # v #! r) f #| "reorderV" | ||
594 | return r | ||
595 | |||
596 | type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt))))) | ||
597 | |||
598 | foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double | ||
599 | foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float | ||
600 | foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt | ||
601 | foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) | ||
602 | foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) | ||
603 | foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z | ||
604 | |||
605 | -- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, | ||
606 | -- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ | ||
607 | -- This function is intended to be used internally by tensor libraries.. | ||
608 | reorderVector :: Element a | ||
609 | => Vector CInt -- ^ @strides@: array strides | ||
610 | -> Vector CInt -- ^ @dims@: array dimensions of new array @v@ | ||
611 | -> Vector a -- ^ @v@: flattened input array | ||
612 | -> Vector a -- ^ @v'@: flattened output array | ||
613 | reorderVector = reorderV | ||
614 | |||
615 | -------------------------------------------------------------------------------- | ||
616 | |||
583 | foreign import ccall unsafe "saveMatrix" c_saveMatrix | 617 | foreign import ccall unsafe "saveMatrix" c_saveMatrix |
584 | :: CString -> CString -> Double ::> Ok | 618 | :: CString -> CString -> Double ::> Ok |
585 | 619 | ||