diff options
Diffstat (limited to 'src/LambdaCube')
-rw-r--r-- | src/LambdaCube/Compiler/Core.hs | 139 | ||||
-rw-r--r-- | src/LambdaCube/Compiler/Infer.hs | 2 | ||||
-rw-r--r-- | src/LambdaCube/Compiler/InferMonad.hs | 5 |
3 files changed, 73 insertions, 73 deletions
diff --git a/src/LambdaCube/Compiler/Core.hs b/src/LambdaCube/Compiler/Core.hs index eb5a99f1..eb7c59cf 100644 --- a/src/LambdaCube/Compiler/Core.hs +++ b/src/LambdaCube/Compiler/Core.hs | |||
@@ -17,7 +17,7 @@ import Data.Function | |||
17 | import Data.List | 17 | import Data.List |
18 | import Control.Arrow hiding ((<+>)) | 18 | import Control.Arrow hiding ((<+>)) |
19 | 19 | ||
20 | --import LambdaCube.Compiler.Utils | 20 | import LambdaCube.Compiler.Utils |
21 | import LambdaCube.Compiler.DeBruijn | 21 | import LambdaCube.Compiler.DeBruijn |
22 | import LambdaCube.Compiler.Pretty hiding (braces, parens) | 22 | import LambdaCube.Compiler.Pretty hiding (braces, parens) |
23 | import LambdaCube.Compiler.DesugaredSource | 23 | import LambdaCube.Compiler.DesugaredSource |
@@ -35,7 +35,7 @@ data CaseFunName = CaseFunName FName Type Int{-num of parameters-} | |||
35 | data TyCaseFunName = TyCaseFunName FName Type | 35 | data TyCaseFunName = TyCaseFunName FName Type |
36 | 36 | ||
37 | data FunDef | 37 | data FunDef |
38 | = DeltaDef !Int{-arity-} (FreeVars -> [Exp] -> Exp) | 38 | = DeltaDef !Int{-arity-} (FreeVars -> [Exp]{-args in reversed order-} -> Exp) |
39 | | NoDef | 39 | | NoDef |
40 | | ExpDef Exp | 40 | | ExpDef Exp |
41 | 41 | ||
@@ -120,10 +120,6 @@ pattern NoRHS <- (isRHS -> False) | |||
120 | isRHS RHS{} = True | 120 | isRHS RHS{} = True |
121 | isRHS _ = False | 121 | isRHS _ = False |
122 | 122 | ||
123 | -- TODO: elim | ||
124 | pattern Reverse xs <- (reverse -> xs) | ||
125 | where Reverse = reverse | ||
126 | |||
127 | pattern Fun f xs n <- Fun_ _ f xs n | 123 | pattern Fun f xs n <- Fun_ _ f xs n |
128 | where Fun f xs n = Fun_ (foldMap getFreeVars xs) f xs n | 124 | where Fun f xs n = Fun_ (foldMap getFreeVars xs) f xs n |
129 | pattern CaseFun_ a b c <- CaseFun__ _ a b c | 125 | pattern CaseFun_ a b c <- CaseFun__ _ a b c |
@@ -155,13 +151,13 @@ pattern App a b <- Neut (App_ (Neut -> a) b) | |||
155 | pattern DFun a t b = Neut (DFunN a t b) | 151 | pattern DFun a t b = Neut (DFunN a t b) |
156 | 152 | ||
157 | -- unreducable function application | 153 | -- unreducable function application |
158 | pattern UFun a b <- Neut (Fun (FunName (FTag a) _ _ t) (reverse -> b) NoRHS) | 154 | pattern UFun a b <- Neut (Fun (FunName (FTag a) _ _ t) b NoRHS) |
159 | 155 | ||
160 | -- saturated delta function application | 156 | -- saturated delta function application |
161 | pattern DFunN a t xs = DFunN_ (FTag a) t xs | 157 | pattern DFunN a t xs = DFunN_ (FTag a) t xs |
162 | 158 | ||
163 | pattern DFunN_ a t xs <- Fun (FunName' a t) (Reverse xs) _ | 159 | pattern DFunN_ a t xs <- Fun (FunName' a t) xs _ |
164 | where DFunN_ a t xs = Fun (FunName' a t) (Reverse xs) delta | 160 | where DFunN_ a t xs = Fun (FunName' a t) xs delta |
165 | 161 | ||
166 | conParams (conTypeName -> TyConName _ _ _ _ (CaseFunName _ _ pars)) = pars | 162 | conParams (conTypeName -> TyConName _ _ _ _ (CaseFunName _ _ pars)) = pars |
167 | mkConPars n (snd . getParams . hnf -> TyCon (TyConName _ _ _ _ (CaseFunName _ _ pars)) xs) = take (min n pars) xs | 163 | mkConPars n (snd . getParams . hnf -> TyCon (TyConName _ _ _ _ (CaseFunName _ _ pars)) xs) = take (min n pars) xs |
@@ -207,12 +203,12 @@ pattern CEmpty s <- ConN FCEmpty (HString s: _) | |||
207 | where CEmpty s = tCon FCEmpty 1 (TString :~> TConstraint) [HString s] | 203 | where CEmpty s = tCon FCEmpty 1 (TString :~> TConstraint) [HString s] |
208 | 204 | ||
209 | pattern CstrT t a b = Neut (CstrT' t a b) | 205 | pattern CstrT t a b = Neut (CstrT' t a b) |
210 | pattern CstrT' t a b = DFunN F'EqCT (TType :~> Var 0 :~> Var 1 :~> TConstraint) [t, a, b] | 206 | pattern CstrT' t a b = DFunN F'EqCT (TType :~> Var 0 :~> Var 1 :~> TConstraint) [b, a, t] |
211 | pattern Coe a b w x = DFun Fcoe (TType :~> TType :~> CW (CstrT TType (Var 1) (Var 0)) :~> Var 2 :~> Var 2) [a,b,w,x] | 207 | pattern Coe a b w x = DFun Fcoe (TType :~> TType :~> CW (CstrT TType (Var 1) (Var 0)) :~> Var 2 :~> Var 2) [x,w,b,a] |
212 | pattern ParEval t a b = DFun FparEval (TType :~> Var 0 :~> Var 1 :~> Var 2) [t, a, b] | 208 | pattern ParEval t a b = DFun FparEval (TType :~> Var 0 :~> Var 1 :~> Var 2) [b, a, t] |
213 | pattern T2 a b = DFun F'T2 (TConstraint :~> TConstraint :~> TConstraint) [a, b] | 209 | pattern T2 a b = DFun F'T2 (TConstraint :~> TConstraint :~> TConstraint) [b, a] |
214 | pattern CW a = DFun F'CW (TConstraint :~> TType) [a] | 210 | pattern CW a = DFun F'CW (TConstraint :~> TType) [a] |
215 | pattern CSplit a b c <- UFun F'Split [a, b, c] | 211 | pattern CSplit a b c <- UFun F'Split [c, b, a] |
216 | 212 | ||
217 | pattern HLit a <- (hnf -> ELit a) | 213 | pattern HLit a <- (hnf -> ELit a) |
218 | where HLit = ELit | 214 | where HLit = ELit |
@@ -247,7 +243,7 @@ mkOrdering x = case x of | |||
247 | conTypeName :: ConName -> TyConName | 243 | conTypeName :: ConName -> TyConName |
248 | conTypeName (ConName _ _ t) = case snd $ getParams t of TyCon n _ -> n | 244 | conTypeName (ConName _ _ t) = case snd $ getParams t of TyCon n _ -> n |
249 | 245 | ||
250 | mkFun_ md (FunName _ _ (DeltaDef ar f) _) as _ | length as == ar = f md $ reverse as | 246 | mkFun_ md (FunName _ _ (DeltaDef ar f) _) as _ | length as == ar = f md as |
251 | mkFun_ md f xs y = Neut $ Fun_ md f xs $ hnf y | 247 | mkFun_ md f xs y = Neut $ Fun_ md f xs $ hnf y |
252 | 248 | ||
253 | mkFun :: FunName -> [Exp] -> Exp -> Exp | 249 | mkFun :: FunName -> [Exp] -> Exp -> Exp |
@@ -440,7 +436,7 @@ getFixLam _ = Nothing | |||
440 | instance MkDoc Neutral where | 436 | instance MkDoc Neutral where |
441 | mkDoc pr@(reduce, body) = \case | 437 | mkDoc pr@(reduce, body) = \case |
442 | CstrT' t a b -> shCstr (mkDoc pr a) (mkDoc pr (ET b t)) | 438 | CstrT' t a b -> shCstr (mkDoc pr a) (mkDoc pr (ET b t)) |
443 | Fun (FunName _ _ (ExpDef d) _) xs _ | body -> mkDoc (reduce, False) (foldl app_ d $ reverse xs) | 439 | Fun (FunName _ _ (ExpDef d) _) xs _ | body -> mkDoc (reduce, False) (foldlrev app_ d xs) |
444 | FFix (getFixLam -> Just (s, xs)) | not body -> foldl DApp (pShow s) $ mkDoc pr <$> xs | 440 | FFix (getFixLam -> Just (s, xs)) | not body -> foldl DApp (pShow s) $ mkDoc pr <$> xs |
445 | FFix f {- | body -} -> foldl DApp "primFix" [{-pShow t -}"_", mkDoc pr f] | 441 | FFix f {- | body -} -> foldl DApp "primFix" [{-pShow t -}"_", mkDoc pr f] |
446 | Fun (FunName _ _ (DeltaDef n _) _) _ _ | body -> text $ "<<delta function with arity " ++ show n ++ ">>" | 442 | Fun (FunName _ _ (DeltaDef n _) _) _ _ | body -> text $ "<<delta function with arity " ++ show n ++ ">>" |
@@ -476,79 +472,80 @@ pattern FunName' a t <- FunName a _ _ t | |||
476 | mkFunDef a@(show -> "primFix") t = fn | 472 | mkFunDef a@(show -> "primFix") t = fn |
477 | where | 473 | where |
478 | fn = FunName a 0 (DeltaDef (length $ fst $ getParams t) fx) t | 474 | fn = FunName a 0 (DeltaDef (length $ fst $ getParams t) fx) t |
479 | fx s xs = Neut $ Fun_ s fn (reverse xs) $ case xs of | 475 | fx s xs = Neut $ Fun_ s fn xs $ case xs of |
480 | _: f: _ -> RHS x where x = f `app_` Neut (Fun_ s fn (reverse xs) $ RHS x) | 476 | f: _{-1-} -> RHS x where x = f `app_` Neut (Fun_ s fn xs $ RHS x) |
481 | _ -> delta | 477 | _ -> delta |
482 | 478 | ||
483 | mkFunDef a t = fn | 479 | mkFunDef a t = fn |
484 | where | 480 | where |
485 | fn = FunName a 0 (maybe NoDef (DeltaDef (length $ fst $ getParams t) . const) $ getFunDef t a $ \xs -> Neut $ Fun fn (reverse xs) delta) t | 481 | fn = FunName a 0 (maybe NoDef (DeltaDef (length $ fst $ getParams t) . const) $ getFunDef t a $ \xs -> Neut $ Fun fn xs delta) t |
486 | 482 | ||
487 | getFunDef t s f = case show s of | 483 | getFunDef t s f = case show s of |
488 | "'EqCT" -> Just $ \case (t: a: b: _) -> cstr t a b | 484 | "'EqCT" -> Just $ \case (b: a: t: _) -> cstr t a b |
489 | "'T2" -> Just $ \case (a: b: _) -> t2 a b | 485 | "'T2" -> Just $ \case (b: a: _) -> t2 a b |
490 | "'CW" -> Just $ \case (a: _) -> cw a | 486 | "'CW" -> Just $ \case (a: _) -> cw a |
491 | "t2C" -> Just $ \case (a: b: _) -> t2C a b | 487 | "t2C" -> Just $ \case (b: a: _) -> t2C a b |
492 | "coe" -> Just $ \case (a: b: t: d: _) -> evalCoe a b t d | 488 | "coe" -> Just $ \case (d: t: b: a: _) -> evalCoe a b t d |
493 | "parEval" -> Just $ \case (t: a: b: _) -> parEval t a b | 489 | "parEval" -> Just $ \case (b: a: t: _) -> parEval t a b |
494 | where | 490 | where |
495 | parEval _ x@RHS{} _ = x | 491 | parEval _ x@RHS{} _ = x |
496 | parEval _ _ x@RHS{} = x | 492 | parEval _ _ x@RHS{} = x |
497 | parEval t a b = ParEval t a b | 493 | parEval t a b = ParEval t a b |
498 | 494 | ||
499 | "unsafeCoerce" -> Just $ \case xs@(_: _: x@(hnf -> NonNeut): _) -> x; xs -> f xs | 495 | "unsafeCoerce" -> Just $ \case xs@(x@(hnf -> NonNeut): _{-2-}) -> x; xs -> f xs |
500 | "reflCstr" -> Just $ \case (a: _) -> TT | 496 | "reflCstr" -> Just $ \case _ -> TT |
501 | "hlistNilCase" -> Just $ \case (_: x: (hnf -> Con n@(ConName _ 0 _) _ _): _) -> x; xs -> f xs | 497 | "hlistNilCase" -> Just $ \case ((hnf -> Con n@(ConName _ 0 _) _ _): x: _{-1-}) -> x; xs -> f xs |
502 | "hlistConsCase" -> Just $ \case (_: _: _: x: (hnf -> Con n@(ConName _ 1 _) _ (_: _: a: b: _)): _) -> x `app_` a `app_` b; xs -> f xs | 498 | "hlistConsCase" -> Just $ \case ((hnf -> Con n@(ConName _ 1 _) _ (_: _: a: b: _)): x: _{-3-}) -> x `app_` a `app_` b; xs -> f xs |
503 | 499 | ||
504 | -- general compiler primitives | 500 | -- general compiler primitives |
505 | "primAddInt" -> Just $ \case (HInt i: HInt j: _) -> HInt (i + j); xs -> f xs | 501 | "primAddInt" -> Just $ \case (HInt j: HInt i: _) -> HInt (i + j); xs -> f xs |
506 | "primSubInt" -> Just $ \case (HInt i: HInt j: _) -> HInt (i - j); xs -> f xs | 502 | "primSubInt" -> Just $ \case (HInt j: HInt i: _) -> HInt (i - j); xs -> f xs |
507 | "primModInt" -> Just $ \case (HInt i: HInt j: _) -> HInt (i `mod` j); xs -> f xs | 503 | "primModInt" -> Just $ \case (HInt j: HInt i: _) -> HInt (i `mod` j); xs -> f xs |
508 | "primSqrtFloat" -> Just $ \case (HFloat i: _) -> HFloat $ sqrt i; xs -> f xs | 504 | "primSqrtFloat" -> Just $ \case (HFloat i: _) -> HFloat $ sqrt i; xs -> f xs |
509 | "primRound" -> Just $ \case (HFloat i: _) -> HInt $ round i; xs -> f xs | 505 | "primRound" -> Just $ \case (HFloat i: _) -> HInt $ round i; xs -> f xs |
510 | "primIntToFloat" -> Just $ \case (HInt i: _) -> HFloat $ fromIntegral i; xs -> f xs | 506 | "primIntToFloat" -> Just $ \case (HInt i: _) -> HFloat $ fromIntegral i; xs -> f xs |
511 | "primIntToNat" -> Just $ \case (HInt i: _) -> ENat $ fromIntegral i; xs -> f xs | 507 | "primIntToNat" -> Just $ \case (HInt i: _) -> ENat $ fromIntegral i; xs -> f xs |
512 | "primCompareInt" -> Just $ \case (HInt x: HInt y: _) -> mkOrdering $ x `compare` y; xs -> f xs | 508 | "primCompareInt" -> Just $ \case (HInt y: HInt x: _) -> mkOrdering $ x `compare` y; xs -> f xs |
513 | "primCompareFloat" -> Just $ \case (HFloat x: HFloat y: _) -> mkOrdering $ x `compare` y; xs -> f xs | 509 | "primCompareFloat" -> Just $ \case (HFloat y: HFloat x: _) -> mkOrdering $ x `compare` y; xs -> f xs |
514 | "primCompareChar" -> Just $ \case (HChar x: HChar y: _) -> mkOrdering $ x `compare` y; xs -> f xs | 510 | "primCompareChar" -> Just $ \case (HChar y: HChar x: _) -> mkOrdering $ x `compare` y; xs -> f xs |
515 | "primCompareString" -> Just $ \case (HString x: HString y: _) -> mkOrdering $ x `compare` y; xs -> f xs | 511 | "primCompareString" -> Just $ \case (HString y: HString x: _) -> mkOrdering $ x `compare` y; xs -> f xs |
516 | 512 | ||
517 | -- LambdaCube 3D specific primitives | 513 | -- LambdaCube 3D specific primitives |
518 | "PrimGreaterThan" -> Just $ \case (t: _: _: _: _: _: _: x: y: _) | Just r <- twoOpBool (>) t x y -> r; xs -> f xs | 514 | "PrimGreaterThan" -> Just $ \case (y: x: _{-7-}) | Just r <- twoOpBool (>) x y -> r; xs -> f xs |
519 | "PrimGreaterThanEqual" | 515 | "PrimGreaterThanEqual" |
520 | -> Just $ \case (t: _: _: _: _: _: _: x: y: _) | Just r <- twoOpBool (>=) t x y -> r; xs -> f xs | 516 | -> Just $ \case (y: x: _{-7-}) | Just r <- twoOpBool (>=) x y -> r; xs -> f xs |
521 | "PrimLessThan" -> Just $ \case (t: _: _: _: _: _: _: x: y: _) | Just r <- twoOpBool (<) t x y -> r; xs -> f xs | 517 | "PrimLessThan" -> Just $ \case (y: x: _{-7-}) | Just r <- twoOpBool (<) x y -> r; xs -> f xs |
522 | "PrimLessThanEqual" -> Just $ \case (t: _: _: _: _: _: _: x: y: _) | Just r <- twoOpBool (<=) t x y -> r; xs -> f xs | 518 | "PrimLessThanEqual" -> Just $ \case (y: x: _{-7-}) | Just r <- twoOpBool (<=) x y -> r; xs -> f xs |
523 | "PrimEqualV" -> Just $ \case (t: _: _: _: _: _: _: x: y: _) | Just r <- twoOpBool (==) t x y -> r; xs -> f xs | 519 | "PrimEqualV" -> Just $ \case (y: x: _{-7-}) | Just r <- twoOpBool (==) x y -> r; xs -> f xs |
524 | "PrimNotEqualV" -> Just $ \case (t: _: _: _: _: _: _: x: y: _) | Just r <- twoOpBool (/=) t x y -> r; xs -> f xs | 520 | "PrimNotEqualV" -> Just $ \case (y: x: _{-7-}) | Just r <- twoOpBool (/=) x y -> r; xs -> f xs |
525 | "PrimEqual" -> Just $ \case (t: _: _: x: y: _) | Just r <- twoOpBool (==) t x y -> r; xs -> f xs | 521 | "PrimEqual" -> Just $ \case (y: x: _{-3-}) | Just r <- twoOpBool (==) x y -> r; xs -> f xs |
526 | "PrimNotEqual" -> Just $ \case (t: _: _: x: y: _) | Just r <- twoOpBool (/=) t x y -> r; xs -> f xs | 522 | "PrimNotEqual" -> Just $ \case (y: x: _{-3-}) | Just r <- twoOpBool (/=) x y -> r; xs -> f xs |
527 | "PrimSubS" -> Just $ \case (_: _: _: _: x: y: _) | Just r <- twoOp (-) x y -> r; xs -> f xs | 523 | "PrimSubS" -> Just $ \case (y: x: _{-4-}) | Just r <- twoOp (-) x y -> r; xs -> f xs |
528 | "PrimSub" -> Just $ \case (_: _: x: y: _) | Just r <- twoOp (-) x y -> r; xs -> f xs | 524 | "PrimSub" -> Just $ \case (y: x: _{-2-}) | Just r <- twoOp (-) x y -> r; xs -> f xs |
529 | "PrimAddS" -> Just $ \case (_: _: _: _: x: y: _) | Just r <- twoOp (+) x y -> r; xs -> f xs | 525 | "PrimAddS" -> Just $ \case (y: x: _{-4-}) | Just r <- twoOp (+) x y -> r; xs -> f xs |
530 | "PrimAdd" -> Just $ \case (_: _: x: y: _) | Just r <- twoOp (+) x y -> r; xs -> f xs | 526 | "PrimAdd" -> Just $ \case (y: x: _{-2-}) | Just r <- twoOp (+) x y -> r; xs -> f xs |
531 | "PrimMulS" -> Just $ \case (_: _: _: _: x: y: _) | Just r <- twoOp (*) x y -> r; xs -> f xs | 527 | "PrimMulS" -> Just $ \case (y: x: _{-4-}) | Just r <- twoOp (*) x y -> r; xs -> f xs |
532 | "PrimMul" -> Just $ \case (_: _: x: y: _) | Just r <- twoOp (*) x y -> r; xs -> f xs | 528 | "PrimMul" -> Just $ \case (y: x: _{-2-}) | Just r <- twoOp (*) x y -> r; xs -> f xs |
533 | "PrimDivS" -> Just $ \case (_: _: _: _: _: x: y: _) | Just r <- twoOp_ (/) div x y -> r; xs -> f xs | 529 | "PrimDivS" -> Just $ \case (y: x: _{-5-}) | Just r <- twoOp_ (/) div x y -> r; xs -> f xs |
534 | "PrimDiv" -> Just $ \case (_: _: _: _: _: x: y: _) | Just r <- twoOp_ (/) div x y -> r; xs -> f xs | 530 | "PrimDiv" -> Just $ \case (y: x: _{-5-}) | Just r <- twoOp_ (/) div x y -> r; xs -> f xs |
535 | "PrimModS" -> Just $ \case (_: _: _: _: _: x: y: _) | Just r <- twoOp_ modF mod x y -> r; xs -> f xs | 531 | "PrimModS" -> Just $ \case (y: x: _{-5-}) | Just r <- twoOp_ modF mod x y -> r; xs -> f xs |
536 | "PrimMod" -> Just $ \case (_: _: _: _: _: x: y: _) | Just r <- twoOp_ modF mod x y -> r; xs -> f xs | 532 | "PrimMod" -> Just $ \case (y: x: _{-5-}) | Just r <- twoOp_ modF mod x y -> r; xs -> f xs |
537 | "PrimNeg" -> Just $ \case (_: x: _) | Just r <- oneOp negate x -> r; xs -> f xs | 533 | "PrimNeg" -> Just $ \case (x: _{-1-}) | Just r <- oneOp negate x -> r; xs -> f xs |
538 | "PrimAnd" -> Just $ \case (EBool x: EBool y: _) -> EBool (x && y); xs -> f xs | 534 | "PrimAnd" -> Just $ \case (EBool y: EBool x: _) -> EBool (x && y); xs -> f xs |
539 | "PrimOr" -> Just $ \case (EBool x: EBool y: _) -> EBool (x || y); xs -> f xs | 535 | "PrimOr" -> Just $ \case (EBool y: EBool x: _) -> EBool (x || y); xs -> f xs |
540 | "PrimXor" -> Just $ \case (EBool x: EBool y: _) -> EBool (x /= y); xs -> f xs | 536 | "PrimXor" -> Just $ \case (EBool y: EBool x: _) -> EBool (x /= y); xs -> f xs |
541 | "PrimNot" -> Just $ \case ((hnf -> TNat): _: _: EBool x: _) -> EBool $ not x; xs -> f xs | 537 | "PrimNot" -> Just $ \case (EBool x: _: _: (hnf -> TNat): _) -> EBool $ not x; xs -> f xs |
542 | 538 | ||
543 | _ -> Nothing | 539 | _ -> Nothing |
544 | where | 540 | where |
545 | twoOpBool :: (forall a . Ord a => a -> a -> Bool) -> Exp -> Exp -> Exp -> Maybe Exp | 541 | |
546 | twoOpBool f t (HFloat x) (HFloat y) = Just $ EBool $ f x y | 542 | twoOpBool :: (forall a . Ord a => a -> a -> Bool) -> Exp -> Exp -> Maybe Exp |
547 | twoOpBool f t (HInt x) (HInt y) = Just $ EBool $ f x y | 543 | twoOpBool f (HFloat x) (HFloat y) = Just $ EBool $ f x y |
548 | twoOpBool f t (HString x) (HString y) = Just $ EBool $ f x y | 544 | twoOpBool f (HInt x) (HInt y) = Just $ EBool $ f x y |
549 | twoOpBool f t (HChar x) (HChar y) = Just $ EBool $ f x y | 545 | twoOpBool f (HString x) (HString y) = Just $ EBool $ f x y |
550 | twoOpBool f t (ENat x) (ENat y) = Just $ EBool $ f x y | 546 | twoOpBool f (HChar x) (HChar y) = Just $ EBool $ f x y |
551 | twoOpBool _ _ _ _ = Nothing | 547 | twoOpBool f (ENat x) (ENat y) = Just $ EBool $ f x y |
548 | twoOpBool _ _ _ = Nothing | ||
552 | 549 | ||
553 | oneOp :: (forall a . Num a => a -> a) -> Exp -> Maybe Exp | 550 | oneOp :: (forall a . Num a => a -> a) -> Exp -> Maybe Exp |
554 | oneOp f = oneOp_ f f | 551 | oneOp f = oneOp_ f f |
@@ -606,11 +603,11 @@ cstr = f [] | |||
606 | f_ (_: ns) typ{-down?-} (down 0 -> Just a) (down 0 -> Just a') = f ns typ a a' | 603 | f_ (_: ns) typ{-down?-} (down 0 -> Just a) (down 0 -> Just a') = f ns typ a a' |
607 | f_ ns TType (Pi h a b) (Pi h' a' b') | h == h' = t2 (f ns TType a a') (f ((a, a'): ns) TType b b') | 604 | f_ ns TType (Pi h a b) (Pi h' a' b') | h == h' = t2 (f ns TType a a') (f ((a, a'): ns) TType b b') |
608 | 605 | ||
609 | f_ [] TType (UFun F'VecScalar [a, b]) (UFun F'VecScalar [a', b']) = t2 (f [] TNat a a') (f [] TType b b') | 606 | f_ [] TType (UFun F'VecScalar [b, a]) (UFun F'VecScalar [b', a']) = t2 (f [] TNat a a') (f [] TType b b') |
610 | f_ [] TType (UFun F'VecScalar [a, b]) (TVec a' b') = t2 (f [] TNat a a') (f [] TType b b') | 607 | f_ [] TType (UFun F'VecScalar [b, a]) (TVec a' b') = t2 (f [] TNat a a') (f [] TType b b') |
611 | f_ [] TType (UFun F'VecScalar [a, b]) t@NonNeut = t2 (f [] TNat a (ENat 1)) (f [] TType b t) | 608 | f_ [] TType (UFun F'VecScalar [b, a]) t@NonNeut = t2 (f [] TNat a (ENat 1)) (f [] TType b t) |
612 | f_ [] TType (TVec a' b') (UFun F'VecScalar [a, b]) = t2 (f [] TNat a' a) (f [] TType b' b) | 609 | f_ [] TType (TVec a' b') (UFun F'VecScalar [b, a]) = t2 (f [] TNat a' a) (f [] TType b' b) |
613 | f_ [] TType t@NonNeut (UFun F'VecScalar [a, b]) = t2 (f [] TNat a (ENat 1)) (f [] TType b t) | 610 | f_ [] TType t@NonNeut (UFun F'VecScalar [b, a]) = t2 (f [] TNat a (ENat 1)) (f [] TType b t) |
614 | 611 | ||
615 | f_ [] typ a@Neut{} a' = CstrT typ a a' | 612 | f_ [] typ a@Neut{} a' = CstrT typ a a' |
616 | f_ [] typ a a'@Neut{} = CstrT typ a a' | 613 | f_ [] typ a a'@Neut{} = CstrT typ a a' |
@@ -626,7 +623,7 @@ nonNeut Neut{} = False | |||
626 | nonNeut _ = True | 623 | nonNeut _ = True |
627 | 624 | ||
628 | t2C (hnf -> TT) (hnf -> TT) = TT | 625 | t2C (hnf -> TT) (hnf -> TT) = TT |
629 | t2C a b = DFun Ft2C (Unit :~> Unit :~> Unit) [a, b] | 626 | t2C a b = DFun Ft2C (Unit :~> Unit :~> Unit) [b, a] |
630 | 627 | ||
631 | cw (hnf -> CUnit) = Unit | 628 | cw (hnf -> CUnit) = Unit |
632 | cw (hnf -> CEmpty a) = Empty a | 629 | cw (hnf -> CEmpty a) = Empty a |
diff --git a/src/LambdaCube/Compiler/Infer.hs b/src/LambdaCube/Compiler/Infer.hs index ecd18b46..f726f30b 100644 --- a/src/LambdaCube/Compiler/Infer.hs +++ b/src/LambdaCube/Compiler/Infer.hs | |||
@@ -500,7 +500,7 @@ handleStmt = \case | |||
500 | Primitive n t_ -> do | 500 | Primitive n t_ -> do |
501 | t <- inferType n $ trSExp' t_ | 501 | t <- inferType n $ trSExp' t_ |
502 | tellType (sourceInfo n) t | 502 | tellType (sourceInfo n) t |
503 | addToEnv n $ flip ET t $ lamify t $ Neut . DFunN_ (FName n) t | 503 | addToEnv n $ flip ET t $ lamify' t $ Neut . DFunN_ (FName n) t |
504 | StLet n mt t_ -> do | 504 | StLet n mt t_ -> do |
505 | let t__ = maybe id (flip SAnn) mt t_ | 505 | let t__ = maybe id (flip SAnn) mt t_ |
506 | ET x t <- inferTerm n $ trSExp' t__ | 506 | ET x t <- inferTerm n $ trSExp' t__ |
diff --git a/src/LambdaCube/Compiler/InferMonad.hs b/src/LambdaCube/Compiler/InferMonad.hs index e2895389..8d46f7c7 100644 --- a/src/LambdaCube/Compiler/InferMonad.hs +++ b/src/LambdaCube/Compiler/InferMonad.hs | |||
@@ -17,7 +17,7 @@ module LambdaCube.Compiler.InferMonad where | |||
17 | 17 | ||
18 | import Data.Monoid | 18 | import Data.Monoid |
19 | import Data.List | 19 | import Data.List |
20 | import Data.Maybe | 20 | --import Data.Maybe |
21 | import qualified Data.Set as Set | 21 | import qualified Data.Set as Set |
22 | import qualified Data.Map as Map | 22 | import qualified Data.Map as Map |
23 | 23 | ||
@@ -143,10 +143,13 @@ addLams ps t = foldr (const Lam) t ps | |||
143 | 143 | ||
144 | lamify t x = addLams (fst $ getParams t) $ x $ downTo 0 $ arity t | 144 | lamify t x = addLams (fst $ getParams t) $ x $ downTo 0 $ arity t |
145 | 145 | ||
146 | lamify' t x = addLams (fst $ getParams t) $ x $ downTo' 0 $ arity t | ||
147 | |||
146 | arity :: Exp -> Int | 148 | arity :: Exp -> Int |
147 | arity = length . fst . getParams | 149 | arity = length . fst . getParams |
148 | 150 | ||
149 | downTo n m = map Var [n+m-1, n+m-2..n] | 151 | downTo n m = map Var [n+m-1, n+m-2..n] |
152 | downTo' n m = map Var [n, n+1..n+m-1] | ||
150 | 153 | ||
151 | withEnv e = local $ second (<> e) | 154 | withEnv e = local $ second (<> e) |
152 | 155 | ||