summaryrefslogtreecommitdiff
path: root/src/LambdaCube
diff options
context:
space:
mode:
authorPéter Diviánszky <divipp@gmail.com>2016-02-16 07:15:29 +0100
committerPéter Diviánszky <divipp@gmail.com>2016-02-16 07:29:03 +0100
commit6700a057f30e8ca6c8aa2dde71d03516fd7ae6cd (patch)
tree5cafa9969ae9eb0884abd4b1ece21916d1680b91 /src/LambdaCube
parent2c1c5a8ae057c4e3a93ccf2a6f796af87188c0d1 (diff)
generate functions in shaders
Diffstat (limited to 'src/LambdaCube')
-rw-r--r--src/LambdaCube/Compiler.hs6
-rw-r--r--src/LambdaCube/Compiler/CoreToIR.hs156
-rw-r--r--src/LambdaCube/Compiler/Infer.hs77
-rw-r--r--src/LambdaCube/Compiler/Parser.hs3
4 files changed, 145 insertions, 97 deletions
diff --git a/src/LambdaCube/Compiler.hs b/src/LambdaCube/Compiler.hs
index 62f7a49e..5f7bd4fd 100644
--- a/src/LambdaCube/Compiler.hs
+++ b/src/LambdaCube/Compiler.hs
@@ -130,7 +130,7 @@ ioFetch paths imp n = do
130 130
131splitMPath fn = (joinPath as, intercalate "." $ bs ++ [y]) 131splitMPath fn = (joinPath as, intercalate "." $ bs ++ [y])
132 where 132 where
133 (as, bs) = span (\x -> null x || x == "." || x == "/" || isLower (head x)) xs 133 (as, bs) = span (\x -> null x || not (isUpper $ head x)) xs
134 (xs, y) = map takeDirectory . splitPath *** id $ splitFileName $ dropExtension fn 134 (xs, y) = map takeDirectory . splitPath *** id $ splitFileName $ dropExtension fn
135 135
136 136
@@ -161,11 +161,11 @@ loadModule imp mname = do
161 ExportId (snd -> d) -> case Map.lookup d $ getPolyEnv x of 161 ExportId (snd -> d) -> case Map.lookup d $ getPolyEnv x of
162 Just def -> PolyEnv (Map.singleton d def) mempty 162 Just def -> PolyEnv (Map.singleton d def) mempty
163 Nothing -> error $ d ++ " is not defined" 163 Nothing -> error $ d ++ " is not defined"
164 ExportModule (snd -> m) | m == snd (splitMPath mname) -> x 164 ExportModule (snd -> m) | m == snd (splitMPath fname) -> x
165 ExportModule m -> case [ ms 165 ExportModule m -> case [ ms
166 | ((m', is), ms) <- zip (moduleImports e) ms, m' == m] of 166 | ((m', is), ms) <- zip (moduleImports e) ms, m' == m] of
167 [PolyEnv x infos] -> PolyEnv x mempty -- TODO 167 [PolyEnv x infos] -> PolyEnv x mempty -- TODO
168 [] -> error "empty export list" 168 [] -> error $ "empty export list: " ++ show (fname, m, map fst $ moduleImports e, snd (splitMPath fname))
169 _ -> error "export list: internal error" 169 _ -> error "export list: internal error"
170 modify $ Map.insert fname $ Right (x', src) 170 modify $ Map.insert fname $ Right (x', src)
171 return (fname, x') 171 return (fname, x')
diff --git a/src/LambdaCube/Compiler/CoreToIR.hs b/src/LambdaCube/Compiler/CoreToIR.hs
index 1d9f9017..42dcd1bb 100644
--- a/src/LambdaCube/Compiler/CoreToIR.hs
+++ b/src/LambdaCube/Compiler/CoreToIR.hs
@@ -27,7 +27,7 @@ import qualified LambdaCube.Linear as IR
27 27
28import LambdaCube.Compiler.Pretty 28import LambdaCube.Compiler.Pretty
29import Text.PrettyPrint.Compact (nest) 29import Text.PrettyPrint.Compact (nest)
30import LambdaCube.Compiler.Infer hiding (Con, Lam, Pi, TType, Var, ELit) 30import LambdaCube.Compiler.Infer hiding (Con, Lam, Pi, TType, Var, ELit, Func)
31import qualified LambdaCube.Compiler.Infer as I 31import qualified LambdaCube.Compiler.Infer as I
32import LambdaCube.Compiler.Parser (up, Up (..)) 32import LambdaCube.Compiler.Parser (up, Up (..))
33 33
@@ -337,7 +337,7 @@ compFrag x = case x of
337 337
338-- todo: remove 338-- todo: remove
339toGLSLType msg (TTuple []) = "void" 339toGLSLType msg (TTuple []) = "void"
340toGLSLType msg x = showGLSLType msg $ compInputType x 340toGLSLType msg x = showGLSLType msg $ compInputType msg x
341 341
342-- move to lambdacube-ir? 342-- move to lambdacube-ir?
343showGLSLType msg = \case 343showGLSLType msg = \case
@@ -369,40 +369,44 @@ showGLSLType msg = \case
369 IR.FTexture2D -> "sampler2D" 369 IR.FTexture2D -> "sampler2D"
370 t -> error $ "toGLSLType: " ++ msg ++ " " ++ show t 370 t -> error $ "toGLSLType: " ++ msg ++ " " ++ show t
371 371
372compInputType x = case x of 372supType = isJust . compInputType_
373 TFloat -> IR.Float 373
374 TVec 2 TFloat -> IR.V2F 374compInputType_ x = case x of
375 TVec 3 TFloat -> IR.V3F 375 TFloat -> Just IR.Float
376 TVec 4 TFloat -> IR.V4F 376 TVec 2 TFloat -> Just IR.V2F
377 TBool -> IR.Bool 377 TVec 3 TFloat -> Just IR.V3F
378 TVec 2 TBool -> IR.V2B 378 TVec 4 TFloat -> Just IR.V4F
379 TVec 3 TBool -> IR.V3B 379 TBool -> Just IR.Bool
380 TVec 4 TBool -> IR.V4B 380 TVec 2 TBool -> Just IR.V2B
381 TInt -> IR.Int 381 TVec 3 TBool -> Just IR.V3B
382 TVec 2 TInt -> IR.V2I 382 TVec 4 TBool -> Just IR.V4B
383 TVec 3 TInt -> IR.V3I 383 TInt -> Just IR.Int
384 TVec 4 TInt -> IR.V4I 384 TVec 2 TInt -> Just IR.V2I
385 TWord -> IR.Word 385 TVec 3 TInt -> Just IR.V3I
386 TVec 2 TWord -> IR.V2U 386 TVec 4 TInt -> Just IR.V4I
387 TVec 3 TWord -> IR.V3U 387 TWord -> Just IR.Word
388 TVec 4 TWord -> IR.V4U 388 TVec 2 TWord -> Just IR.V2U
389 TMat 2 2 TFloat -> IR.M22F 389 TVec 3 TWord -> Just IR.V3U
390 TMat 2 3 TFloat -> IR.M23F 390 TVec 4 TWord -> Just IR.V4U
391 TMat 2 4 TFloat -> IR.M24F 391 TMat 2 2 TFloat -> Just IR.M22F
392 TMat 3 2 TFloat -> IR.M32F 392 TMat 2 3 TFloat -> Just IR.M23F
393 TMat 3 3 TFloat -> IR.M33F 393 TMat 2 4 TFloat -> Just IR.M24F
394 TMat 3 4 TFloat -> IR.M34F 394 TMat 3 2 TFloat -> Just IR.M32F
395 TMat 4 2 TFloat -> IR.M42F 395 TMat 3 3 TFloat -> Just IR.M33F
396 TMat 4 3 TFloat -> IR.M43F 396 TMat 3 4 TFloat -> Just IR.M34F
397 TMat 4 4 TFloat -> IR.M44F 397 TMat 4 2 TFloat -> Just IR.M42F
398 x -> error $ "compInputType " ++ ppShow x 398 TMat 4 3 TFloat -> Just IR.M43F
399 TMat 4 4 TFloat -> Just IR.M44F
400 _ -> Nothing
401
402compInputType msg x = fromMaybe (error $ "compInputType " ++ msg ++ " " ++ ppShow x) $ compInputType_ x
399 403
400is234 = (`elem` [2,3,4]) 404is234 = (`elem` [2,3,4])
401 405
402 406
403compAttribute x = case x of 407compAttribute x = case x of
404 ETuple a -> concatMap compAttribute a 408 ETuple a -> concatMap compAttribute a
405 A1 "Attribute" (EString s) -> [(s, compInputType $ tyOf x)] 409 A1 "Attribute" (EString s) -> [(s, compInputType "compAttr" $ tyOf x)]
406 x -> error $ "compAttribute " ++ ppShow x 410 x -> error $ "compAttribute " ++ ppShow x
407 411
408compList (A2 "Cons" a x) = compValue a : compList x 412compList (A2 "Cons" a x) = compValue a : compList x
@@ -430,7 +434,7 @@ compAttributeValue x = checkLength $ go x
430 434
431 go = \case 435 go = \case
432 ETuple a -> concatMap go a 436 ETuple a -> concatMap go a
433 a -> let A1 "List" (compInputType -> t) = tyOf a 437 a -> let A1 "List" (compInputType "compAV" -> t) = tyOf a
434 values = compList a 438 values = compList a
435 in 439 in
436 [(length values,(t,foldr (flatten t) (emptyArray t) values))] 440 [(length values,(t,foldr (flatten t) (emptyArray t) values))]
@@ -524,6 +528,7 @@ genGLSLs backend
524 uniformDecls vertUniforms 528 uniformDecls vertUniforms
525 <> [shaderDecl (caseWO "attribute" "in") (text t) (text n) | (n, t) <- zip vertInNames vertIns] 529 <> [shaderDecl (caseWO "attribute" "in") (text t) (text n) | (n, t) <- zip vertInNames vertIns]
526 <> vertOutDecls "out" 530 <> vertOutDecls "out"
531 <> map shaderF vertFuncs
527 <> [mainFunc $ 532 <> [mainFunc $
528 [shaderLet (text n) x | (n, x) <- zip vertOutNamesWithPosition vertGLSL] 533 [shaderLet (text n) x | (n, x) <- zip vertOutNamesWithPosition vertGLSL]
529 <> [shaderLet "gl_PointSize" x | Just x <- [ptGLSL]] 534 <> [shaderLet "gl_PointSize" x | Just x <- [ptGLSL]]
@@ -534,6 +539,7 @@ genGLSLs backend
534 uniformDecls fragUniforms 539 uniformDecls fragUniforms
535 <> vertOutDecls "in" 540 <> vertOutDecls "in"
536 <> [shaderDecl "out" (toGLSLType "4" tfrag) fragColorName | noUnit tfrag, backend == OpenGL33] 541 <> [shaderDecl "out" (toGLSLType "4" tfrag) fragColorName | noUnit tfrag, backend == OpenGL33]
542 <> map shaderF fragFuncs
537 <> [mainFunc $ 543 <> [mainFunc $
538 [shaderStmt $ "if" <+> parens ("!" <> parens filt) <+> "discard" | Just filt <- [filtGLSL]] 544 [shaderStmt $ "if" <+> parens ("!" <> parens filt) <+> "discard" | Just filt <- [filtGLSL]]
539 <> [shaderLet fragColorName $ fromMaybe (text $ head vertOutNames) fragGLSL | noUnit tfrag] 545 <> [shaderLet fragColorName $ fromMaybe (text $ head vertOutNames) fragGLSL | noUnit tfrag]
@@ -547,14 +553,41 @@ genGLSLs backend
547 Just (etaRed -> Just (vertIns, verts)) -> (toGLSLType "3" <$> vertIns, eTuple verts) 553 Just (etaRed -> Just (vertIns, verts)) -> (toGLSLType "3" <$> vertIns, eTuple verts)
548 Nothing -> ([], [mkTVar 0 tvert]) 554 Nothing -> ([], [mkTVar 0 tvert])
549 555
550 (((vertGLSL, ptGLSL), vertUniforms), ((filtGLSL, fragGLSL), fragUniforms)) = flip evalState shaderNames $ (,) 556 (((vertGLSL, ptGLSL), (vertUniforms, vertFuncs)), ((filtGLSL, fragGLSL), (fragUniforms, fragFuncs))) = flip evalState shaderNames $ do
551 <$> runWriterT ((,) 557 ((g1, (us1, verts)), (g2, (us2, frags))) <- (,)
552 <$> traverse (genGLSL' vertInNames . (,) vertIns) verts 558 <$> runWriterT ((,)
553 <*> traverse (genGLSL' vertOutNamesWithPosition . red) rp) 559 <$> traverse (genGLSL' vertInNames . (,) vertIns) verts
554 <*> runWriterT ((,) 560 <*> traverse (genGLSL' vertOutNamesWithPosition . red) rp)
555 <$> traverse (genGLSL' vertOutNames . red) ffilter 561 <*> runWriterT ((,)
556 <*> traverse (genGLSL' vertOutNames . red) frag) 562 <$> traverse (genGLSL' vertOutNames . red) ffilter
563 <*> traverse (genGLSL' vertOutNames . red) frag)
564 (,) <$> ((,) g1 <$> fixFuncs us1 [] verts) <*> ((,) g2 <$> fixFuncs us2 [] frags)
565
566 fixFuncs :: Uniforms -> [(SName, (Doc, ExpTV, [ExpTV]))] -> Map.Map SName (ExpTV, ExpTV, [ExpTV]) -> State [SName] (Uniforms, [(SName, (Doc, ExpTV, [ExpTV]))])
567 fixFuncs us fsb fsa
568 | Map.null fsa = return (us, fsb)
569 | otherwise = do
570 (unzip -> (defs, unzip -> (us', fs'))) <- forM fsa' $ \(fn, (def, ty, tys)) -> do
571 let
572 removeLams 0 x = x
573 removeLams i (ELam _ x) = removeLams (i-1) x
574 removeLams i (Lam Hidden _ x) = removeLams i x
575 removeLams i x = error $ "removeLams: " ++ fn ++ " " ++ show i ++ " " ++ show x
576
577 runWriterT $ genGLSL (reverse $ take (length tys) funArgs) $ removeLams (length tys) def
578 let fsb' = zipWith combine fsa' defs <> fsb
579 fixFuncs (us <> mconcat us') fsb' (Map.filterWithKey (\k _ -> k `notElem` map fst fsb') $ mconcat fs')
580 where
581 fsa' = Map.toList fsa
582 combine (fn, (_, ty, tys)) def = (fn, (def, ty, tys))
583
584 shaderF (fn, (def, ty, tys))
585 = shaderFunc' (toGLSLType "44" ty) (text fn)
586 (zipWith (<+>) (map (toGLSLType "45") tys) (map text funArgs))
587 def
557 588
589
590 funArgs = map (("z" ++) . show) [0..]
558 shaderNames = map (("s" ++) . show) [0..] 591 shaderNames = map (("s" ++) . show) [0..]
559 vertInNames = map (("vi" ++) . show) [1..length vertIns] 592 vertInNames = map (("vi" ++) . show) [1..length vertIns]
560 vertOutNames = map (("vo" ++) . show) [1..length vertOuts] 593 vertOutNames = map (("vo" ++) . show) [1..length vertOuts]
@@ -583,6 +616,9 @@ genGLSLs backend
583 <> [shaderFunc "vec4" "texture2D" ["sampler2D s", "vec2 uv"] [shaderReturn "texture(s,uv)"] | backend == OpenGL33] 616 <> [shaderFunc "vec4" "texture2D" ["sampler2D s", "vec2 uv"] [shaderReturn "texture(s,uv)"] | backend == OpenGL33]
584 <> xs 617 <> xs
585 618
619 shaderFunc' ot n [] b = shaderLet (ot <+> n) b
620 shaderFunc' ot n ps b = shaderFunc ot n ps [shaderReturn b]
621
586 shaderFunc outtype name pars body = nest 4 (outtype <+> name <> tupled pars <+> "{" <$$> vcat body) <$$> "}" 622 shaderFunc outtype name pars body = nest 4 (outtype <+> name <> tupled pars <+> "{" <$$> vcat body) <$$> "}"
587 mainFunc xs = shaderFunc "void" "main" [] xs 623 mainFunc xs = shaderFunc "void" "main" [] xs
588 shaderStmt xs = nest 4 $ xs <> ";" 624 shaderStmt xs = nest 4 $ xs <> ";"
@@ -606,12 +642,16 @@ data Uniform
606 642
607type Uniforms = Map String (Uniform, IR.InputType) 643type Uniforms = Map String (Uniform, IR.InputType)
608 644
609genGLSL :: [SName] -> ExpTV -> WriterT Uniforms (State [String]) Doc 645tellUniform x = tell (x, mempty)
646
647genGLSL :: [SName] -> ExpTV -> WriterT (Uniforms, Map.Map SName (ExpTV, ExpTV, [ExpTV])) (State [String]) Doc
610genGLSL dns e = case e of 648genGLSL dns e = case e of
611 649
612 ELit a -> pure $ text $ show a 650 ELit a -> pure $ text $ show a
613 Var i _ -> pure $ text $ dns !! i 651 Var i _ -> pure $ text $ dns !! i
614 652
653 Func fn def ty xs -> tell (mempty, Map.singleton fn (def, ty, map tyOf xs)) >> call fn xs
654
615 Con cn xs -> case cn of 655 Con cn xs -> case cn of
616 "primIfThenElse" -> case xs of [a, b, c] -> hsep <$> sequence [gen a, pure "?", gen b, pure ":", gen c] 656 "primIfThenElse" -> case xs of [a, b, c] -> hsep <$> sequence [gen a, pure "?", gen b, pure ":", gen c]
617 657
@@ -620,15 +660,15 @@ genGLSL dns e = case e of
620 660
621 "Uniform" -> case xs of 661 "Uniform" -> case xs of
622 [EString s] -> do 662 [EString s] -> do
623 tell $ Map.singleton s $ (,) UUniform $ compInputType $ tyOf e 663 tellUniform $ Map.singleton s $ (,) UUniform $ compInputType "unif" $ tyOf e
624 pure $ text s 664 pure $ text s
625 "Sampler" -> case xs of 665 "Sampler" -> case xs of
626 [_, _, A1 "Texture2DSlot" (EString s)] -> do 666 [_, _, A1 "Texture2DSlot" (EString s)] -> do
627 tell $ Map.singleton s $ (,) UTexture2DSlot IR.FTexture2D{-compInputType $ tyOf e -- TODO-} 667 tellUniform $ Map.singleton s $ (,) UTexture2DSlot IR.FTexture2D{-compInputType $ tyOf e -- TODO-}
628 pure $ text s 668 pure $ text s
629 [_, _, A2 "Texture2D" (A2 "V2" (EInt w) (EInt h)) b] -> do 669 [_, _, A2 "Texture2D" (A2 "V2" (EInt w) (EInt h)) b] -> do
630 s <- newName 670 s <- newName
631 tell $ Map.singleton s $ (,) (UTexture2D w h b) IR.FTexture2D 671 tellUniform $ Map.singleton s $ (,) (UTexture2D w h b) IR.FTexture2D
632 pure $ text s 672 pure $ text s
633 673
634 'P':'r':'i':'m':n | n'@(_:_) <- trName (dropS n) -> call n' xs 674 'P':'r':'i':'m':n | n'@(_:_) <- trName (dropS n) -> call n' xs
@@ -735,7 +775,7 @@ genGLSL dns e = case e of
735 "==" -> "==" 775 "==" -> "=="
736 776
737 n | n `elem` ["primNegateWord", "primNegateInt", "primNegateFloat"] -> "-_" 777 n | n `elem` ["primNegateWord", "primNegateInt", "primNegateFloat"] -> "-_"
738 n | n `elem` ["V2", "V3", "V4"] -> toGLSLType "5" $ tyOf e 778 n | n `elem` ["V2", "V3", "V4"] -> toGLSLType (n ++ " " ++ show (length xs)) $ tyOf e
739 _ -> "" 779 _ -> ""
740 780
741 -- not supported 781 -- not supported
@@ -793,7 +833,7 @@ genGLSL dns e = case e of
793data ExpTV = ExpTV_ Exp Exp [Exp] 833data ExpTV = ExpTV_ Exp Exp [Exp]
794 deriving (Show, Eq) 834 deriving (Show, Eq)
795 835
796pattern ExpTV a b c <- ExpTV_ a b c where ExpTV a b c = ExpTV_ (unLab' a) (unLab' b) c 836pattern ExpTV a b c <- ExpTV_ a b c where ExpTV a b c = ExpTV_ (a) (unLab' b) c
797 837
798type Ty = ExpTV 838type Ty = ExpTV
799 839
@@ -809,7 +849,8 @@ pattern Con h b <- (mkCon -> Just (h, b))
809pattern App a b <- (mkApp -> Just (a, b)) 849pattern App a b <- (mkApp -> Just (a, b))
810pattern Var a b <- (mkVar -> Just (a, b)) 850pattern Var a b <- (mkVar -> Just (a, b))
811pattern ELit l <- ExpTV (I.ELit l) _ _ 851pattern ELit l <- ExpTV (I.ELit l) _ _
812pattern TType <- ExpTV I.TType _ _ 852pattern TType <- ExpTV (unLab' -> I.TType) _ _
853pattern Func fn def ty xs <- (mkFunc -> Just (fn, def, ty, xs))
813 854
814pattern EString s <- ELit (LString s) 855pattern EString s <- ELit (LString s)
815pattern EFloat s <- ELit (LFloat s) 856pattern EFloat s <- ELit (LFloat s)
@@ -818,26 +859,33 @@ pattern EInt s <- ELit (LInt s)
818t .@ vs = ExpTV t I.TType vs 859t .@ vs = ExpTV t I.TType vs
819infix 1 .@ 860infix 1 .@
820 861
821mkVar (ExpTV (I.Var i) t vs) = Just (i, t .@ vs) 862mkVar (ExpTV (unLab' -> I.Var i) t vs) = Just (i, t .@ vs)
822mkVar _ = Nothing 863mkVar _ = Nothing
823 864
824mkPi (ExpTV (I.Pi b x y) _ vs) = Just (b, x .@ vs, y .@ addToEnv x vs) 865mkPi (ExpTV (unLab' -> I.Pi b x y) _ vs) = Just (b, x .@ vs, y .@ addToEnv x vs)
825mkPi _ = Nothing 866mkPi _ = Nothing
826 867
827mkLam (ExpTV (I.Lam y) (I.Pi b x yt) vs) = Just (b, x .@ vs, ExpTV y yt $ addToEnv x vs) 868mkLam (ExpTV (unLab' -> I.Lam y) (I.Pi b x yt) vs) = Just (b, x .@ vs, ExpTV y yt $ addToEnv x vs)
828mkLam _ = Nothing 869mkLam _ = Nothing
829 870
830mkCon (ExpTV (I.Con s n xs) et vs) = Just (untick $ show s, chain vs (conType et s) $ mkConPars n et ++ xs) 871mkCon (ExpTV (unLab' -> I.Con s n xs) et vs) = Just (untick $ show s, chain vs (conType et s) $ mkConPars n et ++ xs)
831mkCon (ExpTV (TyCon s xs) et vs) = Just (untick $ show s, chain vs (nType s) xs) 872mkCon (ExpTV (unLab' -> TyCon s xs) et vs) = Just (untick $ show s, chain vs (nType s) xs)
832mkCon (ExpTV (Neut (I.Fun s i (reverse -> xs) def)) et vs) = Just (untick $ show s, chain vs (nType s) xs) 873mkCon (ExpTV (unLab' -> Neut (I.Fun s i (reverse -> xs) def)) et vs) = Just (untick $ show s, chain vs (nType s) xs)
833mkCon (ExpTV (CaseFun s xs n) et vs) = Just (untick $ show s, chain vs (nType s) $ makeCaseFunPars' (mkEnv vs) n ++ xs ++ [Neut n]) 874mkCon (ExpTV (unLab' -> CaseFun s xs n) et vs) = Just (untick $ show s, chain vs (nType s) $ makeCaseFunPars' (mkEnv vs) n ++ xs ++ [Neut n])
834mkCon (ExpTV (TyCaseFun s [m, t, f] n) et vs) = Just (untick $ show s, chain vs (nType s) [m, t, Neut n, f]) 875mkCon (ExpTV (unLab' -> TyCaseFun s [m, t, f] n) et vs) = Just (untick $ show s, chain vs (nType s) [m, t, Neut n, f])
835mkCon _ = Nothing 876mkCon _ = Nothing
836 877
837mkApp (ExpTV (Neut (I.App_ a b)) et vs) = Just (ExpTV (Neut a) t vs, head $ chain vs t [b]) 878mkApp (ExpTV (unLab' -> Neut (I.App_ a b)) et vs) = Just (ExpTV (Neut a) t vs, head $ chain vs t [b])
838 where t = neutType' (mkEnv vs) a 879 where t = neutType' (mkEnv vs) a
839mkApp _ = Nothing 880mkApp _ = Nothing
840 881
882mkFunc r@(ExpTV (I.Func n def nt xs) ty vs) | all (supType . tyOf) (r: xs') && n `notElem` ["typeAnn"] && all validChar n
883 = Just (untick n, toExp (def, nt), tyOf r, xs')
884 where
885 xs' = chain vs nt $ reverse xs
886 validChar = isAlphaNum
887mkFunc _ = Nothing
888
841chain vs t@(I.Pi Hidden at y) (a: as) = chain vs (appTy t a) as 889chain vs t@(I.Pi Hidden at y) (a: as) = chain vs (appTy t a) as
842chain vs t xs = chain' vs t xs 890chain vs t xs = chain' vs t xs
843 891
diff --git a/src/LambdaCube/Compiler/Infer.hs b/src/LambdaCube/Compiler/Infer.hs
index 329ed140..ee78422d 100644
--- a/src/LambdaCube/Compiler/Infer.hs
+++ b/src/LambdaCube/Compiler/Infer.hs
@@ -18,10 +18,9 @@ module LambdaCube.Compiler.Infer
18 ( Binder (..), SName, Lit(..), Visibility(..) 18 ( Binder (..), SName, Lit(..), Visibility(..)
19 , Exp (..), Neutral (..), ExpType, GlobalEnv 19 , Exp (..), Neutral (..), ExpType, GlobalEnv
20 , pattern Var, pattern CaseFun, pattern TyCaseFun, pattern App_ 20 , pattern Var, pattern CaseFun, pattern TyCaseFun, pattern App_
21 , pattern Con, pattern TyCon, pattern Pi, pattern Lam, pattern Fun, pattern ELit 21 , pattern Con, pattern TyCon, pattern Pi, pattern Lam, pattern Fun, pattern ELit, pattern Func, pattern LabelEnd
22 , outputType, boolType, trueExp 22 , outputType, boolType, trueExp
23 , down, Subst (..), free 23 , down, Subst (..), free
24 , litType
25 , initEnv, Env(..), pattern EBind2 24 , initEnv, Env(..), pattern EBind2
26 , SI(..), Range(..) -- todo: remove 25 , SI(..), Range(..) -- todo: remove
27 , Info(..), Infos, listAllInfos, listTypeInfos, listTraceInfos 26 , Info(..), Infos, listAllInfos, listTypeInfos, listTraceInfos
@@ -65,7 +64,7 @@ data Exp
65pattern ELit a <- (unfixlabel -> ELit_ a) where ELit = ELit_ 64pattern ELit a <- (unfixlabel -> ELit_ a) where ELit = ELit_
66 65
67data Neutral 66data Neutral
68 = Fun_ MaxDB FunName !Int{-number of missing parameters-} [Exp]{-given parameters, reversed-} Neutral{-unfolded expression-}{-not neut?-} 67 = Fun_ MaxDB FunName [Exp]{-local vars-} !Int{-number of missing parameters-} [Exp]{-given parameters, reversed-} Neutral{-unfolded expression-}{-not neut?-}
69 | CaseFun__ MaxDB CaseFunName [Exp] Neutral 68 | CaseFun__ MaxDB CaseFunName [Exp] Neutral
70 | TyCaseFun__ MaxDB TyCaseFunName [Exp] Neutral 69 | TyCaseFun__ MaxDB TyCaseFunName [Exp] Neutral
71 | App__ MaxDB Neutral Exp 70 | App__ MaxDB Neutral Exp
@@ -78,7 +77,7 @@ data ConName = ConName SName Int{-ordinal number, e.g. Zero:0, Succ:1-} Type
78 77
79data TyConName = TyConName SName Int{-num of indices-} Type [(ConName, Type)]{-constructors-} CaseFunName 78data TyConName = TyConName SName Int{-num of indices-} Type [(ConName, Type)]{-constructors-} CaseFunName
80 79
81data FunName = FunName SName Type 80data FunName = FunName SName (Maybe Exp) Type
82 81
83data CaseFunName = CaseFunName SName Type Int{-num of parameters-} 82data CaseFunName = CaseFunName SName Type Int{-num of parameters-}
84 83
@@ -92,8 +91,8 @@ instance Show ConName where show (ConName n _ _) = n
92instance Eq ConName where ConName _ n _ == ConName _ n' _ = n == n' 91instance Eq ConName where ConName _ n _ == ConName _ n' _ = n == n'
93instance Show TyConName where show (TyConName n _ _ _ _) = n 92instance Show TyConName where show (TyConName n _ _ _ _) = n
94instance Eq TyConName where TyConName n _ _ _ _ == TyConName n' _ _ _ _ = n == n' 93instance Eq TyConName where TyConName n _ _ _ _ == TyConName n' _ _ _ _ = n == n'
95instance Show FunName where show (FunName n _) = n 94instance Show FunName where show (FunName n _ _) = n
96instance Eq FunName where FunName n _ == FunName n' _ = n == n' 95instance Eq FunName where FunName n _ _ == FunName n' _ _ = n == n'
97instance Show CaseFunName where show (CaseFunName n _ _) = caseName n 96instance Show CaseFunName where show (CaseFunName n _ _) = caseName n
98instance Eq CaseFunName where CaseFunName n _ _ == CaseFunName n' _ _ = n == n' 97instance Eq CaseFunName where CaseFunName n _ _ == CaseFunName n' _ _ = n == n'
99instance Show TyCaseFunName where show (TyCaseFunName n _) = MatchName n 98instance Show TyCaseFunName where show (TyCaseFunName n _) = MatchName n
@@ -109,13 +108,13 @@ pattern NoLE <- (isNoLabelEnd -> True)
109isNoLabelEnd (LabelEnd_ _) = False 108isNoLabelEnd (LabelEnd_ _) = False
110isNoLabelEnd _ = True 109isNoLabelEnd _ = True
111 110
112pattern Fun f i xs n <- Fun_ _ f i xs n where Fun f i xs n = Fun_ (foldMap maxDB_ xs {- <> iterateN i lowerDB (maxDB_ n)-}) f i xs n 111pattern Fun f i xs n <- Fun_ _ f _ i xs n where Fun f i xs n = Fun_ (foldMap maxDB_ xs {- <> iterateN i lowerDB (maxDB_ n)-}) f [] i xs n
113pattern UTFun a t b <- Neut (Fun (FunName a t) _ (reverse -> b) NoLE) 112pattern UTFun a t b <- Neut (Fun (FunName a _ t) _ (reverse -> b) NoLE)
114pattern UFunN a b <- UTFun a _ b 113pattern UFunN a b <- UTFun a _ b
115pattern DFun_ fn xs <- Fun fn 0 (reverse -> xs) (Delta _) where 114pattern DFun_ fn xs <- Fun fn 0 (reverse -> xs) (Delta _) where
116 DFun_ fn@(FunName n _) xs = Fun fn 0 (reverse xs) d where 115 DFun_ fn@(FunName n _ _) xs = Fun fn 0 (reverse xs) d where
117 d = Delta $ SData $ getFunDef n $ \xs -> Neut $ Fun fn 0 (reverse xs) d 116 d = Delta $ SData $ getFunDef n $ \xs -> Neut $ Fun fn 0 (reverse xs) d
118pattern TFun' a t b = DFun_ (FunName a t) b 117pattern TFun' a t b = DFun_ (FunName a Nothing t) b
119pattern TFun a t b = Neut (TFun' a t b) 118pattern TFun a t b = Neut (TFun' a t b)
120 119
121pattern CaseFun_ a b c <- CaseFun__ _ a b c where CaseFun_ a b c = CaseFun__ (foldMap maxDB_ b <> maxDB_ c) a b c 120pattern CaseFun_ a b c <- CaseFun__ _ a b c where CaseFun_ a b c = CaseFun__ (foldMap maxDB_ b <> maxDB_ c) a b c
@@ -184,12 +183,9 @@ pattern Succ n <- ConN "Succ" (n:_) where Succ n = tCon "Succ" 1 (TNat :~>
184 183
185pattern CstrT t a b = Neut (CstrT' t a b) 184pattern CstrT t a b = Neut (CstrT' t a b)
186pattern CstrT' t a b = TFun' "'EqCT" (TType :~> Var 0 :~> Var 1 :~> TType) [t, a, b] 185pattern CstrT' t a b = TFun' "'EqCT" (TType :~> Var 0 :~> Var 1 :~> TType) [t, a, b]
187--pattern ReflCstr x = TFun "reflCstr" (TType :~> CstrT TType (Var 0) (Var 0)) [x]
188pattern Coe a b w x = TFun "coe" (TType :~> TType :~> CstrT TType (Var 1) (Var 0) :~> Var 2 :~> Var 2) [a,b,w,x] 186pattern Coe a b w x = TFun "coe" (TType :~> TType :~> CstrT TType (Var 1) (Var 0) :~> Var 2 :~> Var 2) [a,b,w,x]
189pattern ParEval t a b = TFun "parEval" (TType :~> Var 0 :~> Var 1 :~> Var 2) [t, a, b] 187pattern ParEval t a b = TFun "parEval" (TType :~> Var 0 :~> Var 1 :~> Var 2) [t, a, b]
190pattern Undef t = TFun "undefined" (Pi Hidden TType (Var 0)) [t]
191pattern T2 a b = TFun "'T2" (TType :~> TType :~> TType) [a, b] 188pattern T2 a b = TFun "'T2" (TType :~> TType :~> TType) [a, b]
192pattern T2C a b = TFun "t2C" (Unit :~> Unit :~> Unit) [a, b]
193pattern CSplit a b c <- UFunN "'Split" [a, b, c] 189pattern CSplit a b c <- UFunN "'Split" [a, b, c]
194 190
195pattern EInt a = ELit (LInt a) 191pattern EInt a = ELit (LInt a)
@@ -243,7 +239,7 @@ trueExp = EBool True
243pattern LabelEnd x = Neut (LabelEnd_ x) 239pattern LabelEnd x = Neut (LabelEnd_ x)
244 240
245pmLabel' :: FunName -> Int -> [Exp] -> Exp -> Exp 241pmLabel' :: FunName -> Int -> [Exp] -> Exp -> Exp
246pmLabel' (FunName _ _) 0 as (Neut (Delta (SData f))) = f $ reverse as 242pmLabel' (FunName _ _ _) 0 as (Neut (Delta (SData f))) = f $ reverse as
247pmLabel' f i xs (unfixlabel -> Neut y) = Neut $ Fun f i xs y 243pmLabel' f i xs (unfixlabel -> Neut y) = Neut $ Fun f i xs y
248pmLabel' f i xs y = error $ "pmLabel: " ++ show (f, i, length xs, y) 244pmLabel' f i xs y = error $ "pmLabel: " ++ show (f, i, length xs, y)
249 245
@@ -259,6 +255,15 @@ numLams x = 0
259pattern FL' y <- Fun f 0 xs (LabelEnd_ y) 255pattern FL' y <- Fun f 0 xs (LabelEnd_ y)
260pattern FL y <- Neut (FL' y) 256pattern FL y <- Neut (FL' y)
261 257
258pattern Func n def ty xs <- (mkFunc -> Just (n, def, ty, xs))
259
260mkFunc (Neut (Fun (FunName n (Just def) ty) 0 xs LabelEnd_{})) | Just def' <- removeLams (length xs) def = Just (n, def', ty, xs)
261mkFunc _ = Nothing
262
263removeLams 0 (LabelEnd x) = Just x
264removeLams n (Lam x) | n > 0 = Lam <$> removeLams (n-1) x
265removeLams _ _ = Nothing
266
262unfixlabel (FL y) = unfixlabel y 267unfixlabel (FL y) = unfixlabel y
263unfixlabel a = a 268unfixlabel a = a
264 269
@@ -393,7 +398,7 @@ instance Up Neutral where
393 CaseFun__ c _ _ _ -> c 398 CaseFun__ c _ _ _ -> c
394 TyCaseFun__ c _ _ _ -> c 399 TyCaseFun__ c _ _ _ -> c
395 App__ c a b -> c 400 App__ c a b -> c
396 Fun_ c _ _ _ _ -> c 401 Fun_ c _ _ _ _ _ -> c
397 LabelEnd_ x -> maxDB_ x 402 LabelEnd_ x -> maxDB_ x
398 Delta{} -> mempty 403 Delta{} -> mempty
399 404
@@ -402,7 +407,7 @@ instance Up Neutral where
402 CaseFun__ _ a as n -> CaseFun__ mempty a (closedExp <$> as) (closedExp n) 407 CaseFun__ _ a as n -> CaseFun__ mempty a (closedExp <$> as) (closedExp n)
403 TyCaseFun__ _ a as n -> TyCaseFun__ mempty a (closedExp <$> as) (closedExp n) 408 TyCaseFun__ _ a as n -> TyCaseFun__ mempty a (closedExp <$> as) (closedExp n)
404 App__ _ a b -> App__ mempty (closedExp a) (closedExp b) 409 App__ _ a b -> App__ mempty (closedExp a) (closedExp b)
405 Fun_ _ f i x y -> Fun_ mempty f i (closedExp <$> x) y 410 Fun_ _ f l i x y -> Fun_ mempty f l i (closedExp <$> x) y
406 LabelEnd_ a -> LabelEnd_ (closedExp a) 411 LabelEnd_ a -> LabelEnd_ (closedExp a)
407 d@Delta{} -> d 412 d@Delta{} -> d
408 413
@@ -454,7 +459,7 @@ evalCoe a b t d = Coe a b t d
454getFunDef s f = case s of 459getFunDef s f = case s of
455 "unsafeCoerce" -> \case xs@[_, _, x] -> case x of x@FL{} -> x; Neut{} -> f xs; _ -> x 460 "unsafeCoerce" -> \case xs@[_, _, x] -> case x of x@FL{} -> x; Neut{} -> f xs; _ -> x
456 "'EqCT" -> \case [t, a, b] -> cstr t a b 461 "'EqCT" -> \case [t, a, b] -> cstr t a b
457 "reflCstr" -> \case [a] -> reflCstr a 462 "reflCstr" -> \case [a] -> TT
458 "coe" -> \case [a, b, t, d] -> evalCoe a b t d 463 "coe" -> \case [a, b, t, d] -> evalCoe a b t d
459 "'T2" -> \case [a, b] -> t2 a b 464 "'T2" -> \case [a, b] -> t2 a b
460 "t2C" -> \case [a, b] -> t2C a b 465 "t2C" -> \case [a, b] -> t2C a b
@@ -551,19 +556,8 @@ cstr = f []
551 556
552 isElemTy n = n `elem` ["'Bool", "'Float", "'Int"] 557 isElemTy n = n `elem` ["'Bool", "'Float", "'Int"]
553 558
554
555reflCstr = \case
556{-
557 Unit -> TT
558 TType -> TT -- ?
559 Con n xs -> foldl (t2C te{-todo: more precise env-}) TT $ map (reflCstr te{-todo: more precise env-}) xs
560 TyCon n xs -> foldl (t2C te{-todo: more precise env-}) TT $ map (reflCstr te{-todo: more precise env-}) xs
561 x -> {-error $ "reflCstr: " ++ show x-} ReflCstr x
562-}
563 x -> TT
564
565t2C (unfixlabel -> TT) (unfixlabel -> TT) = TT 559t2C (unfixlabel -> TT) (unfixlabel -> TT) = TT
566t2C a b = T2C a b 560t2C a b = TFun "t2C" (Unit :~> Unit :~> Unit) [a, b]
567 561
568t2 (unfixlabel -> Unit) a = a 562t2 (unfixlabel -> Unit) a = a
569t2 a (unfixlabel -> Unit) = a 563t2 a (unfixlabel -> Unit) = a
@@ -696,7 +690,7 @@ litType = \case
696 690
697class NType a where nType :: a -> Type 691class NType a where nType :: a -> Type
698 692
699instance NType FunName where nType (FunName _ t) = t 693instance NType FunName where nType (FunName _ _ t) = t
700instance NType TyConName where nType (TyConName _ _ t _ _) = t 694instance NType TyConName where nType (TyConName _ _ t _ _) = t
701instance NType CaseFunName where nType (CaseFunName _ t _) = t 695instance NType CaseFunName where nType (CaseFunName _ t _) = t
702instance NType TyCaseFunName where nType (TyCaseFunName _ t) = t 696instance NType TyCaseFunName where nType (TyCaseFunName _ t) = t
@@ -802,19 +796,22 @@ inferN_ tellTrace = infer where
802 checkN te x t = tellTrace "check" (showEnvSExpType te x t) $ checkN_ te x t 796 checkN te x t = tellTrace "check" (showEnvSExpType te x t) $ checkN_ te x t
803 797
804 checkN_ te e t 798 checkN_ te e t
805 -- temporal hack 799 | x@(SGlobal (si, MatchName n)) `SAppV` SLamV (Wildcard _) `SAppV` a `SAppV` SVar siv v `SAppV` b <- e
806 | x@(SGlobal (si, MatchName n)) `SAppV` SLamV (Wildcard_ siw _) `SAppV` a `SAppV` SVar siv v `SAppV` b <- e
807 = infer te $ x `SAppV` SLam Visible SType (STyped mempty (subst (v+1) (Var 0) $ up 1 t, TType)) `SAppV` a `SAppV` SVar siv v `SAppV` b 800 = infer te $ x `SAppV` SLam Visible SType (STyped mempty (subst (v+1) (Var 0) $ up 1 t, TType)) `SAppV` a `SAppV` SVar siv v `SAppV` b
808 -- temporal hack 801 -- temporal hack
809 | x@(SGlobal (si, "'NatCase")) `SAppV` SLamV (Wildcard_ siw _) `SAppV` a `SAppV` b `SAppV` SVar siv v <- e 802 | x@(SGlobal (si, "'NatCase")) `SAppV` SLamV (Wildcard _) `SAppV` a `SAppV` b `SAppV` SVar siv v <- e
810 = infer te $ x `SAppV` STyped mempty (Lam $ subst (v+1) (Var 0) $ up 1 t, TNat :~> TType) `SAppV` a `SAppV` b `SAppV` SVar siv v 803 = infer te $ x `SAppV` SLamV (STyped mempty (subst (v+1) (Var 0) $ up1_ (v+2) $ up 1 t, TType)) `SAppV` a `SAppV` b `SAppV` SVar siv v
804 -- temporal hack
805 | x@(SGlobal (si, "'VecSCase")) `SAppV` SLamV (SLamV (Wildcard _)) `SAppV` a `SAppV` b `SAppV` c `SAppV` SVar siv v <- e
806 , TVec (Var n') _ <- snd $ varType "xx" v te
807 = infer te $ x `SAppV` SLamV (SLamV (STyped mempty (subst (n'+2) (Var 1) $ up1_ (n'+3) $ up 2 t, TType))) `SAppV` a `SAppV` b `SAppV` c `SAppV` SVar siv v
808
811{- 809{-
812 -- temporal hack 810 -- temporal hack
813 | x@(SGlobal "'VecSCase") `SAppV` SLamV (SLamV (Wildcard _)) `SAppV` a `SAppV` b `SAppV` c `SAppV` SVar v <- e 811 | x@(SGlobal (si, "'HListCase")) `SAppV` SLamV (SLamV (Wildcard _)) `SAppV` a `SAppV` b `SAppV` SVar siv v <- e
814 = infer te $ x `SAppV` (SLamV (SLamV (STyped (subst (v+1) (Var 0) $ up 2 t, TType)))) `SAppV` a `SAppV` b `SAppV` c `SAppV` SVar v 812 , TVec (Var n') _ <- snd $ varType "xx" v te
813 = infer te $ x `SAppV` SLamV (SLamV (STyped mempty (subst (n'+2) (Var 1) $ up1_ (n'+3) $ up 2 t, TType))) `SAppV` a `SAppV` b `SAppV` SVar siv v
815-} 814-}
816 -- temporal hack
817 | SGlobal (si, "undefined") <- e = focus_' te e (Undef t, t)
818 | SLabelEnd x <- e = checkN (ELabelEnd te) x t 815 | SLabelEnd x <- e = checkN (ELabelEnd te) x t
819 | SApp si h a b <- e = infer (CheckAppType si h t te b) a 816 | SApp si h a b <- e = infer (CheckAppType si h t te b) a
820 | SLam h a b <- e, Pi h' x y <- t, h == h' = do 817 | SLam h a b <- e, Pi h' x y <- t, h == h' = do
@@ -1143,7 +1140,7 @@ handleStmt defs = \case
1143 Primitive n mf (trSExp' -> t_) -> do 1140 Primitive n mf (trSExp' -> t_) -> do
1144 t <- inferType =<< ($ t_) <$> addF 1141 t <- inferType =<< ($ t_) <$> addF
1145 tellType (fst n) t 1142 tellType (fst n) t
1146 addToEnv n mf $ flip (,) t $ lamify t $ Neut . DFun_ (FunName (snd n) t) 1143 addToEnv n mf $ flip (,) t $ lamify t $ Neut . DFun_ (FunName (snd n) Nothing t)
1147 Let n mf mt t_ -> do 1144 Let n mf mt t_ -> do
1148 af <- addF 1145 af <- addF
1149 let t__ = maybe id (flip SAnn . af) mt t_ 1146 let t__ = maybe id (flip SAnn . af) mt t_
@@ -1224,7 +1221,7 @@ withEnv e = local $ second (<> e)
1224mkELet (False, n) x xt = x 1221mkELet (False, n) x xt = x
1225mkELet (True, n) x xt = term 1222mkELet (True, n) x xt = term
1226 where 1223 where
1227 fn = FunName (snd n) xt 1224 fn = FunName (snd n) (Just x) xt
1228 1225
1229 term = pmLabel fn 0 [] $ getFix x 0 1226 term = pmLabel fn 0 [] $ getFix x 0
1230 1227
diff --git a/src/LambdaCube/Compiler/Parser.hs b/src/LambdaCube/Compiler/Parser.hs
index aa9efe28..1bdf19f2 100644
--- a/src/LambdaCube/Compiler/Parser.hs
+++ b/src/LambdaCube/Compiler/Parser.hs
@@ -280,6 +280,9 @@ instance Up Void where
280 280
281instance Up a => Up (SExp' a) where 281instance Up a => Up (SExp' a) where
282 up_ n = mapS' (\sn j i -> SVar sn $ if j < i then j else j+n) (+1) 282 up_ n = mapS' (\sn j i -> SVar sn $ if j < i then j else j+n) (+1)
283 where
284 mapS' = mapS__ (\i si x -> STyped si $ up_ n i x) (const . SGlobal)
285
283 fold f = foldS (\i si x -> fold f i x) mempty $ \sn j i -> f j i 286 fold f = foldS (\i si x -> fold f i x) mempty $ \sn j i -> f j i
284 maxDB_ _ = error "maxDB @SExp" 287 maxDB_ _ = error "maxDB @SExp"
285 288