summaryrefslogtreecommitdiff
path: root/CGExp.hs
diff options
context:
space:
mode:
Diffstat (limited to 'CGExp.hs')
-rw-r--r--CGExp.hs332
1 files changed, 332 insertions, 0 deletions
diff --git a/CGExp.hs b/CGExp.hs
new file mode 100644
index 00000000..5477e485
--- /dev/null
+++ b/CGExp.hs
@@ -0,0 +1,332 @@
1{-# LANGUAGE LambdaCase #-}
2{-# LANGUAGE ViewPatterns #-}
3{-# LANGUAGE PatternSynonyms #-}
4{-# LANGUAGE FlexibleContexts #-}
5{-# LANGUAGE FlexibleInstances #-}
6{-# LANGUAGE NoMonomorphismRestriction #-}
7{-# LANGUAGE DeriveFunctor #-}
8{-# LANGUAGE DeriveFoldable #-}
9{-# LANGUAGE DeriveTraversable #-}
10{-# LANGUAGE RecursiveDo #-}
11module CGExp
12 ( module CGExp
13 , Lit(..), Export(..), ModuleR(..)
14 ) where
15
16import Control.Monad.Reader
17import Control.Monad.State
18import Control.Monad.Except
19import Control.Monad.Identity
20import Control.Monad.Writer
21import Control.Arrow
22import qualified Data.Set as S
23import qualified Data.Map as M
24import Text.Parsec.Pos
25
26import Pretty
27import qualified Infer as I
28import Infer (SName, Lit(..), Visibility(..), Export(..), ModuleR(..))
29
30--------------------------------------------------------------------------------
31
32data Exp_ a
33 = Pi_ Visibility SName a a
34 | Lam_ Visibility Pat a a
35 | Con_ (SName, a) [a]
36 | ELit_ Lit
37 | Fun_ (SName, a) [a]
38 | App_ a a
39 | Var_ SName a
40 | TType_
41 | Let_ Pat a a
42 | EFieldProj_ a SName
43 deriving (Show, Eq, Functor, Foldable, Traversable)
44
45instance PShow Exp where pShowPrec p = text . show
46
47pattern Pi h n a b = Exp (Pi_ h n a b)
48pattern Lam h n a b = Exp (Lam_ h n a b)
49pattern Con a b = Exp (Con_ a b)
50pattern ELit a = Exp (ELit_ a)
51pattern Fun a b = Exp (Fun_ a b)
52pattern EApp a b = Exp (App_ a b)
53pattern Var a b = Exp (Var_ a b)
54pattern TType = Exp TType_
55pattern ELet a b c = Exp (Let_ a b c)
56pattern EFieldProj a b = Exp (EFieldProj_ a b)
57
58newtype Exp = Exp (Exp_ Exp)
59 deriving (Show, Eq)
60
61type ConvM a = StateT [SName] (Reader [SName]) a
62
63newName = gets head <* modify tail
64
65toExp :: I.Exp -> Exp
66toExp = flip runReader [] . flip evalStateT freshTypeVars . f
67 where
68 f = \case
69 I.FunN "swizzvector" [_, _, _, exp, getSwizzVec -> Just (concat -> s)] -> newName >>= \n -> do
70 e <- f exp
71 return $ app' (EFieldProj (Pi Visible n (tyOf e) (TVec (length s) TFloat)) s) e
72 I.FunN "swizzscalar" [_, _, exp, mkSwizzStr -> Just s] -> newName >>= \n -> do
73 e <- f exp
74 return $ app' (EFieldProj (Pi Visible n (tyOf e) TFloat) s) e
75 I.Var i -> asks $ uncurry Var . (!!! i)
76 I.Pi b x y -> newName >>= \n -> do
77 t <- f x
78 Pi b n t <$> local ((n, t):) (f y)
79 I.Lam b x y -> newName >>= \n -> do
80 t <- f x
81 Lam b (PVar t n) t <$> local ((n, t):) (f y)
82 I.Con (I.ConName s _ _ t) xs -> con s <$> f t <*> mapM f xs
83 I.TyCon (I.TyConName s _ _ t _ _) xs -> con s <$> f t <*> mapM f xs
84 I.ELit l -> pure $ ELit l
85 I.Fun (I.FunName s _ t) xs -> fun s <$> f t <*> mapM f xs
86 I.CaseFun (I.CaseFunName s t _) xs -> fun s <$> f t <*> mapM f xs
87 I.App a b -> app' <$> f a <*> f b
88 I.Label x _ -> f x
89 I.TType -> pure TType
90 I.LabelEnd x -> f x
91 z -> error $ "toExp: " ++ show z
92
93getSwizzVec = \case
94 I.VV2 _ (mkSwizzStr -> Just sx) (mkSwizzStr -> Just sy) -> Just [sx, sy]
95 I.VV3 _ (mkSwizzStr -> Just sx) (mkSwizzStr -> Just sy) (mkSwizzStr -> Just sz) -> Just [sx, sy, sz]
96 I.VV4 _ (mkSwizzStr -> Just sx) (mkSwizzStr -> Just sy) (mkSwizzStr -> Just sz) (mkSwizzStr -> Just sw) -> Just [sx, sy, sz, sw]
97 _ -> Nothing
98
99mkSwizzStr = \case
100 I.ConN "Sx" [] -> Just "x"
101 I.ConN "Sy" [] -> Just "y"
102 I.ConN "Sz" [] -> Just "z"
103 I.ConN "Sw" [] -> Just "w"
104 _ -> Nothing
105
106xs !!! i | i < 0 || i >= length xs = error $ show xs ++ " !! " ++ show i
107xs !!! i = xs !! i
108
109untick ('\'': s) = s
110untick s = s
111
112fun s t xs = Fun (untick s, t) xs
113
114con (untick -> s) t xs = Con (s, t) xs
115
116freeVars :: Exp -> S.Set SName
117freeVars = \case
118 Var n _ -> S.singleton n
119 Con _ xs -> S.unions $ map freeVars xs
120 ELit _ -> mempty
121 Fun _ xs -> S.unions $ map freeVars xs
122 EApp a b -> freeVars a `S.union` freeVars b
123 Pi _ n a b -> freeVars a `S.union` (S.delete n $ freeVars b)
124 Lam _ n a b -> freeVars a `S.union` (foldr S.delete (freeVars b) (patVars n))
125 EFieldProj a _ -> freeVars a
126 TType -> mempty
127 ELet n a b -> freeVars a `S.union` (foldr S.delete (freeVars b) (patVars n))
128
129type Ty = Exp
130
131tyOf :: Exp -> Ty
132tyOf = \case
133 Lam h (PVar _ n) t x -> Pi h n t $ tyOf x
134 EApp f x -> app (tyOf f) x
135 Var _ t -> t
136 Pi{} -> Type
137 Con (_, t) xs -> foldl app t xs
138 Fun (_, t) xs -> foldl app t xs
139 ELit l -> toExp $ I.litType l
140 TType -> TType
141 ELet a b c -> tyOf $ EApp (ELam a c) b
142 EFieldProj t s -> t
143 x -> error $ "tyOf: " ++ show x
144 where
145 app (Pi _ n a b) x = substE n x b
146
147substE n x = \case
148 z@(Var n' _) | n' == n -> x
149 | otherwise -> z
150 Pi h n' a b | n == n' -> Pi h n' (substE n x a) b
151 Pi h n' a b -> Pi h n' (substE n x a) (substE n x b)
152 Lam h n' a b -> Lam h n' (substE n x a) $ if n `elem` patVars n' then b else substE n x b
153 Con cn xs -> Con cn (map (substE n x) xs)
154 Fun cn xs -> Fun cn (map (substE n x) xs)
155 TType -> TType
156 EApp a b -> app' (substE n x a) (substE n x b)
157 z -> error $ "substE: " ++ show z
158
159app' (Lam _ (PVar _ n) _ x) b = substE n b x
160app' a b = EApp a b
161
162--------------------------------------------------------------------------------
163
164data Pat
165 = PVar Exp SName
166 | PTuple [Pat]
167 deriving (Eq, Show)
168
169instance PShow Pat where pShowPrec p = text . show
170
171patVars (PVar _ n) = [n]
172patVars (PTuple ps) = concatMap patVars ps
173
174patTy (PVar t _) = t
175patTy (PTuple ps) = Con ("Tuple" ++ show (length ps), tupTy $ length ps) $ map patTy ps
176
177tupTy n = foldr (:~>) Type $ replicate n Type
178
179-------------
180
181pattern EVar n <- Var n _
182pattern TVar t n = Var n t
183
184pattern ELam n b <- Lam Visible n _ b where ELam n b = Lam Visible n (patTy n) b
185
186pattern a :~> b = Pi Visible "" a b
187infixr 1 :~>
188
189pattern PrimN n xs <- Fun (n, t) (filterRelevant (n, 0) t -> xs) where PrimN n xs = Fun (n, hackType n) xs
190pattern Prim1 n a = PrimN n [a]
191pattern Prim2 n a b = PrimN n [a, b]
192pattern Prim3 n a b c <- PrimN n [a, b, c]
193pattern Prim4 n a b c d <- PrimN n [a, b, c, d]
194pattern Prim5 n a b c d e <- PrimN n [a, b, c, d, e]
195
196-- todo: remove
197hackType = \case
198 "Output" -> TType
199 "Bool" -> TType
200 "Float" -> TType
201 "Nat" -> TType
202 "Zero" -> TNat
203 "Succ" -> TNat :~> TNat
204 "String" -> TType
205 "Sampler" -> TType
206 "VecS" -> TType :~> TNat :~> TType
207-- "EFieldProj" -> Pi Visible "projt" TType $ Pi Visible "projt2" TString $ Pi Visible "projvec" (TVec (error "pn1") TFloat) (TVec (error "pn2") TFloat)
208 n -> error $ "type of " ++ show n
209
210filterRelevant _ _ [] = []
211filterRelevant i (Pi h n t t') (x: xs) = (if h == Visible || exception i then (x:) else id) $ filterRelevant (id *** (+1) $ i) (substE n x t') xs
212 where
213 -- todo: remove
214 exception i = i `elem` [("ColorImage", 0), ("DepthImage", 0), ("StencilImage", 0)]
215
216pattern AN n xs <- Con (n, t) (filterRelevant (n, 0) t -> xs) where AN n xs = Con (n, hackType n) xs
217pattern A0 n = AN n []
218pattern A1 n a = AN n [a]
219pattern A2 n a b = AN n [a, b]
220pattern A3 n a b c <- AN n [a, b, c]
221pattern A4 n a b c d <- AN n [a, b, c, d]
222pattern A5 n a b c d e <- AN n [a, b, c, d, e]
223
224pattern TCon0 n = A0 n
225pattern TCon t n = Con (n, t) []
226
227pattern Type = TType
228pattern Star = TType
229pattern TUnit <- A0 "Tuple0"
230pattern TBool <- A0 "Bool"
231pattern TWord <- A0 "Word"
232pattern TInt <- A0 "Int"
233pattern TNat = A0 "Nat"
234pattern TFloat = A0 "Float"
235pattern TString = A0 "String"
236pattern TList n <- A1 "List" n
237
238pattern TSampler = A0 "Sampler"
239pattern TFrameBuffer a b <- A2 "FrameBuffer" a b
240pattern Depth n <- A1 "Depth" n
241pattern Stencil n <- A1 "Stencil" n
242pattern Color n <- A1 "Color" n
243
244pattern Zero = A0 "Zero"
245pattern Succ n = A1 "Succ" n
246
247pattern TVec n a = A2 "VecS" a (Nat n)
248pattern TMat i j a <- A3 "Mat" (Nat i) (Nat j) a
249
250pattern Nat n <- (fromNat -> Just n) where Nat n = toNat n
251
252toNat :: Int -> Exp
253toNat 0 = Zero
254toNat n = Succ (toNat $ n-1)
255
256fromNat :: Exp -> Maybe Int
257fromNat Zero = Just 0
258fromNat (Succ n) = (1 +) <$> fromNat n
259
260pattern TTuple xs <- (getTuple -> Just xs)
261pattern ETuple xs <- (getTuple -> Just xs)
262
263getTuple = \case
264 AN "Tuple0" [] -> Just []
265 AN "Tuple2" [a, b] -> Just [a, b]
266 AN "Tuple3" [a, b, c] -> Just [a, b, c]
267 AN "Tuple4" [a, b, c, d] -> Just [a, b, c, d]
268 AN "Tuple5" [a, b, c, d, e] -> Just [a, b, c, d, e]
269 AN "Tuple6" [a, b, c, d, e, f] -> Just [a, b, c, d, e, f]
270 AN "Tuple7" [a, b, c, d, e, f, g] -> Just [a, b, c, d, e, f, g]
271 _ -> Nothing
272
273pattern ERecord a <- (const Nothing -> Just a)
274
275--------------------------------------------------------------------------------
276
277showN = id
278show5 = show
279
280pattern ExpN a = a
281
282newtype ErrorMsg = ErrorMsg String
283instance Show ErrorMsg where show (ErrorMsg s) = s
284
285type ErrorT = ExceptT ErrorMsg
286mapError f e = e
287pattern InFile :: String -> ErrorMsg -> ErrorMsg
288pattern InFile s e <- ((,) "?" -> (s, e)) where InFile _ e = e
289
290type Info = (SourcePos, SourcePos, String)
291type Infos = [Info]
292
293type PolyEnv = I.GlobalEnv
294
295joinPolyEnvs :: MonadError ErrorMsg m => Bool -> [PolyEnv] -> m PolyEnv
296joinPolyEnvs _ = return . mconcat
297getPolyEnv = id
298
299type MName = SName
300type VarMT = StateT FreshVars
301type FreshVars = [String]
302freshTypeVars = (flip (:) <$> map show [0..] <*> ['a'..'z'])
303
304throwErrorTCM :: MonadError ErrorMsg m => Doc -> m a
305throwErrorTCM d = throwError $ ErrorMsg $ show d
306
307infos = const []
308
309type EName = SName
310
311parseLC :: MonadError ErrorMsg m => FilePath -> String -> m ModuleR
312parseLC f s = either (throwError . ErrorMsg) return (I.parse f s)
313
314inference_ :: PolyEnv -> ModuleR -> ErrorT (WriterT Infos (VarMT Identity)) PolyEnv
315inference_ pe m = mdo
316 defs <- either (throwError . ErrorMsg) return $ definitions m $ I.mkGlobalEnv' defs `I.joinGlobalEnv'` I.extractGlobalEnv' pe
317 either (throwError . ErrorMsg) return $ I.infer pe defs
318
319reduce = id
320
321type Item = (I.Exp, I.Exp)
322
323tyOfItem :: Item -> Exp
324tyOfItem = toExp . snd
325
326pattern ISubst a b <- ((,) () -> (a, (toExp -> b, tb)))
327
328dummyPos :: SourcePos
329dummyPos = newPos "" 0 0
330
331showErr e = (dummyPos, dummyPos, e)
332