diff options
Diffstat (limited to 'CGExp.hs')
-rw-r--r-- | CGExp.hs | 332 |
1 files changed, 332 insertions, 0 deletions
diff --git a/CGExp.hs b/CGExp.hs new file mode 100644 index 00000000..5477e485 --- /dev/null +++ b/CGExp.hs | |||
@@ -0,0 +1,332 @@ | |||
1 | {-# LANGUAGE LambdaCase #-} | ||
2 | {-# LANGUAGE ViewPatterns #-} | ||
3 | {-# LANGUAGE PatternSynonyms #-} | ||
4 | {-# LANGUAGE FlexibleContexts #-} | ||
5 | {-# LANGUAGE FlexibleInstances #-} | ||
6 | {-# LANGUAGE NoMonomorphismRestriction #-} | ||
7 | {-# LANGUAGE DeriveFunctor #-} | ||
8 | {-# LANGUAGE DeriveFoldable #-} | ||
9 | {-# LANGUAGE DeriveTraversable #-} | ||
10 | {-# LANGUAGE RecursiveDo #-} | ||
11 | module CGExp | ||
12 | ( module CGExp | ||
13 | , Lit(..), Export(..), ModuleR(..) | ||
14 | ) where | ||
15 | |||
16 | import Control.Monad.Reader | ||
17 | import Control.Monad.State | ||
18 | import Control.Monad.Except | ||
19 | import Control.Monad.Identity | ||
20 | import Control.Monad.Writer | ||
21 | import Control.Arrow | ||
22 | import qualified Data.Set as S | ||
23 | import qualified Data.Map as M | ||
24 | import Text.Parsec.Pos | ||
25 | |||
26 | import Pretty | ||
27 | import qualified Infer as I | ||
28 | import Infer (SName, Lit(..), Visibility(..), Export(..), ModuleR(..)) | ||
29 | |||
30 | -------------------------------------------------------------------------------- | ||
31 | |||
32 | data Exp_ a | ||
33 | = Pi_ Visibility SName a a | ||
34 | | Lam_ Visibility Pat a a | ||
35 | | Con_ (SName, a) [a] | ||
36 | | ELit_ Lit | ||
37 | | Fun_ (SName, a) [a] | ||
38 | | App_ a a | ||
39 | | Var_ SName a | ||
40 | | TType_ | ||
41 | | Let_ Pat a a | ||
42 | | EFieldProj_ a SName | ||
43 | deriving (Show, Eq, Functor, Foldable, Traversable) | ||
44 | |||
45 | instance PShow Exp where pShowPrec p = text . show | ||
46 | |||
47 | pattern Pi h n a b = Exp (Pi_ h n a b) | ||
48 | pattern Lam h n a b = Exp (Lam_ h n a b) | ||
49 | pattern Con a b = Exp (Con_ a b) | ||
50 | pattern ELit a = Exp (ELit_ a) | ||
51 | pattern Fun a b = Exp (Fun_ a b) | ||
52 | pattern EApp a b = Exp (App_ a b) | ||
53 | pattern Var a b = Exp (Var_ a b) | ||
54 | pattern TType = Exp TType_ | ||
55 | pattern ELet a b c = Exp (Let_ a b c) | ||
56 | pattern EFieldProj a b = Exp (EFieldProj_ a b) | ||
57 | |||
58 | newtype Exp = Exp (Exp_ Exp) | ||
59 | deriving (Show, Eq) | ||
60 | |||
61 | type ConvM a = StateT [SName] (Reader [SName]) a | ||
62 | |||
63 | newName = gets head <* modify tail | ||
64 | |||
65 | toExp :: I.Exp -> Exp | ||
66 | toExp = flip runReader [] . flip evalStateT freshTypeVars . f | ||
67 | where | ||
68 | f = \case | ||
69 | I.FunN "swizzvector" [_, _, _, exp, getSwizzVec -> Just (concat -> s)] -> newName >>= \n -> do | ||
70 | e <- f exp | ||
71 | return $ app' (EFieldProj (Pi Visible n (tyOf e) (TVec (length s) TFloat)) s) e | ||
72 | I.FunN "swizzscalar" [_, _, exp, mkSwizzStr -> Just s] -> newName >>= \n -> do | ||
73 | e <- f exp | ||
74 | return $ app' (EFieldProj (Pi Visible n (tyOf e) TFloat) s) e | ||
75 | I.Var i -> asks $ uncurry Var . (!!! i) | ||
76 | I.Pi b x y -> newName >>= \n -> do | ||
77 | t <- f x | ||
78 | Pi b n t <$> local ((n, t):) (f y) | ||
79 | I.Lam b x y -> newName >>= \n -> do | ||
80 | t <- f x | ||
81 | Lam b (PVar t n) t <$> local ((n, t):) (f y) | ||
82 | I.Con (I.ConName s _ _ t) xs -> con s <$> f t <*> mapM f xs | ||
83 | I.TyCon (I.TyConName s _ _ t _ _) xs -> con s <$> f t <*> mapM f xs | ||
84 | I.ELit l -> pure $ ELit l | ||
85 | I.Fun (I.FunName s _ t) xs -> fun s <$> f t <*> mapM f xs | ||
86 | I.CaseFun (I.CaseFunName s t _) xs -> fun s <$> f t <*> mapM f xs | ||
87 | I.App a b -> app' <$> f a <*> f b | ||
88 | I.Label x _ -> f x | ||
89 | I.TType -> pure TType | ||
90 | I.LabelEnd x -> f x | ||
91 | z -> error $ "toExp: " ++ show z | ||
92 | |||
93 | getSwizzVec = \case | ||
94 | I.VV2 _ (mkSwizzStr -> Just sx) (mkSwizzStr -> Just sy) -> Just [sx, sy] | ||
95 | I.VV3 _ (mkSwizzStr -> Just sx) (mkSwizzStr -> Just sy) (mkSwizzStr -> Just sz) -> Just [sx, sy, sz] | ||
96 | I.VV4 _ (mkSwizzStr -> Just sx) (mkSwizzStr -> Just sy) (mkSwizzStr -> Just sz) (mkSwizzStr -> Just sw) -> Just [sx, sy, sz, sw] | ||
97 | _ -> Nothing | ||
98 | |||
99 | mkSwizzStr = \case | ||
100 | I.ConN "Sx" [] -> Just "x" | ||
101 | I.ConN "Sy" [] -> Just "y" | ||
102 | I.ConN "Sz" [] -> Just "z" | ||
103 | I.ConN "Sw" [] -> Just "w" | ||
104 | _ -> Nothing | ||
105 | |||
106 | xs !!! i | i < 0 || i >= length xs = error $ show xs ++ " !! " ++ show i | ||
107 | xs !!! i = xs !! i | ||
108 | |||
109 | untick ('\'': s) = s | ||
110 | untick s = s | ||
111 | |||
112 | fun s t xs = Fun (untick s, t) xs | ||
113 | |||
114 | con (untick -> s) t xs = Con (s, t) xs | ||
115 | |||
116 | freeVars :: Exp -> S.Set SName | ||
117 | freeVars = \case | ||
118 | Var n _ -> S.singleton n | ||
119 | Con _ xs -> S.unions $ map freeVars xs | ||
120 | ELit _ -> mempty | ||
121 | Fun _ xs -> S.unions $ map freeVars xs | ||
122 | EApp a b -> freeVars a `S.union` freeVars b | ||
123 | Pi _ n a b -> freeVars a `S.union` (S.delete n $ freeVars b) | ||
124 | Lam _ n a b -> freeVars a `S.union` (foldr S.delete (freeVars b) (patVars n)) | ||
125 | EFieldProj a _ -> freeVars a | ||
126 | TType -> mempty | ||
127 | ELet n a b -> freeVars a `S.union` (foldr S.delete (freeVars b) (patVars n)) | ||
128 | |||
129 | type Ty = Exp | ||
130 | |||
131 | tyOf :: Exp -> Ty | ||
132 | tyOf = \case | ||
133 | Lam h (PVar _ n) t x -> Pi h n t $ tyOf x | ||
134 | EApp f x -> app (tyOf f) x | ||
135 | Var _ t -> t | ||
136 | Pi{} -> Type | ||
137 | Con (_, t) xs -> foldl app t xs | ||
138 | Fun (_, t) xs -> foldl app t xs | ||
139 | ELit l -> toExp $ I.litType l | ||
140 | TType -> TType | ||
141 | ELet a b c -> tyOf $ EApp (ELam a c) b | ||
142 | EFieldProj t s -> t | ||
143 | x -> error $ "tyOf: " ++ show x | ||
144 | where | ||
145 | app (Pi _ n a b) x = substE n x b | ||
146 | |||
147 | substE n x = \case | ||
148 | z@(Var n' _) | n' == n -> x | ||
149 | | otherwise -> z | ||
150 | Pi h n' a b | n == n' -> Pi h n' (substE n x a) b | ||
151 | Pi h n' a b -> Pi h n' (substE n x a) (substE n x b) | ||
152 | Lam h n' a b -> Lam h n' (substE n x a) $ if n `elem` patVars n' then b else substE n x b | ||
153 | Con cn xs -> Con cn (map (substE n x) xs) | ||
154 | Fun cn xs -> Fun cn (map (substE n x) xs) | ||
155 | TType -> TType | ||
156 | EApp a b -> app' (substE n x a) (substE n x b) | ||
157 | z -> error $ "substE: " ++ show z | ||
158 | |||
159 | app' (Lam _ (PVar _ n) _ x) b = substE n b x | ||
160 | app' a b = EApp a b | ||
161 | |||
162 | -------------------------------------------------------------------------------- | ||
163 | |||
164 | data Pat | ||
165 | = PVar Exp SName | ||
166 | | PTuple [Pat] | ||
167 | deriving (Eq, Show) | ||
168 | |||
169 | instance PShow Pat where pShowPrec p = text . show | ||
170 | |||
171 | patVars (PVar _ n) = [n] | ||
172 | patVars (PTuple ps) = concatMap patVars ps | ||
173 | |||
174 | patTy (PVar t _) = t | ||
175 | patTy (PTuple ps) = Con ("Tuple" ++ show (length ps), tupTy $ length ps) $ map patTy ps | ||
176 | |||
177 | tupTy n = foldr (:~>) Type $ replicate n Type | ||
178 | |||
179 | ------------- | ||
180 | |||
181 | pattern EVar n <- Var n _ | ||
182 | pattern TVar t n = Var n t | ||
183 | |||
184 | pattern ELam n b <- Lam Visible n _ b where ELam n b = Lam Visible n (patTy n) b | ||
185 | |||
186 | pattern a :~> b = Pi Visible "" a b | ||
187 | infixr 1 :~> | ||
188 | |||
189 | pattern PrimN n xs <- Fun (n, t) (filterRelevant (n, 0) t -> xs) where PrimN n xs = Fun (n, hackType n) xs | ||
190 | pattern Prim1 n a = PrimN n [a] | ||
191 | pattern Prim2 n a b = PrimN n [a, b] | ||
192 | pattern Prim3 n a b c <- PrimN n [a, b, c] | ||
193 | pattern Prim4 n a b c d <- PrimN n [a, b, c, d] | ||
194 | pattern Prim5 n a b c d e <- PrimN n [a, b, c, d, e] | ||
195 | |||
196 | -- todo: remove | ||
197 | hackType = \case | ||
198 | "Output" -> TType | ||
199 | "Bool" -> TType | ||
200 | "Float" -> TType | ||
201 | "Nat" -> TType | ||
202 | "Zero" -> TNat | ||
203 | "Succ" -> TNat :~> TNat | ||
204 | "String" -> TType | ||
205 | "Sampler" -> TType | ||
206 | "VecS" -> TType :~> TNat :~> TType | ||
207 | -- "EFieldProj" -> Pi Visible "projt" TType $ Pi Visible "projt2" TString $ Pi Visible "projvec" (TVec (error "pn1") TFloat) (TVec (error "pn2") TFloat) | ||
208 | n -> error $ "type of " ++ show n | ||
209 | |||
210 | filterRelevant _ _ [] = [] | ||
211 | filterRelevant i (Pi h n t t') (x: xs) = (if h == Visible || exception i then (x:) else id) $ filterRelevant (id *** (+1) $ i) (substE n x t') xs | ||
212 | where | ||
213 | -- todo: remove | ||
214 | exception i = i `elem` [("ColorImage", 0), ("DepthImage", 0), ("StencilImage", 0)] | ||
215 | |||
216 | pattern AN n xs <- Con (n, t) (filterRelevant (n, 0) t -> xs) where AN n xs = Con (n, hackType n) xs | ||
217 | pattern A0 n = AN n [] | ||
218 | pattern A1 n a = AN n [a] | ||
219 | pattern A2 n a b = AN n [a, b] | ||
220 | pattern A3 n a b c <- AN n [a, b, c] | ||
221 | pattern A4 n a b c d <- AN n [a, b, c, d] | ||
222 | pattern A5 n a b c d e <- AN n [a, b, c, d, e] | ||
223 | |||
224 | pattern TCon0 n = A0 n | ||
225 | pattern TCon t n = Con (n, t) [] | ||
226 | |||
227 | pattern Type = TType | ||
228 | pattern Star = TType | ||
229 | pattern TUnit <- A0 "Tuple0" | ||
230 | pattern TBool <- A0 "Bool" | ||
231 | pattern TWord <- A0 "Word" | ||
232 | pattern TInt <- A0 "Int" | ||
233 | pattern TNat = A0 "Nat" | ||
234 | pattern TFloat = A0 "Float" | ||
235 | pattern TString = A0 "String" | ||
236 | pattern TList n <- A1 "List" n | ||
237 | |||
238 | pattern TSampler = A0 "Sampler" | ||
239 | pattern TFrameBuffer a b <- A2 "FrameBuffer" a b | ||
240 | pattern Depth n <- A1 "Depth" n | ||
241 | pattern Stencil n <- A1 "Stencil" n | ||
242 | pattern Color n <- A1 "Color" n | ||
243 | |||
244 | pattern Zero = A0 "Zero" | ||
245 | pattern Succ n = A1 "Succ" n | ||
246 | |||
247 | pattern TVec n a = A2 "VecS" a (Nat n) | ||
248 | pattern TMat i j a <- A3 "Mat" (Nat i) (Nat j) a | ||
249 | |||
250 | pattern Nat n <- (fromNat -> Just n) where Nat n = toNat n | ||
251 | |||
252 | toNat :: Int -> Exp | ||
253 | toNat 0 = Zero | ||
254 | toNat n = Succ (toNat $ n-1) | ||
255 | |||
256 | fromNat :: Exp -> Maybe Int | ||
257 | fromNat Zero = Just 0 | ||
258 | fromNat (Succ n) = (1 +) <$> fromNat n | ||
259 | |||
260 | pattern TTuple xs <- (getTuple -> Just xs) | ||
261 | pattern ETuple xs <- (getTuple -> Just xs) | ||
262 | |||
263 | getTuple = \case | ||
264 | AN "Tuple0" [] -> Just [] | ||
265 | AN "Tuple2" [a, b] -> Just [a, b] | ||
266 | AN "Tuple3" [a, b, c] -> Just [a, b, c] | ||
267 | AN "Tuple4" [a, b, c, d] -> Just [a, b, c, d] | ||
268 | AN "Tuple5" [a, b, c, d, e] -> Just [a, b, c, d, e] | ||
269 | AN "Tuple6" [a, b, c, d, e, f] -> Just [a, b, c, d, e, f] | ||
270 | AN "Tuple7" [a, b, c, d, e, f, g] -> Just [a, b, c, d, e, f, g] | ||
271 | _ -> Nothing | ||
272 | |||
273 | pattern ERecord a <- (const Nothing -> Just a) | ||
274 | |||
275 | -------------------------------------------------------------------------------- | ||
276 | |||
277 | showN = id | ||
278 | show5 = show | ||
279 | |||
280 | pattern ExpN a = a | ||
281 | |||
282 | newtype ErrorMsg = ErrorMsg String | ||
283 | instance Show ErrorMsg where show (ErrorMsg s) = s | ||
284 | |||
285 | type ErrorT = ExceptT ErrorMsg | ||
286 | mapError f e = e | ||
287 | pattern InFile :: String -> ErrorMsg -> ErrorMsg | ||
288 | pattern InFile s e <- ((,) "?" -> (s, e)) where InFile _ e = e | ||
289 | |||
290 | type Info = (SourcePos, SourcePos, String) | ||
291 | type Infos = [Info] | ||
292 | |||
293 | type PolyEnv = I.GlobalEnv | ||
294 | |||
295 | joinPolyEnvs :: MonadError ErrorMsg m => Bool -> [PolyEnv] -> m PolyEnv | ||
296 | joinPolyEnvs _ = return . mconcat | ||
297 | getPolyEnv = id | ||
298 | |||
299 | type MName = SName | ||
300 | type VarMT = StateT FreshVars | ||
301 | type FreshVars = [String] | ||
302 | freshTypeVars = (flip (:) <$> map show [0..] <*> ['a'..'z']) | ||
303 | |||
304 | throwErrorTCM :: MonadError ErrorMsg m => Doc -> m a | ||
305 | throwErrorTCM d = throwError $ ErrorMsg $ show d | ||
306 | |||
307 | infos = const [] | ||
308 | |||
309 | type EName = SName | ||
310 | |||
311 | parseLC :: MonadError ErrorMsg m => FilePath -> String -> m ModuleR | ||
312 | parseLC f s = either (throwError . ErrorMsg) return (I.parse f s) | ||
313 | |||
314 | inference_ :: PolyEnv -> ModuleR -> ErrorT (WriterT Infos (VarMT Identity)) PolyEnv | ||
315 | inference_ pe m = mdo | ||
316 | defs <- either (throwError . ErrorMsg) return $ definitions m $ I.mkGlobalEnv' defs `I.joinGlobalEnv'` I.extractGlobalEnv' pe | ||
317 | either (throwError . ErrorMsg) return $ I.infer pe defs | ||
318 | |||
319 | reduce = id | ||
320 | |||
321 | type Item = (I.Exp, I.Exp) | ||
322 | |||
323 | tyOfItem :: Item -> Exp | ||
324 | tyOfItem = toExp . snd | ||
325 | |||
326 | pattern ISubst a b <- ((,) () -> (a, (toExp -> b, tb))) | ||
327 | |||
328 | dummyPos :: SourcePos | ||
329 | dummyPos = newPos "" 0 0 | ||
330 | |||
331 | showErr e = (dummyPos, dummyPos, e) | ||
332 | |||