diff options
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 382 |
1 files changed, 0 insertions, 382 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 7c774ef..225b039 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -37,31 +37,6 @@ import Text.Printf | |||
37 | 37 | ||
38 | ----------------------------------------------------------------- | 38 | ----------------------------------------------------------------- |
39 | 39 | ||
40 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | ||
41 | |||
42 | -- | Matrix representation suitable for BLAS\/LAPACK computations. | ||
43 | |||
44 | data Matrix t = Matrix | ||
45 | { irows :: {-# UNPACK #-} !Int | ||
46 | , icols :: {-# UNPACK #-} !Int | ||
47 | , xRow :: {-# UNPACK #-} !Int | ||
48 | , xCol :: {-# UNPACK #-} !Int | ||
49 | , xdat :: {-# UNPACK #-} !(Vector t) | ||
50 | } | ||
51 | |||
52 | |||
53 | rows :: Matrix t -> Int | ||
54 | rows = irows | ||
55 | {-# INLINE rows #-} | ||
56 | |||
57 | cols :: Matrix t -> Int | ||
58 | cols = icols | ||
59 | {-# INLINE cols #-} | ||
60 | |||
61 | size :: Matrix t -> (Int, Int) | ||
62 | size m = (irows m, icols m) | ||
63 | {-# INLINE size #-} | ||
64 | |||
65 | rowOrder :: Matrix t -> Bool | 40 | rowOrder :: Matrix t -> Bool |
66 | rowOrder m = xCol m == 1 || cols m == 1 | 41 | rowOrder m = xCol m == 1 || cols m == 1 |
67 | {-# INLINE rowOrder #-} | 42 | {-# INLINE rowOrder #-} |
@@ -114,33 +89,6 @@ fmat m | |||
114 | | otherwise = extractAll ColumnMajor m | 89 | | otherwise = extractAll ColumnMajor m |
115 | 90 | ||
116 | 91 | ||
117 | -- C-Haskell matrix adapters | ||
118 | {-# INLINE amatr #-} | ||
119 | amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r | ||
120 | amatr x f g = unsafeWith (xdat x) (f . g r c) | ||
121 | where | ||
122 | r = fi (rows x) | ||
123 | c = fi (cols x) | ||
124 | |||
125 | {-# INLINE amat #-} | ||
126 | amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r | ||
127 | amat x f g = unsafeWith (xdat x) (f . g r c sr sc) | ||
128 | where | ||
129 | r = fi (rows x) | ||
130 | c = fi (cols x) | ||
131 | sr = fi (xRow x) | ||
132 | sc = fi (xCol x) | ||
133 | |||
134 | |||
135 | instance Storable t => TransArray (Matrix t) | ||
136 | where | ||
137 | type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b | ||
138 | type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b | ||
139 | apply = amat | ||
140 | {-# INLINE apply #-} | ||
141 | applyRaw = amatr | ||
142 | {-# INLINE applyRaw #-} | ||
143 | |||
144 | infixr 1 # | 92 | infixr 1 # |
145 | (#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r | 93 | (#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r |
146 | a # b = apply a b | 94 | a # b = apply a b |
@@ -240,22 +188,6 @@ atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) | |||
240 | 188 | ||
241 | ------------------------------------------------------------------ | 189 | ------------------------------------------------------------------ |
242 | 190 | ||
243 | matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t | ||
244 | matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } | ||
245 | matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } | ||
246 | matrixFromVector o r c v | ||
247 | | r * c == dim v = m | ||
248 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m | ||
249 | where | ||
250 | m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 } | ||
251 | | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r } | ||
252 | |||
253 | -- allocates memory for a new matrix | ||
254 | createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) | ||
255 | createMatrix ord r c = do | ||
256 | p <- createVector (r*c) | ||
257 | return (matrixFromVector ord r c p) | ||
258 | |||
259 | {- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = tr' . reshape r@ | 191 | {- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = tr' . reshape r@ |
260 | where r is the desired number of rows.) | 192 | where r is the desired number of rows.) |
261 | 193 | ||
@@ -286,101 +218,6 @@ liftMatrix2 f m1@(size->(r,c)) m2 | |||
286 | 218 | ||
287 | ------------------------------------------------------------------ | 219 | ------------------------------------------------------------------ |
288 | 220 | ||
289 | {- | ||
290 | -- | Supported matrix elements. | ||
291 | class (Storable a) => Element a where | ||
292 | constantD :: a -> Int -> Vector a | ||
293 | extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) | ||
294 | setRect :: Int -> Int -> Matrix a -> Matrix a -> IO () | ||
295 | sortI :: Ord a => Vector a -> Vector CInt | ||
296 | sortV :: Ord a => Vector a -> Vector a | ||
297 | compareV :: Ord a => Vector a -> Vector a -> Vector CInt | ||
298 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a | ||
299 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a | ||
300 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () | ||
301 | gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () | ||
302 | reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation | ||
303 | |||
304 | |||
305 | instance Element Float where | ||
306 | constantD = constantAux cconstantF | ||
307 | extractR = extractAux c_extractF | ||
308 | setRect = setRectAux c_setRectF | ||
309 | sortI = sortIdxF | ||
310 | sortV = sortValF | ||
311 | compareV = compareF | ||
312 | selectV = selectF | ||
313 | remapM = remapF | ||
314 | rowOp = rowOpAux c_rowOpF | ||
315 | gemm = gemmg c_gemmF | ||
316 | reorderV = reorderAux c_reorderF | ||
317 | |||
318 | instance Element Double where | ||
319 | constantD = constantAux cconstantR | ||
320 | extractR = extractAux c_extractD | ||
321 | setRect = setRectAux c_setRectD | ||
322 | sortI = sortIdxD | ||
323 | sortV = sortValD | ||
324 | compareV = compareD | ||
325 | selectV = selectD | ||
326 | remapM = remapD | ||
327 | rowOp = rowOpAux c_rowOpD | ||
328 | gemm = gemmg c_gemmD | ||
329 | reorderV = reorderAux c_reorderD | ||
330 | |||
331 | instance Element (Complex Float) where | ||
332 | constantD = constantAux cconstantQ | ||
333 | extractR = extractAux c_extractQ | ||
334 | setRect = setRectAux c_setRectQ | ||
335 | sortI = undefined | ||
336 | sortV = undefined | ||
337 | compareV = undefined | ||
338 | selectV = selectQ | ||
339 | remapM = remapQ | ||
340 | rowOp = rowOpAux c_rowOpQ | ||
341 | gemm = gemmg c_gemmQ | ||
342 | reorderV = reorderAux c_reorderQ | ||
343 | |||
344 | instance Element (Complex Double) where | ||
345 | constantD = constantAux cconstantC | ||
346 | extractR = extractAux c_extractC | ||
347 | setRect = setRectAux c_setRectC | ||
348 | sortI = undefined | ||
349 | sortV = undefined | ||
350 | compareV = undefined | ||
351 | selectV = selectC | ||
352 | remapM = remapC | ||
353 | rowOp = rowOpAux c_rowOpC | ||
354 | gemm = gemmg c_gemmC | ||
355 | reorderV = reorderAux c_reorderC | ||
356 | |||
357 | instance Element (CInt) where | ||
358 | constantD = constantAux cconstantI | ||
359 | extractR = extractAux c_extractI | ||
360 | setRect = setRectAux c_setRectI | ||
361 | sortI = sortIdxI | ||
362 | sortV = sortValI | ||
363 | compareV = compareI | ||
364 | selectV = selectI | ||
365 | remapM = remapI | ||
366 | rowOp = rowOpAux c_rowOpI | ||
367 | gemm = gemmg c_gemmI | ||
368 | reorderV = reorderAux c_reorderI | ||
369 | |||
370 | instance Element Z where | ||
371 | constantD = constantAux cconstantL | ||
372 | extractR = extractAux c_extractL | ||
373 | setRect = setRectAux c_setRectL | ||
374 | sortI = sortIdxL | ||
375 | sortV = sortValL | ||
376 | compareV = compareL | ||
377 | selectV = selectL | ||
378 | remapM = remapL | ||
379 | rowOp = rowOpAux c_rowOpL | ||
380 | gemm = gemmg c_gemmL | ||
381 | reorderV = reorderAux c_reorderL | ||
382 | -} | ||
383 | |||
384 | ------------------------------------------------------------------- | 221 | ------------------------------------------------------------------- |
385 | 222 | ||
386 | -- | reference to a rectangular slice of a matrix (no data copy) | 223 | -- | reference to a rectangular slice of a matrix (no data copy) |
@@ -435,12 +272,6 @@ repRows n x = fromRows (replicate n (flatten x)) | |||
435 | repCols :: Element t => Int -> Matrix t -> Matrix t | 272 | repCols :: Element t => Int -> Matrix t -> Matrix t |
436 | repCols n x = fromColumns (replicate n (flatten x)) | 273 | repCols n x = fromColumns (replicate n (flatten x)) |
437 | 274 | ||
438 | shSize :: Matrix t -> [Char] | ||
439 | shSize = shDim . size | ||
440 | |||
441 | shDim :: (Show a, Show a1) => (a1, a) -> [Char] | ||
442 | shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" | ||
443 | |||
444 | emptyM :: Storable t => Int -> Int -> Matrix t | 275 | emptyM :: Storable t => Int -> Int -> Matrix t |
445 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) | 276 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) |
446 | 277 | ||
@@ -456,19 +287,6 @@ instance (Storable t, NFData t) => NFData (Matrix t) | |||
456 | 287 | ||
457 | --------------------------------------------------------------- | 288 | --------------------------------------------------------------- |
458 | 289 | ||
459 | extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, | ||
460 | Storable t, Num t3, Num t2, Integral t1, Integral t) | ||
461 | => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t | ||
462 | -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) | ||
463 | -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) | ||
464 | extractAux f ord m moder vr modec vc = do | ||
465 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | ||
466 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | ||
467 | r <- createMatrix ord nr nc | ||
468 | (vr # vc # m #! r) (f moder modec) #|"extract" | ||
469 | |||
470 | return r | ||
471 | |||
472 | type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) | 290 | type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) |
473 | 291 | ||
474 | foreign import ccall unsafe "extractD" c_extractD :: Extr Double | 292 | foreign import ccall unsafe "extractD" c_extractD :: Extr Double |
@@ -480,217 +298,17 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z | |||
480 | 298 | ||
481 | --------------------------------------------------------------- | 299 | --------------------------------------------------------------- |
482 | 300 | ||
483 | setRectAux :: (TransArray c1, TransArray c) | ||
484 | => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) | ||
485 | -> Int -> Int -> c1 -> c -> IO () | ||
486 | setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" | ||
487 | |||
488 | type SetRect x = I -> I -> x ::> x::> Ok | ||
489 | |||
490 | foreign import ccall unsafe "setRectD" c_setRectD :: SetRect Double | ||
491 | foreign import ccall unsafe "setRectF" c_setRectF :: SetRect Float | ||
492 | foreign import ccall unsafe "setRectC" c_setRectC :: SetRect (Complex Double) | ||
493 | foreign import ccall unsafe "setRectQ" c_setRectQ :: SetRect (Complex Float) | ||
494 | foreign import ccall unsafe "setRectI" c_setRectI :: SetRect I | ||
495 | foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z | ||
496 | |||
497 | -------------------------------------------------------------------------------- | 301 | -------------------------------------------------------------------------------- |
498 | 302 | ||
499 | sortG :: (Storable t, Storable a) | ||
500 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
501 | sortG f v = unsafePerformIO $ do | ||
502 | r <- createVector (dim v) | ||
503 | (v #! r) f #|"sortG" | ||
504 | return r | ||
505 | |||
506 | sortIdxD :: Vector Double -> Vector CInt | ||
507 | sortIdxD = sortG c_sort_indexD | ||
508 | sortIdxF :: Vector Float -> Vector CInt | ||
509 | sortIdxF = sortG c_sort_indexF | ||
510 | sortIdxI :: Vector CInt -> Vector CInt | ||
511 | sortIdxI = sortG c_sort_indexI | ||
512 | sortIdxL :: Vector Z -> Vector I | ||
513 | sortIdxL = sortG c_sort_indexL | ||
514 | |||
515 | sortValD :: Vector Double -> Vector Double | ||
516 | sortValD = sortG c_sort_valD | ||
517 | sortValF :: Vector Float -> Vector Float | ||
518 | sortValF = sortG c_sort_valF | ||
519 | sortValI :: Vector CInt -> Vector CInt | ||
520 | sortValI = sortG c_sort_valI | ||
521 | sortValL :: Vector Z -> Vector Z | ||
522 | sortValL = sortG c_sort_valL | ||
523 | |||
524 | foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) | ||
525 | foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt)) | ||
526 | foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt)) | ||
527 | foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok | ||
528 | |||
529 | foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt)) | ||
530 | foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt)) | ||
531 | foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt)) | ||
532 | foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok | ||
533 | |||
534 | -------------------------------------------------------------------------------- | 303 | -------------------------------------------------------------------------------- |
535 | 304 | ||
536 | compareG :: (TransArray c, Storable t, Storable a) | ||
537 | => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) | ||
538 | -> c -> Vector t -> Vector a | ||
539 | compareG f u v = unsafePerformIO $ do | ||
540 | r <- createVector (dim v) | ||
541 | (u # v #! r) f #|"compareG" | ||
542 | return r | ||
543 | |||
544 | compareD :: Vector Double -> Vector Double -> Vector CInt | ||
545 | compareD = compareG c_compareD | ||
546 | compareF :: Vector Float -> Vector Float -> Vector CInt | ||
547 | compareF = compareG c_compareF | ||
548 | compareI :: Vector CInt -> Vector CInt -> Vector CInt | ||
549 | compareI = compareG c_compareI | ||
550 | compareL :: Vector Z -> Vector Z -> Vector CInt | ||
551 | compareL = compareG c_compareL | ||
552 | |||
553 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) | ||
554 | foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt))) | ||
555 | foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt))) | ||
556 | foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok | ||
557 | |||
558 | -------------------------------------------------------------------------------- | 305 | -------------------------------------------------------------------------------- |
559 | 306 | ||
560 | selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) | ||
561 | => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) | ||
562 | -> c2 -> c1 -> Vector t -> c -> Vector a | ||
563 | selectG f c u v w = unsafePerformIO $ do | ||
564 | r <- createVector (dim v) | ||
565 | (c # u # v # w #! r) f #|"selectG" | ||
566 | return r | ||
567 | |||
568 | selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double | ||
569 | selectD = selectG c_selectD | ||
570 | selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float | ||
571 | selectF = selectG c_selectF | ||
572 | selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt | ||
573 | selectI = selectG c_selectI | ||
574 | selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z | ||
575 | selectL = selectG c_selectL | ||
576 | selectC :: Vector CInt | ||
577 | -> Vector (Complex Double) | ||
578 | -> Vector (Complex Double) | ||
579 | -> Vector (Complex Double) | ||
580 | -> Vector (Complex Double) | ||
581 | selectC = selectG c_selectC | ||
582 | selectQ :: Vector CInt | ||
583 | -> Vector (Complex Float) | ||
584 | -> Vector (Complex Float) | ||
585 | -> Vector (Complex Float) | ||
586 | -> Vector (Complex Float) | ||
587 | selectQ = selectG c_selectQ | ||
588 | |||
589 | type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) | ||
590 | |||
591 | foreign import ccall unsafe "chooseD" c_selectD :: Sel Double | ||
592 | foreign import ccall unsafe "chooseF" c_selectF :: Sel Float | ||
593 | foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt | ||
594 | foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) | ||
595 | foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) | ||
596 | foreign import ccall unsafe "chooseL" c_selectL :: Sel Z | ||
597 | |||
598 | --------------------------------------------------------------------------- | 307 | --------------------------------------------------------------------------- |
599 | |||
600 | remapG :: (TransArray c, TransArray c1, Storable t, Storable a) | ||
601 | => (CInt -> CInt -> CInt -> CInt -> Ptr t | ||
602 | -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) | ||
603 | -> Matrix t -> c1 -> c -> Matrix a | ||
604 | remapG f i j m = unsafePerformIO $ do | ||
605 | r <- createMatrix RowMajor (rows i) (cols i) | ||
606 | (i # j # m #! r) f #|"remapG" | ||
607 | return r | ||
608 | |||
609 | remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double | ||
610 | remapD = remapG c_remapD | ||
611 | remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float | ||
612 | remapF = remapG c_remapF | ||
613 | remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt | ||
614 | remapI = remapG c_remapI | ||
615 | remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z | ||
616 | remapL = remapG c_remapL | ||
617 | remapC :: Matrix CInt | ||
618 | -> Matrix CInt | ||
619 | -> Matrix (Complex Double) | ||
620 | -> Matrix (Complex Double) | ||
621 | remapC = remapG c_remapC | ||
622 | remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float) | ||
623 | remapQ = remapG c_remapQ | ||
624 | |||
625 | type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) | ||
626 | |||
627 | foreign import ccall unsafe "remapD" c_remapD :: Rem Double | ||
628 | foreign import ccall unsafe "remapF" c_remapF :: Rem Float | ||
629 | foreign import ccall unsafe "remapI" c_remapI :: Rem CInt | ||
630 | foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) | ||
631 | foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) | ||
632 | foreign import ccall unsafe "remapL" c_remapL :: Rem Z | ||
633 | |||
634 | -------------------------------------------------------------------------------- | 308 | -------------------------------------------------------------------------------- |
635 | |||
636 | rowOpAux :: (TransArray c, Storable a) => | ||
637 | (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) | ||
638 | -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () | ||
639 | rowOpAux f c x i1 i2 j1 j2 m = do | ||
640 | px <- newArray [x] | ||
641 | (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" | ||
642 | free px | ||
643 | |||
644 | type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok | ||
645 | |||
646 | foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R | ||
647 | foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float | ||
648 | foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C | ||
649 | foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float) | ||
650 | foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I | ||
651 | foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z | ||
652 | foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I | ||
653 | foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | ||
654 | |||
655 | -------------------------------------------------------------------------------- | 309 | -------------------------------------------------------------------------------- |
656 | 310 | ||
657 | gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3) | ||
658 | => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt)))) | ||
659 | -> c3 -> c2 -> c1 -> c -> IO () | ||
660 | gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" | ||
661 | |||
662 | type Tgemm x = x :> x ::> x ::> x ::> Ok | ||
663 | |||
664 | foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R | ||
665 | foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float | ||
666 | foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C | ||
667 | foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float) | ||
668 | foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I | ||
669 | foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z | ||
670 | foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I | ||
671 | foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z | ||
672 | |||
673 | -------------------------------------------------------------------------------- | 311 | -------------------------------------------------------------------------------- |
674 | |||
675 | reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) => | ||
676 | (CInt -> Ptr a -> CInt -> Ptr t1 | ||
677 | -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)) | ||
678 | -> Vector t1 -> c -> Vector t -> Vector a1 | ||
679 | reorderAux f s d v = unsafePerformIO $ do | ||
680 | k <- createVector (dim s) | ||
681 | r <- createVector (dim v) | ||
682 | (k # s # d # v #! r) f #| "reorderV" | ||
683 | return r | ||
684 | |||
685 | type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt))))) | ||
686 | |||
687 | foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double | ||
688 | foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float | ||
689 | foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt | ||
690 | foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) | ||
691 | foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) | ||
692 | foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z | ||
693 | |||
694 | -- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, | 312 | -- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, |
695 | -- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ | 313 | -- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ |
696 | -- This function is intended to be used internally by tensor libraries. | 314 | -- This function is intended to be used internally by tensor libraries. |