summaryrefslogtreecommitdiff
path: root/src/LambdaCube/Compiler/Infer.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/LambdaCube/Compiler/Infer.hs')
-rw-r--r--src/LambdaCube/Compiler/Infer.hs44
1 files changed, 22 insertions, 22 deletions
diff --git a/src/LambdaCube/Compiler/Infer.hs b/src/LambdaCube/Compiler/Infer.hs
index 16381ba4..977b8f9b 100644
--- a/src/LambdaCube/Compiler/Infer.hs
+++ b/src/LambdaCube/Compiler/Infer.hs
@@ -20,7 +20,7 @@ module LambdaCube.Compiler.Infer
20 , pattern Var, pattern CaseFun, pattern TyCaseFun, pattern App_, app_ 20 , pattern Var, pattern CaseFun, pattern TyCaseFun, pattern App_, app_
21 , pattern Con, pattern TyCon, pattern Pi, pattern Lam, pattern Fun, pattern ELit, pattern Func, pattern LabelEnd, pattern FL, pattern UFL, unFunc_ 21 , pattern Con, pattern TyCon, pattern Pi, pattern Lam, pattern Fun, pattern ELit, pattern Func, pattern LabelEnd, pattern FL, pattern UFL, unFunc_
22 , outputType, boolType, trueExp 22 , outputType, boolType, trueExp
23 , down, Subst (..), free 23 , down, Subst (..), free, subst
24 , initEnv, Env(..), pattern EBind2 24 , initEnv, Env(..), pattern EBind2
25 , SI(..), Range(..) -- todo: remove 25 , SI(..), Range(..) -- todo: remove
26 , Info(..), Infos, listAllInfos, listTypeInfos, listTraceInfos 26 , Info(..), Infos, listAllInfos, listTypeInfos, listTraceInfos
@@ -360,11 +360,13 @@ unfixlabel a = a
360-------------------------------------------------------------------------------- low-level toolbox 360-------------------------------------------------------------------------------- low-level toolbox
361 361
362class Subst b a where 362class Subst b a where
363 subst :: Int -> b -> a -> a 363 subst_ :: Int -> MaxDB -> b -> a -> a
364
365subst i x a = subst_ i (maxDB_ x) x a
364 366
365down :: (Subst Exp a, Up a{-used-}) => Int -> a -> Maybe a 367down :: (Subst Exp a, Up a{-used-}) => Int -> a -> Maybe a
366down t x | used t x = Nothing 368down t x | used t x = Nothing
367 | otherwise = Just $ subst t (error "impossible: down" :: Exp) x 369 | otherwise = Just $ subst_ t mempty (error "impossible: down" :: Exp) x
368 370
369instance Eq Exp where 371instance Eq Exp where
370 FL a == a' = a == a' 372 FL a == a' = a == a'
@@ -408,12 +410,12 @@ instance Up Exp where
408 | otherwise = ((getAny .) . fold ((Any .) . (==))) i e 410 | otherwise = ((getAny .) . fold ((Any .) . (==))) i e
409 411
410 fold f i = \case 412 fold f i = \case
411 Lam b -> {-fold f i t <> todo: explain why this is not needed -} fold f (i+1) b 413 Lam b -> fold f (i+1) b
412 Pi _ a b -> fold f i a <> fold f (i+1) b 414 Pi _ a b -> fold f i a <> fold f (i+1) b
413 Con _ _ as -> foldMap (fold f i) as 415 Con _ _ as -> foldMap (fold f i) as
414 TyCon _ as -> foldMap (fold f i) as 416 TyCon _ as -> foldMap (fold f i) as
415 TType -> mempty 417 TType -> mempty
416 ELit _ -> mempty 418 ELit{} -> mempty
417 Neut x -> fold f i x 419 Neut x -> fold f i x
418 420
419 maxDB_ = \case 421 maxDB_ = \case
@@ -423,7 +425,7 @@ instance Up Exp where
423 TyCon_ c _ _ -> c 425 TyCon_ c _ _ -> c
424 426
425 TType -> mempty 427 TType -> mempty
426 ELit _ -> mempty 428 ELit{} -> mempty
427 Neut x -> maxDB_ x 429 Neut x -> maxDB_ x
428 430
429 closedExp = \case 431 closedExp = \case
@@ -436,9 +438,8 @@ instance Up Exp where
436 Neut a -> Neut $ closedExp a 438 Neut a -> Neut $ closedExp a
437 439
438instance Subst Exp Exp where 440instance Subst Exp Exp where
439 subst i0 x = f i0 441 subst_ i0 dx x = f i0
440 where 442 where
441 dx = maxDB_ x
442 f i (Neut n) = substNeut n 443 f i (Neut n) = substNeut n
443 where 444 where
444 substNeut e | cmpDB i e = Neut e 445 substNeut e | cmpDB i e = Neut e
@@ -503,7 +504,7 @@ instance Up Neutral where
503 d@Delta{} -> d 504 d@Delta{} -> d
504 505
505instance (Subst x a, Subst x b) => Subst x (a, b) where 506instance (Subst x a, Subst x b) => Subst x (a, b) where
506 subst i x (a, b) = (subst i x a, subst i x b) 507 subst_ i dx x (a, b) = (subst_ i dx x a, subst_ i dx x b)
507 508
508varType' :: Int -> [Exp] -> Exp 509varType' :: Int -> [Exp] -> Exp
509varType' i vs = vs !! i 510varType' i vs = vs !! i
@@ -558,7 +559,7 @@ getFunDef s f = case s of
558 parEval _ _ (LabelEnd x) = LabelEnd x 559 parEval _ _ (LabelEnd x) = LabelEnd x
559 parEval t a b = ParEval t a b 560 parEval t a b = ParEval t a b
560 CFName _ (SData s) -> case s of 561 CFName _ (SData s) -> case s of
561 "unsafeCoerce" -> \case xs@(_: _: x: _) -> case x of x@FL{} -> x; Neut{} -> f xs; _ -> x 562 "unsafeCoerce" -> \case xs@(_: _: x@NonNeut: _) -> x; xs -> f xs
562 "reflCstr" -> \case (a: _) -> TT 563 "reflCstr" -> \case (a: _) -> TT
563 564
564 "hlistNilCase" -> \case (_: x: (unfixlabel -> Con n@(ConName _ 0 _) _ _): _) -> x; xs -> f xs 565 "hlistNilCase" -> \case (_: x: (unfixlabel -> Con n@(ConName _ 0 _) _ _): _) -> x; xs -> f xs
@@ -621,24 +622,23 @@ cstr = f []
621 622
622 f_ [] TType (UFunN FVecScalar [a, b]) (UFunN FVecScalar [a', b']) = t2 (f [] TNat a a') (f [] TType b b') 623 f_ [] TType (UFunN FVecScalar [a, b]) (UFunN FVecScalar [a', b']) = t2 (f [] TNat a a') (f [] TType b b')
623 f_ [] TType (UFunN FVecScalar [a, b]) (TVec a' b') = t2 (f [] TNat a a') (f [] TType b b') 624 f_ [] TType (UFunN FVecScalar [a, b]) (TVec a' b') = t2 (f [] TNat a a') (f [] TType b b')
624 f_ [] TType (UFunN FVecScalar [a, b]) t@(TTyCon0 n) | isElemTy n = t2 (f [] TNat a (ENat 1)) (f [] TType b t) 625 f_ [] TType (UFunN FVecScalar [a, b]) t@NonNeut = t2 (f [] TNat a (ENat 1)) (f [] TType b t)
625 f_ [] TType (TVec a' b') (UFunN FVecScalar [a, b]) = t2 (f [] TNat a' a) (f [] TType b' b) 626 f_ [] TType (TVec a' b') (UFunN FVecScalar [a, b]) = t2 (f [] TNat a' a) (f [] TType b' b)
626 f_ [] TType t@(TTyCon0 n) (UFunN FVecScalar [a, b]) | isElemTy n = t2 (f [] TNat a (ENat 1)) (f [] TType b t) 627 f_ [] TType t@NonNeut (UFunN FVecScalar [a, b]) = t2 (f [] TNat a (ENat 1)) (f [] TType b t)
627 628
628 f_ [] typ a@Neut{} a' = CstrT typ a a' 629 f_ [] typ a@Neut{} a' = CstrT typ a a'
629 f_ [] typ a a'@Neut{} = CstrT typ a a' 630 f_ [] typ a a'@Neut{} = CstrT typ a a'
630 631 f_ ns typ a a' = Empty $ unlines [ "can not unify", ppShow a, "with", ppShow a' ]
631 f_ ns typ a a' = Empty $ unlines [ "can not unify"
632 , ppShow a
633 , "with"
634 , ppShow a'
635 ]
636 632
637 ff _ _ [] = Unit 633 ff _ _ [] = Unit
638 ff ns tt@(Pi v t _) ((t1, t2'): ts) = t2 (f ns t t1 t2') $ ff ns (appTy tt t1) ts 634 ff ns tt@(Pi v t _) ((t1, t2'): ts) = t2 (f ns t t1 t2') $ ff ns (appTy tt t1) ts
639 ff ns t zs = error $ "ff: " -- ++ show (a, n, length xs', length $ mkConPars n typ) ++ "\n" ++ ppShow (nType a) ++ "\n" ++ ppShow (foldl appTy (nType a) $ mkConPars n typ) ++ "\n" ++ ppShow (zip xs xs') ++ "\n" ++ ppShow zs ++ "\n" ++ ppShow t 635 ff ns t zs = error $ "ff: " -- ++ show (a, n, length xs', length $ mkConPars n typ) ++ "\n" ++ ppShow (nType a) ++ "\n" ++ ppShow (foldl appTy (nType a) $ mkConPars n typ) ++ "\n" ++ ppShow (zip xs xs') ++ "\n" ++ ppShow zs ++ "\n" ++ ppShow t
640 636
641 isElemTy n = n `elem` [FBool, FFloat, FInt] 637pattern NonNeut <- (nonNeut -> True)
638
639nonNeut FL{} = True
640nonNeut Neut{} = False
641nonNeut _ = True
642 642
643t2C (unfixlabel -> TT) (unfixlabel -> TT) = TT 643t2C (unfixlabel -> TT) (unfixlabel -> TT) = TT
644t2C a b = TFun Ft2C (Unit :~> Unit :~> Unit) [a, b] 644t2C a b = TFun Ft2C (Unit :~> Unit :~> Unit) [a, b]
@@ -710,9 +710,9 @@ instance (Subst Exp a, Up a) => Up (CEnv a) where
710 maxDB_ _ = error "maxDB_ @(CEnv _)" 710 maxDB_ _ = error "maxDB_ @(CEnv _)"
711 711
712instance (Subst Exp a, Up a) => Subst Exp (CEnv a) where 712instance (Subst Exp a, Up a) => Subst Exp (CEnv a) where
713 subst i x = \case 713 subst_ i dx x = \case
714 MEnd a -> MEnd $ subst i x a 714 MEnd a -> MEnd $ subst_ i dx x a
715 Meta a b -> Meta (subst i x a) (subst (i+1) (up 1 x) b) 715 Meta a b -> Meta (subst_ i dx x a) (subst_ (i+1) (upDB 1 dx) (up 1 x) b)
716 Assign j a b 716 Assign j a b
717 | j > i, Just a' <- down i a -> assign (j-1) a' (subst i (subst (j-1) (fst a') x) b) 717 | j > i, Just a' <- down i a -> assign (j-1) a' (subst i (subst (j-1) (fst a') x) b)
718 | j > i, Just x' <- down (j-1) x -> assign (j-1) (subst i x' a) (subst i x' b) 718 | j > i, Just x' <- down (j-1) x -> assign (j-1) (subst i x' a) (subst i x' b)