diff options
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 382 | ||||
-rw-r--r-- | packages/base/src/Internal/Modular.hs | 10 | ||||
-rw-r--r-- | packages/base/src/Internal/Specialized.hs | 561 |
3 files changed, 476 insertions, 477 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 7c774ef..225b039 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -37,31 +37,6 @@ import Text.Printf | |||
37 | 37 | ||
38 | ----------------------------------------------------------------- | 38 | ----------------------------------------------------------------- |
39 | 39 | ||
40 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | ||
41 | |||
42 | -- | Matrix representation suitable for BLAS\/LAPACK computations. | ||
43 | |||
44 | data Matrix t = Matrix | ||
45 | { irows :: {-# UNPACK #-} !Int | ||
46 | , icols :: {-# UNPACK #-} !Int | ||
47 | , xRow :: {-# UNPACK #-} !Int | ||
48 | , xCol :: {-# UNPACK #-} !Int | ||
49 | , xdat :: {-# UNPACK #-} !(Vector t) | ||
50 | } | ||
51 | |||
52 | |||
53 | rows :: Matrix t -> Int | ||
54 | rows = irows | ||
55 | {-# INLINE rows #-} | ||
56 | |||
57 | cols :: Matrix t -> Int | ||
58 | cols = icols | ||
59 | {-# INLINE cols #-} | ||
60 | |||
61 | size :: Matrix t -> (Int, Int) | ||
62 | size m = (irows m, icols m) | ||
63 | {-# INLINE size #-} | ||
64 | |||
65 | rowOrder :: Matrix t -> Bool | 40 | rowOrder :: Matrix t -> Bool |
66 | rowOrder m = xCol m == 1 || cols m == 1 | 41 | rowOrder m = xCol m == 1 || cols m == 1 |
67 | {-# INLINE rowOrder #-} | 42 | {-# INLINE rowOrder #-} |
@@ -114,33 +89,6 @@ fmat m | |||
114 | | otherwise = extractAll ColumnMajor m | 89 | | otherwise = extractAll ColumnMajor m |
115 | 90 | ||
116 | 91 | ||
117 | -- C-Haskell matrix adapters | ||
118 | {-# INLINE amatr #-} | ||
119 | amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r | ||
120 | amatr x f g = unsafeWith (xdat x) (f . g r c) | ||
121 | where | ||
122 | r = fi (rows x) | ||
123 | c = fi (cols x) | ||
124 | |||
125 | {-# INLINE amat #-} | ||
126 | amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r | ||
127 | amat x f g = unsafeWith (xdat x) (f . g r c sr sc) | ||
128 | where | ||
129 | r = fi (rows x) | ||
130 | c = fi (cols x) | ||
131 | sr = fi (xRow x) | ||
132 | sc = fi (xCol x) | ||
133 | |||
134 | |||
135 | instance Storable t => TransArray (Matrix t) | ||
136 | where | ||
137 | type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b | ||
138 | type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b | ||
139 | apply = amat | ||
140 | {-# INLINE apply #-} | ||
141 | applyRaw = amatr | ||
142 | {-# INLINE applyRaw #-} | ||
143 | |||
144 | infixr 1 # | 92 | infixr 1 # |
145 | (#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r | 93 | (#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r |
146 | a # b = apply a b | 94 | a # b = apply a b |
@@ -240,22 +188,6 @@ atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) | |||
240 | 188 | ||
241 | ------------------------------------------------------------------ | 189 | ------------------------------------------------------------------ |
242 | 190 | ||
243 | matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t | ||
244 | matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } | ||
245 | matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } | ||
246 | matrixFromVector o r c v | ||
247 | | r * c == dim v = m | ||
248 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m | ||
249 | where | ||
250 | m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 } | ||
251 | | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r } | ||
252 | |||
253 | -- allocates memory for a new matrix | ||
254 | createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) | ||
255 | createMatrix ord r c = do | ||
256 | p <- createVector (r*c) | ||
257 | return (matrixFromVector ord r c p) | ||
258 | |||
259 | {- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = tr' . reshape r@ | 191 | {- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = tr' . reshape r@ |
260 | where r is the desired number of rows.) | 192 | where r is the desired number of rows.) |
261 | 193 | ||
@@ -286,101 +218,6 @@ liftMatrix2 f m1@(size->(r,c)) m2 | |||
286 | 218 | ||
287 | ------------------------------------------------------------------ | 219 | ------------------------------------------------------------------ |
288 | 220 | ||
289 | {- | ||
290 | -- | Supported matrix elements. | ||
291 | class (Storable a) => Element a where | ||
292 | constantD :: a -> Int -> Vector a | ||
293 | extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) | ||
294 | setRect :: Int -> Int -> Matrix a -> Matrix a -> IO () | ||
295 | sortI :: Ord a => Vector a -> Vector CInt | ||
296 | sortV :: Ord a => Vector a -> Vector a | ||
297 | compareV :: Ord a => Vector a -> Vector a -> Vector CInt | ||
298 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a | ||
299 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a | ||
300 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () | ||
301 | gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () | ||
302 | reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation | ||
303 | |||
304 | |||
305 | instance Element Float where | ||
306 | constantD = constantAux cconstantF | ||
307 | extractR = extractAux c_extractF | ||
308 | setRect = setRectAux c_setRectF | ||
309 | sortI = sortIdxF | ||
310 | sortV = sortValF | ||
311 | compareV = compareF | ||
312 | selectV = selectF | ||
313 | remapM = remapF | ||
314 | rowOp = rowOpAux c_rowOpF | ||
315 | gemm = gemmg c_gemmF | ||
316 | reorderV = reorderAux c_reorderF | ||
317 | |||
318 | instance Element Double where | ||
319 | constantD = constantAux cconstantR | ||
320 | extractR = extractAux c_extractD | ||
321 | setRect = setRectAux c_setRectD | ||
322 | sortI = sortIdxD | ||
323 | sortV = sortValD | ||
324 | compareV = compareD | ||
325 | selectV = selectD | ||
326 | remapM = remapD | ||
327 | rowOp = rowOpAux c_rowOpD | ||
328 | gemm = gemmg c_gemmD | ||
329 | reorderV = reorderAux c_reorderD | ||
330 | |||
331 | instance Element (Complex Float) where | ||
332 | constantD = constantAux cconstantQ | ||
333 | extractR = extractAux c_extractQ | ||
334 | setRect = setRectAux c_setRectQ | ||
335 | sortI = undefined | ||
336 | sortV = undefined | ||
337 | compareV = undefined | ||
338 | selectV = selectQ | ||
339 | remapM = remapQ | ||
340 | rowOp = rowOpAux c_rowOpQ | ||
341 | gemm = gemmg c_gemmQ | ||
342 | reorderV = reorderAux c_reorderQ | ||
343 | |||
344 | instance Element (Complex Double) where | ||
345 | constantD = constantAux cconstantC | ||
346 | extractR = extractAux c_extractC | ||
347 | setRect = setRectAux c_setRectC | ||
348 | sortI = undefined | ||
349 | sortV = undefined | ||
350 | compareV = undefined | ||
351 | selectV = selectC | ||
352 | remapM = remapC | ||
353 | rowOp = rowOpAux c_rowOpC | ||
354 | gemm = gemmg c_gemmC | ||
355 | reorderV = reorderAux c_reorderC | ||
356 | |||
357 | instance Element (CInt) where | ||
358 | constantD = constantAux cconstantI | ||
359 | extractR = extractAux c_extractI | ||
360 | setRect = setRectAux c_setRectI | ||
361 | sortI = sortIdxI | ||
362 | sortV = sortValI | ||
363 | compareV = compareI | ||
364 | selectV = selectI | ||
365 | remapM = remapI | ||
366 | rowOp = rowOpAux c_rowOpI | ||
367 | gemm = gemmg c_gemmI | ||
368 | reorderV = reorderAux c_reorderI | ||
369 | |||
370 | instance Element Z where | ||
371 | constantD = constantAux cconstantL | ||
372 | extractR = extractAux c_extractL | ||
373 | setRect = setRectAux c_setRectL | ||
374 | sortI = sortIdxL | ||
375 | sortV = sortValL | ||
376 | compareV = compareL | ||
377 | selectV = selectL | ||
378 | remapM = remapL | ||
379 | rowOp = rowOpAux c_rowOpL | ||
380 | gemm = gemmg c_gemmL | ||
381 | reorderV = reorderAux c_reorderL | ||
382 | -} | ||
383 | |||
384 | ------------------------------------------------------------------- | 221 | ------------------------------------------------------------------- |
385 | 222 | ||
386 | -- | reference to a rectangular slice of a matrix (no data copy) | 223 | -- | reference to a rectangular slice of a matrix (no data copy) |
@@ -435,12 +272,6 @@ repRows n x = fromRows (replicate n (flatten x)) | |||
435 | repCols :: Element t => Int -> Matrix t -> Matrix t | 272 | repCols :: Element t => Int -> Matrix t -> Matrix t |
436 | repCols n x = fromColumns (replicate n (flatten x)) | 273 | repCols n x = fromColumns (replicate n (flatten x)) |
437 | 274 | ||
438 | shSize :: Matrix t -> [Char] | ||
439 | shSize = shDim . size | ||
440 | |||
441 | shDim :: (Show a, Show a1) => (a1, a) -> [Char] | ||
442 | shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" | ||
443 | |||
444 | emptyM :: Storable t => Int -> Int -> Matrix t | 275 | emptyM :: Storable t => Int -> Int -> Matrix t |
445 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) | 276 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) |
446 | 277 | ||
@@ -456,19 +287,6 @@ instance (Storable t, NFData t) => NFData (Matrix t) | |||
456 | 287 | ||
457 | --------------------------------------------------------------- | 288 | --------------------------------------------------------------- |
458 | 289 | ||
459 | extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, | ||
460 | Storable t, Num t3, Num t2, Integral t1, Integral t) | ||
461 | => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t | ||
462 | -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) | ||
463 | -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) | ||
464 | extractAux f ord m moder vr modec vc = do | ||
465 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | ||
466 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | ||
467 | r <- createMatrix ord nr nc | ||
468 | (vr # vc # m #! r) (f moder modec) #|"extract" | ||
469 | |||
470 | return r | ||
471 | |||
472 | type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) | 290 | type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) |
473 | 291 | ||
474 | foreign import ccall unsafe "extractD" c_extractD :: Extr Double | 292 | foreign import ccall unsafe "extractD" c_extractD :: Extr Double |
@@ -480,217 +298,17 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z | |||
480 | 298 | ||
481 | --------------------------------------------------------------- | 299 | --------------------------------------------------------------- |
482 | 300 | ||
483 | setRectAux :: (TransArray c1, TransArray c) | ||
484 | => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) | ||
485 | -> Int -> Int -> c1 -> c -> IO () | ||
486 | setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" | ||
487 | |||
488 | type SetRect x = I -> I -> x ::> x::> Ok | ||
489 | |||
490 | foreign import ccall unsafe "setRectD" c_setRectD :: SetRect Double | ||
491 | foreign import ccall unsafe "setRectF" c_setRectF :: SetRect Float | ||
492 | foreign import ccall unsafe "setRectC" c_setRectC :: SetRect (Complex Double) | ||
493 | foreign import ccall unsafe "setRectQ" c_setRectQ :: SetRect (Complex Float) | ||
494 | foreign import ccall unsafe "setRectI" c_setRectI :: SetRect I | ||
495 | foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z | ||
496 | |||
497 | -------------------------------------------------------------------------------- | 301 | -------------------------------------------------------------------------------- |
498 | 302 | ||
499 | sortG :: (Storable t, Storable a) | ||
500 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
501 | sortG f v = unsafePerformIO $ do | ||
502 | r <- createVector (dim v) | ||
503 | (v #! r) f #|"sortG" | ||
504 | return r | ||
505 | |||
506 | sortIdxD :: Vector Double -> Vector CInt | ||
507 | sortIdxD = sortG c_sort_indexD | ||
508 | sortIdxF :: Vector Float -> Vector CInt | ||
509 | sortIdxF = sortG c_sort_indexF | ||
510 | sortIdxI :: Vector CInt -> Vector CInt | ||
511 | sortIdxI = sortG c_sort_indexI | ||
512 | sortIdxL :: Vector Z -> Vector I | ||
513 | sortIdxL = sortG c_sort_indexL | ||
514 | |||
515 | sortValD :: Vector Double -> Vector Double | ||
516 | sortValD = sortG c_sort_valD | ||
517 | sortValF :: Vector Float -> Vector Float | ||
518 | sortValF = sortG c_sort_valF | ||
519 | sortValI :: Vector CInt -> Vector CInt | ||
520 | sortValI = sortG c_sort_valI | ||
521 | sortValL :: Vector Z -> Vector Z | ||
522 | sortValL = sortG c_sort_valL | ||
523 | |||
524 | foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) | ||
525 | foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt)) | ||
526 | foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt)) | ||
527 | foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok | ||
528 | |||
529 | foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt)) | ||
530 | foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt)) | ||
531 | foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt)) | ||
532 | foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok | ||
533 | |||
534 | -------------------------------------------------------------------------------- | 303 | -------------------------------------------------------------------------------- |
535 | 304 | ||
536 | compareG :: (TransArray c, Storable t, Storable a) | ||
537 | => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) | ||
538 | -> c -> Vector t -> Vector a | ||
539 | compareG f u v = unsafePerformIO $ do | ||
540 | r <- createVector (dim v) | ||
541 | (u # v #! r) f #|"compareG" | ||
542 | return r | ||
543 | |||
544 | compareD :: Vector Double -> Vector Double -> Vector CInt | ||
545 | compareD = compareG c_compareD | ||
546 | compareF :: Vector Float -> Vector Float -> Vector CInt | ||
547 | compareF = compareG c_compareF | ||
548 | compareI :: Vector CInt -> Vector CInt -> Vector CInt | ||
549 | compareI = compareG c_compareI | ||
550 | compareL :: Vector Z -> Vector Z -> Vector CInt | ||
551 | compareL = compareG c_compareL | ||
552 | |||
553 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) | ||
554 | foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt))) | ||
555 | foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt))) | ||
556 | foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok | ||
557 | |||
558 | -------------------------------------------------------------------------------- | 305 | -------------------------------------------------------------------------------- |
559 | 306 | ||
560 | selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) | ||
561 | => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) | ||
562 | -> c2 -> c1 -> Vector t -> c -> Vector a | ||
563 | selectG f c u v w = unsafePerformIO $ do | ||
564 | r <- createVector (dim v) | ||
565 | (c # u # v # w #! r) f #|"selectG" | ||
566 | return r | ||
567 | |||
568 | selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double | ||
569 | selectD = selectG c_selectD | ||
570 | selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float | ||
571 | selectF = selectG c_selectF | ||
572 | selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt | ||
573 | selectI = selectG c_selectI | ||
574 | selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z | ||
575 | selectL = selectG c_selectL | ||
576 | selectC :: Vector CInt | ||
577 | -> Vector (Complex Double) | ||
578 | -> Vector (Complex Double) | ||
579 | -> Vector (Complex Double) | ||
580 | -> Vector (Complex Double) | ||
581 | selectC = selectG c_selectC | ||
582 | selectQ :: Vector CInt | ||
583 | -> Vector (Complex Float) | ||
584 | -> Vector (Complex Float) | ||
585 | -> Vector (Complex Float) | ||
586 | -> Vector (Complex Float) | ||
587 | selectQ = selectG c_selectQ | ||
588 | |||
589 | type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) | ||
590 | |||
591 | foreign import ccall unsafe "chooseD" c_selectD :: Sel Double | ||
592 | foreign import ccall unsafe "chooseF" c_selectF :: Sel Float | ||
593 | foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt | ||
594 | foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) | ||
595 | foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) | ||
596 | foreign import ccall unsafe "chooseL" c_selectL :: Sel Z | ||
597 | |||
598 | --------------------------------------------------------------------------- | 307 | --------------------------------------------------------------------------- |
599 | |||
600 | remapG :: (TransArray c, TransArray c1, Storable t, Storable a) | ||
601 | => (CInt -> CInt -> CInt -> CInt -> Ptr t | ||
602 | -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) | ||
603 | -> Matrix t -> c1 -> c -> Matrix a | ||
604 | remapG f i j m = unsafePerformIO $ do | ||
605 | r <- createMatrix RowMajor (rows i) (cols i) | ||
606 | (i # j # m #! r) f #|"remapG" | ||
607 | return r | ||
608 | |||
609 | remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double | ||
610 | remapD = remapG c_remapD | ||
611 | remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float | ||
612 | remapF = remapG c_remapF | ||
613 | remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt | ||
614 | remapI = remapG c_remapI | ||
615 | remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z | ||
616 | remapL = remapG c_remapL | ||
617 | remapC :: Matrix CInt | ||
618 | -> Matrix CInt | ||
619 | -> Matrix (Complex Double) | ||
620 | -> Matrix (Complex Double) | ||
621 | remapC = remapG c_remapC | ||
622 | remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float) | ||
623 | remapQ = remapG c_remapQ | ||
624 | |||
625 | type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) | ||
626 | |||
627 | foreign import ccall unsafe "remapD" c_remapD :: Rem Double | ||
628 | foreign import ccall unsafe "remapF" c_remapF :: Rem Float | ||
629 | foreign import ccall unsafe "remapI" c_remapI :: Rem CInt | ||
630 | foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) | ||
631 | foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) | ||
632 | foreign import ccall unsafe "remapL" c_remapL :: Rem Z | ||
633 | |||
634 | -------------------------------------------------------------------------------- | 308 | -------------------------------------------------------------------------------- |
635 | |||
636 | rowOpAux :: (TransArray c, Storable a) => | ||
637 | (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) | ||
638 | -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () | ||
639 | rowOpAux f c x i1 i2 j1 j2 m = do | ||
640 | px <- newArray [x] | ||
641 | (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" | ||
642 | free px | ||
643 | |||
644 | type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok | ||
645 | |||
646 | foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R | ||
647 | foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float | ||
648 | foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C | ||
649 | foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float) | ||
650 | foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I | ||
651 | foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z | ||
652 | foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I | ||
653 | foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | ||
654 | |||
655 | -------------------------------------------------------------------------------- | 309 | -------------------------------------------------------------------------------- |
656 | 310 | ||
657 | gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3) | ||
658 | => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt)))) | ||
659 | -> c3 -> c2 -> c1 -> c -> IO () | ||
660 | gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" | ||
661 | |||
662 | type Tgemm x = x :> x ::> x ::> x ::> Ok | ||
663 | |||
664 | foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R | ||
665 | foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float | ||
666 | foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C | ||
667 | foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float) | ||
668 | foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I | ||
669 | foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z | ||
670 | foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I | ||
671 | foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z | ||
672 | |||
673 | -------------------------------------------------------------------------------- | 311 | -------------------------------------------------------------------------------- |
674 | |||
675 | reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) => | ||
676 | (CInt -> Ptr a -> CInt -> Ptr t1 | ||
677 | -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)) | ||
678 | -> Vector t1 -> c -> Vector t -> Vector a1 | ||
679 | reorderAux f s d v = unsafePerformIO $ do | ||
680 | k <- createVector (dim s) | ||
681 | r <- createVector (dim v) | ||
682 | (k # s # d # v #! r) f #| "reorderV" | ||
683 | return r | ||
684 | |||
685 | type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt))))) | ||
686 | |||
687 | foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double | ||
688 | foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float | ||
689 | foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt | ||
690 | foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) | ||
691 | foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) | ||
692 | foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z | ||
693 | |||
694 | -- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, | 312 | -- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, |
695 | -- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ | 313 | -- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ |
696 | -- This function is intended to be used internally by tensor libraries. | 314 | -- This function is intended to be used internally by tensor libraries. |
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index a211dd3..10ff8a3 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -257,16 +257,6 @@ instance KnownNat m => Normed (Vector (Mod m Z)) | |||
257 | instance KnownNat m => Numeric (Mod m I) | 257 | instance KnownNat m => Numeric (Mod m I) |
258 | instance KnownNat m => Numeric (Mod m Z) | 258 | instance KnownNat m => Numeric (Mod m Z) |
259 | 259 | ||
260 | f2i :: Storable t => Vector (Mod n t) -> Vector t | ||
261 | f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) | ||
262 | where (fp,i,n) = unsafeToForeignPtr v | ||
263 | |||
264 | f2iM :: (Element t, Element (Mod n t)) => Matrix (Mod n t) -> Matrix t | ||
265 | f2iM m = m { xdat = f2i (xdat m) } | ||
266 | |||
267 | i2fM :: (Element t, Element (Mod n t)) => Matrix t -> Matrix (Mod n t) | ||
268 | i2fM m = m { xdat = i2f (xdat m) } | ||
269 | |||
270 | vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t) | 260 | vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t) |
271 | vmod = i2f . cmod' m' | 261 | vmod = i2f . cmod' m' |
272 | where | 262 | where |
diff --git a/packages/base/src/Internal/Specialized.hs b/packages/base/src/Internal/Specialized.hs index c79194f..c063369 100644 --- a/packages/base/src/Internal/Specialized.hs +++ b/packages/base/src/Internal/Specialized.hs | |||
@@ -8,6 +8,8 @@ | |||
8 | {-# LANGUAGE RankNTypes #-} | 8 | {-# LANGUAGE RankNTypes #-} |
9 | {-# LANGUAGE ScopedTypeVariables #-} | 9 | {-# LANGUAGE ScopedTypeVariables #-} |
10 | {-# LANGUAGE KindSignatures #-} | 10 | {-# LANGUAGE KindSignatures #-} |
11 | {-# LANGUAGE ViewPatterns #-} | ||
12 | {-# LANGUAGE LambdaCase #-} | ||
11 | module Internal.Specialized where | 13 | module Internal.Specialized where |
12 | 14 | ||
13 | import Control.Monad | 15 | import Control.Monad |
@@ -16,6 +18,7 @@ import Data.Coerce | |||
16 | import Data.Complex | 18 | import Data.Complex |
17 | import Data.Functor | 19 | import Data.Functor |
18 | import Data.Int | 20 | import Data.Int |
21 | import Data.Maybe | ||
19 | import Data.Typeable (eqT,Proxy) | 22 | import Data.Typeable (eqT,Proxy) |
20 | import Type.Reflection | 23 | import Type.Reflection |
21 | import Foreign.Marshal.Alloc(free,malloc) | 24 | import Foreign.Marshal.Alloc(free,malloc) |
@@ -31,127 +34,281 @@ import GHC.TypeLits hiding (Mod) | |||
31 | import GHC.TypeLits | 34 | import GHC.TypeLits |
32 | #endif | 35 | #endif |
33 | 36 | ||
34 | import Internal.Vector (Vector,createVector,unsafeFromForeignPtr,unsafeToForeignPtr) | 37 | import Internal.Vector -- (Vector,createVector,unsafeFromForeignPtr,unsafeToForeignPtr,(@>)) |
35 | import Internal.Devel | 38 | import Internal.Devel |
36 | 39 | ||
37 | eqt :: (Typeable a, Typeable b) => a -> Maybe (a :~: b) | 40 | eqp :: (Typeable a, Typeable b) => proxy a -> Maybe (a :~: b) |
38 | eqt _ = eqT | 41 | eqp _ = eqT |
39 | eq32 :: (Typeable a) => a -> Maybe (a :~: Int32) | 42 | ep32 :: (Typeable a) => proxy a -> Maybe (a :~: Int32) |
40 | eq32 _ = eqT | 43 | ep32 _ = eqT |
41 | eq64 :: (Typeable a) => a -> Maybe (a :~: Int64) | 44 | ep64 :: (Typeable a) => proxy a -> Maybe (a :~: Int64) |
42 | eq64 _ = eqT | 45 | ep64 _ = eqT |
43 | eqint :: (Typeable a) => a -> Maybe (a :~: CInt) | 46 | epint :: (Typeable a) => proxy a -> Maybe (a :~: CInt) |
44 | eqint _ = eqT | 47 | epint _ = eqT |
45 | 48 | ||
46 | type Element t = (Storable t, Typeable t) | 49 | type Element t = (Storable t, Typeable t) |
47 | 50 | ||
51 | -- | Wrapper with a phantom integer for statically checked modular arithmetic. | ||
52 | newtype Mod (n :: Nat) t = Mod {unMod:: t} | ||
53 | deriving (Storable) | ||
54 | |||
55 | instance (NFData t) => NFData (Mod n t) | ||
56 | where | ||
57 | rnf (Mod x) = rnf x | ||
58 | |||
59 | i2fM :: Storable t => Matrix t -> Matrix (Mod n t) | ||
60 | i2fM m = m { xdat = i2f (xdat m) } | ||
61 | |||
62 | i2f :: Storable t => Vector t -> Vector (Mod n t) | ||
63 | i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) | ||
64 | where (fp,i,n) = unsafeToForeignPtr v | ||
65 | |||
66 | f2i :: Storable t => Vector (Mod n t) -> Vector t | ||
67 | f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) | ||
68 | where (fp,i,n) = unsafeToForeignPtr v | ||
69 | |||
70 | f2iM :: Storable t => Matrix (Mod n t) -> Matrix t | ||
71 | f2iM m = m { xdat = f2i (xdat m) } | ||
72 | |||
73 | data IntegralRep t a = IntegralRep | ||
74 | { i2rep :: Vector t -> Vector a | ||
75 | , i2repM :: Matrix t -> Matrix a | ||
76 | , rep2i :: Vector a -> Vector t | ||
77 | , rep2iM :: Matrix a -> Matrix t | ||
78 | , rep2one :: a -> t | ||
79 | , modulo :: Maybe t | ||
80 | } | ||
81 | |||
82 | idint :: Storable t => IntegralRep t t | ||
83 | idint = IntegralRep id id id id id Nothing | ||
84 | |||
85 | coerceint :: Coercible t a => IntegralRep t a | ||
86 | coerceint = IntegralRep coerce coerce coerce coerce coerce Nothing | ||
87 | |||
88 | modint :: forall t n. (Read t, Storable t) => TypeRep n -> IntegralRep t (Mod n t) | ||
89 | modint r = IntegralRep i2f i2fM f2i f2iM unMod (Just n) | ||
90 | where | ||
91 | n = read . show $ r -- XXX: Hack to get nat value from Type.Reflection | ||
92 | -- n = fromIntegral . natVal $ (undefined :: Proxy n) | ||
93 | |||
94 | |||
95 | typeRepOf :: Typeable a => proxy a -> TypeRep a | ||
96 | typeRepOf proxy = typeRep | ||
97 | |||
48 | data Specialized a | 98 | data Specialized a |
49 | = SpFloat !(a :~: Float) | 99 | = SpFloat !(a :~: Float) |
50 | | SpDouble !(a :~: Double) | 100 | | SpDouble !(a :~: Double) |
51 | | SpCFloat !(a :~: Complex Float) | 101 | | SpCFloat !(a :~: Complex Float) |
52 | | SpCDouble !(a :~: Complex Double) | 102 | | SpCDouble !(a :~: Complex Double) |
53 | | SpInt32 !(Vector Int32 -> Vector a) !Int32 | 103 | | SpInt32 !(IntegralRep Int32 a) |
54 | | SpInt64 !(Vector Int64 -> Vector a) !Int64 | 104 | | SpInt64 !(IntegralRep Int64 a) |
55 | -- | SpModInt32 !Int32 Int32 !(forall f. f Int32 -> f a) | ||
56 | -- | SpModInt64 !Int32 Int64 !(forall f. f Int64 -> f a) | ||
57 | 105 | ||
58 | specialize :: forall a. Typeable a => a -> Maybe (Specialized a) | 106 | specialize :: forall m a. Typeable a => m a -> Maybe (Specialized a) |
59 | specialize x = foldr1 mplus | 107 | specialize x = foldr1 mplus |
60 | [ SpDouble <$> eqt x | 108 | [ SpDouble <$> eqp x |
61 | , eq64 x <&> \Refl -> SpInt64 id x | 109 | , ep64 x <&> \Refl -> SpInt64 idint |
62 | , SpFloat <$> eqt x | 110 | , SpFloat <$> eqp x |
63 | , eq32 x <&> \Refl -> SpInt32 id x | 111 | , ep32 x <&> \Refl -> SpInt32 idint |
64 | , SpCDouble <$> eqt x | 112 | , SpCDouble <$> eqp x |
65 | , SpCFloat <$> eqt x | 113 | , SpCFloat <$> eqp x |
66 | , eqint x <&> \Refl -> case x of CInt y -> SpInt32 coerce y | 114 | , epint x <&> \Refl -> SpInt32 coerceint |
67 | -- , em32 x <&> \(nat,Refl) -> case x of Mod y -> SpInt32 (i2f' nat) y | 115 | , case typeRepOf x of |
68 | , case typeOf x of | 116 | App (App modtyp n) inttyp |
69 | App (App modtyp ntyp) inttyp -> case eqTypeRep (typeRep :: TypeRep (Mod :: Nat -> * -> *)) modtyp of | 117 | -> do HRefl <- eqTypeRep (typeRep :: TypeRep (Mod :: Nat -> * -> *)) modtyp |
70 | Just HRefl -> let i = unMod x | 118 | mplus (eqTypeRep (typeRep :: TypeRep Int32) inttyp <&> \HRefl -> SpInt32 $ modint n) |
71 | in case eqTypeRep (typeRep :: TypeRep Int32) inttyp of | 119 | (eqTypeRep (typeRep :: TypeRep Int64) inttyp <&> \HRefl -> SpInt64 $ modint n) |
72 | Just HRefl -> Just $ SpInt32 i2f i | ||
73 | _ -> case eqTypeRep (typeRep :: TypeRep Int64) inttyp of | ||
74 | Just HRefl -> Just $ SpInt64 i2f i | ||
75 | _ -> Nothing | ||
76 | Nothing -> Nothing | ||
77 | _ -> Nothing | 120 | _ -> Nothing |
78 | ] | 121 | ] |
79 | 122 | ||
80 | -- | Supported matrix elements. | 123 | -- | Supported matrix elements. |
81 | constantD :: Typeable a => a -> Int -> Vector a | 124 | constantD :: Typeable a => a -> Int -> Vector a |
82 | constantD x = case specialize x of | 125 | constantD x = fromMaybe (error "constantD") $ specialize (const x) <&> \case |
83 | Nothing -> error "constantD" | 126 | SpDouble Refl -> constantAux cconstantR x |
84 | Just (SpDouble Refl) -> constantAux cconstantR x | 127 | SpInt64 r -> i2rep r . constantAux cconstantL (rep2one r x) |
85 | Just (SpInt64 out y) -> out . constantAux cconstantL y | 128 | SpFloat Refl -> constantAux cconstantF x |
86 | Just (SpFloat Refl) -> constantAux cconstantF x | 129 | SpInt32 r -> i2rep r . constantAux cconstantI (rep2one r x) |
87 | Just (SpInt32 out y) -> out . constantAux cconstantI y | 130 | SpCDouble Refl -> constantAux cconstantC x |
88 | Just (SpCDouble Refl) -> constantAux cconstantC x | 131 | SpCFloat Refl -> constantAux cconstantQ x |
89 | Just (SpCFloat Refl) -> constantAux cconstantQ x | ||
90 | -- Just (SpModInt32 _ y ret) -> \n -> ret (constantAux cconstantI y n) | ||
91 | 132 | ||
92 | -- | Wrapper with a phantom integer for statically checked modular arithmetic. | 133 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) |
93 | newtype Mod (n :: Nat) t = Mod {unMod:: t} | ||
94 | deriving (Storable) | ||
95 | 134 | ||
96 | instance (NFData t) => NFData (Mod n t) | 135 | -- | Matrix representation suitable for BLAS\/LAPACK computations. |
136 | data Matrix t = Matrix | ||
137 | { irows :: {-# UNPACK #-} !Int | ||
138 | , icols :: {-# UNPACK #-} !Int | ||
139 | , xRow :: {-# UNPACK #-} !Int | ||
140 | , xCol :: {-# UNPACK #-} !Int | ||
141 | , xdat :: {-# UNPACK #-} !(Vector t) | ||
142 | } | ||
143 | |||
144 | -- allocates memory for a new matrix | ||
145 | createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) | ||
146 | createMatrix ord r c = do | ||
147 | p <- createVector (r*c) | ||
148 | return (matrixFromVector ord r c p) | ||
149 | |||
150 | matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t | ||
151 | matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } | ||
152 | matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } | ||
153 | matrixFromVector o r c v | ||
154 | | r * c == dim v = m | ||
155 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m | ||
97 | where | 156 | where |
98 | rnf (Mod x) = rnf x | 157 | m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 } |
158 | | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r } | ||
99 | 159 | ||
100 | i2f :: Storable t => Vector t -> Vector (Mod n t) | 160 | shSize :: Matrix t -> [Char] |
101 | i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) | 161 | shSize = shDim . size |
102 | where (fp,i,n) = unsafeToForeignPtr v | ||
103 | 162 | ||
163 | shDim :: (Show a, Show a1) => (a1, a) -> [Char] | ||
164 | shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" | ||
165 | |||
166 | size :: Matrix t -> (Int, Int) | ||
167 | size m = (irows m, icols m) | ||
168 | {-# INLINE size #-} | ||
104 | 169 | ||
105 | {- | ||
106 | extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) | 170 | extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) |
171 | extractR ord m = fromMaybe (\mi is mj js -> error "extractR") $ specialize m <&> \case | ||
172 | SpDouble Refl -> extractAux c_extractD ord m | ||
173 | SpInt64 r -> \mi is mj js -> i2repM r <$> extractAux c_extractL ord (rep2iM r m) mi is mj js | ||
174 | SpFloat Refl -> extractAux c_extractF ord m | ||
175 | SpInt32 r -> \mi is mj js -> i2repM r <$> extractAux (coerce c_extractI) ord (rep2iM r m) mi is mj js | ||
176 | SpCDouble Refl -> extractAux c_extractC ord m | ||
177 | SpCFloat Refl -> extractAux c_extractQ ord m | ||
178 | |||
107 | setRect :: Typeable a => Int -> Int -> Matrix a -> Matrix a -> IO () | 179 | setRect :: Typeable a => Int -> Int -> Matrix a -> Matrix a -> IO () |
108 | sortI :: (Typeable a , Ord a ) => Vector a -> Vector CInt | 180 | setRect i j m x = fromMaybe (error "setRect") $ specialize m <&> \case |
109 | sortV :: (Typeable a , Ord a ) => Vector a -> Vector a | 181 | SpDouble Refl -> setRectAux c_setRectD i j m x |
182 | SpInt64 r -> setRectAux c_setRectL i j (rep2iM r m) (rep2iM r x) | ||
183 | SpFloat Refl -> setRectAux c_setRectF i j m x | ||
184 | SpInt32 r -> setRectAux (coerce c_setRectI) i j (rep2iM r m) (rep2iM r x) | ||
185 | SpCDouble Refl -> setRectAux c_setRectC i j m x | ||
186 | SpCFloat Refl -> setRectAux c_setRectQ i j m x | ||
187 | |||
188 | sortI :: (Typeable a , Ord a) => Vector a -> Vector CInt | ||
189 | sortI v = maybe (error "sortI") ($ v) $ specialize v <&> \case | ||
190 | SpDouble Refl -> sortIdxD | ||
191 | SpInt64 r -> sortIdxL . rep2i r | ||
192 | SpFloat Refl -> sortIdxF | ||
193 | SpInt32 r -> coerce sortIdxI . rep2i r | ||
194 | SpCDouble Refl -> undefined -- Unreachable: Ord not implemented for Complex | ||
195 | SpCFloat Refl -> undefined -- Unreachable: Ord not implemented for Complex | ||
196 | |||
197 | sortV :: (Typeable a , Ord a ) => Vector a -> Vector a | ||
198 | sortV v = maybe (error "sortV") ($ v) $ specialize v <&> \case | ||
199 | SpDouble Refl -> sortValD | ||
200 | SpInt64 r -> i2rep r . sortValL . rep2i r | ||
201 | SpFloat Refl -> sortValF | ||
202 | SpInt32 r -> i2rep r . coerce sortValI . rep2i r | ||
203 | SpCDouble Refl -> undefined -- Unreachable: Ord not implemented for Complex | ||
204 | SpCFloat Refl -> undefined -- Unreachable: Ord not implemented for Complex | ||
205 | |||
110 | compareV :: (Typeable a , Ord a ) => Vector a -> Vector a -> Vector CInt | 206 | compareV :: (Typeable a , Ord a ) => Vector a -> Vector a -> Vector CInt |
207 | compareV u v = fromMaybe (error "compareV" u v) $ specialize u <&> \case | ||
208 | SpDouble Refl -> compareD u v | ||
209 | SpInt64 r -> compareL (rep2i r u) (rep2i r v) | ||
210 | SpFloat Refl -> compareF u v | ||
211 | SpInt32 r -> coerce compareI (rep2i r u) (rep2i r v) | ||
212 | SpCDouble Refl -> undefined -- Unreachable: Ord not implemented for Complex | ||
213 | SpCFloat Refl -> undefined -- Unreachable: Ord not implemented for Complex | ||
214 | |||
111 | selectV :: Typeable a => Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a | 215 | selectV :: Typeable a => Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a |
216 | selectV c l e g = fromMaybe (error "selectV" c l e g) $ specialize l <&> \case | ||
217 | SpDouble Refl -> selectD c l e g | ||
218 | SpInt64 r -> i2rep r (selectL c (rep2i r l) (rep2i r e) (rep2i r g)) | ||
219 | SpFloat Refl -> selectF c l e g | ||
220 | SpInt32 r -> i2rep r (coerce selectI c (rep2i r l) (rep2i r e) (rep2i r g)) | ||
221 | SpCDouble Refl -> selectC c l e g | ||
222 | SpCFloat Refl -> selectQ c l e g | ||
223 | |||
112 | remapM :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a | 224 | remapM :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a |
225 | remapM i j m = fromMaybe (error "remapM" i j m) $ specialize m <&> \case | ||
226 | SpDouble Refl -> remapD i j m | ||
227 | SpInt64 r -> i2repM r (remapL i j (rep2iM r m)) | ||
228 | SpFloat Refl -> remapF i j m | ||
229 | SpInt32 r -> i2repM r (coerce remapI i j (rep2iM r m)) | ||
230 | SpCDouble Refl -> remapC i j m | ||
231 | SpCFloat Refl -> remapQ i j m | ||
232 | |||
113 | rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () | 233 | rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () |
234 | rowOp c a i1 i2 j1 j2 x = fromMaybe (error "rowOp") $ specialize x <&> \case | ||
235 | SpDouble Refl -> rowOpAux c_rowOpD c a i1 i2 j1 j2 x | ||
236 | SpInt64 r -> case modulo r of | ||
237 | Just m' -> rowOpAux (c_rowOpML m') c (rep2one r a) i1 i2 j1 j2 (rep2iM r x) | ||
238 | Nothing -> rowOpAux c_rowOpL c (rep2one r a) i1 i2 j1 j2 (rep2iM r x) | ||
239 | SpFloat Refl -> rowOpAux c_rowOpF c a i1 i2 j1 j2 x | ||
240 | SpInt32 r -> case modulo r of | ||
241 | Just m' -> rowOpAux (coerce c_rowOpMI m') c (rep2one r a) i1 i2 j1 j2 (rep2iM r x) | ||
242 | Nothing -> rowOpAux (coerce c_rowOpI) c (rep2one r a) i1 i2 j1 j2 (rep2iM r x) | ||
243 | SpCDouble Refl -> rowOpAux c_rowOpC c a i1 i2 j1 j2 x | ||
244 | SpCFloat Refl -> rowOpAux c_rowOpQ c a i1 i2 j1 j2 x | ||
245 | |||
114 | gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () | 246 | gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () |
247 | gemm u a b c = fromMaybe (error "gemm") $ specialize u <&> \case | ||
248 | SpDouble Refl -> gemmg c_gemmD u a b c | ||
249 | SpInt64 r -> case modulo r of | ||
250 | Just m' -> gemmg (c_gemmML m') (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c) | ||
251 | Nothing -> gemmg c_gemmL (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c) | ||
252 | SpFloat Refl -> gemmg c_gemmF u a b c | ||
253 | SpInt32 r -> case modulo r of | ||
254 | Just m' -> gemmg (coerce c_gemmMI m') (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c) | ||
255 | Nothing -> gemmg (coerce c_gemmI) (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c) | ||
256 | SpCDouble Refl -> gemmg c_gemmC u a b c | ||
257 | SpCFloat Refl -> gemmg c_gemmQ u a b c | ||
258 | |||
115 | reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation | 259 | reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation |
260 | reorderV strides dims v = fromMaybe (error "reorderV") $ specialize v <&> \case | ||
261 | SpDouble Refl -> reorderAux c_reorderD strides dims v | ||
262 | SpInt64 r -> i2rep r $ reorderAux c_reorderL strides dims (rep2i r v) | ||
263 | SpFloat Refl -> reorderAux c_reorderF strides dims v | ||
264 | SpInt32 r -> i2rep r $ reorderAux (coerce c_reorderI) strides dims (rep2i r v) | ||
265 | SpCDouble Refl -> reorderAux c_reorderC strides dims v | ||
266 | SpCFloat Refl -> reorderAux c_reorderQ strides dims v | ||
116 | 267 | ||
117 | instance KnownNat m => Element (Mod m I) | 268 | |
269 | instance Storable t => TransArray (Matrix t) | ||
118 | where | 270 | where |
119 | constantD x n = i2f (constantD (unMod x) n) | 271 | type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b |
120 | extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js | 272 | type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b |
121 | setRect i j m x = setRect i j (f2iM m) (f2iM x) | 273 | apply = amat |
122 | sortI = sortI . f2i | 274 | {-# INLINE apply #-} |
123 | sortV = i2f . sortV . f2i | 275 | applyRaw = amatr |
124 | compareV u v = compareV (f2i u) (f2i v) | 276 | {-# INLINE applyRaw #-} |
125 | selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) | 277 | |
126 | remapM i j m = i2fM (remap i j (f2iM m)) | 278 | -- C-Haskell matrix adapters |
127 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) | 279 | {-# INLINE amatr #-} |
128 | where | 280 | amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r |
129 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 281 | amatr x f g = unsafeWith (xdat x) (f . g r c) |
130 | gemm u a b c = gemmg (c_gemmMI m') (f2i u) (f2iM a) (f2iM b) (f2iM c) | ||
131 | where | ||
132 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
133 | |||
134 | instance KnownNat m => Element (Mod m Z) | ||
135 | where | 282 | where |
136 | constantD x n = i2f (constantD (unMod x) n) | 283 | r = fi (rows x) |
137 | extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js | 284 | c = fi (cols x) |
138 | setRect i j m x = setRect i j (f2iM m) (f2iM x) | 285 | |
139 | sortI = sortI . f2i | 286 | {-# INLINE amat #-} |
140 | sortV = i2f . sortV . f2i | 287 | amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r |
141 | compareV u v = compareV (f2i u) (f2i v) | 288 | amat x f g = unsafeWith (xdat x) (f . g r c sr sc) |
142 | selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) | 289 | where |
143 | remapM i j m = i2fM (remap i j (f2iM m)) | 290 | r = fi (rows x) |
144 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) | 291 | c = fi (cols x) |
145 | where | 292 | sr = fi (xRow x) |
146 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 293 | sc = fi (xCol x) |
147 | gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c) | 294 | |
148 | where | 295 | rows :: Matrix t -> Int |
149 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 296 | rows = irows |
150 | -} | 297 | {-# INLINE rows #-} |
298 | |||
299 | cols :: Matrix t -> Int | ||
300 | cols = icols | ||
301 | {-# INLINE cols #-} | ||
302 | |||
151 | 303 | ||
304 | infixr 1 # | ||
305 | (#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r | ||
306 | a # b = apply a b | ||
307 | {-# INLINE (#) #-} | ||
152 | 308 | ||
153 | ( extractR , setRect , sortI , sortV , compareV , selectV , remapM , rowOp , gemm , reorderV ) | 309 | (#!) :: (TransArray c, TransArray c1) => c1 -> c -> Trans c1 (Trans c (IO r)) -> IO r |
154 | = error "todo Element" | 310 | a #! b = a # b # id |
311 | {-# INLINE (#!) #-} | ||
155 | 312 | ||
156 | constantAux :: (Storable a1, Storable a) | 313 | constantAux :: (Storable a1, Storable a) |
157 | => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a | 314 | => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a |
@@ -169,3 +326,237 @@ foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float) | |||
169 | foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double) | 326 | foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double) |
170 | foreign import ccall unsafe "constantI" cconstantI :: TConst Int32 | 327 | foreign import ccall unsafe "constantI" cconstantI :: TConst Int32 |
171 | foreign import ccall unsafe "constantL" cconstantL :: TConst Int64 | 328 | foreign import ccall unsafe "constantL" cconstantL :: TConst Int64 |
329 | |||
330 | {- | ||
331 | extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, | ||
332 | Storable t, Num t3, Num t2, Integral t1, Integral t) | ||
333 | => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t | ||
334 | -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) | ||
335 | -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) | ||
336 | extractAux f ord m moder vr modec vc = do | ||
337 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | ||
338 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | ||
339 | r <- createMatrix ord nr nc | ||
340 | (vr # vc # m #!r) (f moder modec) #|"extract" | ||
341 | return r | ||
342 | -} | ||
343 | |||
344 | extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, | ||
345 | Storable t, Num t3, Num t2, Integral t1, Integral t) | ||
346 | => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t | ||
347 | -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) | ||
348 | -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) | ||
349 | extractAux f ord m moder vr modec vc = do | ||
350 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | ||
351 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | ||
352 | r <- createMatrix ord nr nc | ||
353 | (vr # vc # m #! r) (f moder modec) #|"extract" | ||
354 | return r | ||
355 | |||
356 | type Extr x = CInt -> CInt -> | ||
357 | CInt -> Ptr CInt -> -- CIdxs | ||
358 | CInt -> Ptr CInt -> -- CIdxs | ||
359 | CInt -> CInt -> CInt -> CInt -> Ptr x -> -- OM x | ||
360 | CInt -> CInt -> CInt -> CInt -> Ptr x -> -- OM x | ||
361 | IO CInt | ||
362 | foreign import ccall unsafe "extractD" c_extractD :: Extr Double | ||
363 | foreign import ccall unsafe "extractF" c_extractF :: Extr Float | ||
364 | foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double) | ||
365 | foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float) | ||
366 | foreign import ccall unsafe "extractI" c_extractI :: Extr CInt | ||
367 | foreign import ccall unsafe "extractL" c_extractL :: Extr Int64 | ||
368 | |||
369 | setRectAux :: (TransArray c1, TransArray c) | ||
370 | => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) | ||
371 | -> Int -> Int -> c1 -> c -> IO () | ||
372 | setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" | ||
373 | |||
374 | type SetRect x = I -> I -> x ::> x::> Ok | ||
375 | |||
376 | foreign import ccall unsafe "setRectD" c_setRectD :: SetRect Double | ||
377 | foreign import ccall unsafe "setRectF" c_setRectF :: SetRect Float | ||
378 | foreign import ccall unsafe "setRectC" c_setRectC :: SetRect (Complex Double) | ||
379 | foreign import ccall unsafe "setRectQ" c_setRectQ :: SetRect (Complex Float) | ||
380 | foreign import ccall unsafe "setRectI" c_setRectI :: SetRect I | ||
381 | foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z | ||
382 | |||
383 | sortG :: (Storable t, Storable a) | ||
384 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
385 | sortG f v = unsafePerformIO $ do | ||
386 | r <- createVector (dim v) | ||
387 | (v #! r) f #|"sortG" | ||
388 | return r | ||
389 | |||
390 | sortIdxD :: Vector Double -> Vector CInt | ||
391 | sortIdxD = sortG c_sort_indexD | ||
392 | sortIdxF :: Vector Float -> Vector CInt | ||
393 | sortIdxF = sortG c_sort_indexF | ||
394 | sortIdxI :: Vector CInt -> Vector CInt | ||
395 | sortIdxI = sortG c_sort_indexI | ||
396 | sortIdxL :: Vector Z -> Vector I | ||
397 | sortIdxL = sortG c_sort_indexL | ||
398 | |||
399 | foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) | ||
400 | foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt)) | ||
401 | foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt)) | ||
402 | foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok | ||
403 | |||
404 | sortValD :: Vector Double -> Vector Double | ||
405 | sortValD = sortG c_sort_valD | ||
406 | sortValF :: Vector Float -> Vector Float | ||
407 | sortValF = sortG c_sort_valF | ||
408 | sortValI :: Vector CInt -> Vector CInt | ||
409 | sortValI = sortG c_sort_valI | ||
410 | sortValL :: Vector Z -> Vector Z | ||
411 | sortValL = sortG c_sort_valL | ||
412 | |||
413 | foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt)) | ||
414 | foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt)) | ||
415 | foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt)) | ||
416 | foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok | ||
417 | |||
418 | compareG :: (TransArray c, Storable t, Storable a) | ||
419 | => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) | ||
420 | -> c -> Vector t -> Vector a | ||
421 | compareG f u v = unsafePerformIO $ do | ||
422 | r <- createVector (dim v) | ||
423 | (u # v #! r) f #|"compareG" | ||
424 | return r | ||
425 | |||
426 | compareD :: Vector Double -> Vector Double -> Vector CInt | ||
427 | compareD = compareG c_compareD | ||
428 | compareF :: Vector Float -> Vector Float -> Vector CInt | ||
429 | compareF = compareG c_compareF | ||
430 | compareI :: Vector CInt -> Vector CInt -> Vector CInt | ||
431 | compareI = compareG c_compareI | ||
432 | compareL :: Vector Z -> Vector Z -> Vector CInt | ||
433 | compareL = compareG c_compareL | ||
434 | |||
435 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) | ||
436 | foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt))) | ||
437 | foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt))) | ||
438 | foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok | ||
439 | |||
440 | selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) | ||
441 | => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) | ||
442 | -> c2 -> c1 -> Vector t -> c -> Vector a | ||
443 | selectG f c u v w = unsafePerformIO $ do | ||
444 | r <- createVector (dim v) | ||
445 | (c # u # v # w #! r) f #|"selectG" | ||
446 | return r | ||
447 | |||
448 | selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double | ||
449 | selectD = selectG c_selectD | ||
450 | selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float | ||
451 | selectF = selectG c_selectF | ||
452 | selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt | ||
453 | selectI = selectG c_selectI | ||
454 | selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z | ||
455 | selectL = selectG c_selectL | ||
456 | selectC :: Vector CInt | ||
457 | -> Vector (Complex Double) | ||
458 | -> Vector (Complex Double) | ||
459 | -> Vector (Complex Double) | ||
460 | -> Vector (Complex Double) | ||
461 | selectC = selectG c_selectC | ||
462 | selectQ :: Vector CInt | ||
463 | -> Vector (Complex Float) | ||
464 | -> Vector (Complex Float) | ||
465 | -> Vector (Complex Float) | ||
466 | -> Vector (Complex Float) | ||
467 | selectQ = selectG c_selectQ | ||
468 | |||
469 | type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) | ||
470 | |||
471 | foreign import ccall unsafe "chooseD" c_selectD :: Sel Double | ||
472 | foreign import ccall unsafe "chooseF" c_selectF :: Sel Float | ||
473 | foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt | ||
474 | foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) | ||
475 | foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) | ||
476 | foreign import ccall unsafe "chooseL" c_selectL :: Sel Z | ||
477 | |||
478 | |||
479 | remapG :: (TransArray c, TransArray c1, Storable t, Storable a) | ||
480 | => (CInt -> CInt -> CInt -> CInt -> Ptr t | ||
481 | -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) | ||
482 | -> Matrix t -> c1 -> c -> Matrix a | ||
483 | remapG f i j m = unsafePerformIO $ do | ||
484 | r <- createMatrix RowMajor (rows i) (cols i) | ||
485 | (i # j # m #! r) f #|"remapG" | ||
486 | return r | ||
487 | |||
488 | remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double | ||
489 | remapD = remapG c_remapD | ||
490 | remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float | ||
491 | remapF = remapG c_remapF | ||
492 | remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt | ||
493 | remapI = remapG c_remapI | ||
494 | remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z | ||
495 | remapL = remapG c_remapL | ||
496 | remapC :: Matrix CInt | ||
497 | -> Matrix CInt | ||
498 | -> Matrix (Complex Double) | ||
499 | -> Matrix (Complex Double) | ||
500 | remapC = remapG c_remapC | ||
501 | remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float) | ||
502 | remapQ = remapG c_remapQ | ||
503 | |||
504 | type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) | ||
505 | |||
506 | foreign import ccall unsafe "remapD" c_remapD :: Rem Double | ||
507 | foreign import ccall unsafe "remapF" c_remapF :: Rem Float | ||
508 | foreign import ccall unsafe "remapI" c_remapI :: Rem CInt | ||
509 | foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) | ||
510 | foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) | ||
511 | foreign import ccall unsafe "remapL" c_remapL :: Rem Z | ||
512 | |||
513 | |||
514 | rowOpAux :: (TransArray c, Storable a) => | ||
515 | (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) | ||
516 | -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () | ||
517 | rowOpAux f c x i1 i2 j1 j2 m = do | ||
518 | px <- newArray [x] | ||
519 | (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" | ||
520 | free px | ||
521 | |||
522 | type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok | ||
523 | |||
524 | foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R | ||
525 | foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float | ||
526 | foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C | ||
527 | foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float) | ||
528 | foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I | ||
529 | foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z | ||
530 | foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I | ||
531 | foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | ||
532 | |||
533 | |||
534 | gemmg :: Storable x => Tgemm x -> Vector x -> Matrix x -> Matrix x -> Matrix x -> IO () | ||
535 | gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" | ||
536 | |||
537 | type Tgemm x = x :> x ::> x ::> x ::> Ok | ||
538 | |||
539 | foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R | ||
540 | foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float | ||
541 | foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C | ||
542 | foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float) | ||
543 | foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I | ||
544 | foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z | ||
545 | foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I | ||
546 | foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z | ||
547 | |||
548 | reorderAux :: Storable x => Reorder x -> Vector CInt -> Vector CInt -> Vector x -> Vector x | ||
549 | reorderAux f s d v = unsafePerformIO $ do | ||
550 | k <- createVector (dim s) | ||
551 | r <- createVector (dim v) | ||
552 | (k # s # d # v #! r) f #| "reorderV" | ||
553 | return r | ||
554 | |||
555 | type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt))))) | ||
556 | |||
557 | foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double | ||
558 | foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float | ||
559 | foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt | ||
560 | foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) | ||
561 | foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) | ||
562 | foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z | ||