summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Specialized.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/Specialized.hs')
-rw-r--r--packages/base/src/Internal/Specialized.hs128
1 files changed, 105 insertions, 23 deletions
diff --git a/packages/base/src/Internal/Specialized.hs b/packages/base/src/Internal/Specialized.hs
index ff89024..c79194f 100644
--- a/packages/base/src/Internal/Specialized.hs
+++ b/packages/base/src/Internal/Specialized.hs
@@ -1,58 +1,106 @@
1{-# LANGUAGE BangPatterns #-} 1{-# LANGUAGE BangPatterns #-}
2{-# LANGUAGE GeneralizedNewtypeDeriving #-}
2{-# LANGUAGE ConstraintKinds #-} 3{-# LANGUAGE ConstraintKinds #-}
3{-# LANGUAGE TypeOperators #-} 4{-# LANGUAGE TypeOperators #-}
4{-# LANGUAGE TypeFamilies #-} 5{-# LANGUAGE TypeFamilies #-}
6{-# LANGUAGE CPP #-}
7{-# LANGUAGE DataKinds #-}
8{-# LANGUAGE RankNTypes #-}
9{-# LANGUAGE ScopedTypeVariables #-}
10{-# LANGUAGE KindSignatures #-}
5module Internal.Specialized where 11module Internal.Specialized where
6 12
7import Control.Monad 13import Control.Monad
14import Control.DeepSeq ( NFData(..) )
15import Data.Coerce
8import Data.Complex 16import Data.Complex
17import Data.Functor
9import Data.Int 18import Data.Int
10import Data.Typeable 19import Data.Typeable (eqT,Proxy)
20import Type.Reflection
11import Foreign.Marshal.Alloc(free,malloc) 21import Foreign.Marshal.Alloc(free,malloc)
12import Foreign.Marshal.Array(newArray,copyArray) 22import Foreign.Marshal.Array(newArray,copyArray)
23import Foreign.ForeignPtr(castForeignPtr)
13import Foreign.Ptr 24import Foreign.Ptr
14import Foreign.Storable 25import Foreign.Storable
15import Foreign.C.Types (CInt(..)) 26import Foreign.C.Types (CInt(..))
16import Data.Vector.Storable (Vector)
17import System.IO.Unsafe 27import System.IO.Unsafe
28#if MIN_VERSION_base(4,11,0)
29import GHC.TypeLits hiding (Mod)
30#else
31import GHC.TypeLits
32#endif
18 33
19import Internal.Vector (createVector) 34import Internal.Vector (Vector,createVector,unsafeFromForeignPtr,unsafeToForeignPtr)
20import Internal.Devel 35import Internal.Devel
21 36
37eqt :: (Typeable a, Typeable b) => a -> Maybe (a :~: b)
38eqt _ = eqT
39eq32 :: (Typeable a) => a -> Maybe (a :~: Int32)
40eq32 _ = eqT
41eq64 :: (Typeable a) => a -> Maybe (a :~: Int64)
42eq64 _ = eqT
43eqint :: (Typeable a) => a -> Maybe (a :~: CInt)
44eqint _ = eqT
45
22type Element t = (Storable t, Typeable t) 46type Element t = (Storable t, Typeable t)
23 47
24data Specialized a 48data Specialized a
25 = SpFloat !(a :~: Float) 49 = SpFloat !(a :~: Float)
26 | SpDouble !(a :~: Double) 50 | SpDouble !(a :~: Double)
27 | SpCFloat !(a :~: Complex Float) 51 | SpCFloat !(a :~: Complex Float)
28 | SpCDouble !(a :~: Complex Double) 52 | SpCDouble !(a :~: Complex Double)
29 | SpInt32 !(a :~: Int32) 53 | SpInt32 !(Vector Int32 -> Vector a) !Int32
30 | SpInt64 !(a :~: Int64) 54 | SpInt64 !(Vector Int64 -> Vector a) !Int64
31 55 -- | SpModInt32 !Int32 Int32 !(forall f. f Int32 -> f a)
32specialize :: Typeable a => a -> Maybe (Specialized a) 56 -- | SpModInt64 !Int32 Int64 !(forall f. f Int64 -> f a)
57
58specialize :: forall a. Typeable a => a -> Maybe (Specialized a)
33specialize x = foldr1 mplus 59specialize x = foldr1 mplus
34 [ SpDouble <$> cst x 60 [ SpDouble <$> eqt x
35 , SpInt64 <$> cst x 61 , eq64 x <&> \Refl -> SpInt64 id x
36 , SpFloat <$> cst x 62 , SpFloat <$> eqt x
37 , SpInt32 <$> cst x 63 , eq32 x <&> \Refl -> SpInt32 id x
38 , SpCDouble <$> cst x 64 , SpCDouble <$> eqt x
39 , SpCFloat <$> cst x 65 , SpCFloat <$> eqt x
40 -- , fmap (\(CInt y) -> SpInt32 y) <$> cast x 66 , eqint x <&> \Refl -> case x of CInt y -> SpInt32 coerce y
67 -- , em32 x <&> \(nat,Refl) -> case x of Mod y -> SpInt32 (i2f' nat) y
68 , case typeOf x of
69 App (App modtyp ntyp) inttyp -> case eqTypeRep (typeRep :: TypeRep (Mod :: Nat -> * -> *)) modtyp of
70 Just HRefl -> let i = unMod x
71 in case eqTypeRep (typeRep :: TypeRep Int32) inttyp of
72 Just HRefl -> Just $ SpInt32 i2f i
73 _ -> case eqTypeRep (typeRep :: TypeRep Int64) inttyp of
74 Just HRefl -> Just $ SpInt64 i2f i
75 _ -> Nothing
76 Nothing -> Nothing
77 _ -> Nothing
41 ] 78 ]
42 where
43 cst :: (Typeable a, Typeable b) => a -> Maybe (a :~: b)
44 cst _ = eqT
45 79
46-- | Supported matrix elements. 80-- | Supported matrix elements.
47constantD :: Typeable a => a -> Int -> Vector a 81constantD :: Typeable a => a -> Int -> Vector a
48constantD x = case specialize x of 82constantD x = case specialize x of
49 Nothing -> error "constantD" 83 Nothing -> error "constantD"
50 Just (SpDouble Refl) -> constantAux cconstantR x 84 Just (SpDouble Refl) -> constantAux cconstantR x
51 Just (SpInt64 Refl) -> constantAux cconstantL x 85 Just (SpInt64 out y) -> out . constantAux cconstantL y
52 Just (SpFloat Refl) -> constantAux cconstantF x 86 Just (SpFloat Refl) -> constantAux cconstantF x
53 Just (SpInt32 Refl) -> constantAux cconstantI x 87 Just (SpInt32 out y) -> out . constantAux cconstantI y
54 Just (SpCDouble Refl) -> constantAux cconstantC x 88 Just (SpCDouble Refl) -> constantAux cconstantC x
55 Just (SpCFloat Refl) -> constantAux cconstantQ x 89 Just (SpCFloat Refl) -> constantAux cconstantQ x
90 -- Just (SpModInt32 _ y ret) -> \n -> ret (constantAux cconstantI y n)
91
92-- | Wrapper with a phantom integer for statically checked modular arithmetic.
93newtype Mod (n :: Nat) t = Mod {unMod:: t}
94 deriving (Storable)
95
96instance (NFData t) => NFData (Mod n t)
97 where
98 rnf (Mod x) = rnf x
99
100i2f :: Storable t => Vector t -> Vector (Mod n t)
101i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n)
102 where (fp,i,n) = unsafeToForeignPtr v
103
56 104
57{- 105{-
58extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) 106extractR :: Typeable a => MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
@@ -65,6 +113,40 @@ remapM :: Typeable a => Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
65rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () 113rowOp :: Typeable a => Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
66gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () 114gemm :: Typeable a => Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
67reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation 115reorderV :: Typeable a => Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation
116
117instance KnownNat m => Element (Mod m I)
118 where
119 constantD x n = i2f (constantD (unMod x) n)
120 extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js
121 setRect i j m x = setRect i j (f2iM m) (f2iM x)
122 sortI = sortI . f2i
123 sortV = i2f . sortV . f2i
124 compareV u v = compareV (f2i u) (f2i v)
125 selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g))
126 remapM i j m = i2fM (remap i j (f2iM m))
127 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x)
128 where
129 m' = fromIntegral . natVal $ (undefined :: Proxy m)
130 gemm u a b c = gemmg (c_gemmMI m') (f2i u) (f2iM a) (f2iM b) (f2iM c)
131 where
132 m' = fromIntegral . natVal $ (undefined :: Proxy m)
133
134instance KnownNat m => Element (Mod m Z)
135 where
136 constantD x n = i2f (constantD (unMod x) n)
137 extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js
138 setRect i j m x = setRect i j (f2iM m) (f2iM x)
139 sortI = sortI . f2i
140 sortV = i2f . sortV . f2i
141 compareV u v = compareV (f2i u) (f2i v)
142 selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g))
143 remapM i j m = i2fM (remap i j (f2iM m))
144 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x)
145 where
146 m' = fromIntegral . natVal $ (undefined :: Proxy m)
147 gemm u a b c = gemmg (c_gemmML m') (f2i u) (f2iM a) (f2iM b) (f2iM c)
148 where
149 m' = fromIntegral . natVal $ (undefined :: Proxy m)
68-} 150-}
69 151
70 152