summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/LambdaCube/Compiler/CoreToIR.hs2
-rw-r--r--src/LambdaCube/Compiler/Infer.hs67
2 files changed, 31 insertions, 38 deletions
diff --git a/src/LambdaCube/Compiler/CoreToIR.hs b/src/LambdaCube/Compiler/CoreToIR.hs
index 58679a6a..293b1bc9 100644
--- a/src/LambdaCube/Compiler/CoreToIR.hs
+++ b/src/LambdaCube/Compiler/CoreToIR.hs
@@ -901,7 +901,7 @@ toExp = flip runReader [] . flip evalStateT freshTypeVars . f_
901 app' <$> f_ (I.Neut a, t) <*> (head <$> chain [] t [b]) 901 app' <$> f_ (I.Neut a, t) <*> (head <$> chain [] t [b])
902 I.ELit l -> pure $ ELit l 902 I.ELit l -> pure $ ELit l
903 I.TType -> pure TType 903 I.TType -> pure TType
904 x@I.PMLabel{} -> f_ (I.unpmlabel x, et) 904 (I.unpmlabel -> Just x) -> f_ (x, et)
905 I.FixLabel _ x -> f_ (x, et) 905 I.FixLabel _ x -> f_ (x, et)
906-- I.LabelEnd x -> f x -- not possible 906-- I.LabelEnd x -> f x -- not possible
907 z -> error $ "toExp: " ++ show z 907 z -> error $ "toExp: " ++ show z
diff --git a/src/LambdaCube/Compiler/Infer.hs b/src/LambdaCube/Compiler/Infer.hs
index 291d1e00..ae2a9f1b 100644
--- a/src/LambdaCube/Compiler/Infer.hs
+++ b/src/LambdaCube/Compiler/Infer.hs
@@ -17,7 +17,7 @@
17module LambdaCube.Compiler.Infer 17module LambdaCube.Compiler.Infer
18 ( Binder (..), SName, Lit(..), Visibility(..), Export(..), Module(..) 18 ( Binder (..), SName, Lit(..), Visibility(..), Export(..), Module(..)
19 , Exp (..), ExpType, GlobalEnv 19 , Exp (..), ExpType, GlobalEnv
20 , pattern Var, pattern Fun, pattern CaseFun, pattern TyCaseFun, pattern App_, pattern PMLabel, pattern FixLabel 20 , pattern Var, pattern Fun, pattern CaseFun, pattern TyCaseFun, pattern App_, pattern FixLabel
21 , pattern Con, pattern TyCon, pattern Pi, pattern Lam 21 , pattern Con, pattern TyCon, pattern Pi, pattern Lam
22 , outputType, boolType, trueExp 22 , outputType, boolType, trueExp
23 , down 23 , down
@@ -169,7 +169,7 @@ pattern ParEval t a b = TFun "parEval" (TType :~> Var 0 :~> Var 1 :~> Var 2) [t,
169pattern Undef t = TFun "undefined" (Pi Hidden TType (Var 0)) [t] 169pattern Undef t = TFun "undefined" (Pi Hidden TType (Var 0)) [t]
170pattern T2 a b = TFun "'T2" (TType :~> TType :~> TType) [a, b] 170pattern T2 a b = TFun "'T2" (TType :~> TType :~> TType) [a, b]
171pattern T2C a b = TFun "t2C" (Unit :~> Unit :~> Unit) [a, b] 171pattern T2C a b = TFun "t2C" (Unit :~> Unit :~> Unit) [a, b]
172pattern CSplit a b c <- FunN "'Split" [a, b, c] 172pattern CSplit a b c <- UFunN "'Split" [a, b, c]
173 173
174pattern EInt a = ELit (LInt a) 174pattern EInt a = ELit (LInt a)
175pattern EFloat a = ELit (LFloat a) 175pattern EFloat a = ELit (LFloat a)
@@ -241,31 +241,23 @@ pattern LabelEnd x = LabelEnd_ LEPM x
241 241
242label LabelFix x y = FixLabel x y 242label LabelFix x y = FixLabel x y
243pmLabel :: FunName -> Int -> [Exp] -> Exp -> Exp 243pmLabel :: FunName -> Int -> [Exp] -> Exp -> Exp
244pmLabel _ _ _ (unlabel'' -> LabelEnd y) = y 244pmLabel _ _ _ (unfixlabel -> LabelEnd y) = y
245pmLabel f i xs y@Neut{} = PMLabel f i xs y 245pmLabel f i xs y@Neut{} = PMLabel f i xs y
246pmLabel f i xs y@Lam{} = PMLabel f i xs y 246pmLabel f i xs y@Lam{} = PMLabel f i xs y
247pmLabel f i xs y = error $ "pmLabel: " ++ show y 247pmLabel f i xs y = error $ "pmLabel: " ++ show y
248 248
249pattern UL a <- (unlabel -> a) where UL = unlabel 249pattern UFunN a b <- (unpmlabel -> Just (FunN a b))
250 250
251unpmlabel (PMLabel f i a _) 251unpmlabel (PMLabel f i a _)
252 | i >= 0 = iterateN i Lam $ Fun f $ a ++ downTo 0 i 252 | i >= 0 = Just $ iterateN i Lam $ Fun f $ a ++ downTo 0 i
253 | otherwise = foldl app_ (Fun f $ reverse $ drop (-i) $ reverse a) (reverse $ take (-i) $ reverse a) 253 | otherwise = Just $ foldl app_ (Fun f $ reverse $ drop (-i) $ reverse a) (reverse $ take (-i) $ reverse a)
254unpmlabel _ = Nothing
254 255
255unlabel x@PMLabel{} = unlabel (unpmlabel x) 256unfixlabel (FixLabel _ a) = unfixlabel a
256unlabel (FixLabel _ a) = unlabel a 257unfixlabel a = a
257--unlabel (LabelEnd_ _ a) = unlabel a
258unlabel a = a
259 258
260unlabel'' (FixLabel _ a) = unlabel'' a 259unlabelend (LabelEnd_ _ a) = unlabelend a
261unlabel'' a = a 260unlabelend a = a
262
263pattern UL' a <- (unlabel' -> a) where UL' = unlabel'
264
265--unlabel (PMLabel a _) = unlabel a
266--unlabel (FixLabel _ a) = unlabel a
267unlabel' (LabelEnd_ _ a) = unlabel' a
268unlabel' a = a
269 261
270 262
271-------------------------------------------------------------------------------- low-level toolbox 263-------------------------------------------------------------------------------- low-level toolbox
@@ -523,7 +515,7 @@ getFunDef s = case show s of
523 515
524cstr = f [] 516cstr = f []
525 where 517 where
526 f _ _ (UL a) (UL a') | a == a' = Unit 518 f _ _ a a' | a == a' = Unit
527 f ns typ (LabelEnd_ k a) a' = f ns typ a a' 519 f ns typ (LabelEnd_ k a) a' = f ns typ a a'
528 f ns typ a (LabelEnd_ k a') = f ns typ a a' 520 f ns typ a (LabelEnd_ k a') = f ns typ a a'
529 f ns typ (FixLabel _ a) a' = f ns typ a a' 521 f ns typ (FixLabel _ a) a' = f ns typ a a'
@@ -535,19 +527,19 @@ cstr = f []
535 f (_: ns) typ{-down?-} (down 0 -> Just a) (down 0 -> Just a') = f ns typ a a' 527 f (_: ns) typ{-down?-} (down 0 -> Just a) (down 0 -> Just a') = f ns typ a a'
536 f ns TType (Pi h a b) (Pi h' a' b') | h == h' = t2 (f ns TType a a') (f ((a, a'): ns) TType b b') 528 f ns TType (Pi h a b) (Pi h' a' b') | h == h' = t2 (f ns TType a a') (f ((a, a'): ns) TType b b')
537 529
538 f [] TType (UL (FunN "'VecScalar" [a, b])) (TVec a' b') = t2 (f [] TNat a a') (f [] TType b b') 530 f [] TType (UFunN "'VecScalar" [a, b]) (TVec a' b') = t2 (f [] TNat a a') (f [] TType b b')
539 f [] TType (UL (FunN "'VecScalar" [a, b])) (UL (FunN "'VecScalar" [a', b'])) = t2 (f [] TNat a a') (f [] TType b b') 531 f [] TType (UFunN "'VecScalar" [a, b]) (UFunN "'VecScalar" [a', b']) = t2 (f [] TNat a a') (f [] TType b b')
540 f [] TType (UL (FunN "'VecScalar" [a, b])) t@(TTyCon0 n) | isElemTy n = t2 (f [] TNat a (ENat 1)) (f [] TType b t) 532 f [] TType (UFunN "'VecScalar" [a, b]) t@(TTyCon0 n) | isElemTy n = t2 (f [] TNat a (ENat 1)) (f [] TType b t)
541 f [] TType t@(TTyCon0 n) (UL (FunN "'VecScalar" [a, b])) | isElemTy n = t2 (f [] TNat a (ENat 1)) (f [] TType b t) 533 f [] TType t@(TTyCon0 n) (UFunN "'VecScalar" [a, b]) | isElemTy n = t2 (f [] TNat a (ENat 1)) (f [] TType b t)
542 534
543 f [] TType (UL (FunN "'FragOps" [a])) (TyConN "'FragmentOperation" [x]) = f [] (TList TImageSemantics) a (cons x nil) 535 f [] TType (UFunN "'FragOps" [a]) (TyConN "'FragmentOperation" [x]) = f [] (TList TImageSemantics) a (cons x nil)
544 f [] TType (UL (FunN "'FragOps" [a])) (TyConN "'Tuple2" [TyConN "'FragmentOperation" [x], TyConN "'FragmentOperation" [y]]) = f [] (TList TImageSemantics) a $ cons x $ cons y nil 536 f [] TType (UFunN "'FragOps" [a]) (TyConN "'Tuple2" [TyConN "'FragmentOperation" [x], TyConN "'FragmentOperation" [y]]) = f [] (TList TImageSemantics) a $ cons x $ cons y nil
545 537
546 f ns@[] TType (TyConN "'Tuple2" [x, y]) (UL (FunN "'JoinTupleType" [x', y'])) = t2 (f ns TType x x') (f ns TType y y') 538 f ns@[] TType (TyConN "'Tuple2" [x, y]) (UFunN "'JoinTupleType" [x', y']) = t2 (f ns TType x x') (f ns TType y y')
547 f ns@[] TType (UL (FunN "'JoinTupleType" [x', y'])) (TyConN "'Tuple2" [x, y]) = t2 (f ns TType x' x) (f ns TType y' y) 539 f ns@[] TType (UFunN "'JoinTupleType" [x', y']) (TyConN "'Tuple2" [x, y]) = t2 (f ns TType x' x) (f ns TType y' y)
548 f ns@[] TType (UL (FunN "'JoinTupleType" [x', y'])) x@NoTup = t2 (f ns TType x' x) (f ns TType y' $ TTyCon0 "'Tuple0") 540 f ns@[] TType (UFunN "'JoinTupleType" [x', y']) x@NoTup = t2 (f ns TType x' x) (f ns TType y' $ TTyCon0 "'Tuple0")
549 541
550 f ns@[] TType (x@NoTup) (UL (FunN "'InterpolatedType" [x'])) = f ns TType (TTyCon "'Interpolated" (TType :~> TType) [x]) x' 542 f ns@[] TType (x@NoTup) (UFunN "'InterpolatedType" [x']) = f ns TType (TTyCon "'Interpolated" (TType :~> TType) [x]) x'
551 543
552 f [] typ a@Neut{} a' = CstrT typ a a' 544 f [] typ a@Neut{} a' = CstrT typ a a'
553 f [] typ a a'@Neut{} = CstrT typ a a' 545 f [] typ a a'@Neut{} = CstrT typ a a'
@@ -931,8 +923,8 @@ replaceMetas bind = \case
931 923
932 924
933isCstr CstrT{} = True 925isCstr CstrT{} = True
934isCstr (UL (FunN s _)) = s `elem` ["'Eq", "'Ord", "'Num", "'CNum", "'Signed", "'Component", "'Integral", "'NumComponent", "'Floating"] -- todo: use Constraint type to decide this 926isCstr (UFunN s _) = s `elem` ["'Eq", "'Ord", "'Num", "'CNum", "'Signed", "'Component", "'Integral", "'NumComponent", "'Floating"] -- todo: use Constraint type to decide this
935isCstr (UL c) = {- trace_ (ppShow c ++ show c) $ -} False 927isCstr _ = {- trace_ (ppShow c ++ show c) $ -} False
936 928
937-------------------------------------------------------------------------------- re-checking 929-------------------------------------------------------------------------------- re-checking
938 930
@@ -1056,19 +1048,20 @@ initEnv = Map.fromList
1056extractDesugarInfo :: GlobalEnv -> DesugarInfo 1048extractDesugarInfo :: GlobalEnv -> DesugarInfo
1057extractDesugarInfo ge = 1049extractDesugarInfo ge =
1058 ( Map.fromList 1050 ( Map.fromList
1059 [ (n, f) | (n, (d, _, si)) <- Map.toList ge, f <- maybeToList $ case UL' d of 1051 [ (n, f) | (n, (d, _, si)) <- Map.toList ge, f <- maybeToList $ case d of
1060 Con (ConName _ f _ _ _) 0 [] -> f 1052 Con (ConName _ f _ _ _) 0 [] -> f
1061 TyCon (TyConName _ f _ _ _ _) [] -> f 1053 TyCon (TyConName _ f _ _ _ _) [] -> f
1062 (getLams -> UL (getLams -> Fun (FunName _ f _) _)) -> f 1054 (getLams -> (Fun (FunName _ f _) [])) -> f
1055 PMLabel (FunName _ f _) _ [] _ -> f
1063 Fun (FunName _ f _) [] -> f 1056 Fun (FunName _ f _) [] -> f
1064 _ -> Nothing 1057 _ -> Nothing
1065 ] 1058 ]
1066 , Map.fromList $ 1059 , Map.fromList $
1067 [ (n, Left ((t, inum), map f cons)) 1060 [ (n, Left ((t, inum), map f cons))
1068 | (n, (UL' (Con cn 0 []), _, si)) <- Map.toList ge, let TyConName t _ inum _ cons _ = conTypeName cn 1061 | (n, ( (Con cn 0 []), _, si)) <- Map.toList ge, let TyConName t _ inum _ cons _ = conTypeName cn
1069 ] ++ 1062 ] ++
1070 [ (n, Right $ pars t) 1063 [ (n, Right $ pars t)
1071 | (n, (UL' (TyCon (TyConName _ _ _ t _ _) []), _, _)) <- Map.toList ge 1064 | (n, ( (TyCon (TyConName _ _ _ t _ _) []), _, _)) <- Map.toList ge
1072 ] 1065 ]
1073 ) 1066 )
1074 where 1067 where
@@ -1199,7 +1192,7 @@ arity :: Exp -> Int
1199arity = length . fst . getParams 1192arity = length . fst . getParams
1200 1193
1201getParams :: Exp -> ([(Visibility, Exp)], Exp) 1194getParams :: Exp -> ([(Visibility, Exp)], Exp)
1202getParams (UL' (Pi h a b)) = first ((h, a):) $ getParams b 1195getParams (unlabelend -> Pi h a b) = first ((h, a):) $ getParams b
1203getParams x = ([], x) 1196getParams x = ([], x)
1204 1197
1205getLams (Lam b) = getLams b 1198getLams (Lam b) = getLams b