summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorPéter Diviánszky <divipp@gmail.com>2016-05-13 23:43:31 +0200
committerPéter Diviánszky <divipp@gmail.com>2016-05-13 23:43:31 +0200
commitee858ba089b2f8f582f86bdff38893b6ed17bd01 (patch)
tree415794bbe0d6c69b2d3adbce0bc905035ad5b073 /src
parentabf6df57ab706f7f27d35b406702bccabc892ce5 (diff)
refactoring: use less reverse call
Diffstat (limited to 'src')
-rw-r--r--src/LambdaCube/Compiler/Core.hs139
-rw-r--r--src/LambdaCube/Compiler/Infer.hs2
-rw-r--r--src/LambdaCube/Compiler/InferMonad.hs5
3 files changed, 73 insertions, 73 deletions
diff --git a/src/LambdaCube/Compiler/Core.hs b/src/LambdaCube/Compiler/Core.hs
index eb5a99f1..eb7c59cf 100644
--- a/src/LambdaCube/Compiler/Core.hs
+++ b/src/LambdaCube/Compiler/Core.hs
@@ -17,7 +17,7 @@ import Data.Function
17import Data.List 17import Data.List
18import Control.Arrow hiding ((<+>)) 18import Control.Arrow hiding ((<+>))
19 19
20--import LambdaCube.Compiler.Utils 20import LambdaCube.Compiler.Utils
21import LambdaCube.Compiler.DeBruijn 21import LambdaCube.Compiler.DeBruijn
22import LambdaCube.Compiler.Pretty hiding (braces, parens) 22import LambdaCube.Compiler.Pretty hiding (braces, parens)
23import LambdaCube.Compiler.DesugaredSource 23import LambdaCube.Compiler.DesugaredSource
@@ -35,7 +35,7 @@ data CaseFunName = CaseFunName FName Type Int{-num of parameters-}
35data TyCaseFunName = TyCaseFunName FName Type 35data TyCaseFunName = TyCaseFunName FName Type
36 36
37data FunDef 37data FunDef
38 = DeltaDef !Int{-arity-} (FreeVars -> [Exp] -> Exp) 38 = DeltaDef !Int{-arity-} (FreeVars -> [Exp]{-args in reversed order-} -> Exp)
39 | NoDef 39 | NoDef
40 | ExpDef Exp 40 | ExpDef Exp
41 41
@@ -120,10 +120,6 @@ pattern NoRHS <- (isRHS -> False)
120isRHS RHS{} = True 120isRHS RHS{} = True
121isRHS _ = False 121isRHS _ = False
122 122
123-- TODO: elim
124pattern Reverse xs <- (reverse -> xs)
125 where Reverse = reverse
126
127pattern Fun f xs n <- Fun_ _ f xs n 123pattern Fun f xs n <- Fun_ _ f xs n
128 where Fun f xs n = Fun_ (foldMap getFreeVars xs) f xs n 124 where Fun f xs n = Fun_ (foldMap getFreeVars xs) f xs n
129pattern CaseFun_ a b c <- CaseFun__ _ a b c 125pattern CaseFun_ a b c <- CaseFun__ _ a b c
@@ -155,13 +151,13 @@ pattern App a b <- Neut (App_ (Neut -> a) b)
155pattern DFun a t b = Neut (DFunN a t b) 151pattern DFun a t b = Neut (DFunN a t b)
156 152
157-- unreducable function application 153-- unreducable function application
158pattern UFun a b <- Neut (Fun (FunName (FTag a) _ _ t) (reverse -> b) NoRHS) 154pattern UFun a b <- Neut (Fun (FunName (FTag a) _ _ t) b NoRHS)
159 155
160-- saturated delta function application 156-- saturated delta function application
161pattern DFunN a t xs = DFunN_ (FTag a) t xs 157pattern DFunN a t xs = DFunN_ (FTag a) t xs
162 158
163pattern DFunN_ a t xs <- Fun (FunName' a t) (Reverse xs) _ 159pattern DFunN_ a t xs <- Fun (FunName' a t) xs _
164 where DFunN_ a t xs = Fun (FunName' a t) (Reverse xs) delta 160 where DFunN_ a t xs = Fun (FunName' a t) xs delta
165 161
166conParams (conTypeName -> TyConName _ _ _ _ (CaseFunName _ _ pars)) = pars 162conParams (conTypeName -> TyConName _ _ _ _ (CaseFunName _ _ pars)) = pars
167mkConPars n (snd . getParams . hnf -> TyCon (TyConName _ _ _ _ (CaseFunName _ _ pars)) xs) = take (min n pars) xs 163mkConPars n (snd . getParams . hnf -> TyCon (TyConName _ _ _ _ (CaseFunName _ _ pars)) xs) = take (min n pars) xs
@@ -207,12 +203,12 @@ pattern CEmpty s <- ConN FCEmpty (HString s: _)
207 where CEmpty s = tCon FCEmpty 1 (TString :~> TConstraint) [HString s] 203 where CEmpty s = tCon FCEmpty 1 (TString :~> TConstraint) [HString s]
208 204
209pattern CstrT t a b = Neut (CstrT' t a b) 205pattern CstrT t a b = Neut (CstrT' t a b)
210pattern CstrT' t a b = DFunN F'EqCT (TType :~> Var 0 :~> Var 1 :~> TConstraint) [t, a, b] 206pattern CstrT' t a b = DFunN F'EqCT (TType :~> Var 0 :~> Var 1 :~> TConstraint) [b, a, t]
211pattern Coe a b w x = DFun Fcoe (TType :~> TType :~> CW (CstrT TType (Var 1) (Var 0)) :~> Var 2 :~> Var 2) [a,b,w,x] 207pattern Coe a b w x = DFun Fcoe (TType :~> TType :~> CW (CstrT TType (Var 1) (Var 0)) :~> Var 2 :~> Var 2) [x,w,b,a]
212pattern ParEval t a b = DFun FparEval (TType :~> Var 0 :~> Var 1 :~> Var 2) [t, a, b] 208pattern ParEval t a b = DFun FparEval (TType :~> Var 0 :~> Var 1 :~> Var 2) [b, a, t]
213pattern T2 a b = DFun F'T2 (TConstraint :~> TConstraint :~> TConstraint) [a, b] 209pattern T2 a b = DFun F'T2 (TConstraint :~> TConstraint :~> TConstraint) [b, a]
214pattern CW a = DFun F'CW (TConstraint :~> TType) [a] 210pattern CW a = DFun F'CW (TConstraint :~> TType) [a]
215pattern CSplit a b c <- UFun F'Split [a, b, c] 211pattern CSplit a b c <- UFun F'Split [c, b, a]
216 212
217pattern HLit a <- (hnf -> ELit a) 213pattern HLit a <- (hnf -> ELit a)
218 where HLit = ELit 214 where HLit = ELit
@@ -247,7 +243,7 @@ mkOrdering x = case x of
247conTypeName :: ConName -> TyConName 243conTypeName :: ConName -> TyConName
248conTypeName (ConName _ _ t) = case snd $ getParams t of TyCon n _ -> n 244conTypeName (ConName _ _ t) = case snd $ getParams t of TyCon n _ -> n
249 245
250mkFun_ md (FunName _ _ (DeltaDef ar f) _) as _ | length as == ar = f md $ reverse as 246mkFun_ md (FunName _ _ (DeltaDef ar f) _) as _ | length as == ar = f md as
251mkFun_ md f xs y = Neut $ Fun_ md f xs $ hnf y 247mkFun_ md f xs y = Neut $ Fun_ md f xs $ hnf y
252 248
253mkFun :: FunName -> [Exp] -> Exp -> Exp 249mkFun :: FunName -> [Exp] -> Exp -> Exp
@@ -440,7 +436,7 @@ getFixLam _ = Nothing
440instance MkDoc Neutral where 436instance MkDoc Neutral where
441 mkDoc pr@(reduce, body) = \case 437 mkDoc pr@(reduce, body) = \case
442 CstrT' t a b -> shCstr (mkDoc pr a) (mkDoc pr (ET b t)) 438 CstrT' t a b -> shCstr (mkDoc pr a) (mkDoc pr (ET b t))
443 Fun (FunName _ _ (ExpDef d) _) xs _ | body -> mkDoc (reduce, False) (foldl app_ d $ reverse xs) 439 Fun (FunName _ _ (ExpDef d) _) xs _ | body -> mkDoc (reduce, False) (foldlrev app_ d xs)
444 FFix (getFixLam -> Just (s, xs)) | not body -> foldl DApp (pShow s) $ mkDoc pr <$> xs 440 FFix (getFixLam -> Just (s, xs)) | not body -> foldl DApp (pShow s) $ mkDoc pr <$> xs
445 FFix f {- | body -} -> foldl DApp "primFix" [{-pShow t -}"_", mkDoc pr f] 441 FFix f {- | body -} -> foldl DApp "primFix" [{-pShow t -}"_", mkDoc pr f]
446 Fun (FunName _ _ (DeltaDef n _) _) _ _ | body -> text $ "<<delta function with arity " ++ show n ++ ">>" 442 Fun (FunName _ _ (DeltaDef n _) _) _ _ | body -> text $ "<<delta function with arity " ++ show n ++ ">>"
@@ -476,79 +472,80 @@ pattern FunName' a t <- FunName a _ _ t
476mkFunDef a@(show -> "primFix") t = fn 472mkFunDef a@(show -> "primFix") t = fn
477 where 473 where
478 fn = FunName a 0 (DeltaDef (length $ fst $ getParams t) fx) t 474 fn = FunName a 0 (DeltaDef (length $ fst $ getParams t) fx) t
479 fx s xs = Neut $ Fun_ s fn (reverse xs) $ case xs of 475 fx s xs = Neut $ Fun_ s fn xs $ case xs of
480 _: f: _ -> RHS x where x = f `app_` Neut (Fun_ s fn (reverse xs) $ RHS x) 476 f: _{-1-} -> RHS x where x = f `app_` Neut (Fun_ s fn xs $ RHS x)
481 _ -> delta 477 _ -> delta
482 478
483mkFunDef a t = fn 479mkFunDef a t = fn
484 where 480 where
485 fn = FunName a 0 (maybe NoDef (DeltaDef (length $ fst $ getParams t) . const) $ getFunDef t a $ \xs -> Neut $ Fun fn (reverse xs) delta) t 481 fn = FunName a 0 (maybe NoDef (DeltaDef (length $ fst $ getParams t) . const) $ getFunDef t a $ \xs -> Neut $ Fun fn xs delta) t
486 482
487getFunDef t s f = case show s of 483getFunDef t s f = case show s of
488 "'EqCT" -> Just $ \case (t: a: b: _) -> cstr t a b 484 "'EqCT" -> Just $ \case (b: a: t: _) -> cstr t a b
489 "'T2" -> Just $ \case (a: b: _) -> t2 a b 485 "'T2" -> Just $ \case (b: a: _) -> t2 a b
490 "'CW" -> Just $ \case (a: _) -> cw a 486 "'CW" -> Just $ \case (a: _) -> cw a
491 "t2C" -> Just $ \case (a: b: _) -> t2C a b 487 "t2C" -> Just $ \case (b: a: _) -> t2C a b
492 "coe" -> Just $ \case (a: b: t: d: _) -> evalCoe a b t d 488 "coe" -> Just $ \case (d: t: b: a: _) -> evalCoe a b t d
493 "parEval" -> Just $ \case (t: a: b: _) -> parEval t a b 489 "parEval" -> Just $ \case (b: a: t: _) -> parEval t a b
494 where 490 where
495 parEval _ x@RHS{} _ = x 491 parEval _ x@RHS{} _ = x
496 parEval _ _ x@RHS{} = x 492 parEval _ _ x@RHS{} = x
497 parEval t a b = ParEval t a b 493 parEval t a b = ParEval t a b
498 494
499 "unsafeCoerce" -> Just $ \case xs@(_: _: x@(hnf -> NonNeut): _) -> x; xs -> f xs 495 "unsafeCoerce" -> Just $ \case xs@(x@(hnf -> NonNeut): _{-2-}) -> x; xs -> f xs
500 "reflCstr" -> Just $ \case (a: _) -> TT 496 "reflCstr" -> Just $ \case _ -> TT
501 "hlistNilCase" -> Just $ \case (_: x: (hnf -> Con n@(ConName _ 0 _) _ _): _) -> x; xs -> f xs 497 "hlistNilCase" -> Just $ \case ((hnf -> Con n@(ConName _ 0 _) _ _): x: _{-1-}) -> x; xs -> f xs
502 "hlistConsCase" -> Just $ \case (_: _: _: x: (hnf -> Con n@(ConName _ 1 _) _ (_: _: a: b: _)): _) -> x `app_` a `app_` b; xs -> f xs 498 "hlistConsCase" -> Just $ \case ((hnf -> Con n@(ConName _ 1 _) _ (_: _: a: b: _)): x: _{-3-}) -> x `app_` a `app_` b; xs -> f xs
503 499
504 -- general compiler primitives 500 -- general compiler primitives
505 "primAddInt" -> Just $ \case (HInt i: HInt j: _) -> HInt (i + j); xs -> f xs 501 "primAddInt" -> Just $ \case (HInt j: HInt i: _) -> HInt (i + j); xs -> f xs
506 "primSubInt" -> Just $ \case (HInt i: HInt j: _) -> HInt (i - j); xs -> f xs 502 "primSubInt" -> Just $ \case (HInt j: HInt i: _) -> HInt (i - j); xs -> f xs
507 "primModInt" -> Just $ \case (HInt i: HInt j: _) -> HInt (i `mod` j); xs -> f xs 503 "primModInt" -> Just $ \case (HInt j: HInt i: _) -> HInt (i `mod` j); xs -> f xs
508 "primSqrtFloat" -> Just $ \case (HFloat i: _) -> HFloat $ sqrt i; xs -> f xs 504 "primSqrtFloat" -> Just $ \case (HFloat i: _) -> HFloat $ sqrt i; xs -> f xs
509 "primRound" -> Just $ \case (HFloat i: _) -> HInt $ round i; xs -> f xs 505 "primRound" -> Just $ \case (HFloat i: _) -> HInt $ round i; xs -> f xs
510 "primIntToFloat" -> Just $ \case (HInt i: _) -> HFloat $ fromIntegral i; xs -> f xs 506 "primIntToFloat" -> Just $ \case (HInt i: _) -> HFloat $ fromIntegral i; xs -> f xs
511 "primIntToNat" -> Just $ \case (HInt i: _) -> ENat $ fromIntegral i; xs -> f xs 507 "primIntToNat" -> Just $ \case (HInt i: _) -> ENat $ fromIntegral i; xs -> f xs
512 "primCompareInt" -> Just $ \case (HInt x: HInt y: _) -> mkOrdering $ x `compare` y; xs -> f xs 508 "primCompareInt" -> Just $ \case (HInt y: HInt x: _) -> mkOrdering $ x `compare` y; xs -> f xs
513 "primCompareFloat" -> Just $ \case (HFloat x: HFloat y: _) -> mkOrdering $ x `compare` y; xs -> f xs 509 "primCompareFloat" -> Just $ \case (HFloat y: HFloat x: _) -> mkOrdering $ x `compare` y; xs -> f xs
514 "primCompareChar" -> Just $ \case (HChar x: HChar y: _) -> mkOrdering $ x `compare` y; xs -> f xs 510 "primCompareChar" -> Just $ \case (HChar y: HChar x: _) -> mkOrdering $ x `compare` y; xs -> f xs
515 "primCompareString" -> Just $ \case (HString x: HString y: _) -> mkOrdering $ x `compare` y; xs -> f xs 511 "primCompareString" -> Just $ \case (HString y: HString x: _) -> mkOrdering $ x `compare` y; xs -> f xs
516 512
517 -- LambdaCube 3D specific primitives 513 -- LambdaCube 3D specific primitives
518 "PrimGreaterThan" -> Just $ \case (t: _: _: _: _: _: _: x: y: _) | Just r <- twoOpBool (>) t x y -> r; xs -> f xs 514 "PrimGreaterThan" -> Just $ \case (y: x: _{-7-}) | Just r <- twoOpBool (>) x y -> r; xs -> f xs
519 "PrimGreaterThanEqual" 515 "PrimGreaterThanEqual"
520 -> Just $ \case (t: _: _: _: _: _: _: x: y: _) | Just r <- twoOpBool (>=) t x y -> r; xs -> f xs 516 -> Just $ \case (y: x: _{-7-}) | Just r <- twoOpBool (>=) x y -> r; xs -> f xs
521 "PrimLessThan" -> Just $ \case (t: _: _: _: _: _: _: x: y: _) | Just r <- twoOpBool (<) t x y -> r; xs -> f xs 517 "PrimLessThan" -> Just $ \case (y: x: _{-7-}) | Just r <- twoOpBool (<) x y -> r; xs -> f xs
522 "PrimLessThanEqual" -> Just $ \case (t: _: _: _: _: _: _: x: y: _) | Just r <- twoOpBool (<=) t x y -> r; xs -> f xs 518 "PrimLessThanEqual" -> Just $ \case (y: x: _{-7-}) | Just r <- twoOpBool (<=) x y -> r; xs -> f xs
523 "PrimEqualV" -> Just $ \case (t: _: _: _: _: _: _: x: y: _) | Just r <- twoOpBool (==) t x y -> r; xs -> f xs 519 "PrimEqualV" -> Just $ \case (y: x: _{-7-}) | Just r <- twoOpBool (==) x y -> r; xs -> f xs
524 "PrimNotEqualV" -> Just $ \case (t: _: _: _: _: _: _: x: y: _) | Just r <- twoOpBool (/=) t x y -> r; xs -> f xs 520 "PrimNotEqualV" -> Just $ \case (y: x: _{-7-}) | Just r <- twoOpBool (/=) x y -> r; xs -> f xs
525 "PrimEqual" -> Just $ \case (t: _: _: x: y: _) | Just r <- twoOpBool (==) t x y -> r; xs -> f xs 521 "PrimEqual" -> Just $ \case (y: x: _{-3-}) | Just r <- twoOpBool (==) x y -> r; xs -> f xs
526 "PrimNotEqual" -> Just $ \case (t: _: _: x: y: _) | Just r <- twoOpBool (/=) t x y -> r; xs -> f xs 522 "PrimNotEqual" -> Just $ \case (y: x: _{-3-}) | Just r <- twoOpBool (/=) x y -> r; xs -> f xs
527 "PrimSubS" -> Just $ \case (_: _: _: _: x: y: _) | Just r <- twoOp (-) x y -> r; xs -> f xs 523 "PrimSubS" -> Just $ \case (y: x: _{-4-}) | Just r <- twoOp (-) x y -> r; xs -> f xs
528 "PrimSub" -> Just $ \case (_: _: x: y: _) | Just r <- twoOp (-) x y -> r; xs -> f xs 524 "PrimSub" -> Just $ \case (y: x: _{-2-}) | Just r <- twoOp (-) x y -> r; xs -> f xs
529 "PrimAddS" -> Just $ \case (_: _: _: _: x: y: _) | Just r <- twoOp (+) x y -> r; xs -> f xs 525 "PrimAddS" -> Just $ \case (y: x: _{-4-}) | Just r <- twoOp (+) x y -> r; xs -> f xs
530 "PrimAdd" -> Just $ \case (_: _: x: y: _) | Just r <- twoOp (+) x y -> r; xs -> f xs 526 "PrimAdd" -> Just $ \case (y: x: _{-2-}) | Just r <- twoOp (+) x y -> r; xs -> f xs
531 "PrimMulS" -> Just $ \case (_: _: _: _: x: y: _) | Just r <- twoOp (*) x y -> r; xs -> f xs 527 "PrimMulS" -> Just $ \case (y: x: _{-4-}) | Just r <- twoOp (*) x y -> r; xs -> f xs
532 "PrimMul" -> Just $ \case (_: _: x: y: _) | Just r <- twoOp (*) x y -> r; xs -> f xs 528 "PrimMul" -> Just $ \case (y: x: _{-2-}) | Just r <- twoOp (*) x y -> r; xs -> f xs
533 "PrimDivS" -> Just $ \case (_: _: _: _: _: x: y: _) | Just r <- twoOp_ (/) div x y -> r; xs -> f xs 529 "PrimDivS" -> Just $ \case (y: x: _{-5-}) | Just r <- twoOp_ (/) div x y -> r; xs -> f xs
534 "PrimDiv" -> Just $ \case (_: _: _: _: _: x: y: _) | Just r <- twoOp_ (/) div x y -> r; xs -> f xs 530 "PrimDiv" -> Just $ \case (y: x: _{-5-}) | Just r <- twoOp_ (/) div x y -> r; xs -> f xs
535 "PrimModS" -> Just $ \case (_: _: _: _: _: x: y: _) | Just r <- twoOp_ modF mod x y -> r; xs -> f xs 531 "PrimModS" -> Just $ \case (y: x: _{-5-}) | Just r <- twoOp_ modF mod x y -> r; xs -> f xs
536 "PrimMod" -> Just $ \case (_: _: _: _: _: x: y: _) | Just r <- twoOp_ modF mod x y -> r; xs -> f xs 532 "PrimMod" -> Just $ \case (y: x: _{-5-}) | Just r <- twoOp_ modF mod x y -> r; xs -> f xs
537 "PrimNeg" -> Just $ \case (_: x: _) | Just r <- oneOp negate x -> r; xs -> f xs 533 "PrimNeg" -> Just $ \case (x: _{-1-}) | Just r <- oneOp negate x -> r; xs -> f xs
538 "PrimAnd" -> Just $ \case (EBool x: EBool y: _) -> EBool (x && y); xs -> f xs 534 "PrimAnd" -> Just $ \case (EBool y: EBool x: _) -> EBool (x && y); xs -> f xs
539 "PrimOr" -> Just $ \case (EBool x: EBool y: _) -> EBool (x || y); xs -> f xs 535 "PrimOr" -> Just $ \case (EBool y: EBool x: _) -> EBool (x || y); xs -> f xs
540 "PrimXor" -> Just $ \case (EBool x: EBool y: _) -> EBool (x /= y); xs -> f xs 536 "PrimXor" -> Just $ \case (EBool y: EBool x: _) -> EBool (x /= y); xs -> f xs
541 "PrimNot" -> Just $ \case ((hnf -> TNat): _: _: EBool x: _) -> EBool $ not x; xs -> f xs 537 "PrimNot" -> Just $ \case (EBool x: _: _: (hnf -> TNat): _) -> EBool $ not x; xs -> f xs
542 538
543 _ -> Nothing 539 _ -> Nothing
544 where 540 where
545 twoOpBool :: (forall a . Ord a => a -> a -> Bool) -> Exp -> Exp -> Exp -> Maybe Exp 541
546 twoOpBool f t (HFloat x) (HFloat y) = Just $ EBool $ f x y 542 twoOpBool :: (forall a . Ord a => a -> a -> Bool) -> Exp -> Exp -> Maybe Exp
547 twoOpBool f t (HInt x) (HInt y) = Just $ EBool $ f x y 543 twoOpBool f (HFloat x) (HFloat y) = Just $ EBool $ f x y
548 twoOpBool f t (HString x) (HString y) = Just $ EBool $ f x y 544 twoOpBool f (HInt x) (HInt y) = Just $ EBool $ f x y
549 twoOpBool f t (HChar x) (HChar y) = Just $ EBool $ f x y 545 twoOpBool f (HString x) (HString y) = Just $ EBool $ f x y
550 twoOpBool f t (ENat x) (ENat y) = Just $ EBool $ f x y 546 twoOpBool f (HChar x) (HChar y) = Just $ EBool $ f x y
551 twoOpBool _ _ _ _ = Nothing 547 twoOpBool f (ENat x) (ENat y) = Just $ EBool $ f x y
548 twoOpBool _ _ _ = Nothing
552 549
553 oneOp :: (forall a . Num a => a -> a) -> Exp -> Maybe Exp 550 oneOp :: (forall a . Num a => a -> a) -> Exp -> Maybe Exp
554 oneOp f = oneOp_ f f 551 oneOp f = oneOp_ f f
@@ -606,11 +603,11 @@ cstr = f []
606 f_ (_: ns) typ{-down?-} (down 0 -> Just a) (down 0 -> Just a') = f ns typ a a' 603 f_ (_: ns) typ{-down?-} (down 0 -> Just a) (down 0 -> Just a') = f ns typ a a'
607 f_ ns TType (Pi h a b) (Pi h' a' b') | h == h' = t2 (f ns TType a a') (f ((a, a'): ns) TType b b') 604 f_ ns TType (Pi h a b) (Pi h' a' b') | h == h' = t2 (f ns TType a a') (f ((a, a'): ns) TType b b')
608 605
609 f_ [] TType (UFun F'VecScalar [a, b]) (UFun F'VecScalar [a', b']) = t2 (f [] TNat a a') (f [] TType b b') 606 f_ [] TType (UFun F'VecScalar [b, a]) (UFun F'VecScalar [b', a']) = t2 (f [] TNat a a') (f [] TType b b')
610 f_ [] TType (UFun F'VecScalar [a, b]) (TVec a' b') = t2 (f [] TNat a a') (f [] TType b b') 607 f_ [] TType (UFun F'VecScalar [b, a]) (TVec a' b') = t2 (f [] TNat a a') (f [] TType b b')
611 f_ [] TType (UFun F'VecScalar [a, b]) t@NonNeut = t2 (f [] TNat a (ENat 1)) (f [] TType b t) 608 f_ [] TType (UFun F'VecScalar [b, a]) t@NonNeut = t2 (f [] TNat a (ENat 1)) (f [] TType b t)
612 f_ [] TType (TVec a' b') (UFun F'VecScalar [a, b]) = t2 (f [] TNat a' a) (f [] TType b' b) 609 f_ [] TType (TVec a' b') (UFun F'VecScalar [b, a]) = t2 (f [] TNat a' a) (f [] TType b' b)
613 f_ [] TType t@NonNeut (UFun F'VecScalar [a, b]) = t2 (f [] TNat a (ENat 1)) (f [] TType b t) 610 f_ [] TType t@NonNeut (UFun F'VecScalar [b, a]) = t2 (f [] TNat a (ENat 1)) (f [] TType b t)
614 611
615 f_ [] typ a@Neut{} a' = CstrT typ a a' 612 f_ [] typ a@Neut{} a' = CstrT typ a a'
616 f_ [] typ a a'@Neut{} = CstrT typ a a' 613 f_ [] typ a a'@Neut{} = CstrT typ a a'
@@ -626,7 +623,7 @@ nonNeut Neut{} = False
626nonNeut _ = True 623nonNeut _ = True
627 624
628t2C (hnf -> TT) (hnf -> TT) = TT 625t2C (hnf -> TT) (hnf -> TT) = TT
629t2C a b = DFun Ft2C (Unit :~> Unit :~> Unit) [a, b] 626t2C a b = DFun Ft2C (Unit :~> Unit :~> Unit) [b, a]
630 627
631cw (hnf -> CUnit) = Unit 628cw (hnf -> CUnit) = Unit
632cw (hnf -> CEmpty a) = Empty a 629cw (hnf -> CEmpty a) = Empty a
diff --git a/src/LambdaCube/Compiler/Infer.hs b/src/LambdaCube/Compiler/Infer.hs
index ecd18b46..f726f30b 100644
--- a/src/LambdaCube/Compiler/Infer.hs
+++ b/src/LambdaCube/Compiler/Infer.hs
@@ -500,7 +500,7 @@ handleStmt = \case
500 Primitive n t_ -> do 500 Primitive n t_ -> do
501 t <- inferType n $ trSExp' t_ 501 t <- inferType n $ trSExp' t_
502 tellType (sourceInfo n) t 502 tellType (sourceInfo n) t
503 addToEnv n $ flip ET t $ lamify t $ Neut . DFunN_ (FName n) t 503 addToEnv n $ flip ET t $ lamify' t $ Neut . DFunN_ (FName n) t
504 StLet n mt t_ -> do 504 StLet n mt t_ -> do
505 let t__ = maybe id (flip SAnn) mt t_ 505 let t__ = maybe id (flip SAnn) mt t_
506 ET x t <- inferTerm n $ trSExp' t__ 506 ET x t <- inferTerm n $ trSExp' t__
diff --git a/src/LambdaCube/Compiler/InferMonad.hs b/src/LambdaCube/Compiler/InferMonad.hs
index e2895389..8d46f7c7 100644
--- a/src/LambdaCube/Compiler/InferMonad.hs
+++ b/src/LambdaCube/Compiler/InferMonad.hs
@@ -17,7 +17,7 @@ module LambdaCube.Compiler.InferMonad where
17 17
18import Data.Monoid 18import Data.Monoid
19import Data.List 19import Data.List
20import Data.Maybe 20--import Data.Maybe
21import qualified Data.Set as Set 21import qualified Data.Set as Set
22import qualified Data.Map as Map 22import qualified Data.Map as Map
23 23
@@ -143,10 +143,13 @@ addLams ps t = foldr (const Lam) t ps
143 143
144lamify t x = addLams (fst $ getParams t) $ x $ downTo 0 $ arity t 144lamify t x = addLams (fst $ getParams t) $ x $ downTo 0 $ arity t
145 145
146lamify' t x = addLams (fst $ getParams t) $ x $ downTo' 0 $ arity t
147
146arity :: Exp -> Int 148arity :: Exp -> Int
147arity = length . fst . getParams 149arity = length . fst . getParams
148 150
149downTo n m = map Var [n+m-1, n+m-2..n] 151downTo n m = map Var [n+m-1, n+m-2..n]
152downTo' n m = map Var [n, n+1..n+m-1]
150 153
151withEnv e = local $ second (<> e) 154withEnv e = local $ second (<> e)
152 155