diff options
author | Joe Crayne <joe@jerkface.net> | 2019-08-09 22:58:27 -0400 |
---|---|---|
committer | Joe Crayne <joe@jerkface.net> | 2019-08-09 22:58:27 -0400 |
commit | d304980b586fb7c7ee369b7d83620c9d992dea5a (patch) | |
tree | bca2142f07c694849da102e17022a63cf3570351 | |
parent | 27bde39cfb955505ad65d54c575b2dde016d8ee6 (diff) |
Added Mod specialization to Specialized module.
-rw-r--r-- | packages/base/src/Internal/Modular.hs | 46 | ||||
-rw-r--r-- | packages/base/src/Internal/Specialized.hs | 128 |
2 files changed, 105 insertions, 69 deletions
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 5af038b..a211dd3 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -60,14 +60,6 @@ import Prelude hiding ((<>)) | |||
60 | 60 | ||
61 | 61 | ||
62 | 62 | ||
63 | -- | Wrapper with a phantom integer for statically checked modular arithmetic. | ||
64 | newtype Mod (n :: Nat) t = Mod {unMod:: t} | ||
65 | deriving (Storable) | ||
66 | |||
67 | instance (NFData t) => NFData (Mod n t) | ||
68 | where | ||
69 | rnf (Mod x) = rnf x | ||
70 | |||
71 | infixr 5 ./. | 63 | infixr 5 ./. |
72 | type (./.) x n = Mod n x | 64 | type (./.) x n = Mod n x |
73 | 65 | ||
@@ -136,40 +128,6 @@ instance (Integral t, KnownNat n) => Num (Mod n t) | |||
136 | fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) | 128 | fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) |
137 | 129 | ||
138 | 130 | ||
139 | instance KnownNat m => Element (Mod m I) | ||
140 | where | ||
141 | constantD x n = i2f (constantD (unMod x) n) | ||
142 | extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js | ||
143 | setRect i j m x = setRect i j (f2iM m) (f2iM x) | ||
144 | sortI = sortI . f2i | ||
145 | sortV = i2f . sortV . f2i | ||
146 | compareV u v = compareV (f2i u) (f2i v) | ||
147 | selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) | ||
148 | remapM i j m = i2fM (remap i j (f2iM m)) | ||
149 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) | ||
150 | where | ||
151 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
152 | gemm u a b c = gemmg (c_gemmMI m') (f2i u) (f2iM a) (f2iM b) (f2iM c) | ||
153 | where | ||
154 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
155 | |||
156 | instance KnownNat m => Element (Mod m Z) | ||
157 | where | ||
158 | constantD x n = i2f (constantD (unMod x) n) | ||
159 | extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js | ||
160 | setRect i j m x = setRect i j (f2iM m) (f2iM x) | ||
161 | sortI = sortI . f2i | ||
162 | sortV = i2f . sortV . f2i | ||
163 | compareV u v = compareV (f2i u) (f2i v) | ||
164 | selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) | ||
165 | remapM i j m = i2fM (remap i j (f2iM m)) | ||
166 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) | ||
167 | where | ||
168 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
169 | gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c) | ||
170 | where | ||
171 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
172 | |||
173 | 131 | ||
174 | instance KnownNat m => CTrans (Mod m I) | 132 | instance KnownNat m => CTrans (Mod m I) |
175 | instance KnownNat m => CTrans (Mod m Z) | 133 | instance KnownNat m => CTrans (Mod m Z) |
@@ -299,10 +257,6 @@ instance KnownNat m => Normed (Vector (Mod m Z)) | |||
299 | instance KnownNat m => Numeric (Mod m I) | 257 | instance KnownNat m => Numeric (Mod m I) |
300 | instance KnownNat m => Numeric (Mod m Z) | 258 | instance KnownNat m => Numeric (Mod m Z) |
301 | 259 | ||
302 | i2f :: Storable t => Vector t -> Vector (Mod n t) | ||
303 | i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) | ||
304 | where (fp,i,n) = unsafeToForeignPtr v | ||
305 | |||
306 | f2i :: Storable t => Vector (Mod n t) -> Vector t | 260 | f2i :: Storable t => Vector (Mod n t) -> Vector t |
307 | f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) | 261 | f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) |
308 | where (fp,i,n) = unsafeToForeignPtr v | 262 | where (fp,i,n) = unsafeToForeignPtr v |
diff --git a/packages/base/src/Internal/Specialized.hs b/packages/base/src/Internal/Specialized.hs index ff89024..c79194f 100644 --- a/packages/base/src/Internal/Specialized.hs +++ b/packages/base/src/Internal/Specialized.hs | |||
@@ -1,58 +1,106 @@ | |||
1 | {-# LANGUAGE BangPatterns #-} | 1 | {-# LANGUAGE BangPatterns #-} |
2 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | ||
2 | {-# LANGUAGE ConstraintKinds #-} | 3 | {-# LANGUAGE ConstraintKinds #-} |
3 | {-# LANGUAGE TypeOperators #-} | 4 | {-# LANGUAGE TypeOperators #-} |
4 | {-# LANGUAGE TypeFamilies #-} | 5 | {-# LANGUAGE TypeFamilies #-} |
6 | {-# LANGUAGE CPP #-} | ||
7 | {-# LANGUAGE DataKinds #-} | ||
8 | {-# LANGUAGE RankNTypes #-} | ||
9 | {-# LANGUAGE ScopedTypeVariables #-} | ||
10 | {-# LANGUAGE KindSignatures #-} | ||
5 | module Internal.Specialized where | 11 | module Internal.Specialized where |
6 | 12 | ||
7 | import Control.Monad | 13 | import Control.Monad |
14 | import Control.DeepSeq ( NFData(..) ) | ||
15 | import Data.Coerce | ||
8 | import Data.Complex | 16 | import Data.Complex |
17 | import Data.Functor | ||
9 | import Data.Int | 18 | import Data.Int |
10 | import Data.Typeable | 19 | import Data.Typeable (eqT,Proxy) |
20 | import Type.Reflection | ||
11 | import Foreign.Marshal.Alloc(free,malloc) | 21 | import Foreign.Marshal.Alloc(free,malloc) |
12 | import Foreign.Marshal.Array(newArray,copyArray) | 22 | import Foreign.Marshal.Array(newArray,copyArray) |
23 | import Foreign.ForeignPtr(castForeignPtr) | ||
13 | import Foreign.Ptr | 24 | import Foreign.Ptr |
14 | import Foreign.Storable | 25 | import Foreign.Storable |
15 | import Foreign.C.Types (CInt(..)) | 26 | import Foreign.C.Types (CInt(..)) |
16 | import Data.Vector.Storable (Vector) | ||
17 | import System.IO.Unsafe | 27 | import System.IO.Unsafe |
28 | #if MIN_VERSION_base(4,11,0) | ||
29 | import GHC.TypeLits hiding (Mod) | ||
30 | #else | ||
31 | import GHC.TypeLits | ||
32 | #endif | ||
18 | 33 | ||
19 | import Internal.Vector (createVector) | 34 | import Internal.Vector (Vector,createVector,unsafeFromForeignPtr,unsafeToForeignPtr) |
20 | import Internal.Devel | 35 | import Internal.Devel |
21 | 36 | ||
37 | eqt :: (Typeable a, Typeable b) => a -> Maybe (a :~: b) | ||
38 | eqt _ = eqT | ||
39 | eq32 :: (Typeable a) => a -> Maybe (a :~: Int32) | ||
40 | eq32 _ = eqT | ||
41 | eq64 :: (Typeable a) => a -> Maybe (a :~: Int64) | ||
42 | eq64 _ = eqT | ||
43 | eqint :: (Typeable a) => a -> Maybe (a :~: CInt) | ||
44 | eqint _ = eqT | ||
45 | |||
22 | type Element t = (Storable t, Typeable t) | 46 | type Element t = (Storable t, Typeable t) |
23 | 47 | ||
24 | data Specialized a | 48 | data Specialized a |
25 | = SpFloat !(a :~: Float) | 49 | = SpFloat !(a :~: Float) |
26 | | SpDouble !(a :~: Double) | 50 | | SpDouble !(a :~: Double) |
27 | | SpCFloat !(a :~: Complex Float) | 51 | | SpCFloat !(a :~: Complex Float) |
28 | | SpCDouble !(a :~: Complex Double) | 52 | | SpCDouble !(a :~: Complex Double) |
29 | | SpInt32 !(a :~: Int32) | 53 | | SpInt32 !(Vector Int32 -> Vector a) !Int32 |
30 | | SpInt64 !(a :~: Int64) | 54 | | SpInt64 !(Vector Int64 -> Vector a) !Int64 |
31 | 55 | -- | SpModInt32 !Int32 Int32 !(forall f. f Int32 -> f a) | |
32 | specialize :: Typeable a => a -> Maybe (Specialized a) | 56 | -- | SpModInt64 !Int32 Int64 !(forall f. f Int64 -> f a) |
57 | |||
58 | specialize :: forall a. Typeable a => a -> Maybe (Specialized a) | ||
33 | specialize x = foldr1 mplus | 59 | specialize x = foldr1 mplus |
34 | [ SpDouble <$> cst x | 60 | [ SpDouble <$> eqt x |
35 | , SpInt64 <$> cst x | 61 | , eq64 x <&> \Refl -> SpInt64 id x |
36 | , SpFloat <$> cst x | 62 | , SpFloat <$> eqt x |
37 | , SpInt32 <$> cst x | 63 | , eq32 x <&> \Refl -> SpInt32 id x |
38 | , SpCDouble <$> cst x | 64 | , SpCDouble <$> eqt x |
39 | , SpCFloat <$> cst x | 65 | , SpCFloat <$> eqt x |
40 | -- , fmap (\(CInt y) -> SpInt32 y) <$> cast x | 66 | , eqint x <&> \Refl -> case x of CInt y -> SpInt32 coerce y |
67 | -- , em32 x <&> \(nat,Refl) -> case x of Mod y -> SpInt32 (i2f' nat) y | ||
68 | , case typeOf x of | ||
69 | App (App modtyp ntyp) inttyp -> case eqTypeRep (typeRep :: TypeRep (Mod :: Nat -> * -> *)) modtyp of | ||
70 | Just HRefl -> let i = unMod x | ||
71 | in case eqTypeRep (typeRep :: TypeRep Int32) inttyp of | ||
72 | Just HRefl -> Just $ SpInt32 i2f i | ||
73 | _ -> case eqTypeRep (typeRep :: TypeRep Int64) inttyp of | ||
74 | Just HRefl -> Just $ SpInt64 i2f i | ||
75 | _ -> Nothing | ||
76 | Nothing -> Nothing | ||
77 | _ -> Nothing | ||
41 | ] | 78 | ] |
42 | where | ||
43 | cst :: (Typeable a, Typeable b) => a -> Maybe (a :~: b) | ||
44 | cst _ = eqT | ||
45 | 79 | ||
46 | -- | Supported matrix elements. | 80 | -- | Supported matrix elements. |
47 | constantD :: Typeable a => a -> Int -> Vector a | 81 | constantD :: Typeable a => a -> Int -> Vector a |
48 | constantD x = case specialize x of | 82 | constantD x = case specialize x of |
49 | Nothing -> error "constantD" | 83 | Nothing -> error "constantD" |
50 | Just (SpDouble Refl) -> constantAux cconstantR x | 84 | Just (SpDouble Refl) -> constantAux cconstantR x |
51 | Just (SpInt64 Refl) -> constantAux cconstantL x | 85 | Just (SpInt64 out y) -> out . constantAux cconstantL y |
52 | Just (SpFloat Refl) -> constantAux cconstantF x | 86 | Just (SpFloat Refl) -> constantAux cconstantF x |
53 | Just (SpInt32 Refl) -> constantAux cconstantI x | 87 | Just (SpInt32 out y) -> out . constantAux cconstantI y |
54 | Just (SpCDouble Refl) -> constantAux cconstantC x | 88 | Just (SpCDouble Refl) -> constantAux cconstantC x |
55 | Just (SpCFloat Refl) -> constantAux cconstantQ x | 89 | Just (SpCFloat Refl) -> constantAux cconstantQ x |
90 | -- Just (SpModInt32 _ y ret) -> \n -> ret (constantAux cconstantI y n) | ||
91 | |||
92 | -- | Wrapper with a phantom integer for statically checked modular arithmetic. | ||
93 | newtype Mod (n :: Nat) t = Mod {unMod:: t} | ||
94 | deriving (Storable) | ||
95 | |||
96 | instance (NFData t) => NFData (Mod n t) | ||
97 | where | ||
98 | rnf (Mod x) = rnf x | ||
99 | |||
100 | i2f :: Storable t => Vector t -> Vector (Mod n t) | ||
101 | i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) | ||
102 | where (fp,i,n) = unsafeToForeignPtr v | ||
103 | |||
56 | 104 | ||
57 | {- | 105 | {- |
58 | extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) | 106 | extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) |
@@ -65,6 +113,40 @@ remapM :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a | |||
65 | rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () | 113 | rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () |
66 | gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () | 114 | gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () |
67 | reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation | 115 | reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation |
116 | |||
117 | instance KnownNat m => Element (Mod m I) | ||
118 | where | ||
119 | constantD x n = i2f (constantD (unMod x) n) | ||
120 | extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js | ||
121 | setRect i j m x = setRect i j (f2iM m) (f2iM x) | ||
122 | sortI = sortI . f2i | ||
123 | sortV = i2f . sortV . f2i | ||
124 | compareV u v = compareV (f2i u) (f2i v) | ||
125 | selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) | ||
126 | remapM i j m = i2fM (remap i j (f2iM m)) | ||
127 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) | ||
128 | where | ||
129 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
130 | gemm u a b c = gemmg (c_gemmMI m') (f2i u) (f2iM a) (f2iM b) (f2iM c) | ||
131 | where | ||
132 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
133 | |||
134 | instance KnownNat m => Element (Mod m Z) | ||
135 | where | ||
136 | constantD x n = i2f (constantD (unMod x) n) | ||
137 | extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js | ||
138 | setRect i j m x = setRect i j (f2iM m) (f2iM x) | ||
139 | sortI = sortI . f2i | ||
140 | sortV = i2f . sortV . f2i | ||
141 | compareV u v = compareV (f2i u) (f2i v) | ||
142 | selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) | ||
143 | remapM i j m = i2fM (remap i j (f2iM m)) | ||
144 | rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) | ||
145 | where | ||
146 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
147 | gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c) | ||
148 | where | ||
149 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
68 | -} | 150 | -} |
69 | 151 | ||
70 | 152 | ||