From d304980b586fb7c7ee369b7d83620c9d992dea5a Mon Sep 17 00:00:00 2001 From: Joe Crayne Date: Fri, 9 Aug 2019 22:58:27 -0400 Subject: Added Mod specialization to Specialized module. --- packages/base/src/Internal/Modular.hs | 46 ----------- packages/base/src/Internal/Specialized.hs | 128 ++++++++++++++++++++++++------ 2 files changed, 105 insertions(+), 69 deletions(-) diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 5af038b..a211dd3 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs @@ -60,14 +60,6 @@ import Prelude hiding ((<>)) --- | Wrapper with a phantom integer for statically checked modular arithmetic. -newtype Mod (n :: Nat) t = Mod {unMod:: t} - deriving (Storable) - -instance (NFData t) => NFData (Mod n t) - where - rnf (Mod x) = rnf x - infixr 5 ./. type (./.) x n = Mod n x @@ -136,40 +128,6 @@ instance (Integral t, KnownNat n) => Num (Mod n t) fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) -instance KnownNat m => Element (Mod m I) - where - constantD x n = i2f (constantD (unMod x) n) - extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js - setRect i j m x = setRect i j (f2iM m) (f2iM x) - sortI = sortI . f2i - sortV = i2f . sortV . f2i - compareV u v = compareV (f2i u) (f2i v) - selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) - remapM i j m = i2fM (remap i j (f2iM m)) - rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) - where - m' = fromIntegral . natVal $ (undefined :: Proxy m) - gemm u a b c = gemmg (c_gemmMI m') (f2i u) (f2iM a) (f2iM b) (f2iM c) - where - m' = fromIntegral . natVal $ (undefined :: Proxy m) - -instance KnownNat m => Element (Mod m Z) - where - constantD x n = i2f (constantD (unMod x) n) - extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js - setRect i j m x = setRect i j (f2iM m) (f2iM x) - sortI = sortI . f2i - sortV = i2f . sortV . f2i - compareV u v = compareV (f2i u) (f2i v) - selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) - remapM i j m = i2fM (remap i j (f2iM m)) - rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) - where - m' = fromIntegral . natVal $ (undefined :: Proxy m) - gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c) - where - m' = fromIntegral . natVal $ (undefined :: Proxy m) - instance KnownNat m => CTrans (Mod m I) instance KnownNat m => CTrans (Mod m Z) @@ -299,10 +257,6 @@ instance KnownNat m => Normed (Vector (Mod m Z)) instance KnownNat m => Numeric (Mod m I) instance KnownNat m => Numeric (Mod m Z) -i2f :: Storable t => Vector t -> Vector (Mod n t) -i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) - where (fp,i,n) = unsafeToForeignPtr v - f2i :: Storable t => Vector (Mod n t) -> Vector t f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) where (fp,i,n) = unsafeToForeignPtr v diff --git a/packages/base/src/Internal/Specialized.hs b/packages/base/src/Internal/Specialized.hs index ff89024..c79194f 100644 --- a/packages/base/src/Internal/Specialized.hs +++ b/packages/base/src/Internal/Specialized.hs @@ -1,58 +1,106 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE KindSignatures #-} module Internal.Specialized where import Control.Monad +import Control.DeepSeq ( NFData(..) ) +import Data.Coerce import Data.Complex +import Data.Functor import Data.Int -import Data.Typeable +import Data.Typeable (eqT,Proxy) +import Type.Reflection import Foreign.Marshal.Alloc(free,malloc) import Foreign.Marshal.Array(newArray,copyArray) +import Foreign.ForeignPtr(castForeignPtr) import Foreign.Ptr import Foreign.Storable import Foreign.C.Types (CInt(..)) -import Data.Vector.Storable (Vector) import System.IO.Unsafe +#if MIN_VERSION_base(4,11,0) +import GHC.TypeLits hiding (Mod) +#else +import GHC.TypeLits +#endif -import Internal.Vector (createVector) +import Internal.Vector (Vector,createVector,unsafeFromForeignPtr,unsafeToForeignPtr) import Internal.Devel +eqt :: (Typeable a, Typeable b) => a -> Maybe (a :~: b) +eqt _ = eqT +eq32 :: (Typeable a) => a -> Maybe (a :~: Int32) +eq32 _ = eqT +eq64 :: (Typeable a) => a -> Maybe (a :~: Int64) +eq64 _ = eqT +eqint :: (Typeable a) => a -> Maybe (a :~: CInt) +eqint _ = eqT + type Element t = (Storable t, Typeable t) data Specialized a - = SpFloat !(a :~: Float) - | SpDouble !(a :~: Double) - | SpCFloat !(a :~: Complex Float) - | SpCDouble !(a :~: Complex Double) - | SpInt32 !(a :~: Int32) - | SpInt64 !(a :~: Int64) - -specialize :: Typeable a => a -> Maybe (Specialized a) + = SpFloat !(a :~: Float) + | SpDouble !(a :~: Double) + | SpCFloat !(a :~: Complex Float) + | SpCDouble !(a :~: Complex Double) + | SpInt32 !(Vector Int32 -> Vector a) !Int32 + | SpInt64 !(Vector Int64 -> Vector a) !Int64 + -- | SpModInt32 !Int32 Int32 !(forall f. f Int32 -> f a) + -- | SpModInt64 !Int32 Int64 !(forall f. f Int64 -> f a) + +specialize :: forall a. Typeable a => a -> Maybe (Specialized a) specialize x = foldr1 mplus - [ SpDouble <$> cst x - , SpInt64 <$> cst x - , SpFloat <$> cst x - , SpInt32 <$> cst x - , SpCDouble <$> cst x - , SpCFloat <$> cst x - -- , fmap (\(CInt y) -> SpInt32 y) <$> cast x + [ SpDouble <$> eqt x + , eq64 x <&> \Refl -> SpInt64 id x + , SpFloat <$> eqt x + , eq32 x <&> \Refl -> SpInt32 id x + , SpCDouble <$> eqt x + , SpCFloat <$> eqt x + , eqint x <&> \Refl -> case x of CInt y -> SpInt32 coerce y + -- , em32 x <&> \(nat,Refl) -> case x of Mod y -> SpInt32 (i2f' nat) y + , case typeOf x of + App (App modtyp ntyp) inttyp -> case eqTypeRep (typeRep :: TypeRep (Mod :: Nat -> * -> *)) modtyp of + Just HRefl -> let i = unMod x + in case eqTypeRep (typeRep :: TypeRep Int32) inttyp of + Just HRefl -> Just $ SpInt32 i2f i + _ -> case eqTypeRep (typeRep :: TypeRep Int64) inttyp of + Just HRefl -> Just $ SpInt64 i2f i + _ -> Nothing + Nothing -> Nothing + _ -> Nothing ] - where - cst :: (Typeable a, Typeable b) => a -> Maybe (a :~: b) - cst _ = eqT -- | Supported matrix elements. constantD :: Typeable a => a -> Int -> Vector a constantD x = case specialize x of Nothing -> error "constantD" Just (SpDouble Refl) -> constantAux cconstantR x - Just (SpInt64 Refl) -> constantAux cconstantL x + Just (SpInt64 out y) -> out . constantAux cconstantL y Just (SpFloat Refl) -> constantAux cconstantF x - Just (SpInt32 Refl) -> constantAux cconstantI x + Just (SpInt32 out y) -> out . constantAux cconstantI y Just (SpCDouble Refl) -> constantAux cconstantC x Just (SpCFloat Refl) -> constantAux cconstantQ x + -- Just (SpModInt32 _ y ret) -> \n -> ret (constantAux cconstantI y n) + +-- | Wrapper with a phantom integer for statically checked modular arithmetic. +newtype Mod (n :: Nat) t = Mod {unMod:: t} + deriving (Storable) + +instance (NFData t) => NFData (Mod n t) + where + rnf (Mod x) = rnf x + +i2f :: Storable t => Vector t -> Vector (Mod n t) +i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) + where (fp,i,n) = unsafeToForeignPtr v + {- extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) @@ -65,6 +113,40 @@ remapM :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation + +instance KnownNat m => Element (Mod m I) + where + constantD x n = i2f (constantD (unMod x) n) + extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js + setRect i j m x = setRect i j (f2iM m) (f2iM x) + sortI = sortI . f2i + sortV = i2f . sortV . f2i + compareV u v = compareV (f2i u) (f2i v) + selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) + remapM i j m = i2fM (remap i j (f2iM m)) + rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) + gemm u a b c = gemmg (c_gemmMI m') (f2i u) (f2iM a) (f2iM b) (f2iM c) + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) + +instance KnownNat m => Element (Mod m Z) + where + constantD x n = i2f (constantD (unMod x) n) + extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js + setRect i j m x = setRect i j (f2iM m) (f2iM x) + sortI = sortI . f2i + sortV = i2f . sortV . f2i + compareV u v = compareV (f2i u) (f2i v) + selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) + remapM i j m = i2fM (remap i j (f2iM m)) + rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) + gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c) + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) -} -- cgit v1.2.3