summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Crayne <joe@jerkface.net>2019-08-09 22:58:27 -0400
committerJoe Crayne <joe@jerkface.net>2019-08-09 22:58:27 -0400
commitd304980b586fb7c7ee369b7d83620c9d992dea5a (patch)
treebca2142f07c694849da102e17022a63cf3570351
parent27bde39cfb955505ad65d54c575b2dde016d8ee6 (diff)
Added Mod specialization to Specialized module.
-rw-r--r--packages/base/src/Internal/Modular.hs46
-rw-r--r--packages/base/src/Internal/Specialized.hs128
2 files changed, 105 insertions, 69 deletions
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index 5af038b..a211dd3 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -60,14 +60,6 @@ import Prelude hiding ((<>))
60 60
61 61
62 62
63-- | Wrapper with a phantom integer for statically checked modular arithmetic.
64newtype Mod (n :: Nat) t = Mod {unMod:: t}
65 deriving (Storable)
66
67instance (NFData t) => NFData (Mod n t)
68 where
69 rnf (Mod x) = rnf x
70
71infixr 5 ./. 63infixr 5 ./.
72type (./.) x n = Mod n x 64type (./.) x n = Mod n x
73 65
@@ -136,40 +128,6 @@ instance (Integral t, KnownNat n) => Num (Mod n t)
136 fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) 128 fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m))
137 129
138 130
139instance KnownNat m => Element (Mod m I)
140 where
141 constantD x n = i2f (constantD (unMod x) n)
142 extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js
143 setRect i j m x = setRect i j (f2iM m) (f2iM x)
144 sortI = sortI . f2i
145 sortV = i2f . sortV . f2i
146 compareV u v = compareV (f2i u) (f2i v)
147 selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g))
148 remapM i j m = i2fM (remap i j (f2iM m))
149 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x)
150 where
151 m' = fromIntegral . natVal $ (undefined :: Proxy m)
152 gemm u a b c = gemmg (c_gemmMI m') (f2i u) (f2iM a) (f2iM b) (f2iM c)
153 where
154 m' = fromIntegral . natVal $ (undefined :: Proxy m)
155
156instance KnownNat m => Element (Mod m Z)
157 where
158 constantD x n = i2f (constantD (unMod x) n)
159 extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js
160 setRect i j m x = setRect i j (f2iM m) (f2iM x)
161 sortI = sortI . f2i
162 sortV = i2f . sortV . f2i
163 compareV u v = compareV (f2i u) (f2i v)
164 selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g))
165 remapM i j m = i2fM (remap i j (f2iM m))
166 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x)
167 where
168 m' = fromIntegral . natVal $ (undefined :: Proxy m)
169 gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c)
170 where
171 m' = fromIntegral . natVal $ (undefined :: Proxy m)
172
173 131
174instance KnownNat m => CTrans (Mod m I) 132instance KnownNat m => CTrans (Mod m I)
175instance KnownNat m => CTrans (Mod m Z) 133instance KnownNat m => CTrans (Mod m Z)
@@ -299,10 +257,6 @@ instance KnownNat m => Normed (Vector (Mod m Z))
299instance KnownNat m => Numeric (Mod m I) 257instance KnownNat m => Numeric (Mod m I)
300instance KnownNat m => Numeric (Mod m Z) 258instance KnownNat m => Numeric (Mod m Z)
301 259
302i2f :: Storable t => Vector t -> Vector (Mod n t)
303i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n)
304 where (fp,i,n) = unsafeToForeignPtr v
305
306f2i :: Storable t => Vector (Mod n t) -> Vector t 260f2i :: Storable t => Vector (Mod n t) -> Vector t
307f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) 261f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n)
308 where (fp,i,n) = unsafeToForeignPtr v 262 where (fp,i,n) = unsafeToForeignPtr v
diff --git a/packages/base/src/Internal/Specialized.hs b/packages/base/src/Internal/Specialized.hs
index ff89024..c79194f 100644
--- a/packages/base/src/Internal/Specialized.hs
+++ b/packages/base/src/Internal/Specialized.hs
@@ -1,58 +1,106 @@
1{-# LANGUAGE BangPatterns #-} 1{-# LANGUAGE BangPatterns #-}
2{-# LANGUAGE GeneralizedNewtypeDeriving #-}
2{-# LANGUAGE ConstraintKinds #-} 3{-# LANGUAGE ConstraintKinds #-}
3{-# LANGUAGE TypeOperators #-} 4{-# LANGUAGE TypeOperators #-}
4{-# LANGUAGE TypeFamilies #-} 5{-# LANGUAGE TypeFamilies #-}
6{-# LANGUAGE CPP #-}
7{-# LANGUAGE DataKinds #-}
8{-# LANGUAGE RankNTypes #-}
9{-# LANGUAGE ScopedTypeVariables #-}
10{-# LANGUAGE KindSignatures #-}
5module Internal.Specialized where 11module Internal.Specialized where
6 12
7import Control.Monad 13import Control.Monad
14import Control.DeepSeq ( NFData(..) )
15import Data.Coerce
8import Data.Complex 16import Data.Complex
17import Data.Functor
9import Data.Int 18import Data.Int
10import Data.Typeable 19import Data.Typeable (eqT,Proxy)
20import Type.Reflection
11import Foreign.Marshal.Alloc(free,malloc) 21import Foreign.Marshal.Alloc(free,malloc)
12import Foreign.Marshal.Array(newArray,copyArray) 22import Foreign.Marshal.Array(newArray,copyArray)
23import Foreign.ForeignPtr(castForeignPtr)
13import Foreign.Ptr 24import Foreign.Ptr
14import Foreign.Storable 25import Foreign.Storable
15import Foreign.C.Types (CInt(..)) 26import Foreign.C.Types (CInt(..))
16import Data.Vector.Storable (Vector)
17import System.IO.Unsafe 27import System.IO.Unsafe
28#if MIN_VERSION_base(4,11,0)
29import GHC.TypeLits hiding (Mod)
30#else
31import GHC.TypeLits
32#endif
18 33
19import Internal.Vector (createVector) 34import Internal.Vector (Vector,createVector,unsafeFromForeignPtr,unsafeToForeignPtr)
20import Internal.Devel 35import Internal.Devel
21 36
37eqt :: (Typeable a, Typeable b) => a -> Maybe (a :~: b)
38eqt _ = eqT
39eq32 :: (Typeable a) => a -> Maybe (a :~: Int32)
40eq32 _ = eqT
41eq64 :: (Typeable a) => a -> Maybe (a :~: Int64)
42eq64 _ = eqT
43eqint :: (Typeable a) => a -> Maybe (a :~: CInt)
44eqint _ = eqT
45
22type Element t = (Storable t, Typeable t) 46type Element t = (Storable t, Typeable t)
23 47
24data Specialized a 48data Specialized a
25 = SpFloat !(a :~: Float) 49 = SpFloat !(a :~: Float)
26 | SpDouble !(a :~: Double) 50 | SpDouble !(a :~: Double)
27 | SpCFloat !(a :~: Complex Float) 51 | SpCFloat !(a :~: Complex Float)
28 | SpCDouble !(a :~: Complex Double) 52 | SpCDouble !(a :~: Complex Double)
29 | SpInt32 !(a :~: Int32) 53 | SpInt32 !(Vector Int32 -> Vector a) !Int32
30 | SpInt64 !(a :~: Int64) 54 | SpInt64 !(Vector Int64 -> Vector a) !Int64
31 55 -- | SpModInt32 !Int32 Int32 !(forall f. f Int32 -> f a)
32specialize :: Typeable a => a -> Maybe (Specialized a) 56 -- | SpModInt64 !Int32 Int64 !(forall f. f Int64 -> f a)
57
58specialize :: forall a. Typeable a => a -> Maybe (Specialized a)
33specialize x = foldr1 mplus 59specialize x = foldr1 mplus
34 [ SpDouble <$> cst x 60 [ SpDouble <$> eqt x
35 , SpInt64 <$> cst x 61 , eq64 x <&> \Refl -> SpInt64 id x
36 , SpFloat <$> cst x 62 , SpFloat <$> eqt x
37 , SpInt32 <$> cst x 63 , eq32 x <&> \Refl -> SpInt32 id x
38 , SpCDouble <$> cst x 64 , SpCDouble <$> eqt x
39 , SpCFloat <$> cst x 65 , SpCFloat <$> eqt x
40 -- , fmap (\(CInt y) -> SpInt32 y) <$> cast x 66 , eqint x <&> \Refl -> case x of CInt y -> SpInt32 coerce y
67 -- , em32 x <&> \(nat,Refl) -> case x of Mod y -> SpInt32 (i2f' nat) y
68 , case typeOf x of
69 App (App modtyp ntyp) inttyp -> case eqTypeRep (typeRep :: TypeRep (Mod :: Nat -> * -> *)) modtyp of
70 Just HRefl -> let i = unMod x
71 in case eqTypeRep (typeRep :: TypeRep Int32) inttyp of
72 Just HRefl -> Just $ SpInt32 i2f i
73 _ -> case eqTypeRep (typeRep :: TypeRep Int64) inttyp of
74 Just HRefl -> Just $ SpInt64 i2f i
75 _ -> Nothing
76 Nothing -> Nothing
77 _ -> Nothing
41 ] 78 ]
42 where
43 cst :: (Typeable a, Typeable b) => a -> Maybe (a :~: b)
44 cst _ = eqT
45 79
46-- | Supported matrix elements. 80-- | Supported matrix elements.
47constantD :: Typeable a => a -> Int -> Vector a 81constantD :: Typeable a => a -> Int -> Vector a
48constantD x = case specialize x of 82constantD x = case specialize x of
49 Nothing -> error "constantD" 83 Nothing -> error "constantD"
50 Just (SpDouble Refl) -> constantAux cconstantR x 84 Just (SpDouble Refl) -> constantAux cconstantR x
51 Just (SpInt64 Refl) -> constantAux cconstantL x 85 Just (SpInt64 out y) -> out . constantAux cconstantL y
52 Just (SpFloat Refl) -> constantAux cconstantF x 86 Just (SpFloat Refl) -> constantAux cconstantF x
53 Just (SpInt32 Refl) -> constantAux cconstantI x 87 Just (SpInt32 out y) -> out . constantAux cconstantI y
54 Just (SpCDouble Refl) -> constantAux cconstantC x 88 Just (SpCDouble Refl) -> constantAux cconstantC x
55 Just (SpCFloat Refl) -> constantAux cconstantQ x 89 Just (SpCFloat Refl) -> constantAux cconstantQ x
90 -- Just (SpModInt32 _ y ret) -> \n -> ret (constantAux cconstantI y n)
91
92-- | Wrapper with a phantom integer for statically checked modular arithmetic.
93newtype Mod (n :: Nat) t = Mod {unMod:: t}
94 deriving (Storable)
95
96instance (NFData t) => NFData (Mod n t)
97 where
98 rnf (Mod x) = rnf x
99
100i2f :: Storable t => Vector t -> Vector (Mod n t)
101i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n)
102 where (fp,i,n) = unsafeToForeignPtr v
103
56 104
57{- 105{-
58extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) 106extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
@@ -65,6 +113,40 @@ remapM :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
65rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () 113rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
66gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () 114gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
67reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation 115reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation
116
117instance KnownNat m => Element (Mod m I)
118 where
119 constantD x n = i2f (constantD (unMod x) n)
120 extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js
121 setRect i j m x = setRect i j (f2iM m) (f2iM x)
122 sortI = sortI . f2i
123 sortV = i2f . sortV . f2i
124 compareV u v = compareV (f2i u) (f2i v)
125 selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g))
126 remapM i j m = i2fM (remap i j (f2iM m))
127 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x)
128 where
129 m' = fromIntegral . natVal $ (undefined :: Proxy m)
130 gemm u a b c = gemmg (c_gemmMI m') (f2i u) (f2iM a) (f2iM b) (f2iM c)
131 where
132 m' = fromIntegral . natVal $ (undefined :: Proxy m)
133
134instance KnownNat m => Element (Mod m Z)
135 where
136 constantD x n = i2f (constantD (unMod x) n)
137 extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js
138 setRect i j m x = setRect i j (f2iM m) (f2iM x)
139 sortI = sortI . f2i
140 sortV = i2f . sortV . f2i
141 compareV u v = compareV (f2i u) (f2i v)
142 selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g))
143 remapM i j m = i2fM (remap i j (f2iM m))
144 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x)
145 where
146 m' = fromIntegral . natVal $ (undefined :: Proxy m)
147 gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c)
148 where
149 m' = fromIntegral . natVal $ (undefined :: Proxy m)
68-} 150-}
69 151
70 152