diff options
-rw-r--r-- | src/LambdaCube/Compiler/CoreToIR.hs | 2 | ||||
-rw-r--r-- | src/LambdaCube/Compiler/Infer.hs | 67 |
2 files changed, 31 insertions, 38 deletions
diff --git a/src/LambdaCube/Compiler/CoreToIR.hs b/src/LambdaCube/Compiler/CoreToIR.hs index 58679a6a..293b1bc9 100644 --- a/src/LambdaCube/Compiler/CoreToIR.hs +++ b/src/LambdaCube/Compiler/CoreToIR.hs | |||
@@ -901,7 +901,7 @@ toExp = flip runReader [] . flip evalStateT freshTypeVars . f_ | |||
901 | app' <$> f_ (I.Neut a, t) <*> (head <$> chain [] t [b]) | 901 | app' <$> f_ (I.Neut a, t) <*> (head <$> chain [] t [b]) |
902 | I.ELit l -> pure $ ELit l | 902 | I.ELit l -> pure $ ELit l |
903 | I.TType -> pure TType | 903 | I.TType -> pure TType |
904 | x@I.PMLabel{} -> f_ (I.unpmlabel x, et) | 904 | (I.unpmlabel -> Just x) -> f_ (x, et) |
905 | I.FixLabel _ x -> f_ (x, et) | 905 | I.FixLabel _ x -> f_ (x, et) |
906 | -- I.LabelEnd x -> f x -- not possible | 906 | -- I.LabelEnd x -> f x -- not possible |
907 | z -> error $ "toExp: " ++ show z | 907 | z -> error $ "toExp: " ++ show z |
diff --git a/src/LambdaCube/Compiler/Infer.hs b/src/LambdaCube/Compiler/Infer.hs index 291d1e00..ae2a9f1b 100644 --- a/src/LambdaCube/Compiler/Infer.hs +++ b/src/LambdaCube/Compiler/Infer.hs | |||
@@ -17,7 +17,7 @@ | |||
17 | module LambdaCube.Compiler.Infer | 17 | module LambdaCube.Compiler.Infer |
18 | ( Binder (..), SName, Lit(..), Visibility(..), Export(..), Module(..) | 18 | ( Binder (..), SName, Lit(..), Visibility(..), Export(..), Module(..) |
19 | , Exp (..), ExpType, GlobalEnv | 19 | , Exp (..), ExpType, GlobalEnv |
20 | , pattern Var, pattern Fun, pattern CaseFun, pattern TyCaseFun, pattern App_, pattern PMLabel, pattern FixLabel | 20 | , pattern Var, pattern Fun, pattern CaseFun, pattern TyCaseFun, pattern App_, pattern FixLabel |
21 | , pattern Con, pattern TyCon, pattern Pi, pattern Lam | 21 | , pattern Con, pattern TyCon, pattern Pi, pattern Lam |
22 | , outputType, boolType, trueExp | 22 | , outputType, boolType, trueExp |
23 | , down | 23 | , down |
@@ -169,7 +169,7 @@ pattern ParEval t a b = TFun "parEval" (TType :~> Var 0 :~> Var 1 :~> Var 2) [t, | |||
169 | pattern Undef t = TFun "undefined" (Pi Hidden TType (Var 0)) [t] | 169 | pattern Undef t = TFun "undefined" (Pi Hidden TType (Var 0)) [t] |
170 | pattern T2 a b = TFun "'T2" (TType :~> TType :~> TType) [a, b] | 170 | pattern T2 a b = TFun "'T2" (TType :~> TType :~> TType) [a, b] |
171 | pattern T2C a b = TFun "t2C" (Unit :~> Unit :~> Unit) [a, b] | 171 | pattern T2C a b = TFun "t2C" (Unit :~> Unit :~> Unit) [a, b] |
172 | pattern CSplit a b c <- FunN "'Split" [a, b, c] | 172 | pattern CSplit a b c <- UFunN "'Split" [a, b, c] |
173 | 173 | ||
174 | pattern EInt a = ELit (LInt a) | 174 | pattern EInt a = ELit (LInt a) |
175 | pattern EFloat a = ELit (LFloat a) | 175 | pattern EFloat a = ELit (LFloat a) |
@@ -241,31 +241,23 @@ pattern LabelEnd x = LabelEnd_ LEPM x | |||
241 | 241 | ||
242 | label LabelFix x y = FixLabel x y | 242 | label LabelFix x y = FixLabel x y |
243 | pmLabel :: FunName -> Int -> [Exp] -> Exp -> Exp | 243 | pmLabel :: FunName -> Int -> [Exp] -> Exp -> Exp |
244 | pmLabel _ _ _ (unlabel'' -> LabelEnd y) = y | 244 | pmLabel _ _ _ (unfixlabel -> LabelEnd y) = y |
245 | pmLabel f i xs y@Neut{} = PMLabel f i xs y | 245 | pmLabel f i xs y@Neut{} = PMLabel f i xs y |
246 | pmLabel f i xs y@Lam{} = PMLabel f i xs y | 246 | pmLabel f i xs y@Lam{} = PMLabel f i xs y |
247 | pmLabel f i xs y = error $ "pmLabel: " ++ show y | 247 | pmLabel f i xs y = error $ "pmLabel: " ++ show y |
248 | 248 | ||
249 | pattern UL a <- (unlabel -> a) where UL = unlabel | 249 | pattern UFunN a b <- (unpmlabel -> Just (FunN a b)) |
250 | 250 | ||
251 | unpmlabel (PMLabel f i a _) | 251 | unpmlabel (PMLabel f i a _) |
252 | | i >= 0 = iterateN i Lam $ Fun f $ a ++ downTo 0 i | 252 | | i >= 0 = Just $ iterateN i Lam $ Fun f $ a ++ downTo 0 i |
253 | | otherwise = foldl app_ (Fun f $ reverse $ drop (-i) $ reverse a) (reverse $ take (-i) $ reverse a) | 253 | | otherwise = Just $ foldl app_ (Fun f $ reverse $ drop (-i) $ reverse a) (reverse $ take (-i) $ reverse a) |
254 | unpmlabel _ = Nothing | ||
254 | 255 | ||
255 | unlabel x@PMLabel{} = unlabel (unpmlabel x) | 256 | unfixlabel (FixLabel _ a) = unfixlabel a |
256 | unlabel (FixLabel _ a) = unlabel a | 257 | unfixlabel a = a |
257 | --unlabel (LabelEnd_ _ a) = unlabel a | ||
258 | unlabel a = a | ||
259 | 258 | ||
260 | unlabel'' (FixLabel _ a) = unlabel'' a | 259 | unlabelend (LabelEnd_ _ a) = unlabelend a |
261 | unlabel'' a = a | 260 | unlabelend a = a |
262 | |||
263 | pattern UL' a <- (unlabel' -> a) where UL' = unlabel' | ||
264 | |||
265 | --unlabel (PMLabel a _) = unlabel a | ||
266 | --unlabel (FixLabel _ a) = unlabel a | ||
267 | unlabel' (LabelEnd_ _ a) = unlabel' a | ||
268 | unlabel' a = a | ||
269 | 261 | ||
270 | 262 | ||
271 | -------------------------------------------------------------------------------- low-level toolbox | 263 | -------------------------------------------------------------------------------- low-level toolbox |
@@ -523,7 +515,7 @@ getFunDef s = case show s of | |||
523 | 515 | ||
524 | cstr = f [] | 516 | cstr = f [] |
525 | where | 517 | where |
526 | f _ _ (UL a) (UL a') | a == a' = Unit | 518 | f _ _ a a' | a == a' = Unit |
527 | f ns typ (LabelEnd_ k a) a' = f ns typ a a' | 519 | f ns typ (LabelEnd_ k a) a' = f ns typ a a' |
528 | f ns typ a (LabelEnd_ k a') = f ns typ a a' | 520 | f ns typ a (LabelEnd_ k a') = f ns typ a a' |
529 | f ns typ (FixLabel _ a) a' = f ns typ a a' | 521 | f ns typ (FixLabel _ a) a' = f ns typ a a' |
@@ -535,19 +527,19 @@ cstr = f [] | |||
535 | f (_: ns) typ{-down?-} (down 0 -> Just a) (down 0 -> Just a') = f ns typ a a' | 527 | f (_: ns) typ{-down?-} (down 0 -> Just a) (down 0 -> Just a') = f ns typ a a' |
536 | f ns TType (Pi h a b) (Pi h' a' b') | h == h' = t2 (f ns TType a a') (f ((a, a'): ns) TType b b') | 528 | f ns TType (Pi h a b) (Pi h' a' b') | h == h' = t2 (f ns TType a a') (f ((a, a'): ns) TType b b') |
537 | 529 | ||
538 | f [] TType (UL (FunN "'VecScalar" [a, b])) (TVec a' b') = t2 (f [] TNat a a') (f [] TType b b') | 530 | f [] TType (UFunN "'VecScalar" [a, b]) (TVec a' b') = t2 (f [] TNat a a') (f [] TType b b') |
539 | f [] TType (UL (FunN "'VecScalar" [a, b])) (UL (FunN "'VecScalar" [a', b'])) = t2 (f [] TNat a a') (f [] TType b b') | 531 | f [] TType (UFunN "'VecScalar" [a, b]) (UFunN "'VecScalar" [a', b']) = t2 (f [] TNat a a') (f [] TType b b') |
540 | f [] TType (UL (FunN "'VecScalar" [a, b])) t@(TTyCon0 n) | isElemTy n = t2 (f [] TNat a (ENat 1)) (f [] TType b t) | 532 | f [] TType (UFunN "'VecScalar" [a, b]) t@(TTyCon0 n) | isElemTy n = t2 (f [] TNat a (ENat 1)) (f [] TType b t) |
541 | f [] TType t@(TTyCon0 n) (UL (FunN "'VecScalar" [a, b])) | isElemTy n = t2 (f [] TNat a (ENat 1)) (f [] TType b t) | 533 | f [] TType t@(TTyCon0 n) (UFunN "'VecScalar" [a, b]) | isElemTy n = t2 (f [] TNat a (ENat 1)) (f [] TType b t) |
542 | 534 | ||
543 | f [] TType (UL (FunN "'FragOps" [a])) (TyConN "'FragmentOperation" [x]) = f [] (TList TImageSemantics) a (cons x nil) | 535 | f [] TType (UFunN "'FragOps" [a]) (TyConN "'FragmentOperation" [x]) = f [] (TList TImageSemantics) a (cons x nil) |
544 | f [] TType (UL (FunN "'FragOps" [a])) (TyConN "'Tuple2" [TyConN "'FragmentOperation" [x], TyConN "'FragmentOperation" [y]]) = f [] (TList TImageSemantics) a $ cons x $ cons y nil | 536 | f [] TType (UFunN "'FragOps" [a]) (TyConN "'Tuple2" [TyConN "'FragmentOperation" [x], TyConN "'FragmentOperation" [y]]) = f [] (TList TImageSemantics) a $ cons x $ cons y nil |
545 | 537 | ||
546 | f ns@[] TType (TyConN "'Tuple2" [x, y]) (UL (FunN "'JoinTupleType" [x', y'])) = t2 (f ns TType x x') (f ns TType y y') | 538 | f ns@[] TType (TyConN "'Tuple2" [x, y]) (UFunN "'JoinTupleType" [x', y']) = t2 (f ns TType x x') (f ns TType y y') |
547 | f ns@[] TType (UL (FunN "'JoinTupleType" [x', y'])) (TyConN "'Tuple2" [x, y]) = t2 (f ns TType x' x) (f ns TType y' y) | 539 | f ns@[] TType (UFunN "'JoinTupleType" [x', y']) (TyConN "'Tuple2" [x, y]) = t2 (f ns TType x' x) (f ns TType y' y) |
548 | f ns@[] TType (UL (FunN "'JoinTupleType" [x', y'])) x@NoTup = t2 (f ns TType x' x) (f ns TType y' $ TTyCon0 "'Tuple0") | 540 | f ns@[] TType (UFunN "'JoinTupleType" [x', y']) x@NoTup = t2 (f ns TType x' x) (f ns TType y' $ TTyCon0 "'Tuple0") |
549 | 541 | ||
550 | f ns@[] TType (x@NoTup) (UL (FunN "'InterpolatedType" [x'])) = f ns TType (TTyCon "'Interpolated" (TType :~> TType) [x]) x' | 542 | f ns@[] TType (x@NoTup) (UFunN "'InterpolatedType" [x']) = f ns TType (TTyCon "'Interpolated" (TType :~> TType) [x]) x' |
551 | 543 | ||
552 | f [] typ a@Neut{} a' = CstrT typ a a' | 544 | f [] typ a@Neut{} a' = CstrT typ a a' |
553 | f [] typ a a'@Neut{} = CstrT typ a a' | 545 | f [] typ a a'@Neut{} = CstrT typ a a' |
@@ -931,8 +923,8 @@ replaceMetas bind = \case | |||
931 | 923 | ||
932 | 924 | ||
933 | isCstr CstrT{} = True | 925 | isCstr CstrT{} = True |
934 | isCstr (UL (FunN s _)) = s `elem` ["'Eq", "'Ord", "'Num", "'CNum", "'Signed", "'Component", "'Integral", "'NumComponent", "'Floating"] -- todo: use Constraint type to decide this | 926 | isCstr (UFunN s _) = s `elem` ["'Eq", "'Ord", "'Num", "'CNum", "'Signed", "'Component", "'Integral", "'NumComponent", "'Floating"] -- todo: use Constraint type to decide this |
935 | isCstr (UL c) = {- trace_ (ppShow c ++ show c) $ -} False | 927 | isCstr _ = {- trace_ (ppShow c ++ show c) $ -} False |
936 | 928 | ||
937 | -------------------------------------------------------------------------------- re-checking | 929 | -------------------------------------------------------------------------------- re-checking |
938 | 930 | ||
@@ -1056,19 +1048,20 @@ initEnv = Map.fromList | |||
1056 | extractDesugarInfo :: GlobalEnv -> DesugarInfo | 1048 | extractDesugarInfo :: GlobalEnv -> DesugarInfo |
1057 | extractDesugarInfo ge = | 1049 | extractDesugarInfo ge = |
1058 | ( Map.fromList | 1050 | ( Map.fromList |
1059 | [ (n, f) | (n, (d, _, si)) <- Map.toList ge, f <- maybeToList $ case UL' d of | 1051 | [ (n, f) | (n, (d, _, si)) <- Map.toList ge, f <- maybeToList $ case d of |
1060 | Con (ConName _ f _ _ _) 0 [] -> f | 1052 | Con (ConName _ f _ _ _) 0 [] -> f |
1061 | TyCon (TyConName _ f _ _ _ _) [] -> f | 1053 | TyCon (TyConName _ f _ _ _ _) [] -> f |
1062 | (getLams -> UL (getLams -> Fun (FunName _ f _) _)) -> f | 1054 | (getLams -> (Fun (FunName _ f _) [])) -> f |
1055 | PMLabel (FunName _ f _) _ [] _ -> f | ||
1063 | Fun (FunName _ f _) [] -> f | 1056 | Fun (FunName _ f _) [] -> f |
1064 | _ -> Nothing | 1057 | _ -> Nothing |
1065 | ] | 1058 | ] |
1066 | , Map.fromList $ | 1059 | , Map.fromList $ |
1067 | [ (n, Left ((t, inum), map f cons)) | 1060 | [ (n, Left ((t, inum), map f cons)) |
1068 | | (n, (UL' (Con cn 0 []), _, si)) <- Map.toList ge, let TyConName t _ inum _ cons _ = conTypeName cn | 1061 | | (n, ( (Con cn 0 []), _, si)) <- Map.toList ge, let TyConName t _ inum _ cons _ = conTypeName cn |
1069 | ] ++ | 1062 | ] ++ |
1070 | [ (n, Right $ pars t) | 1063 | [ (n, Right $ pars t) |
1071 | | (n, (UL' (TyCon (TyConName _ _ _ t _ _) []), _, _)) <- Map.toList ge | 1064 | | (n, ( (TyCon (TyConName _ _ _ t _ _) []), _, _)) <- Map.toList ge |
1072 | ] | 1065 | ] |
1073 | ) | 1066 | ) |
1074 | where | 1067 | where |
@@ -1199,7 +1192,7 @@ arity :: Exp -> Int | |||
1199 | arity = length . fst . getParams | 1192 | arity = length . fst . getParams |
1200 | 1193 | ||
1201 | getParams :: Exp -> ([(Visibility, Exp)], Exp) | 1194 | getParams :: Exp -> ([(Visibility, Exp)], Exp) |
1202 | getParams (UL' (Pi h a b)) = first ((h, a):) $ getParams b | 1195 | getParams (unlabelend -> Pi h a b) = first ((h, a):) $ getParams b |
1203 | getParams x = ([], x) | 1196 | getParams x = ([], x) |
1204 | 1197 | ||
1205 | getLams (Lam b) = getLams b | 1198 | getLams (Lam b) = getLams b |