summaryrefslogtreecommitdiff
path: root/src/LambdaCube/Compiler/Infer.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/LambdaCube/Compiler/Infer.hs')
-rw-r--r--src/LambdaCube/Compiler/Infer.hs37
1 files changed, 24 insertions, 13 deletions
diff --git a/src/LambdaCube/Compiler/Infer.hs b/src/LambdaCube/Compiler/Infer.hs
index f39c9332..0b922c3b 100644
--- a/src/LambdaCube/Compiler/Infer.hs
+++ b/src/LambdaCube/Compiler/Infer.hs
@@ -678,7 +678,10 @@ reflCstr te = \case
678-} 678-}
679 x -> TT 679 x -> TT
680 680
681cstrT t a a' | a == a' = Unit 681cstrT t (UL a) (UL a') | a == a' = Unit
682cstrT t (ConN "Succ" [a]) (ConN "Succ" [a']) = cstrT TNat a a'
683cstrT t (FixLabel _ a) a' = cstrT t a a'
684cstrT t a (FixLabel _ a') = cstrT t a a'
682cstrT t a a' = CstrT t a a' 685cstrT t a a' = CstrT t a a'
683 686
684cstr = cstrT_ TType 687cstr = cstrT_ TType
@@ -693,9 +696,11 @@ cstrT_ typ = cstr__ []
693 cstr_ ns a (FixLabel _ a') = cstr_ ns a a' 696 cstr_ ns a (FixLabel _ a') = cstr_ ns a a'
694-- cstr_ ns (PMLabel a _) a' = cstr_ ns a a' 697-- cstr_ ns (PMLabel a _) a' = cstr_ ns a a'
695-- cstr_ ns a (PMLabel a' _) = cstr_ ns a a' 698-- cstr_ ns a (PMLabel a' _) = cstr_ ns a a'
696 cstr_ ns TType TType = Unit 699-- cstr_ ns TType TType = Unit
697 cstr_ ns (Con a []) (Con a' []) | a == a' = Unit 700 cstr_ ns (Con a xs) (Con a' xs') | a == a' = foldr t2 Unit $ zipWith (cstr__ ns) xs xs'
698 cstr_ ns (TyCon a []) (TyCon a' []) | a == a' = Unit 701 cstr_ [] (TyConN "'FrameBuffer" [a, b]) (TyConN "'FrameBuffer" [a', b']) = t2 (cstrT TNat a a') (cstr__ [] b b') -- todo: elim
702 cstr_ ns (TyCon a xs) (TyCon a' xs') | a == a' = foldr t2 Unit $ zipWith (cstr__ ns) xs xs'
703-- cstr_ ns (TyCon a []) (TyCon a' []) | a == a' = Unit
699 cstr_ ns (Var i) (Var i') | i == i', i < length ns = Unit 704 cstr_ ns (Var i) (Var i') | i == i', i < length ns = Unit
700 cstr_ (_: ns) (downE 0 -> Just a) (downE 0 -> Just a') = cstr__ ns a a' 705 cstr_ (_: ns) (downE 0 -> Just a) (downE 0 -> Just a') = cstr__ ns a a'
701-- cstr_ ((t, t'): ns) (UApp (downE 0 -> Just a) (UVar 0)) (UApp (downE 0 -> Just a') (UVar 0)) = traceInj2 (a, "V0") (a', "V0") $ cstr__ ns a a' 706-- cstr_ ((t, t'): ns) (UApp (downE 0 -> Just a) (UVar 0)) (UApp (downE 0 -> Just a') (UVar 0)) = traceInj2 (a, "V0") (a', "V0") $ cstr__ ns a a'
@@ -705,7 +710,8 @@ cstrT_ typ = cstr__ []
705 cstr_ ns (UBind h a b) (UBind h' a' b') | h == h' = t2 (cstr__ ns a a') (cstr__ ((a, a'): ns) b b') 710 cstr_ ns (UBind h a b) (UBind h' a' b') | h == h' = t2 (cstr__ ns a a') (cstr__ ((a, a'): ns) b b')
706-- cstr_ [] t (Meta a b) = Meta a $ cstr_ [] (up1E 0 t) b 711-- cstr_ [] t (Meta a b) = Meta a $ cstr_ [] (up1E 0 t) b
707-- cstr_ [] (Meta a b) t = Meta a $ cstr_ [] b (up1E 0 t) 712-- cstr_ [] (Meta a b) t = Meta a $ cstr_ [] b (up1E 0 t)
708 cstr_ ns (unApp -> Just (a, b)) (unApp -> Just (a', b')) = traceInj2 (a, show b) (a', show b') $ t2 (cstr__ ns a a') (cstr__ ns b b') 713-- cstr_ ns (unApp -> Just (a, b)) (unApp -> Just (a', b')) = traceInj2 (a, show b) (a', show b') $ t2 (cstr__ ns a a') (cstr__ ns b b')
714-- cstr_ ns (unApp -> Just (a, b)) (unApp -> Just (a', b')) = traceInj2 (a, show b) (a', show b') $ t2 (cstr__ ns a a') (cstr__ ns b b')
709-- cstr_ ns (Label f xs _) (Label f' xs' _) | f == f' = foldr1 T2 $ zipWith (cstr__ ns) xs xs' 715-- cstr_ ns (Label f xs _) (Label f' xs' _) | f == f' = foldr1 T2 $ zipWith (cstr__ ns) xs xs'
710 cstr_ [] (UL (FunN "'VecScalar" [a, b])) (TVec a' b') = t2 (cstrT TNat a a') (cstr__ [] b b') 716 cstr_ [] (UL (FunN "'VecScalar" [a, b])) (TVec a' b') = t2 (cstrT TNat a a') (cstr__ [] b b')
711 cstr_ [] (UL (FunN "'VecScalar" [a, b])) (UL (FunN "'VecScalar" [a', b'])) = t2 (cstrT TNat a a') (cstr__ [] b b') 717 cstr_ [] (UL (FunN "'VecScalar" [a, b])) (UL (FunN "'VecScalar" [a', b'])) = t2 (cstrT TNat a a') (cstr__ [] b b')
@@ -1116,7 +1122,7 @@ getGEnv exs f = do
1116 src <- ask 1122 src <- ask
1117 gets (\ge -> EGlobal src ge mempty) >>= f 1123 gets (\ge -> EGlobal src ge mempty) >>= f
1118inferTerm exs msg tr f t = getGEnv exs $ \env -> let env' = f env in smartTrace exs $ \tr -> 1124inferTerm exs msg tr f t = getGEnv exs $ \env -> let env' = f env in smartTrace exs $ \tr ->
1119 fmap (\t -> if tr_light exs then length (showExp $ fst t) `seq` t else t) $ fmap (addType . recheck msg env') $ replaceMetas "lam" Lam . fst =<< lift (lift $ inferN (if tr then trace_level exs else 0) env' t) 1125 fmap (addType . recheck msg env') $ replaceMetas "lam" Lam . fst =<< lift (lift $ inferN (if tr then trace_level exs else 0) env' t)
1120inferType exs tr t = getGEnv exs $ \env -> fmap (recheck "inferType" env) $ replaceMetas "pi" Pi . fst =<< lift (lift $ inferN (if tr then trace_level exs else 0) (CheckType_ (debugSI "inferType CheckType_") TType env) t) 1126inferType exs tr t = getGEnv exs $ \env -> fmap (recheck "inferType" env) $ replaceMetas "pi" Pi . fst =<< lift (lift $ inferN (if tr then trace_level exs else 0) (CheckType_ (debugSI "inferType CheckType_") TType env) t)
1121 1127
1122smartTrace :: MonadError String m => Extensions -> (Bool -> m a) -> m a 1128smartTrace :: MonadError String m => Extensions -> (Bool -> m a) -> m a
@@ -1198,6 +1204,7 @@ defined' = Map.keys
1198addF exs = gets $ addForalls exs . defined' 1204addF exs = gets $ addForalls exs . defined'
1199 1205
1200fixType = Pi Hidden TType $ Pi Visible (Pi Visible (Var 0) (Var 1)) (Var 1) -- forall a . (a -> a) -> a 1206fixType = Pi Hidden TType $ Pi Visible (Pi Visible (Var 0) (Var 1)) (Var 1) -- forall a . (a -> a) -> a
1207fixTerm = lamify fixType $ TFun "f_i_x" fixType
1201 1208
1202addLams' x [] _ e = Fun x $ reverse e 1209addLams' x [] _ e = Fun x $ reverse e
1203addLams' x (h: ar) (Pi h' d t) e | h == h' = Lam h d $ addLams' x ar t (Var 0: map (up1E 0) e) 1210addLams' x (h: ar) (Pi h' d t) e | h == h' = Lam h d $ addLams' x ar t (Var 0: map (up1E 0) e)
@@ -1229,18 +1236,22 @@ handleStmt exs = \case
1229 -- recursive let 1236 -- recursive let
1230 Let (si, n) mf mt ar t_ -> do 1237 Let (si, n) mf mt ar t_ -> do
1231 af <- addF exs 1238 af <- addF exs
1232 (x@(Lam Hidden _ e), _) 1239 (Lam Hidden _ e, _)
1233 <- inferTerm exs n tr (EBind2 BMeta fixType) (SAppV (SVar si 0) $ SLamV $ maybe id (flip SAnn . af) mt t_) 1240 <- inferTerm exs n tr (EBind2 BMeta fixType) $ SAppV (SVar si 0) $ SLamV $ maybe id (flip SAnn . af) mt t_
1234 let 1241 let
1235 par i (Hidden: ar) (Pi Hidden _ tt) (Lam Hidden k z) = Lam Hidden k $ par (i+1) ar tt z 1242 par i (Hidden: ar) (Pi Hidden _ tt) (Lam Hidden k z) = Lam Hidden k $ par (i+1) ar tt z
1236 par i ar@(Visible: _) (Pi Hidden _ tt) (Lam Hidden k z) = Lam Hidden k $ par (i+1) ar tt z 1243 par i ar@(Visible: _) (Pi Hidden _ tt) (Lam Hidden k z) = Lam Hidden k $ par (i+1) ar tt z
1237 par i ar tt (Var i' `App` _ `App` f) | i == i' = x where 1244 par i ar tt (Var i' `App` _ `App` Lam' f) | i == i'
1238 x = label LabelPM (addLams' (FunName n mf t) ar tt $ reverse $ downTo 0 i) $ label LabelFix (addLams' (FunName n mf t) ar tt $ reverse $ downTo 0 i) $ f `app_` x 1245 = substE "let2" 0 (label LabelFix (addLams' fname ar tt $ reverse $ downTo 0 i) (foldl app_ term $ downTo 0 i)) f
1239 1246
1240 x' = x `app_` (Lam Hidden TType $ Lam Visible (Pi Visible (Var 0) (Var 1)) $ TFun "f_i_x" fixType [Var 1, Var 0]) 1247 fname = FunName n mf t
1241 t = expType x' 1248 alt = addLams' fname ar t []
1249
1250 term = label LabelPM alt $ par 0 ar t e
1251
1252 t = expType $ substE "let" 0 fixTerm e
1242 tellStmtType exs si t 1253 tellStmtType exs si t
1243 addToEnv exs (si, n) (par 0 ar t e, traceD ("addToEnv: " ++ n ++ " = " ++ showExp x') t) 1254 addToEnv exs (si, n) (term, t)
1244 TypeFamily s ps t -> handleStmt exs $ Primitive s Nothing $ addParamsS ps t 1255 TypeFamily s ps t -> handleStmt exs $ Primitive s Nothing $ addParamsS ps t
1245 Data (si,s) ps t_ addfa cs -> do 1256 Data (si,s) ps t_ addfa cs -> do
1246 af <- if addfa then gets $ addForalls exs . (s:) . defined' else return id 1257 af <- if addfa then gets $ addForalls exs . (s:) . defined' else return id