summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Crayne <joe@jerkface.net>2019-08-10 01:39:35 -0400
committerJoe Crayne <joe@jerkface.net>2019-08-10 03:27:06 -0400
commit145a61cc82ab66853daed8b352cb283fdcc790c5 (patch)
tree945689a6c3373001ff3d74eedaa9f190261bfbdc
parentd304980b586fb7c7ee369b7d83620c9d992dea5a (diff)
More specialization.
-rw-r--r--packages/base/src/Internal/Matrix.hs382
-rw-r--r--packages/base/src/Internal/Modular.hs10
-rw-r--r--packages/base/src/Internal/Specialized.hs561
3 files changed, 476 insertions, 477 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 7c774ef..225b039 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -37,31 +37,6 @@ import Text.Printf
37 37
38----------------------------------------------------------------- 38-----------------------------------------------------------------
39 39
40data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
41
42-- | Matrix representation suitable for BLAS\/LAPACK computations.
43
44data Matrix t = Matrix
45 { irows :: {-# UNPACK #-} !Int
46 , icols :: {-# UNPACK #-} !Int
47 , xRow :: {-# UNPACK #-} !Int
48 , xCol :: {-# UNPACK #-} !Int
49 , xdat :: {-# UNPACK #-} !(Vector t)
50 }
51
52
53rows :: Matrix t -> Int
54rows = irows
55{-# INLINE rows #-}
56
57cols :: Matrix t -> Int
58cols = icols
59{-# INLINE cols #-}
60
61size :: Matrix t -> (Int, Int)
62size m = (irows m, icols m)
63{-# INLINE size #-}
64
65rowOrder :: Matrix t -> Bool 40rowOrder :: Matrix t -> Bool
66rowOrder m = xCol m == 1 || cols m == 1 41rowOrder m = xCol m == 1 || cols m == 1
67{-# INLINE rowOrder #-} 42{-# INLINE rowOrder #-}
@@ -114,33 +89,6 @@ fmat m
114 | otherwise = extractAll ColumnMajor m 89 | otherwise = extractAll ColumnMajor m
115 90
116 91
117-- C-Haskell matrix adapters
118{-# INLINE amatr #-}
119amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r
120amatr x f g = unsafeWith (xdat x) (f . g r c)
121 where
122 r = fi (rows x)
123 c = fi (cols x)
124
125{-# INLINE amat #-}
126amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r
127amat x f g = unsafeWith (xdat x) (f . g r c sr sc)
128 where
129 r = fi (rows x)
130 c = fi (cols x)
131 sr = fi (xRow x)
132 sc = fi (xCol x)
133
134
135instance Storable t => TransArray (Matrix t)
136 where
137 type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b
138 type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b
139 apply = amat
140 {-# INLINE apply #-}
141 applyRaw = amatr
142 {-# INLINE applyRaw #-}
143
144infixr 1 # 92infixr 1 #
145(#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r 93(#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r
146a # b = apply a b 94a # b = apply a b
@@ -240,22 +188,6 @@ atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m))
240 188
241------------------------------------------------------------------ 189------------------------------------------------------------------
242 190
243matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t
244matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 }
245matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d }
246matrixFromVector o r c v
247 | r * c == dim v = m
248 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
249 where
250 m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 }
251 | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r }
252
253-- allocates memory for a new matrix
254createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
255createMatrix ord r c = do
256 p <- createVector (r*c)
257 return (matrixFromVector ord r c p)
258
259{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = tr' . reshape r@ 191{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = tr' . reshape r@
260where r is the desired number of rows.) 192where r is the desired number of rows.)
261 193
@@ -286,101 +218,6 @@ liftMatrix2 f m1@(size->(r,c)) m2
286 218
287------------------------------------------------------------------ 219------------------------------------------------------------------
288 220
289{-
290-- | Supported matrix elements.
291class (Storable a) => Element a where
292 constantD :: a -> Int -> Vector a
293 extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
294 setRect :: Int -> Int -> Matrix a -> Matrix a -> IO ()
295 sortI :: Ord a => Vector a -> Vector CInt
296 sortV :: Ord a => Vector a -> Vector a
297 compareV :: Ord a => Vector a -> Vector a -> Vector CInt
298 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
299 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
300 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
301 gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
302 reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation
303
304
305instance Element Float where
306 constantD = constantAux cconstantF
307 extractR = extractAux c_extractF
308 setRect = setRectAux c_setRectF
309 sortI = sortIdxF
310 sortV = sortValF
311 compareV = compareF
312 selectV = selectF
313 remapM = remapF
314 rowOp = rowOpAux c_rowOpF
315 gemm = gemmg c_gemmF
316 reorderV = reorderAux c_reorderF
317
318instance Element Double where
319 constantD = constantAux cconstantR
320 extractR = extractAux c_extractD
321 setRect = setRectAux c_setRectD
322 sortI = sortIdxD
323 sortV = sortValD
324 compareV = compareD
325 selectV = selectD
326 remapM = remapD
327 rowOp = rowOpAux c_rowOpD
328 gemm = gemmg c_gemmD
329 reorderV = reorderAux c_reorderD
330
331instance Element (Complex Float) where
332 constantD = constantAux cconstantQ
333 extractR = extractAux c_extractQ
334 setRect = setRectAux c_setRectQ
335 sortI = undefined
336 sortV = undefined
337 compareV = undefined
338 selectV = selectQ
339 remapM = remapQ
340 rowOp = rowOpAux c_rowOpQ
341 gemm = gemmg c_gemmQ
342 reorderV = reorderAux c_reorderQ
343
344instance Element (Complex Double) where
345 constantD = constantAux cconstantC
346 extractR = extractAux c_extractC
347 setRect = setRectAux c_setRectC
348 sortI = undefined
349 sortV = undefined
350 compareV = undefined
351 selectV = selectC
352 remapM = remapC
353 rowOp = rowOpAux c_rowOpC
354 gemm = gemmg c_gemmC
355 reorderV = reorderAux c_reorderC
356
357instance Element (CInt) where
358 constantD = constantAux cconstantI
359 extractR = extractAux c_extractI
360 setRect = setRectAux c_setRectI
361 sortI = sortIdxI
362 sortV = sortValI
363 compareV = compareI
364 selectV = selectI
365 remapM = remapI
366 rowOp = rowOpAux c_rowOpI
367 gemm = gemmg c_gemmI
368 reorderV = reorderAux c_reorderI
369
370instance Element Z where
371 constantD = constantAux cconstantL
372 extractR = extractAux c_extractL
373 setRect = setRectAux c_setRectL
374 sortI = sortIdxL
375 sortV = sortValL
376 compareV = compareL
377 selectV = selectL
378 remapM = remapL
379 rowOp = rowOpAux c_rowOpL
380 gemm = gemmg c_gemmL
381 reorderV = reorderAux c_reorderL
382-}
383
384------------------------------------------------------------------- 221-------------------------------------------------------------------
385 222
386-- | reference to a rectangular slice of a matrix (no data copy) 223-- | reference to a rectangular slice of a matrix (no data copy)
@@ -435,12 +272,6 @@ repRows n x = fromRows (replicate n (flatten x))
435repCols :: Element t => Int -> Matrix t -> Matrix t 272repCols :: Element t => Int -> Matrix t -> Matrix t
436repCols n x = fromColumns (replicate n (flatten x)) 273repCols n x = fromColumns (replicate n (flatten x))
437 274
438shSize :: Matrix t -> [Char]
439shSize = shDim . size
440
441shDim :: (Show a, Show a1) => (a1, a) -> [Char]
442shDim (r,c) = "(" ++ show r ++"x"++ show c ++")"
443
444emptyM :: Storable t => Int -> Int -> Matrix t 275emptyM :: Storable t => Int -> Int -> Matrix t
445emptyM r c = matrixFromVector RowMajor r c (fromList[]) 276emptyM r c = matrixFromVector RowMajor r c (fromList[])
446 277
@@ -456,19 +287,6 @@ instance (Storable t, NFData t) => NFData (Matrix t)
456 287
457--------------------------------------------------------------- 288---------------------------------------------------------------
458 289
459extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1,
460 Storable t, Num t3, Num t2, Integral t1, Integral t)
461 => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t
462 -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))
463 -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a)
464extractAux f ord m moder vr modec vc = do
465 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
466 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
467 r <- createMatrix ord nr nc
468 (vr # vc # m #! r) (f moder modec) #|"extract"
469
470 return r
471
472type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) 290type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt))))
473 291
474foreign import ccall unsafe "extractD" c_extractD :: Extr Double 292foreign import ccall unsafe "extractD" c_extractD :: Extr Double
@@ -480,217 +298,17 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z
480 298
481--------------------------------------------------------------- 299---------------------------------------------------------------
482 300
483setRectAux :: (TransArray c1, TransArray c)
484 => (CInt -> CInt -> Trans c1 (Trans c (IO CInt)))
485 -> Int -> Int -> c1 -> c -> IO ()
486setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect"
487
488type SetRect x = I -> I -> x ::> x::> Ok
489
490foreign import ccall unsafe "setRectD" c_setRectD :: SetRect Double
491foreign import ccall unsafe "setRectF" c_setRectF :: SetRect Float
492foreign import ccall unsafe "setRectC" c_setRectC :: SetRect (Complex Double)
493foreign import ccall unsafe "setRectQ" c_setRectQ :: SetRect (Complex Float)
494foreign import ccall unsafe "setRectI" c_setRectI :: SetRect I
495foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z
496
497-------------------------------------------------------------------------------- 301--------------------------------------------------------------------------------
498 302
499sortG :: (Storable t, Storable a)
500 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
501sortG f v = unsafePerformIO $ do
502 r <- createVector (dim v)
503 (v #! r) f #|"sortG"
504 return r
505
506sortIdxD :: Vector Double -> Vector CInt
507sortIdxD = sortG c_sort_indexD
508sortIdxF :: Vector Float -> Vector CInt
509sortIdxF = sortG c_sort_indexF
510sortIdxI :: Vector CInt -> Vector CInt
511sortIdxI = sortG c_sort_indexI
512sortIdxL :: Vector Z -> Vector I
513sortIdxL = sortG c_sort_indexL
514
515sortValD :: Vector Double -> Vector Double
516sortValD = sortG c_sort_valD
517sortValF :: Vector Float -> Vector Float
518sortValF = sortG c_sort_valF
519sortValI :: Vector CInt -> Vector CInt
520sortValI = sortG c_sort_valI
521sortValL :: Vector Z -> Vector Z
522sortValL = sortG c_sort_valL
523
524foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt))
525foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt))
526foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt))
527foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok
528
529foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt))
530foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt))
531foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt))
532foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok
533
534-------------------------------------------------------------------------------- 303--------------------------------------------------------------------------------
535 304
536compareG :: (TransArray c, Storable t, Storable a)
537 => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt)
538 -> c -> Vector t -> Vector a
539compareG f u v = unsafePerformIO $ do
540 r <- createVector (dim v)
541 (u # v #! r) f #|"compareG"
542 return r
543
544compareD :: Vector Double -> Vector Double -> Vector CInt
545compareD = compareG c_compareD
546compareF :: Vector Float -> Vector Float -> Vector CInt
547compareF = compareG c_compareF
548compareI :: Vector CInt -> Vector CInt -> Vector CInt
549compareI = compareG c_compareI
550compareL :: Vector Z -> Vector Z -> Vector CInt
551compareL = compareG c_compareL
552
553foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt)))
554foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt)))
555foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt)))
556foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok
557
558-------------------------------------------------------------------------------- 305--------------------------------------------------------------------------------
559 306
560selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a)
561 => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt)))
562 -> c2 -> c1 -> Vector t -> c -> Vector a
563selectG f c u v w = unsafePerformIO $ do
564 r <- createVector (dim v)
565 (c # u # v # w #! r) f #|"selectG"
566 return r
567
568selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double
569selectD = selectG c_selectD
570selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float
571selectF = selectG c_selectF
572selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt
573selectI = selectG c_selectI
574selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z
575selectL = selectG c_selectL
576selectC :: Vector CInt
577 -> Vector (Complex Double)
578 -> Vector (Complex Double)
579 -> Vector (Complex Double)
580 -> Vector (Complex Double)
581selectC = selectG c_selectC
582selectQ :: Vector CInt
583 -> Vector (Complex Float)
584 -> Vector (Complex Float)
585 -> Vector (Complex Float)
586 -> Vector (Complex Float)
587selectQ = selectG c_selectQ
588
589type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt)))))
590
591foreign import ccall unsafe "chooseD" c_selectD :: Sel Double
592foreign import ccall unsafe "chooseF" c_selectF :: Sel Float
593foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt
594foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double)
595foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float)
596foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
597
598--------------------------------------------------------------------------- 307---------------------------------------------------------------------------
599
600remapG :: (TransArray c, TransArray c1, Storable t, Storable a)
601 => (CInt -> CInt -> CInt -> CInt -> Ptr t
602 -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)))
603 -> Matrix t -> c1 -> c -> Matrix a
604remapG f i j m = unsafePerformIO $ do
605 r <- createMatrix RowMajor (rows i) (cols i)
606 (i # j # m #! r) f #|"remapG"
607 return r
608
609remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double
610remapD = remapG c_remapD
611remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float
612remapF = remapG c_remapF
613remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt
614remapI = remapG c_remapI
615remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z
616remapL = remapG c_remapL
617remapC :: Matrix CInt
618 -> Matrix CInt
619 -> Matrix (Complex Double)
620 -> Matrix (Complex Double)
621remapC = remapG c_remapC
622remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float)
623remapQ = remapG c_remapQ
624
625type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt))))
626
627foreign import ccall unsafe "remapD" c_remapD :: Rem Double
628foreign import ccall unsafe "remapF" c_remapF :: Rem Float
629foreign import ccall unsafe "remapI" c_remapI :: Rem CInt
630foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double)
631foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float)
632foreign import ccall unsafe "remapL" c_remapL :: Rem Z
633
634-------------------------------------------------------------------------------- 308--------------------------------------------------------------------------------
635
636rowOpAux :: (TransArray c, Storable a) =>
637 (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt))
638 -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO ()
639rowOpAux f c x i1 i2 j1 j2 m = do
640 px <- newArray [x]
641 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp"
642 free px
643
644type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok
645
646foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R
647foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float
648foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C
649foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float)
650foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I
651foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z
652foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I
653foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
654
655-------------------------------------------------------------------------------- 309--------------------------------------------------------------------------------
656 310
657gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3)
658 => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt))))
659 -> c3 -> c2 -> c1 -> c -> IO ()
660gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg"
661
662type Tgemm x = x :> x ::> x ::> x ::> Ok
663
664foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R
665foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float
666foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C
667foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float)
668foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I
669foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z
670foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I
671foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z
672
673-------------------------------------------------------------------------------- 311--------------------------------------------------------------------------------
674
675reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) =>
676 (CInt -> Ptr a -> CInt -> Ptr t1
677 -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt))
678 -> Vector t1 -> c -> Vector t -> Vector a1
679reorderAux f s d v = unsafePerformIO $ do
680 k <- createVector (dim s)
681 r <- createVector (dim v)
682 (k # s # d # v #! r) f #| "reorderV"
683 return r
684
685type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt)))))
686
687foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double
688foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float
689foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt
690foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double)
691foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float)
692foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z
693
694-- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, 312-- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices,
695-- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ 313-- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@
696-- This function is intended to be used internally by tensor libraries. 314-- This function is intended to be used internally by tensor libraries.
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index a211dd3..10ff8a3 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -257,16 +257,6 @@ instance KnownNat m => Normed (Vector (Mod m Z))
257instance KnownNat m => Numeric (Mod m I) 257instance KnownNat m => Numeric (Mod m I)
258instance KnownNat m => Numeric (Mod m Z) 258instance KnownNat m => Numeric (Mod m Z)
259 259
260f2i :: Storable t => Vector (Mod n t) -> Vector t
261f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n)
262 where (fp,i,n) = unsafeToForeignPtr v
263
264f2iM :: (Element t, Element (Mod n t)) => Matrix (Mod n t) -> Matrix t
265f2iM m = m { xdat = f2i (xdat m) }
266
267i2fM :: (Element t, Element (Mod n t)) => Matrix t -> Matrix (Mod n t)
268i2fM m = m { xdat = i2f (xdat m) }
269
270vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t) 260vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t)
271vmod = i2f . cmod' m' 261vmod = i2f . cmod' m'
272 where 262 where
diff --git a/packages/base/src/Internal/Specialized.hs b/packages/base/src/Internal/Specialized.hs
index c79194f..c063369 100644
--- a/packages/base/src/Internal/Specialized.hs
+++ b/packages/base/src/Internal/Specialized.hs
@@ -8,6 +8,8 @@
8{-# LANGUAGE RankNTypes #-} 8{-# LANGUAGE RankNTypes #-}
9{-# LANGUAGE ScopedTypeVariables #-} 9{-# LANGUAGE ScopedTypeVariables #-}
10{-# LANGUAGE KindSignatures #-} 10{-# LANGUAGE KindSignatures #-}
11{-# LANGUAGE ViewPatterns #-}
12{-# LANGUAGE LambdaCase #-}
11module Internal.Specialized where 13module Internal.Specialized where
12 14
13import Control.Monad 15import Control.Monad
@@ -16,6 +18,7 @@ import Data.Coerce
16import Data.Complex 18import Data.Complex
17import Data.Functor 19import Data.Functor
18import Data.Int 20import Data.Int
21import Data.Maybe
19import Data.Typeable (eqT,Proxy) 22import Data.Typeable (eqT,Proxy)
20import Type.Reflection 23import Type.Reflection
21import Foreign.Marshal.Alloc(free,malloc) 24import Foreign.Marshal.Alloc(free,malloc)
@@ -31,127 +34,281 @@ import GHC.TypeLits hiding (Mod)
31import GHC.TypeLits 34import GHC.TypeLits
32#endif 35#endif
33 36
34import Internal.Vector (Vector,createVector,unsafeFromForeignPtr,unsafeToForeignPtr) 37import Internal.Vector -- (Vector,createVector,unsafeFromForeignPtr,unsafeToForeignPtr,(@>))
35import Internal.Devel 38import Internal.Devel
36 39
37eqt :: (Typeable a, Typeable b) => a -> Maybe (a :~: b) 40eqp :: (Typeable a, Typeable b) => proxy a -> Maybe (a :~: b)
38eqt _ = eqT 41eqp _ = eqT
39eq32 :: (Typeable a) => a -> Maybe (a :~: Int32) 42ep32 :: (Typeable a) => proxy a -> Maybe (a :~: Int32)
40eq32 _ = eqT 43ep32 _ = eqT
41eq64 :: (Typeable a) => a -> Maybe (a :~: Int64) 44ep64 :: (Typeable a) => proxy a -> Maybe (a :~: Int64)
42eq64 _ = eqT 45ep64 _ = eqT
43eqint :: (Typeable a) => a -> Maybe (a :~: CInt) 46epint :: (Typeable a) => proxy a -> Maybe (a :~: CInt)
44eqint _ = eqT 47epint _ = eqT
45 48
46type Element t = (Storable t, Typeable t) 49type Element t = (Storable t, Typeable t)
47 50
51-- | Wrapper with a phantom integer for statically checked modular arithmetic.
52newtype Mod (n :: Nat) t = Mod {unMod:: t}
53 deriving (Storable)
54
55instance (NFData t) => NFData (Mod n t)
56 where
57 rnf (Mod x) = rnf x
58
59i2fM :: Storable t => Matrix t -> Matrix (Mod n t)
60i2fM m = m { xdat = i2f (xdat m) }
61
62i2f :: Storable t => Vector t -> Vector (Mod n t)
63i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n)
64 where (fp,i,n) = unsafeToForeignPtr v
65
66f2i :: Storable t => Vector (Mod n t) -> Vector t
67f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n)
68 where (fp,i,n) = unsafeToForeignPtr v
69
70f2iM :: Storable t => Matrix (Mod n t) -> Matrix t
71f2iM m = m { xdat = f2i (xdat m) }
72
73data IntegralRep t a = IntegralRep
74 { i2rep :: Vector t -> Vector a
75 , i2repM :: Matrix t -> Matrix a
76 , rep2i :: Vector a -> Vector t
77 , rep2iM :: Matrix a -> Matrix t
78 , rep2one :: a -> t
79 , modulo :: Maybe t
80 }
81
82idint :: Storable t => IntegralRep t t
83idint = IntegralRep id id id id id Nothing
84
85coerceint :: Coercible t a => IntegralRep t a
86coerceint = IntegralRep coerce coerce coerce coerce coerce Nothing
87
88modint :: forall t n. (Read t, Storable t) => TypeRep n -> IntegralRep t (Mod n t)
89modint r = IntegralRep i2f i2fM f2i f2iM unMod (Just n)
90 where
91 n = read . show $ r -- XXX: Hack to get nat value from Type.Reflection
92 -- n = fromIntegral . natVal $ (undefined :: Proxy n)
93
94
95typeRepOf :: Typeable a => proxy a -> TypeRep a
96typeRepOf proxy = typeRep
97
48data Specialized a 98data Specialized a
49 = SpFloat !(a :~: Float) 99 = SpFloat !(a :~: Float)
50 | SpDouble !(a :~: Double) 100 | SpDouble !(a :~: Double)
51 | SpCFloat !(a :~: Complex Float) 101 | SpCFloat !(a :~: Complex Float)
52 | SpCDouble !(a :~: Complex Double) 102 | SpCDouble !(a :~: Complex Double)
53 | SpInt32 !(Vector Int32 -> Vector a) !Int32 103 | SpInt32 !(IntegralRep Int32 a)
54 | SpInt64 !(Vector Int64 -> Vector a) !Int64 104 | SpInt64 !(IntegralRep Int64 a)
55 -- | SpModInt32 !Int32 Int32 !(forall f. f Int32 -> f a)
56 -- | SpModInt64 !Int32 Int64 !(forall f. f Int64 -> f a)
57 105
58specialize :: forall a. Typeable a => a -> Maybe (Specialized a) 106specialize :: forall m a. Typeable a => m a -> Maybe (Specialized a)
59specialize x = foldr1 mplus 107specialize x = foldr1 mplus
60 [ SpDouble <$> eqt x 108 [ SpDouble <$> eqp x
61 , eq64 x <&> \Refl -> SpInt64 id x 109 , ep64 x <&> \Refl -> SpInt64 idint
62 , SpFloat <$> eqt x 110 , SpFloat <$> eqp x
63 , eq32 x <&> \Refl -> SpInt32 id x 111 , ep32 x <&> \Refl -> SpInt32 idint
64 , SpCDouble <$> eqt x 112 , SpCDouble <$> eqp x
65 , SpCFloat <$> eqt x 113 , SpCFloat <$> eqp x
66 , eqint x <&> \Refl -> case x of CInt y -> SpInt32 coerce y 114 , epint x <&> \Refl -> SpInt32 coerceint
67 -- , em32 x <&> \(nat,Refl) -> case x of Mod y -> SpInt32 (i2f' nat) y 115 , case typeRepOf x of
68 , case typeOf x of 116 App (App modtyp n) inttyp
69 App (App modtyp ntyp) inttyp -> case eqTypeRep (typeRep :: TypeRep (Mod :: Nat -> * -> *)) modtyp of 117 -> do HRefl <- eqTypeRep (typeRep :: TypeRep (Mod :: Nat -> * -> *)) modtyp
70 Just HRefl -> let i = unMod x 118 mplus (eqTypeRep (typeRep :: TypeRep Int32) inttyp <&> \HRefl -> SpInt32 $ modint n)
71 in case eqTypeRep (typeRep :: TypeRep Int32) inttyp of 119 (eqTypeRep (typeRep :: TypeRep Int64) inttyp <&> \HRefl -> SpInt64 $ modint n)
72 Just HRefl -> Just $ SpInt32 i2f i
73 _ -> case eqTypeRep (typeRep :: TypeRep Int64) inttyp of
74 Just HRefl -> Just $ SpInt64 i2f i
75 _ -> Nothing
76 Nothing -> Nothing
77 _ -> Nothing 120 _ -> Nothing
78 ] 121 ]
79 122
80-- | Supported matrix elements. 123-- | Supported matrix elements.
81constantD :: Typeable a => a -> Int -> Vector a 124constantD :: Typeable a => a -> Int -> Vector a
82constantD x = case specialize x of 125constantD x = fromMaybe (error "constantD") $ specialize (const x) <&> \case
83 Nothing -> error "constantD" 126 SpDouble Refl -> constantAux cconstantR x
84 Just (SpDouble Refl) -> constantAux cconstantR x 127 SpInt64 r -> i2rep r . constantAux cconstantL (rep2one r x)
85 Just (SpInt64 out y) -> out . constantAux cconstantL y 128 SpFloat Refl -> constantAux cconstantF x
86 Just (SpFloat Refl) -> constantAux cconstantF x 129 SpInt32 r -> i2rep r . constantAux cconstantI (rep2one r x)
87 Just (SpInt32 out y) -> out . constantAux cconstantI y 130 SpCDouble Refl -> constantAux cconstantC x
88 Just (SpCDouble Refl) -> constantAux cconstantC x 131 SpCFloat Refl -> constantAux cconstantQ x
89 Just (SpCFloat Refl) -> constantAux cconstantQ x
90 -- Just (SpModInt32 _ y ret) -> \n -> ret (constantAux cconstantI y n)
91 132
92-- | Wrapper with a phantom integer for statically checked modular arithmetic. 133data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
93newtype Mod (n :: Nat) t = Mod {unMod:: t}
94 deriving (Storable)
95 134
96instance (NFData t) => NFData (Mod n t) 135-- | Matrix representation suitable for BLAS\/LAPACK computations.
136data Matrix t = Matrix
137 { irows :: {-# UNPACK #-} !Int
138 , icols :: {-# UNPACK #-} !Int
139 , xRow :: {-# UNPACK #-} !Int
140 , xCol :: {-# UNPACK #-} !Int
141 , xdat :: {-# UNPACK #-} !(Vector t)
142 }
143
144-- allocates memory for a new matrix
145createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
146createMatrix ord r c = do
147 p <- createVector (r*c)
148 return (matrixFromVector ord r c p)
149
150matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t
151matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 }
152matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d }
153matrixFromVector o r c v
154 | r * c == dim v = m
155 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
97 where 156 where
98 rnf (Mod x) = rnf x 157 m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 }
158 | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r }
99 159
100i2f :: Storable t => Vector t -> Vector (Mod n t) 160shSize :: Matrix t -> [Char]
101i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) 161shSize = shDim . size
102 where (fp,i,n) = unsafeToForeignPtr v
103 162
163shDim :: (Show a, Show a1) => (a1, a) -> [Char]
164shDim (r,c) = "(" ++ show r ++"x"++ show c ++")"
165
166size :: Matrix t -> (Int, Int)
167size m = (irows m, icols m)
168{-# INLINE size #-}
104 169
105{-
106extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) 170extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
171extractR ord m = fromMaybe (\mi is mj js -> error "extractR") $ specialize m <&> \case
172 SpDouble Refl -> extractAux c_extractD ord m
173 SpInt64 r -> \mi is mj js -> i2repM r <$> extractAux c_extractL ord (rep2iM r m) mi is mj js
174 SpFloat Refl -> extractAux c_extractF ord m
175 SpInt32 r -> \mi is mj js -> i2repM r <$> extractAux (coerce c_extractI) ord (rep2iM r m) mi is mj js
176 SpCDouble Refl -> extractAux c_extractC ord m
177 SpCFloat Refl -> extractAux c_extractQ ord m
178
107setRect :: Typeable a => Int -> Int -> Matrix a -> Matrix a -> IO () 179setRect :: Typeable a => Int -> Int -> Matrix a -> Matrix a -> IO ()
108sortI :: (Typeable a , Ord a ) => Vector a -> Vector CInt 180setRect i j m x = fromMaybe (error "setRect") $ specialize m <&> \case
109sortV :: (Typeable a , Ord a ) => Vector a -> Vector a 181 SpDouble Refl -> setRectAux c_setRectD i j m x
182 SpInt64 r -> setRectAux c_setRectL i j (rep2iM r m) (rep2iM r x)
183 SpFloat Refl -> setRectAux c_setRectF i j m x
184 SpInt32 r -> setRectAux (coerce c_setRectI) i j (rep2iM r m) (rep2iM r x)
185 SpCDouble Refl -> setRectAux c_setRectC i j m x
186 SpCFloat Refl -> setRectAux c_setRectQ i j m x
187
188sortI :: (Typeable a , Ord a) => Vector a -> Vector CInt
189sortI v = maybe (error "sortI") ($ v) $ specialize v <&> \case
190 SpDouble Refl -> sortIdxD
191 SpInt64 r -> sortIdxL . rep2i r
192 SpFloat Refl -> sortIdxF
193 SpInt32 r -> coerce sortIdxI . rep2i r
194 SpCDouble Refl -> undefined -- Unreachable: Ord not implemented for Complex
195 SpCFloat Refl -> undefined -- Unreachable: Ord not implemented for Complex
196
197sortV :: (Typeable a , Ord a ) => Vector a -> Vector a
198sortV v = maybe (error "sortV") ($ v) $ specialize v <&> \case
199 SpDouble Refl -> sortValD
200 SpInt64 r -> i2rep r . sortValL . rep2i r
201 SpFloat Refl -> sortValF
202 SpInt32 r -> i2rep r . coerce sortValI . rep2i r
203 SpCDouble Refl -> undefined -- Unreachable: Ord not implemented for Complex
204 SpCFloat Refl -> undefined -- Unreachable: Ord not implemented for Complex
205
110compareV :: (Typeable a , Ord a ) => Vector a -> Vector a -> Vector CInt 206compareV :: (Typeable a , Ord a ) => Vector a -> Vector a -> Vector CInt
207compareV u v = fromMaybe (error "compareV" u v) $ specialize u <&> \case
208 SpDouble Refl -> compareD u v
209 SpInt64 r -> compareL (rep2i r u) (rep2i r v)
210 SpFloat Refl -> compareF u v
211 SpInt32 r -> coerce compareI (rep2i r u) (rep2i r v)
212 SpCDouble Refl -> undefined -- Unreachable: Ord not implemented for Complex
213 SpCFloat Refl -> undefined -- Unreachable: Ord not implemented for Complex
214
111selectV :: Typeable a => Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a 215selectV :: Typeable a => Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
216selectV c l e g = fromMaybe (error "selectV" c l e g) $ specialize l <&> \case
217 SpDouble Refl -> selectD c l e g
218 SpInt64 r -> i2rep r (selectL c (rep2i r l) (rep2i r e) (rep2i r g))
219 SpFloat Refl -> selectF c l e g
220 SpInt32 r -> i2rep r (coerce selectI c (rep2i r l) (rep2i r e) (rep2i r g))
221 SpCDouble Refl -> selectC c l e g
222 SpCFloat Refl -> selectQ c l e g
223
112remapM :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a 224remapM :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
225remapM i j m = fromMaybe (error "remapM" i j m) $ specialize m <&> \case
226 SpDouble Refl -> remapD i j m
227 SpInt64 r -> i2repM r (remapL i j (rep2iM r m))
228 SpFloat Refl -> remapF i j m
229 SpInt32 r -> i2repM r (coerce remapI i j (rep2iM r m))
230 SpCDouble Refl -> remapC i j m
231 SpCFloat Refl -> remapQ i j m
232
113rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () 233rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
234rowOp c a i1 i2 j1 j2 x = fromMaybe (error "rowOp") $ specialize x <&> \case
235 SpDouble Refl -> rowOpAux c_rowOpD c a i1 i2 j1 j2 x
236 SpInt64 r -> case modulo r of
237 Just m' -> rowOpAux (c_rowOpML m') c (rep2one r a) i1 i2 j1 j2 (rep2iM r x)
238 Nothing -> rowOpAux c_rowOpL c (rep2one r a) i1 i2 j1 j2 (rep2iM r x)
239 SpFloat Refl -> rowOpAux c_rowOpF c a i1 i2 j1 j2 x
240 SpInt32 r -> case modulo r of
241 Just m' -> rowOpAux (coerce c_rowOpMI m') c (rep2one r a) i1 i2 j1 j2 (rep2iM r x)
242 Nothing -> rowOpAux (coerce c_rowOpI) c (rep2one r a) i1 i2 j1 j2 (rep2iM r x)
243 SpCDouble Refl -> rowOpAux c_rowOpC c a i1 i2 j1 j2 x
244 SpCFloat Refl -> rowOpAux c_rowOpQ c a i1 i2 j1 j2 x
245
114gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () 246gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
247gemm u a b c = fromMaybe (error "gemm") $ specialize u <&> \case
248 SpDouble Refl -> gemmg c_gemmD u a b c
249 SpInt64 r -> case modulo r of
250 Just m' -> gemmg (c_gemmML m') (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c)
251 Nothing -> gemmg c_gemmL (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c)
252 SpFloat Refl -> gemmg c_gemmF u a b c
253 SpInt32 r -> case modulo r of
254 Just m' -> gemmg (coerce c_gemmMI m') (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c)
255 Nothing -> gemmg (coerce c_gemmI) (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c)
256 SpCDouble Refl -> gemmg c_gemmC u a b c
257 SpCFloat Refl -> gemmg c_gemmQ u a b c
258
115reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation 259reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation
260reorderV strides dims v = fromMaybe (error "reorderV") $ specialize v <&> \case
261 SpDouble Refl -> reorderAux c_reorderD strides dims v
262 SpInt64 r -> i2rep r $ reorderAux c_reorderL strides dims (rep2i r v)
263 SpFloat Refl -> reorderAux c_reorderF strides dims v
264 SpInt32 r -> i2rep r $ reorderAux (coerce c_reorderI) strides dims (rep2i r v)
265 SpCDouble Refl -> reorderAux c_reorderC strides dims v
266 SpCFloat Refl -> reorderAux c_reorderQ strides dims v
116 267
117instance KnownNat m => Element (Mod m I) 268
269instance Storable t => TransArray (Matrix t)
118 where 270 where
119 constantD x n = i2f (constantD (unMod x) n) 271 type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b
120 extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js 272 type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b
121 setRect i j m x = setRect i j (f2iM m) (f2iM x) 273 apply = amat
122 sortI = sortI . f2i 274 {-# INLINE apply #-}
123 sortV = i2f . sortV . f2i 275 applyRaw = amatr
124 compareV u v = compareV (f2i u) (f2i v) 276 {-# INLINE applyRaw #-}
125 selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) 277
126 remapM i j m = i2fM (remap i j (f2iM m)) 278-- C-Haskell matrix adapters
127 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) 279{-# INLINE amatr #-}
128 where 280amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r
129 m' = fromIntegral . natVal $ (undefined :: Proxy m) 281amatr x f g = unsafeWith (xdat x) (f . g r c)
130 gemm u a b c = gemmg (c_gemmMI m') (f2i u) (f2iM a) (f2iM b) (f2iM c)
131 where
132 m' = fromIntegral . natVal $ (undefined :: Proxy m)
133
134instance KnownNat m => Element (Mod m Z)
135 where 282 where
136 constantD x n = i2f (constantD (unMod x) n) 283 r = fi (rows x)
137 extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js 284 c = fi (cols x)
138 setRect i j m x = setRect i j (f2iM m) (f2iM x) 285
139 sortI = sortI . f2i 286{-# INLINE amat #-}
140 sortV = i2f . sortV . f2i 287amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r
141 compareV u v = compareV (f2i u) (f2i v) 288amat x f g = unsafeWith (xdat x) (f . g r c sr sc)
142 selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) 289 where
143 remapM i j m = i2fM (remap i j (f2iM m)) 290 r = fi (rows x)
144 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) 291 c = fi (cols x)
145 where 292 sr = fi (xRow x)
146 m' = fromIntegral . natVal $ (undefined :: Proxy m) 293 sc = fi (xCol x)
147 gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c) 294
148 where 295rows :: Matrix t -> Int
149 m' = fromIntegral . natVal $ (undefined :: Proxy m) 296rows = irows
150-} 297{-# INLINE rows #-}
298
299cols :: Matrix t -> Int
300cols = icols
301{-# INLINE cols #-}
302
151 303
304infixr 1 #
305(#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r
306a # b = apply a b
307{-# INLINE (#) #-}
152 308
153( extractR , setRect , sortI , sortV , compareV , selectV , remapM , rowOp , gemm , reorderV ) 309(#!) :: (TransArray c, TransArray c1) => c1 -> c -> Trans c1 (Trans c (IO r)) -> IO r
154 = error "todo Element" 310a #! b = a # b # id
311{-# INLINE (#!) #-}
155 312
156constantAux :: (Storable a1, Storable a) 313constantAux :: (Storable a1, Storable a)
157 => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a 314 => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a
@@ -169,3 +326,237 @@ foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float)
169foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double) 326foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double)
170foreign import ccall unsafe "constantI" cconstantI :: TConst Int32 327foreign import ccall unsafe "constantI" cconstantI :: TConst Int32
171foreign import ccall unsafe "constantL" cconstantL :: TConst Int64 328foreign import ccall unsafe "constantL" cconstantL :: TConst Int64
329
330{-
331extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1,
332 Storable t, Num t3, Num t2, Integral t1, Integral t)
333 => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t
334 -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))
335 -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a)
336extractAux f ord m moder vr modec vc = do
337 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
338 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
339 r <- createMatrix ord nr nc
340 (vr # vc # m #!r) (f moder modec) #|"extract"
341 return r
342-}
343
344extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1,
345 Storable t, Num t3, Num t2, Integral t1, Integral t)
346 => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t
347 -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))
348 -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a)
349extractAux f ord m moder vr modec vc = do
350 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
351 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
352 r <- createMatrix ord nr nc
353 (vr # vc # m #! r) (f moder modec) #|"extract"
354 return r
355
356type Extr x = CInt -> CInt ->
357 CInt -> Ptr CInt -> -- CIdxs
358 CInt -> Ptr CInt -> -- CIdxs
359 CInt -> CInt -> CInt -> CInt -> Ptr x -> -- OM x
360 CInt -> CInt -> CInt -> CInt -> Ptr x -> -- OM x
361 IO CInt
362foreign import ccall unsafe "extractD" c_extractD :: Extr Double
363foreign import ccall unsafe "extractF" c_extractF :: Extr Float
364foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double)
365foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float)
366foreign import ccall unsafe "extractI" c_extractI :: Extr CInt
367foreign import ccall unsafe "extractL" c_extractL :: Extr Int64
368
369setRectAux :: (TransArray c1, TransArray c)
370 => (CInt -> CInt -> Trans c1 (Trans c (IO CInt)))
371 -> Int -> Int -> c1 -> c -> IO ()
372setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect"
373
374type SetRect x = I -> I -> x ::> x::> Ok
375
376foreign import ccall unsafe "setRectD" c_setRectD :: SetRect Double
377foreign import ccall unsafe "setRectF" c_setRectF :: SetRect Float
378foreign import ccall unsafe "setRectC" c_setRectC :: SetRect (Complex Double)
379foreign import ccall unsafe "setRectQ" c_setRectQ :: SetRect (Complex Float)
380foreign import ccall unsafe "setRectI" c_setRectI :: SetRect I
381foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z
382
383sortG :: (Storable t, Storable a)
384 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
385sortG f v = unsafePerformIO $ do
386 r <- createVector (dim v)
387 (v #! r) f #|"sortG"
388 return r
389
390sortIdxD :: Vector Double -> Vector CInt
391sortIdxD = sortG c_sort_indexD
392sortIdxF :: Vector Float -> Vector CInt
393sortIdxF = sortG c_sort_indexF
394sortIdxI :: Vector CInt -> Vector CInt
395sortIdxI = sortG c_sort_indexI
396sortIdxL :: Vector Z -> Vector I
397sortIdxL = sortG c_sort_indexL
398
399foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt))
400foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt))
401foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt))
402foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok
403
404sortValD :: Vector Double -> Vector Double
405sortValD = sortG c_sort_valD
406sortValF :: Vector Float -> Vector Float
407sortValF = sortG c_sort_valF
408sortValI :: Vector CInt -> Vector CInt
409sortValI = sortG c_sort_valI
410sortValL :: Vector Z -> Vector Z
411sortValL = sortG c_sort_valL
412
413foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt))
414foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt))
415foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt))
416foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok
417
418compareG :: (TransArray c, Storable t, Storable a)
419 => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt)
420 -> c -> Vector t -> Vector a
421compareG f u v = unsafePerformIO $ do
422 r <- createVector (dim v)
423 (u # v #! r) f #|"compareG"
424 return r
425
426compareD :: Vector Double -> Vector Double -> Vector CInt
427compareD = compareG c_compareD
428compareF :: Vector Float -> Vector Float -> Vector CInt
429compareF = compareG c_compareF
430compareI :: Vector CInt -> Vector CInt -> Vector CInt
431compareI = compareG c_compareI
432compareL :: Vector Z -> Vector Z -> Vector CInt
433compareL = compareG c_compareL
434
435foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt)))
436foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt)))
437foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt)))
438foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok
439
440selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a)
441 => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt)))
442 -> c2 -> c1 -> Vector t -> c -> Vector a
443selectG f c u v w = unsafePerformIO $ do
444 r <- createVector (dim v)
445 (c # u # v # w #! r) f #|"selectG"
446 return r
447
448selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double
449selectD = selectG c_selectD
450selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float
451selectF = selectG c_selectF
452selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt
453selectI = selectG c_selectI
454selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z
455selectL = selectG c_selectL
456selectC :: Vector CInt
457 -> Vector (Complex Double)
458 -> Vector (Complex Double)
459 -> Vector (Complex Double)
460 -> Vector (Complex Double)
461selectC = selectG c_selectC
462selectQ :: Vector CInt
463 -> Vector (Complex Float)
464 -> Vector (Complex Float)
465 -> Vector (Complex Float)
466 -> Vector (Complex Float)
467selectQ = selectG c_selectQ
468
469type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt)))))
470
471foreign import ccall unsafe "chooseD" c_selectD :: Sel Double
472foreign import ccall unsafe "chooseF" c_selectF :: Sel Float
473foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt
474foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double)
475foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float)
476foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
477
478
479remapG :: (TransArray c, TransArray c1, Storable t, Storable a)
480 => (CInt -> CInt -> CInt -> CInt -> Ptr t
481 -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)))
482 -> Matrix t -> c1 -> c -> Matrix a
483remapG f i j m = unsafePerformIO $ do
484 r <- createMatrix RowMajor (rows i) (cols i)
485 (i # j # m #! r) f #|"remapG"
486 return r
487
488remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double
489remapD = remapG c_remapD
490remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float
491remapF = remapG c_remapF
492remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt
493remapI = remapG c_remapI
494remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z
495remapL = remapG c_remapL
496remapC :: Matrix CInt
497 -> Matrix CInt
498 -> Matrix (Complex Double)
499 -> Matrix (Complex Double)
500remapC = remapG c_remapC
501remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float)
502remapQ = remapG c_remapQ
503
504type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt))))
505
506foreign import ccall unsafe "remapD" c_remapD :: Rem Double
507foreign import ccall unsafe "remapF" c_remapF :: Rem Float
508foreign import ccall unsafe "remapI" c_remapI :: Rem CInt
509foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double)
510foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float)
511foreign import ccall unsafe "remapL" c_remapL :: Rem Z
512
513
514rowOpAux :: (TransArray c, Storable a) =>
515 (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt))
516 -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO ()
517rowOpAux f c x i1 i2 j1 j2 m = do
518 px <- newArray [x]
519 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp"
520 free px
521
522type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok
523
524foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R
525foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float
526foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C
527foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float)
528foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I
529foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z
530foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I
531foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
532
533
534gemmg :: Storable x => Tgemm x -> Vector x -> Matrix x -> Matrix x -> Matrix x -> IO ()
535gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg"
536
537type Tgemm x = x :> x ::> x ::> x ::> Ok
538
539foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R
540foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float
541foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C
542foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float)
543foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I
544foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z
545foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I
546foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z
547
548reorderAux :: Storable x => Reorder x -> Vector CInt -> Vector CInt -> Vector x -> Vector x
549reorderAux f s d v = unsafePerformIO $ do
550 k <- createVector (dim s)
551 r <- createVector (dim v)
552 (k # s # d # v #! r) f #| "reorderV"
553 return r
554
555type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt)))))
556
557foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double
558foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float
559foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt
560foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double)
561foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float)
562foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z