From 145a61cc82ab66853daed8b352cb283fdcc790c5 Mon Sep 17 00:00:00 2001 From: Joe Crayne Date: Sat, 10 Aug 2019 01:39:35 -0400 Subject: More specialization. --- packages/base/src/Internal/Matrix.hs | 382 -------------------- packages/base/src/Internal/Modular.hs | 10 - packages/base/src/Internal/Specialized.hs | 561 +++++++++++++++++++++++++----- 3 files changed, 476 insertions(+), 477 deletions(-) diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 7c774ef..225b039 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs @@ -37,31 +37,6 @@ import Text.Printf ----------------------------------------------------------------- -data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) - --- | Matrix representation suitable for BLAS\/LAPACK computations. - -data Matrix t = Matrix - { irows :: {-# UNPACK #-} !Int - , icols :: {-# UNPACK #-} !Int - , xRow :: {-# UNPACK #-} !Int - , xCol :: {-# UNPACK #-} !Int - , xdat :: {-# UNPACK #-} !(Vector t) - } - - -rows :: Matrix t -> Int -rows = irows -{-# INLINE rows #-} - -cols :: Matrix t -> Int -cols = icols -{-# INLINE cols #-} - -size :: Matrix t -> (Int, Int) -size m = (irows m, icols m) -{-# INLINE size #-} - rowOrder :: Matrix t -> Bool rowOrder m = xCol m == 1 || cols m == 1 {-# INLINE rowOrder #-} @@ -114,33 +89,6 @@ fmat m | otherwise = extractAll ColumnMajor m --- C-Haskell matrix adapters -{-# INLINE amatr #-} -amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r -amatr x f g = unsafeWith (xdat x) (f . g r c) - where - r = fi (rows x) - c = fi (cols x) - -{-# INLINE amat #-} -amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r -amat x f g = unsafeWith (xdat x) (f . g r c sr sc) - where - r = fi (rows x) - c = fi (cols x) - sr = fi (xRow x) - sc = fi (xCol x) - - -instance Storable t => TransArray (Matrix t) - where - type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b - type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b - apply = amat - {-# INLINE apply #-} - applyRaw = amatr - {-# INLINE applyRaw #-} - infixr 1 # (#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r a # b = apply a b @@ -240,22 +188,6 @@ atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) ------------------------------------------------------------------ -matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t -matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } -matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } -matrixFromVector o r c v - | r * c == dim v = m - | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m - where - m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 } - | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r } - --- allocates memory for a new matrix -createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) -createMatrix ord r c = do - p <- createVector (r*c) - return (matrixFromVector ord r c p) - {- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = tr' . reshape r@ where r is the desired number of rows.) @@ -286,101 +218,6 @@ liftMatrix2 f m1@(size->(r,c)) m2 ------------------------------------------------------------------ -{- --- | Supported matrix elements. -class (Storable a) => Element a where - constantD :: a -> Int -> Vector a - extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) - setRect :: Int -> Int -> Matrix a -> Matrix a -> IO () - sortI :: Ord a => Vector a -> Vector CInt - sortV :: Ord a => Vector a -> Vector a - compareV :: Ord a => Vector a -> Vector a -> Vector CInt - selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a - remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a - rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () - gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () - reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation - - -instance Element Float where - constantD = constantAux cconstantF - extractR = extractAux c_extractF - setRect = setRectAux c_setRectF - sortI = sortIdxF - sortV = sortValF - compareV = compareF - selectV = selectF - remapM = remapF - rowOp = rowOpAux c_rowOpF - gemm = gemmg c_gemmF - reorderV = reorderAux c_reorderF - -instance Element Double where - constantD = constantAux cconstantR - extractR = extractAux c_extractD - setRect = setRectAux c_setRectD - sortI = sortIdxD - sortV = sortValD - compareV = compareD - selectV = selectD - remapM = remapD - rowOp = rowOpAux c_rowOpD - gemm = gemmg c_gemmD - reorderV = reorderAux c_reorderD - -instance Element (Complex Float) where - constantD = constantAux cconstantQ - extractR = extractAux c_extractQ - setRect = setRectAux c_setRectQ - sortI = undefined - sortV = undefined - compareV = undefined - selectV = selectQ - remapM = remapQ - rowOp = rowOpAux c_rowOpQ - gemm = gemmg c_gemmQ - reorderV = reorderAux c_reorderQ - -instance Element (Complex Double) where - constantD = constantAux cconstantC - extractR = extractAux c_extractC - setRect = setRectAux c_setRectC - sortI = undefined - sortV = undefined - compareV = undefined - selectV = selectC - remapM = remapC - rowOp = rowOpAux c_rowOpC - gemm = gemmg c_gemmC - reorderV = reorderAux c_reorderC - -instance Element (CInt) where - constantD = constantAux cconstantI - extractR = extractAux c_extractI - setRect = setRectAux c_setRectI - sortI = sortIdxI - sortV = sortValI - compareV = compareI - selectV = selectI - remapM = remapI - rowOp = rowOpAux c_rowOpI - gemm = gemmg c_gemmI - reorderV = reorderAux c_reorderI - -instance Element Z where - constantD = constantAux cconstantL - extractR = extractAux c_extractL - setRect = setRectAux c_setRectL - sortI = sortIdxL - sortV = sortValL - compareV = compareL - selectV = selectL - remapM = remapL - rowOp = rowOpAux c_rowOpL - gemm = gemmg c_gemmL - reorderV = reorderAux c_reorderL --} - ------------------------------------------------------------------- -- | reference to a rectangular slice of a matrix (no data copy) @@ -435,12 +272,6 @@ repRows n x = fromRows (replicate n (flatten x)) repCols :: Element t => Int -> Matrix t -> Matrix t repCols n x = fromColumns (replicate n (flatten x)) -shSize :: Matrix t -> [Char] -shSize = shDim . size - -shDim :: (Show a, Show a1) => (a1, a) -> [Char] -shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" - emptyM :: Storable t => Int -> Int -> Matrix t emptyM r c = matrixFromVector RowMajor r c (fromList[]) @@ -456,19 +287,6 @@ instance (Storable t, NFData t) => NFData (Matrix t) --------------------------------------------------------------- -extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, - Storable t, Num t3, Num t2, Integral t1, Integral t) - => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t - -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) - -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) -extractAux f ord m moder vr modec vc = do - let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr - nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc - r <- createMatrix ord nr nc - (vr # vc # m #! r) (f moder modec) #|"extract" - - return r - type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) foreign import ccall unsafe "extractD" c_extractD :: Extr Double @@ -480,217 +298,17 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z --------------------------------------------------------------- -setRectAux :: (TransArray c1, TransArray c) - => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) - -> Int -> Int -> c1 -> c -> IO () -setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" - -type SetRect x = I -> I -> x ::> x::> Ok - -foreign import ccall unsafe "setRectD" c_setRectD :: SetRect Double -foreign import ccall unsafe "setRectF" c_setRectF :: SetRect Float -foreign import ccall unsafe "setRectC" c_setRectC :: SetRect (Complex Double) -foreign import ccall unsafe "setRectQ" c_setRectQ :: SetRect (Complex Float) -foreign import ccall unsafe "setRectI" c_setRectI :: SetRect I -foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z - -------------------------------------------------------------------------------- -sortG :: (Storable t, Storable a) - => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a -sortG f v = unsafePerformIO $ do - r <- createVector (dim v) - (v #! r) f #|"sortG" - return r - -sortIdxD :: Vector Double -> Vector CInt -sortIdxD = sortG c_sort_indexD -sortIdxF :: Vector Float -> Vector CInt -sortIdxF = sortG c_sort_indexF -sortIdxI :: Vector CInt -> Vector CInt -sortIdxI = sortG c_sort_indexI -sortIdxL :: Vector Z -> Vector I -sortIdxL = sortG c_sort_indexL - -sortValD :: Vector Double -> Vector Double -sortValD = sortG c_sort_valD -sortValF :: Vector Float -> Vector Float -sortValF = sortG c_sort_valF -sortValI :: Vector CInt -> Vector CInt -sortValI = sortG c_sort_valI -sortValL :: Vector Z -> Vector Z -sortValL = sortG c_sort_valL - -foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) -foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt)) -foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt)) -foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok - -foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt)) -foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt)) -foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt)) -foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok - -------------------------------------------------------------------------------- -compareG :: (TransArray c, Storable t, Storable a) - => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) - -> c -> Vector t -> Vector a -compareG f u v = unsafePerformIO $ do - r <- createVector (dim v) - (u # v #! r) f #|"compareG" - return r - -compareD :: Vector Double -> Vector Double -> Vector CInt -compareD = compareG c_compareD -compareF :: Vector Float -> Vector Float -> Vector CInt -compareF = compareG c_compareF -compareI :: Vector CInt -> Vector CInt -> Vector CInt -compareI = compareG c_compareI -compareL :: Vector Z -> Vector Z -> Vector CInt -compareL = compareG c_compareL - -foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) -foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt))) -foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt))) -foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok - -------------------------------------------------------------------------------- -selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) - => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) - -> c2 -> c1 -> Vector t -> c -> Vector a -selectG f c u v w = unsafePerformIO $ do - r <- createVector (dim v) - (c # u # v # w #! r) f #|"selectG" - return r - -selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double -selectD = selectG c_selectD -selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float -selectF = selectG c_selectF -selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -selectI = selectG c_selectI -selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z -selectL = selectG c_selectL -selectC :: Vector CInt - -> Vector (Complex Double) - -> Vector (Complex Double) - -> Vector (Complex Double) - -> Vector (Complex Double) -selectC = selectG c_selectC -selectQ :: Vector CInt - -> Vector (Complex Float) - -> Vector (Complex Float) - -> Vector (Complex Float) - -> Vector (Complex Float) -selectQ = selectG c_selectQ - -type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) - -foreign import ccall unsafe "chooseD" c_selectD :: Sel Double -foreign import ccall unsafe "chooseF" c_selectF :: Sel Float -foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt -foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) -foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) -foreign import ccall unsafe "chooseL" c_selectL :: Sel Z - --------------------------------------------------------------------------- - -remapG :: (TransArray c, TransArray c1, Storable t, Storable a) - => (CInt -> CInt -> CInt -> CInt -> Ptr t - -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) - -> Matrix t -> c1 -> c -> Matrix a -remapG f i j m = unsafePerformIO $ do - r <- createMatrix RowMajor (rows i) (cols i) - (i # j # m #! r) f #|"remapG" - return r - -remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double -remapD = remapG c_remapD -remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float -remapF = remapG c_remapF -remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt -remapI = remapG c_remapI -remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z -remapL = remapG c_remapL -remapC :: Matrix CInt - -> Matrix CInt - -> Matrix (Complex Double) - -> Matrix (Complex Double) -remapC = remapG c_remapC -remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float) -remapQ = remapG c_remapQ - -type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) - -foreign import ccall unsafe "remapD" c_remapD :: Rem Double -foreign import ccall unsafe "remapF" c_remapF :: Rem Float -foreign import ccall unsafe "remapI" c_remapI :: Rem CInt -foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) -foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) -foreign import ccall unsafe "remapL" c_remapL :: Rem Z - -------------------------------------------------------------------------------- - -rowOpAux :: (TransArray c, Storable a) => - (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) - -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () -rowOpAux f c x i1 i2 j1 j2 m = do - px <- newArray [x] - (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" - free px - -type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok - -foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R -foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float -foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C -foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float) -foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I -foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z -foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I -foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z - -------------------------------------------------------------------------------- -gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3) - => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt)))) - -> c3 -> c2 -> c1 -> c -> IO () -gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" - -type Tgemm x = x :> x ::> x ::> x ::> Ok - -foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R -foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float -foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C -foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float) -foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I -foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z -foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I -foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z - -------------------------------------------------------------------------------- - -reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) => - (CInt -> Ptr a -> CInt -> Ptr t1 - -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)) - -> Vector t1 -> c -> Vector t -> Vector a1 -reorderAux f s d v = unsafePerformIO $ do - k <- createVector (dim s) - r <- createVector (dim v) - (k # s # d # v #! r) f #| "reorderV" - return r - -type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt))))) - -foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double -foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float -foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt -foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) -foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) -foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z - -- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, -- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ -- This function is intended to be used internally by tensor libraries. diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index a211dd3..10ff8a3 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs @@ -257,16 +257,6 @@ instance KnownNat m => Normed (Vector (Mod m Z)) instance KnownNat m => Numeric (Mod m I) instance KnownNat m => Numeric (Mod m Z) -f2i :: Storable t => Vector (Mod n t) -> Vector t -f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) - where (fp,i,n) = unsafeToForeignPtr v - -f2iM :: (Element t, Element (Mod n t)) => Matrix (Mod n t) -> Matrix t -f2iM m = m { xdat = f2i (xdat m) } - -i2fM :: (Element t, Element (Mod n t)) => Matrix t -> Matrix (Mod n t) -i2fM m = m { xdat = i2f (xdat m) } - vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t) vmod = i2f . cmod' m' where diff --git a/packages/base/src/Internal/Specialized.hs b/packages/base/src/Internal/Specialized.hs index c79194f..c063369 100644 --- a/packages/base/src/Internal/Specialized.hs +++ b/packages/base/src/Internal/Specialized.hs @@ -8,6 +8,8 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE KindSignatures #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE LambdaCase #-} module Internal.Specialized where import Control.Monad @@ -16,6 +18,7 @@ import Data.Coerce import Data.Complex import Data.Functor import Data.Int +import Data.Maybe import Data.Typeable (eqT,Proxy) import Type.Reflection import Foreign.Marshal.Alloc(free,malloc) @@ -31,127 +34,281 @@ import GHC.TypeLits hiding (Mod) import GHC.TypeLits #endif -import Internal.Vector (Vector,createVector,unsafeFromForeignPtr,unsafeToForeignPtr) +import Internal.Vector -- (Vector,createVector,unsafeFromForeignPtr,unsafeToForeignPtr,(@>)) import Internal.Devel -eqt :: (Typeable a, Typeable b) => a -> Maybe (a :~: b) -eqt _ = eqT -eq32 :: (Typeable a) => a -> Maybe (a :~: Int32) -eq32 _ = eqT -eq64 :: (Typeable a) => a -> Maybe (a :~: Int64) -eq64 _ = eqT -eqint :: (Typeable a) => a -> Maybe (a :~: CInt) -eqint _ = eqT +eqp :: (Typeable a, Typeable b) => proxy a -> Maybe (a :~: b) +eqp _ = eqT +ep32 :: (Typeable a) => proxy a -> Maybe (a :~: Int32) +ep32 _ = eqT +ep64 :: (Typeable a) => proxy a -> Maybe (a :~: Int64) +ep64 _ = eqT +epint :: (Typeable a) => proxy a -> Maybe (a :~: CInt) +epint _ = eqT type Element t = (Storable t, Typeable t) +-- | Wrapper with a phantom integer for statically checked modular arithmetic. +newtype Mod (n :: Nat) t = Mod {unMod:: t} + deriving (Storable) + +instance (NFData t) => NFData (Mod n t) + where + rnf (Mod x) = rnf x + +i2fM :: Storable t => Matrix t -> Matrix (Mod n t) +i2fM m = m { xdat = i2f (xdat m) } + +i2f :: Storable t => Vector t -> Vector (Mod n t) +i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) + where (fp,i,n) = unsafeToForeignPtr v + +f2i :: Storable t => Vector (Mod n t) -> Vector t +f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) + where (fp,i,n) = unsafeToForeignPtr v + +f2iM :: Storable t => Matrix (Mod n t) -> Matrix t +f2iM m = m { xdat = f2i (xdat m) } + +data IntegralRep t a = IntegralRep + { i2rep :: Vector t -> Vector a + , i2repM :: Matrix t -> Matrix a + , rep2i :: Vector a -> Vector t + , rep2iM :: Matrix a -> Matrix t + , rep2one :: a -> t + , modulo :: Maybe t + } + +idint :: Storable t => IntegralRep t t +idint = IntegralRep id id id id id Nothing + +coerceint :: Coercible t a => IntegralRep t a +coerceint = IntegralRep coerce coerce coerce coerce coerce Nothing + +modint :: forall t n. (Read t, Storable t) => TypeRep n -> IntegralRep t (Mod n t) +modint r = IntegralRep i2f i2fM f2i f2iM unMod (Just n) + where + n = read . show $ r -- XXX: Hack to get nat value from Type.Reflection + -- n = fromIntegral . natVal $ (undefined :: Proxy n) + + +typeRepOf :: Typeable a => proxy a -> TypeRep a +typeRepOf proxy = typeRep + data Specialized a = SpFloat !(a :~: Float) | SpDouble !(a :~: Double) | SpCFloat !(a :~: Complex Float) | SpCDouble !(a :~: Complex Double) - | SpInt32 !(Vector Int32 -> Vector a) !Int32 - | SpInt64 !(Vector Int64 -> Vector a) !Int64 - -- | SpModInt32 !Int32 Int32 !(forall f. f Int32 -> f a) - -- | SpModInt64 !Int32 Int64 !(forall f. f Int64 -> f a) + | SpInt32 !(IntegralRep Int32 a) + | SpInt64 !(IntegralRep Int64 a) -specialize :: forall a. Typeable a => a -> Maybe (Specialized a) +specialize :: forall m a. Typeable a => m a -> Maybe (Specialized a) specialize x = foldr1 mplus - [ SpDouble <$> eqt x - , eq64 x <&> \Refl -> SpInt64 id x - , SpFloat <$> eqt x - , eq32 x <&> \Refl -> SpInt32 id x - , SpCDouble <$> eqt x - , SpCFloat <$> eqt x - , eqint x <&> \Refl -> case x of CInt y -> SpInt32 coerce y - -- , em32 x <&> \(nat,Refl) -> case x of Mod y -> SpInt32 (i2f' nat) y - , case typeOf x of - App (App modtyp ntyp) inttyp -> case eqTypeRep (typeRep :: TypeRep (Mod :: Nat -> * -> *)) modtyp of - Just HRefl -> let i = unMod x - in case eqTypeRep (typeRep :: TypeRep Int32) inttyp of - Just HRefl -> Just $ SpInt32 i2f i - _ -> case eqTypeRep (typeRep :: TypeRep Int64) inttyp of - Just HRefl -> Just $ SpInt64 i2f i - _ -> Nothing - Nothing -> Nothing + [ SpDouble <$> eqp x + , ep64 x <&> \Refl -> SpInt64 idint + , SpFloat <$> eqp x + , ep32 x <&> \Refl -> SpInt32 idint + , SpCDouble <$> eqp x + , SpCFloat <$> eqp x + , epint x <&> \Refl -> SpInt32 coerceint + , case typeRepOf x of + App (App modtyp n) inttyp + -> do HRefl <- eqTypeRep (typeRep :: TypeRep (Mod :: Nat -> * -> *)) modtyp + mplus (eqTypeRep (typeRep :: TypeRep Int32) inttyp <&> \HRefl -> SpInt32 $ modint n) + (eqTypeRep (typeRep :: TypeRep Int64) inttyp <&> \HRefl -> SpInt64 $ modint n) _ -> Nothing ] -- | Supported matrix elements. constantD :: Typeable a => a -> Int -> Vector a -constantD x = case specialize x of - Nothing -> error "constantD" - Just (SpDouble Refl) -> constantAux cconstantR x - Just (SpInt64 out y) -> out . constantAux cconstantL y - Just (SpFloat Refl) -> constantAux cconstantF x - Just (SpInt32 out y) -> out . constantAux cconstantI y - Just (SpCDouble Refl) -> constantAux cconstantC x - Just (SpCFloat Refl) -> constantAux cconstantQ x - -- Just (SpModInt32 _ y ret) -> \n -> ret (constantAux cconstantI y n) +constantD x = fromMaybe (error "constantD") $ specialize (const x) <&> \case + SpDouble Refl -> constantAux cconstantR x + SpInt64 r -> i2rep r . constantAux cconstantL (rep2one r x) + SpFloat Refl -> constantAux cconstantF x + SpInt32 r -> i2rep r . constantAux cconstantI (rep2one r x) + SpCDouble Refl -> constantAux cconstantC x + SpCFloat Refl -> constantAux cconstantQ x --- | Wrapper with a phantom integer for statically checked modular arithmetic. -newtype Mod (n :: Nat) t = Mod {unMod:: t} - deriving (Storable) +data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) -instance (NFData t) => NFData (Mod n t) +-- | Matrix representation suitable for BLAS\/LAPACK computations. +data Matrix t = Matrix + { irows :: {-# UNPACK #-} !Int + , icols :: {-# UNPACK #-} !Int + , xRow :: {-# UNPACK #-} !Int + , xCol :: {-# UNPACK #-} !Int + , xdat :: {-# UNPACK #-} !(Vector t) + } + +-- allocates memory for a new matrix +createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) +createMatrix ord r c = do + p <- createVector (r*c) + return (matrixFromVector ord r c p) + +matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t +matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } +matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } +matrixFromVector o r c v + | r * c == dim v = m + | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m where - rnf (Mod x) = rnf x + m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 } + | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r } -i2f :: Storable t => Vector t -> Vector (Mod n t) -i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) - where (fp,i,n) = unsafeToForeignPtr v +shSize :: Matrix t -> [Char] +shSize = shDim . size +shDim :: (Show a, Show a1) => (a1, a) -> [Char] +shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" + +size :: Matrix t -> (Int, Int) +size m = (irows m, icols m) +{-# INLINE size #-} -{- extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) +extractR ord m = fromMaybe (\mi is mj js -> error "extractR") $ specialize m <&> \case + SpDouble Refl -> extractAux c_extractD ord m + SpInt64 r -> \mi is mj js -> i2repM r <$> extractAux c_extractL ord (rep2iM r m) mi is mj js + SpFloat Refl -> extractAux c_extractF ord m + SpInt32 r -> \mi is mj js -> i2repM r <$> extractAux (coerce c_extractI) ord (rep2iM r m) mi is mj js + SpCDouble Refl -> extractAux c_extractC ord m + SpCFloat Refl -> extractAux c_extractQ ord m + setRect :: Typeable a => Int -> Int -> Matrix a -> Matrix a -> IO () -sortI :: (Typeable a , Ord a ) => Vector a -> Vector CInt -sortV :: (Typeable a , Ord a ) => Vector a -> Vector a +setRect i j m x = fromMaybe (error "setRect") $ specialize m <&> \case + SpDouble Refl -> setRectAux c_setRectD i j m x + SpInt64 r -> setRectAux c_setRectL i j (rep2iM r m) (rep2iM r x) + SpFloat Refl -> setRectAux c_setRectF i j m x + SpInt32 r -> setRectAux (coerce c_setRectI) i j (rep2iM r m) (rep2iM r x) + SpCDouble Refl -> setRectAux c_setRectC i j m x + SpCFloat Refl -> setRectAux c_setRectQ i j m x + +sortI :: (Typeable a , Ord a) => Vector a -> Vector CInt +sortI v = maybe (error "sortI") ($ v) $ specialize v <&> \case + SpDouble Refl -> sortIdxD + SpInt64 r -> sortIdxL . rep2i r + SpFloat Refl -> sortIdxF + SpInt32 r -> coerce sortIdxI . rep2i r + SpCDouble Refl -> undefined -- Unreachable: Ord not implemented for Complex + SpCFloat Refl -> undefined -- Unreachable: Ord not implemented for Complex + +sortV :: (Typeable a , Ord a ) => Vector a -> Vector a +sortV v = maybe (error "sortV") ($ v) $ specialize v <&> \case + SpDouble Refl -> sortValD + SpInt64 r -> i2rep r . sortValL . rep2i r + SpFloat Refl -> sortValF + SpInt32 r -> i2rep r . coerce sortValI . rep2i r + SpCDouble Refl -> undefined -- Unreachable: Ord not implemented for Complex + SpCFloat Refl -> undefined -- Unreachable: Ord not implemented for Complex + compareV :: (Typeable a , Ord a ) => Vector a -> Vector a -> Vector CInt +compareV u v = fromMaybe (error "compareV" u v) $ specialize u <&> \case + SpDouble Refl -> compareD u v + SpInt64 r -> compareL (rep2i r u) (rep2i r v) + SpFloat Refl -> compareF u v + SpInt32 r -> coerce compareI (rep2i r u) (rep2i r v) + SpCDouble Refl -> undefined -- Unreachable: Ord not implemented for Complex + SpCFloat Refl -> undefined -- Unreachable: Ord not implemented for Complex + selectV :: Typeable a => Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a +selectV c l e g = fromMaybe (error "selectV" c l e g) $ specialize l <&> \case + SpDouble Refl -> selectD c l e g + SpInt64 r -> i2rep r (selectL c (rep2i r l) (rep2i r e) (rep2i r g)) + SpFloat Refl -> selectF c l e g + SpInt32 r -> i2rep r (coerce selectI c (rep2i r l) (rep2i r e) (rep2i r g)) + SpCDouble Refl -> selectC c l e g + SpCFloat Refl -> selectQ c l e g + remapM :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a +remapM i j m = fromMaybe (error "remapM" i j m) $ specialize m <&> \case + SpDouble Refl -> remapD i j m + SpInt64 r -> i2repM r (remapL i j (rep2iM r m)) + SpFloat Refl -> remapF i j m + SpInt32 r -> i2repM r (coerce remapI i j (rep2iM r m)) + SpCDouble Refl -> remapC i j m + SpCFloat Refl -> remapQ i j m + rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () +rowOp c a i1 i2 j1 j2 x = fromMaybe (error "rowOp") $ specialize x <&> \case + SpDouble Refl -> rowOpAux c_rowOpD c a i1 i2 j1 j2 x + SpInt64 r -> case modulo r of + Just m' -> rowOpAux (c_rowOpML m') c (rep2one r a) i1 i2 j1 j2 (rep2iM r x) + Nothing -> rowOpAux c_rowOpL c (rep2one r a) i1 i2 j1 j2 (rep2iM r x) + SpFloat Refl -> rowOpAux c_rowOpF c a i1 i2 j1 j2 x + SpInt32 r -> case modulo r of + Just m' -> rowOpAux (coerce c_rowOpMI m') c (rep2one r a) i1 i2 j1 j2 (rep2iM r x) + Nothing -> rowOpAux (coerce c_rowOpI) c (rep2one r a) i1 i2 j1 j2 (rep2iM r x) + SpCDouble Refl -> rowOpAux c_rowOpC c a i1 i2 j1 j2 x + SpCFloat Refl -> rowOpAux c_rowOpQ c a i1 i2 j1 j2 x + gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () +gemm u a b c = fromMaybe (error "gemm") $ specialize u <&> \case + SpDouble Refl -> gemmg c_gemmD u a b c + SpInt64 r -> case modulo r of + Just m' -> gemmg (c_gemmML m') (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c) + Nothing -> gemmg c_gemmL (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c) + SpFloat Refl -> gemmg c_gemmF u a b c + SpInt32 r -> case modulo r of + Just m' -> gemmg (coerce c_gemmMI m') (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c) + Nothing -> gemmg (coerce c_gemmI) (rep2i r u) (rep2iM r a) (rep2iM r b) (rep2iM r c) + SpCDouble Refl -> gemmg c_gemmC u a b c + SpCFloat Refl -> gemmg c_gemmQ u a b c + reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation +reorderV strides dims v = fromMaybe (error "reorderV") $ specialize v <&> \case + SpDouble Refl -> reorderAux c_reorderD strides dims v + SpInt64 r -> i2rep r $ reorderAux c_reorderL strides dims (rep2i r v) + SpFloat Refl -> reorderAux c_reorderF strides dims v + SpInt32 r -> i2rep r $ reorderAux (coerce c_reorderI) strides dims (rep2i r v) + SpCDouble Refl -> reorderAux c_reorderC strides dims v + SpCFloat Refl -> reorderAux c_reorderQ strides dims v -instance KnownNat m => Element (Mod m I) + +instance Storable t => TransArray (Matrix t) where - constantD x n = i2f (constantD (unMod x) n) - extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js - setRect i j m x = setRect i j (f2iM m) (f2iM x) - sortI = sortI . f2i - sortV = i2f . sortV . f2i - compareV u v = compareV (f2i u) (f2i v) - selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) - remapM i j m = i2fM (remap i j (f2iM m)) - rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) - where - m' = fromIntegral . natVal $ (undefined :: Proxy m) - gemm u a b c = gemmg (c_gemmMI m') (f2i u) (f2iM a) (f2iM b) (f2iM c) - where - m' = fromIntegral . natVal $ (undefined :: Proxy m) - -instance KnownNat m => Element (Mod m Z) + type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b + type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b + apply = amat + {-# INLINE apply #-} + applyRaw = amatr + {-# INLINE applyRaw #-} + +-- C-Haskell matrix adapters +{-# INLINE amatr #-} +amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r +amatr x f g = unsafeWith (xdat x) (f . g r c) where - constantD x n = i2f (constantD (unMod x) n) - extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js - setRect i j m x = setRect i j (f2iM m) (f2iM x) - sortI = sortI . f2i - sortV = i2f . sortV . f2i - compareV u v = compareV (f2i u) (f2i v) - selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) - remapM i j m = i2fM (remap i j (f2iM m)) - rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) - where - m' = fromIntegral . natVal $ (undefined :: Proxy m) - gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c) - where - m' = fromIntegral . natVal $ (undefined :: Proxy m) --} + r = fi (rows x) + c = fi (cols x) + +{-# INLINE amat #-} +amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r +amat x f g = unsafeWith (xdat x) (f . g r c sr sc) + where + r = fi (rows x) + c = fi (cols x) + sr = fi (xRow x) + sc = fi (xCol x) + +rows :: Matrix t -> Int +rows = irows +{-# INLINE rows #-} + +cols :: Matrix t -> Int +cols = icols +{-# INLINE cols #-} + +infixr 1 # +(#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r +a # b = apply a b +{-# INLINE (#) #-} -( extractR , setRect , sortI , sortV , compareV , selectV , remapM , rowOp , gemm , reorderV ) - = error "todo Element" +(#!) :: (TransArray c, TransArray c1) => c1 -> c -> Trans c1 (Trans c (IO r)) -> IO r +a #! b = a # b # id +{-# INLINE (#!) #-} constantAux :: (Storable a1, Storable a) => (Ptr a1 -> CInt -> Ptr a -> IO CInt) -> a1 -> Int -> Vector a @@ -169,3 +326,237 @@ foreign import ccall unsafe "constantQ" cconstantQ :: TConst (Complex Float) foreign import ccall unsafe "constantC" cconstantC :: TConst (Complex Double) foreign import ccall unsafe "constantI" cconstantI :: TConst Int32 foreign import ccall unsafe "constantL" cconstantL :: TConst Int64 + +{- +extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, + Storable t, Num t3, Num t2, Integral t1, Integral t) + => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t + -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) + -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) +extractAux f ord m moder vr modec vc = do + let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr + nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc + r <- createMatrix ord nr nc + (vr # vc # m #!r) (f moder modec) #|"extract" + return r +-} + +extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, + Storable t, Num t3, Num t2, Integral t1, Integral t) + => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t + -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) + -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) +extractAux f ord m moder vr modec vc = do + let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr + nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc + r <- createMatrix ord nr nc + (vr # vc # m #! r) (f moder modec) #|"extract" + return r + +type Extr x = CInt -> CInt -> + CInt -> Ptr CInt -> -- CIdxs + CInt -> Ptr CInt -> -- CIdxs + CInt -> CInt -> CInt -> CInt -> Ptr x -> -- OM x + CInt -> CInt -> CInt -> CInt -> Ptr x -> -- OM x + IO CInt +foreign import ccall unsafe "extractD" c_extractD :: Extr Double +foreign import ccall unsafe "extractF" c_extractF :: Extr Float +foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double) +foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float) +foreign import ccall unsafe "extractI" c_extractI :: Extr CInt +foreign import ccall unsafe "extractL" c_extractL :: Extr Int64 + +setRectAux :: (TransArray c1, TransArray c) + => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) + -> Int -> Int -> c1 -> c -> IO () +setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" + +type SetRect x = I -> I -> x ::> x::> Ok + +foreign import ccall unsafe "setRectD" c_setRectD :: SetRect Double +foreign import ccall unsafe "setRectF" c_setRectF :: SetRect Float +foreign import ccall unsafe "setRectC" c_setRectC :: SetRect (Complex Double) +foreign import ccall unsafe "setRectQ" c_setRectQ :: SetRect (Complex Float) +foreign import ccall unsafe "setRectI" c_setRectI :: SetRect I +foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z + +sortG :: (Storable t, Storable a) + => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a +sortG f v = unsafePerformIO $ do + r <- createVector (dim v) + (v #! r) f #|"sortG" + return r + +sortIdxD :: Vector Double -> Vector CInt +sortIdxD = sortG c_sort_indexD +sortIdxF :: Vector Float -> Vector CInt +sortIdxF = sortG c_sort_indexF +sortIdxI :: Vector CInt -> Vector CInt +sortIdxI = sortG c_sort_indexI +sortIdxL :: Vector Z -> Vector I +sortIdxL = sortG c_sort_indexL + +foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) +foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt)) +foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt)) +foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok + +sortValD :: Vector Double -> Vector Double +sortValD = sortG c_sort_valD +sortValF :: Vector Float -> Vector Float +sortValF = sortG c_sort_valF +sortValI :: Vector CInt -> Vector CInt +sortValI = sortG c_sort_valI +sortValL :: Vector Z -> Vector Z +sortValL = sortG c_sort_valL + +foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt)) +foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt)) +foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt)) +foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok + +compareG :: (TransArray c, Storable t, Storable a) + => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) + -> c -> Vector t -> Vector a +compareG f u v = unsafePerformIO $ do + r <- createVector (dim v) + (u # v #! r) f #|"compareG" + return r + +compareD :: Vector Double -> Vector Double -> Vector CInt +compareD = compareG c_compareD +compareF :: Vector Float -> Vector Float -> Vector CInt +compareF = compareG c_compareF +compareI :: Vector CInt -> Vector CInt -> Vector CInt +compareI = compareG c_compareI +compareL :: Vector Z -> Vector Z -> Vector CInt +compareL = compareG c_compareL + +foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) +foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt))) +foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt))) +foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok + +selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) + => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) + -> c2 -> c1 -> Vector t -> c -> Vector a +selectG f c u v w = unsafePerformIO $ do + r <- createVector (dim v) + (c # u # v # w #! r) f #|"selectG" + return r + +selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double +selectD = selectG c_selectD +selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float +selectF = selectG c_selectF +selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt +selectI = selectG c_selectI +selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z +selectL = selectG c_selectL +selectC :: Vector CInt + -> Vector (Complex Double) + -> Vector (Complex Double) + -> Vector (Complex Double) + -> Vector (Complex Double) +selectC = selectG c_selectC +selectQ :: Vector CInt + -> Vector (Complex Float) + -> Vector (Complex Float) + -> Vector (Complex Float) + -> Vector (Complex Float) +selectQ = selectG c_selectQ + +type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) + +foreign import ccall unsafe "chooseD" c_selectD :: Sel Double +foreign import ccall unsafe "chooseF" c_selectF :: Sel Float +foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt +foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) +foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) +foreign import ccall unsafe "chooseL" c_selectL :: Sel Z + + +remapG :: (TransArray c, TransArray c1, Storable t, Storable a) + => (CInt -> CInt -> CInt -> CInt -> Ptr t + -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) + -> Matrix t -> c1 -> c -> Matrix a +remapG f i j m = unsafePerformIO $ do + r <- createMatrix RowMajor (rows i) (cols i) + (i # j # m #! r) f #|"remapG" + return r + +remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double +remapD = remapG c_remapD +remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float +remapF = remapG c_remapF +remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt +remapI = remapG c_remapI +remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z +remapL = remapG c_remapL +remapC :: Matrix CInt + -> Matrix CInt + -> Matrix (Complex Double) + -> Matrix (Complex Double) +remapC = remapG c_remapC +remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float) +remapQ = remapG c_remapQ + +type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) + +foreign import ccall unsafe "remapD" c_remapD :: Rem Double +foreign import ccall unsafe "remapF" c_remapF :: Rem Float +foreign import ccall unsafe "remapI" c_remapI :: Rem CInt +foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) +foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) +foreign import ccall unsafe "remapL" c_remapL :: Rem Z + + +rowOpAux :: (TransArray c, Storable a) => + (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) + -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () +rowOpAux f c x i1 i2 j1 j2 m = do + px <- newArray [x] + (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" + free px + +type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok + +foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R +foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float +foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C +foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float) +foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I +foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z +foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I +foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z + + +gemmg :: Storable x => Tgemm x -> Vector x -> Matrix x -> Matrix x -> Matrix x -> IO () +gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" + +type Tgemm x = x :> x ::> x ::> x ::> Ok + +foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R +foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float +foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C +foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float) +foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I +foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z +foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I +foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z + +reorderAux :: Storable x => Reorder x -> Vector CInt -> Vector CInt -> Vector x -> Vector x +reorderAux f s d v = unsafePerformIO $ do + k <- createVector (dim s) + r <- createVector (dim v) + (k # s # d # v #! r) f #| "reorderV" + return r + +type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt))))) + +foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double +foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float +foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt +foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) +foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) +foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z -- cgit v1.2.3