diff options
Diffstat (limited to 'src/LambdaCube/Compiler/Infer.hs')
-rw-r--r-- | src/LambdaCube/Compiler/Infer.hs | 37 |
1 files changed, 24 insertions, 13 deletions
diff --git a/src/LambdaCube/Compiler/Infer.hs b/src/LambdaCube/Compiler/Infer.hs index f39c9332..0b922c3b 100644 --- a/src/LambdaCube/Compiler/Infer.hs +++ b/src/LambdaCube/Compiler/Infer.hs | |||
@@ -678,7 +678,10 @@ reflCstr te = \case | |||
678 | -} | 678 | -} |
679 | x -> TT | 679 | x -> TT |
680 | 680 | ||
681 | cstrT t a a' | a == a' = Unit | 681 | cstrT t (UL a) (UL a') | a == a' = Unit |
682 | cstrT t (ConN "Succ" [a]) (ConN "Succ" [a']) = cstrT TNat a a' | ||
683 | cstrT t (FixLabel _ a) a' = cstrT t a a' | ||
684 | cstrT t a (FixLabel _ a') = cstrT t a a' | ||
682 | cstrT t a a' = CstrT t a a' | 685 | cstrT t a a' = CstrT t a a' |
683 | 686 | ||
684 | cstr = cstrT_ TType | 687 | cstr = cstrT_ TType |
@@ -693,9 +696,11 @@ cstrT_ typ = cstr__ [] | |||
693 | cstr_ ns a (FixLabel _ a') = cstr_ ns a a' | 696 | cstr_ ns a (FixLabel _ a') = cstr_ ns a a' |
694 | -- cstr_ ns (PMLabel a _) a' = cstr_ ns a a' | 697 | -- cstr_ ns (PMLabel a _) a' = cstr_ ns a a' |
695 | -- cstr_ ns a (PMLabel a' _) = cstr_ ns a a' | 698 | -- cstr_ ns a (PMLabel a' _) = cstr_ ns a a' |
696 | cstr_ ns TType TType = Unit | 699 | -- cstr_ ns TType TType = Unit |
697 | cstr_ ns (Con a []) (Con a' []) | a == a' = Unit | 700 | cstr_ ns (Con a xs) (Con a' xs') | a == a' = foldr t2 Unit $ zipWith (cstr__ ns) xs xs' |
698 | cstr_ ns (TyCon a []) (TyCon a' []) | a == a' = Unit | 701 | cstr_ [] (TyConN "'FrameBuffer" [a, b]) (TyConN "'FrameBuffer" [a', b']) = t2 (cstrT TNat a a') (cstr__ [] b b') -- todo: elim |
702 | cstr_ ns (TyCon a xs) (TyCon a' xs') | a == a' = foldr t2 Unit $ zipWith (cstr__ ns) xs xs' | ||
703 | -- cstr_ ns (TyCon a []) (TyCon a' []) | a == a' = Unit | ||
699 | cstr_ ns (Var i) (Var i') | i == i', i < length ns = Unit | 704 | cstr_ ns (Var i) (Var i') | i == i', i < length ns = Unit |
700 | cstr_ (_: ns) (downE 0 -> Just a) (downE 0 -> Just a') = cstr__ ns a a' | 705 | cstr_ (_: ns) (downE 0 -> Just a) (downE 0 -> Just a') = cstr__ ns a a' |
701 | -- cstr_ ((t, t'): ns) (UApp (downE 0 -> Just a) (UVar 0)) (UApp (downE 0 -> Just a') (UVar 0)) = traceInj2 (a, "V0") (a', "V0") $ cstr__ ns a a' | 706 | -- cstr_ ((t, t'): ns) (UApp (downE 0 -> Just a) (UVar 0)) (UApp (downE 0 -> Just a') (UVar 0)) = traceInj2 (a, "V0") (a', "V0") $ cstr__ ns a a' |
@@ -705,7 +710,8 @@ cstrT_ typ = cstr__ [] | |||
705 | cstr_ ns (UBind h a b) (UBind h' a' b') | h == h' = t2 (cstr__ ns a a') (cstr__ ((a, a'): ns) b b') | 710 | cstr_ ns (UBind h a b) (UBind h' a' b') | h == h' = t2 (cstr__ ns a a') (cstr__ ((a, a'): ns) b b') |
706 | -- cstr_ [] t (Meta a b) = Meta a $ cstr_ [] (up1E 0 t) b | 711 | -- cstr_ [] t (Meta a b) = Meta a $ cstr_ [] (up1E 0 t) b |
707 | -- cstr_ [] (Meta a b) t = Meta a $ cstr_ [] b (up1E 0 t) | 712 | -- cstr_ [] (Meta a b) t = Meta a $ cstr_ [] b (up1E 0 t) |
708 | cstr_ ns (unApp -> Just (a, b)) (unApp -> Just (a', b')) = traceInj2 (a, show b) (a', show b') $ t2 (cstr__ ns a a') (cstr__ ns b b') | 713 | -- cstr_ ns (unApp -> Just (a, b)) (unApp -> Just (a', b')) = traceInj2 (a, show b) (a', show b') $ t2 (cstr__ ns a a') (cstr__ ns b b') |
714 | -- cstr_ ns (unApp -> Just (a, b)) (unApp -> Just (a', b')) = traceInj2 (a, show b) (a', show b') $ t2 (cstr__ ns a a') (cstr__ ns b b') | ||
709 | -- cstr_ ns (Label f xs _) (Label f' xs' _) | f == f' = foldr1 T2 $ zipWith (cstr__ ns) xs xs' | 715 | -- cstr_ ns (Label f xs _) (Label f' xs' _) | f == f' = foldr1 T2 $ zipWith (cstr__ ns) xs xs' |
710 | cstr_ [] (UL (FunN "'VecScalar" [a, b])) (TVec a' b') = t2 (cstrT TNat a a') (cstr__ [] b b') | 716 | cstr_ [] (UL (FunN "'VecScalar" [a, b])) (TVec a' b') = t2 (cstrT TNat a a') (cstr__ [] b b') |
711 | cstr_ [] (UL (FunN "'VecScalar" [a, b])) (UL (FunN "'VecScalar" [a', b'])) = t2 (cstrT TNat a a') (cstr__ [] b b') | 717 | cstr_ [] (UL (FunN "'VecScalar" [a, b])) (UL (FunN "'VecScalar" [a', b'])) = t2 (cstrT TNat a a') (cstr__ [] b b') |
@@ -1116,7 +1122,7 @@ getGEnv exs f = do | |||
1116 | src <- ask | 1122 | src <- ask |
1117 | gets (\ge -> EGlobal src ge mempty) >>= f | 1123 | gets (\ge -> EGlobal src ge mempty) >>= f |
1118 | inferTerm exs msg tr f t = getGEnv exs $ \env -> let env' = f env in smartTrace exs $ \tr -> | 1124 | inferTerm exs msg tr f t = getGEnv exs $ \env -> let env' = f env in smartTrace exs $ \tr -> |
1119 | fmap (\t -> if tr_light exs then length (showExp $ fst t) `seq` t else t) $ fmap (addType . recheck msg env') $ replaceMetas "lam" Lam . fst =<< lift (lift $ inferN (if tr then trace_level exs else 0) env' t) | 1125 | fmap (addType . recheck msg env') $ replaceMetas "lam" Lam . fst =<< lift (lift $ inferN (if tr then trace_level exs else 0) env' t) |
1120 | inferType exs tr t = getGEnv exs $ \env -> fmap (recheck "inferType" env) $ replaceMetas "pi" Pi . fst =<< lift (lift $ inferN (if tr then trace_level exs else 0) (CheckType_ (debugSI "inferType CheckType_") TType env) t) | 1126 | inferType exs tr t = getGEnv exs $ \env -> fmap (recheck "inferType" env) $ replaceMetas "pi" Pi . fst =<< lift (lift $ inferN (if tr then trace_level exs else 0) (CheckType_ (debugSI "inferType CheckType_") TType env) t) |
1121 | 1127 | ||
1122 | smartTrace :: MonadError String m => Extensions -> (Bool -> m a) -> m a | 1128 | smartTrace :: MonadError String m => Extensions -> (Bool -> m a) -> m a |
@@ -1198,6 +1204,7 @@ defined' = Map.keys | |||
1198 | addF exs = gets $ addForalls exs . defined' | 1204 | addF exs = gets $ addForalls exs . defined' |
1199 | 1205 | ||
1200 | fixType = Pi Hidden TType $ Pi Visible (Pi Visible (Var 0) (Var 1)) (Var 1) -- forall a . (a -> a) -> a | 1206 | fixType = Pi Hidden TType $ Pi Visible (Pi Visible (Var 0) (Var 1)) (Var 1) -- forall a . (a -> a) -> a |
1207 | fixTerm = lamify fixType $ TFun "f_i_x" fixType | ||
1201 | 1208 | ||
1202 | addLams' x [] _ e = Fun x $ reverse e | 1209 | addLams' x [] _ e = Fun x $ reverse e |
1203 | addLams' x (h: ar) (Pi h' d t) e | h == h' = Lam h d $ addLams' x ar t (Var 0: map (up1E 0) e) | 1210 | addLams' x (h: ar) (Pi h' d t) e | h == h' = Lam h d $ addLams' x ar t (Var 0: map (up1E 0) e) |
@@ -1229,18 +1236,22 @@ handleStmt exs = \case | |||
1229 | -- recursive let | 1236 | -- recursive let |
1230 | Let (si, n) mf mt ar t_ -> do | 1237 | Let (si, n) mf mt ar t_ -> do |
1231 | af <- addF exs | 1238 | af <- addF exs |
1232 | (x@(Lam Hidden _ e), _) | 1239 | (Lam Hidden _ e, _) |
1233 | <- inferTerm exs n tr (EBind2 BMeta fixType) (SAppV (SVar si 0) $ SLamV $ maybe id (flip SAnn . af) mt t_) | 1240 | <- inferTerm exs n tr (EBind2 BMeta fixType) $ SAppV (SVar si 0) $ SLamV $ maybe id (flip SAnn . af) mt t_ |
1234 | let | 1241 | let |
1235 | par i (Hidden: ar) (Pi Hidden _ tt) (Lam Hidden k z) = Lam Hidden k $ par (i+1) ar tt z | 1242 | par i (Hidden: ar) (Pi Hidden _ tt) (Lam Hidden k z) = Lam Hidden k $ par (i+1) ar tt z |
1236 | par i ar@(Visible: _) (Pi Hidden _ tt) (Lam Hidden k z) = Lam Hidden k $ par (i+1) ar tt z | 1243 | par i ar@(Visible: _) (Pi Hidden _ tt) (Lam Hidden k z) = Lam Hidden k $ par (i+1) ar tt z |
1237 | par i ar tt (Var i' `App` _ `App` f) | i == i' = x where | 1244 | par i ar tt (Var i' `App` _ `App` Lam' f) | i == i' |
1238 | x = label LabelPM (addLams' (FunName n mf t) ar tt $ reverse $ downTo 0 i) $ label LabelFix (addLams' (FunName n mf t) ar tt $ reverse $ downTo 0 i) $ f `app_` x | 1245 | = substE "let2" 0 (label LabelFix (addLams' fname ar tt $ reverse $ downTo 0 i) (foldl app_ term $ downTo 0 i)) f |
1239 | 1246 | ||
1240 | x' = x `app_` (Lam Hidden TType $ Lam Visible (Pi Visible (Var 0) (Var 1)) $ TFun "f_i_x" fixType [Var 1, Var 0]) | 1247 | fname = FunName n mf t |
1241 | t = expType x' | 1248 | alt = addLams' fname ar t [] |
1249 | |||
1250 | term = label LabelPM alt $ par 0 ar t e | ||
1251 | |||
1252 | t = expType $ substE "let" 0 fixTerm e | ||
1242 | tellStmtType exs si t | 1253 | tellStmtType exs si t |
1243 | addToEnv exs (si, n) (par 0 ar t e, traceD ("addToEnv: " ++ n ++ " = " ++ showExp x') t) | 1254 | addToEnv exs (si, n) (term, t) |
1244 | TypeFamily s ps t -> handleStmt exs $ Primitive s Nothing $ addParamsS ps t | 1255 | TypeFamily s ps t -> handleStmt exs $ Primitive s Nothing $ addParamsS ps t |
1245 | Data (si,s) ps t_ addfa cs -> do | 1256 | Data (si,s) ps t_ addfa cs -> do |
1246 | af <- if addfa then gets $ addForalls exs . (s:) . defined' else return id | 1257 | af <- if addfa then gets $ addForalls exs . (s:) . defined' else return id |