diff --git a/src/main/scala/avoidancestlc/Lang.scala b/src/main/scala/avoidancestlc/Lang.scala index 1b0762a..21c088c 100644 --- a/src/main/scala/avoidancestlc/Lang.scala +++ b/src/main/scala/avoidancestlc/Lang.scala @@ -1,4 +1,4 @@ -package diamond.avoidancestlc +package diamond.avoidancestlc.core import diamond._ @@ -50,3 +50,64 @@ enum Expr: case EAssign(lhs: Expr, rhs: Expr) case EDeref(e: Expr) case EAscribe(e: Expr, t: QType) + +/* Auxiliary embedded syntax */ + +import Expr._ +import Type._ + +object TypeSyntax: + val ◆ = Fresh() + extension (t: QType) + def ~>(s: QType): TFun = TFun(freshVar("AnnoFun"), freshVar("Arg"), t, s) + extension (id: String) + def ♯(t: TFun): TFun = TFun(id, t.arg, t.t1, t.t2) + extension (t: Type) + def ^(q: Qual): QType = QType(t, q) + def ^(q: QElem): QType = QType(t, Qual(Set(q))) + def ^(q: Unit): QType = QType(t, Qual(Set())) + def ^(q: Tuple): QType = QType(t, Qual(q.toList.asInstanceOf[List[QElem]].toSet)) + // type to qualified type conversion, default is untracked + given Conversion[Type, QType] = QType(_, Qual.untrack) + +object ExprSyntax: + import Expr._ + import Type._ + + val 𝑥 = "x" + val x = EVar("x") + val 𝑦 = "y" + val y = EVar("y") + val 𝑧 = "z" + val z = EVar("z") + val 𝑛 = "n" + val n = EVar("n") + val 𝑓 = "f" + val f = EVar("f") + val 𝑔 = "g" + val g = EVar("g") + + case class BindTy(id: String, ty: QType) { + def ⇐(e: Expr): Bind = Bind(id, e, Some(ty)) + def ~>(rt: QType): TFun = TFun(freshVar("AnnoFun"), id, ty, rt) + } + case class Bind(id: String, rhs: Expr, ty: Option[QType]) + + extension (id: String) + def ⦂(t: QType): BindTy = BindTy(id, t) + def ⇐(e: Expr): Bind = Bind(id, e, None) + + def λ(f: String, x: String)(ft: TFun)(e: => Expr): ELam = ELam(f, x, ft.t1, e, Some(ft.t2)) + def λ(f: String, xt: BindTy, rt: QType)(e: => Expr): ELam = ELam(f, xt.id, xt.ty, e, Some(rt)) + def λ(xt: BindTy)(e: => Expr): ELam = ELam(freshVar("AnnoFun"), xt.id, xt.ty, e, None) + def let(xv: Bind)(e: Expr): Expr = ELet(xv.id, xv.ty, xv.rhs, e) + def alloc(e: Expr): Expr = EAlloc(e) + + extension (e: Expr) + def apply(a: Expr): Expr = EApp(e, a) + def applyFresh(a: Expr): Expr = EApp(e, a, Some(true)) + def apply(n: Int): Expr = EApp(e, ENum(n)) + def deref: Expr = EDeref(e) + def assign(e0: Expr): Expr = EAssign(e, e0) + + given Conversion[Int, ENum] = ENum(_) diff --git a/src/main/scala/avoidancestlc/Parser.scala b/src/main/scala/avoidancestlc/Parser.scala new file mode 100644 index 0000000..552c4ea --- /dev/null +++ b/src/main/scala/avoidancestlc/Parser.scala @@ -0,0 +1,385 @@ +package diamond.avoidancestlc + +import diamond._ +import diamond.parser._ +import org.antlr.v4.runtime._ +import scala.collection.JavaConverters._ + +// Italic style letters are prefixes for anonymous entities generated during parsing + +val letPre = "ℓ" +def letPre(n: Int): String = letPre + "#" + n +val tyFunPre = "𝐹" +def tyFunPre(n: Int): String = tyFunPre + "#" + n +val funPre = "𝑓" +def funPre(n: Int): String = funPre + "#" + n +val varPre = "𝑥" +def varPre(n: Int): String = varPre + "#" + n +val tyVarPre = "𝑋" +def tyVarPre(n: Int): String = tyVarPre + "#" + n + +package ir { + abstract class IR + trait TopLevel + case class Program(tops: List[Expr]) extends IR { + def toCore: core.Expr = { + if (tops.size > 0) { + val (newTops, last) = tops.last match { + case Expr(e) => (tops.dropRight(1), e) + } + newTops.foldRight(last) { + case (Expr(e), last) => + core.Expr.ELet(freshVar(letPre), None, e, last) + } + } else core.Expr.EUnit + } + } + + case class Type(t: core.Type) extends IR { + def toCore = t + } + + case class Qual(q: core.Qual) extends IR { + def toCore = q + } + + case class QType(t: core.Type, q: core.Qual) extends IR { + def toCore: core.QType = core.QType(t, q) + } + + case class ParamList(params: List[Param]) extends IR + case class Param(name: String, qty: core.QType) extends IR + + case class TyParamList(tyParams: List[TyParam]) extends IR + case class TyParam(tvar: String, qvar: String, bound: core.QType) extends IR + + case class ArgList(args: List[core.Expr]) extends IR + case class TyArgList(args: List[core.QType]) extends IR + + case class Expr(e: core.Expr) extends IR { + def toCore: core.Expr = e + } + + abstract class Def extends IR { + def toLet(e: core.Expr): core.Expr.ELet + } + case class MonoFunDef(name: String, params: List[Param], rt: Option[QType], body: core.Expr) extends Def { + def toLet(e: core.Expr): core.Expr.ELet = { + val realParams: List[Param] = + if (params.size == 0) List(Param(freshVar(varPre), core.QType(core.Type.TUnit, core.Qual.untrack))) + else params + val rhs = realParams.zipWithIndex.foldRight(body) { + case ((param, idx), body) => + val funName = if (idx == 0) name else freshVar(funPre) + val realRt = if (idx == realParams.size-1) rt.map(_.toCore) else None + core.Expr.ELam(funName, param.name, param.qty, body, realRt) + } + val rhsTy = None + core.Expr.ELet(name, rhsTy, rhs, e) + } + } + case class PolyFunDef(name: String, tyParams: List[TyParam], params: List[Param], rt: Option[QType], body: core.Expr) extends Def { + def toLet(e: core.Expr): core.Expr.ELet = { + val realParams: List[Param] = + if (params.size == 0) List(Param(freshVar(varPre), core.QType(core.Type.TUnit, core.Qual.untrack))) + else params + val lam = realParams.zipWithIndex.foldRight(body) { + case ((param, idx), body) => + val realRt = if (idx == realParams.size-1) rt.map(_.toCore) else None + core.Expr.ELam(freshVar(funPre), param.name, param.qty, body, realRt) + } + val rhs = tyParams.zipWithIndex.foldRight(lam) { + case ((ty, idx), body) => + val funName = if (idx == 0) name else freshVar(tyFunPre) + ??? + //core.Expr.ETyLam(funName, ty.tvar, ty.qvar, ty.bound, body, None) + } + val rhsTy = None + core.Expr.ELet(name, rhsTy, rhs, e) + } + } +} + +class DiamondVisitor extends DiamondParserBaseVisitor[ir.IR] { + import DiamondParser._ + import ir._ + + val coreTop: core.QType = ??? //core.QType(core.Type.TTop, core.Qual.fresh) + + def error = ??? + + override def visitIdQty(ctx: IdQtyContext): Param = { + Param(ctx.ID.getText.toString, visitQty(ctx.qty).toCore) + } + + override def visitParam(ctx: ParamContext): Param = { + if (ctx.qty != null) Param(freshVar("Arg"), visitQty(ctx.qty).toCore) + else if (ctx.idQty != null) visitIdQty(ctx.idQty) + else error + } + + override def visitParamList(ctx: ParamListContext): ParamList = + ParamList(ctx.param.asScala.map(visitParam(_)).toList) + + override def visitFunTy(ctx: FunTyContext): Type = { + val f = if (ctx.ID != null) ctx.ID.getText.toString else freshVar(funPre) + val args = + if (ctx.paramList != null) visitParamList(ctx.paramList).params + else List(Param(freshVar(varPre), core.QType(core.Type.TUnit, core.Qual.untrack))) + val ret = visitQty(ctx.qty).toCore + val rest = args.zipWithIndex.drop(1).foldRight(ret) { + case ((arg, idx), rt) => + val q = core.Qual(args.take(idx).map(_.name).toSet) + core.QType(core.Type.TFun(freshVar(funPre), arg.name, arg.qty, rt), q) + } + val fty = core.Type.TFun(f, args(0).name, args(0).qty, rest) + Type(fty) + } + + override def visitTyParam(ctx: TyParamContext): TyParam = { + if (ctx.ID.size == 1 && ctx.ty == null) { + TyParam(ctx.ID(0).getText.toString, freshVar(varPre), coreTop) + } else if (ctx.ID.size == 1 && ctx.ty != null) { + TyParam(ctx.ID(0).getText.toString, freshVar(varPre), core.QType(visitTy(ctx.ty).toCore, core.Qual.fresh)) + } else if (ctx.ID.size == 2) { + TyParam(ctx.ID(0).getText.toString, ctx.ID(1).getText.toString, visitQty(ctx.qty).toCore) + } else error + } + + override def visitTyParamList(ctx: TyParamListContext): TyParamList = TyParamList(ctx.tyParam.asScala.map(visitTyParam(_)).toList) + + override def visitTyFunTy(ctx: TyFunTyContext): Type = { + val f = if (ctx.ID != null) ctx.ID.getText.toString else freshVar(tyFunPre) + val args = visitTyParamList(ctx.tyParamList).tyParams + val ret = visitQty(ctx.qty).toCore + //Note: we have not supported multi-argument forall types, they can only be curried + if (args.size == 1) { + ??? + //Type(core.Type.TForall(f, args(0).tvar, args(0).qvar, args(0).bound, ret)) + } else error + } + + override def visitTy(ctx: TyContext): Type = { + if (ctx.ID != null) { + val s = ctx.ID.getText.toString + if (s == "Int") Type(core.Type.TNum) + else if (s == "Unit") Type(core.Type.TUnit) + else if (s == "Top") ??? //Type(core.Type.TTop) + else if (s == "Bool") Type(core.Type.TBool) + else ??? //Type(core.Type.TVar(s)) + } else if (ctx.REF != null) { + Type(core.Type.TRef(visitQty(ctx.qty).toCore)) + } + else if (ctx.funTy != null) visitFunTy(ctx.funTy) + else if (ctx.tyFunTy != null) visitTyFunTy(ctx.tyFunTy) + else visitTy(ctx.ty) + } + + override def visitQual(ctx: QualContext): Qual = { + if (ctx.FRESH != null) Qual(core.Qual.fresh) + else if (ctx.ID != null) Qual(core.Qual.singleton(ctx.ID.getText.toString)) + else if (ctx.qualElems != null) { + val elems: Set[core.QElem] = ctx.qualElems.qualElem.asScala.map( e => + if (e.ID != null) e.ID.getText.toString + else core.Fresh() + ).toSet + Qual(core.Qual(elems)) + } else Qual(core.Qual.untrack) + } + + override def visitQty(ctx: QtyContext): QType = { + val ty = visitTy(ctx.ty).t + if (ctx.qual != null) QType(ty, visitQual(ctx.qual).q) + else QType(ty, core.Qual.untrack) + } + + override def visitNamedParamList(ctx: NamedParamListContext): ParamList = + ParamList(ctx.idQty.asScala.map(visitIdQty(_)).toList) + + override def visitMonoFunDef(ctx: MonoFunDefContext): MonoFunDef = { + val name = ctx.ID.getText.toString + val args = + if (ctx.namedParamList != null) + visitNamedParamList(ctx.namedParamList).params + else List() + val rt = if (ctx.qty != null) Some(visitQty(ctx.qty)) else None + val body = visitExpr(ctx.expr).toCore + MonoFunDef(name, args, rt, body) + } + + override def visitPolyFunDef(ctx: PolyFunDefContext): PolyFunDef = { + val name = ctx.ID.getText.toString + val tyArgs = + if (ctx.tyParamList != null) + visitTyParamList(ctx.tyParamList).tyParams + else List() + val args = + if (ctx.namedParamList != null) + visitNamedParamList(ctx.namedParamList).params + else List() + val rt = if (ctx.qty != null) Some(visitQty(ctx.qty)) else None + val body = visitExpr(ctx.expr).toCore + PolyFunDef(name, tyArgs, args, rt, body) + } + + override def visitLam(ctx: LamContext): Expr = { + val name = if (ctx.ID != null) ctx.ID.getText.toString else freshVar(funPre) + val args = + if (ctx.namedParamList != null) + visitNamedParamList(ctx.namedParamList).params + else List(Param(freshVar(varPre), core.QType(core.Type.TUnit, core.Qual.untrack))) + val rt = if (ctx.qty != null) Some(visitQty(ctx.qty).toCore) else None + val body = visitExpr(ctx.expr).toCore + val ret = args.zipWithIndex.foldRight(body) { + case ((arg, idx), body) => + val realName = if (idx == 0) name else freshVar(funPre) + val realRt = if (idx == args.size-1) rt else None + core.Expr.ELam(realName, arg.name, arg.qty, body, realRt) + } + Expr(ret) + } + + override def visitTyLam(ctx: TyLamContext): Expr = { + val name = if (ctx.ID != null) ctx.ID.getText.toString else freshVar(tyFunPre) + val tyArgs = + if (ctx.tyParamList != null) + visitTyParamList(ctx.tyParamList).tyParams + else List() + val rt = if (ctx.qty != null) Some(visitQty(ctx.qty).toCore) else None + val body = visitExpr(ctx.expr).toCore + //Note: we have not supported multi-argument type lambdas, they can only be curried + if (tyArgs.size == 1) { + ??? + //Expr(core.Expr.ETyLam(name, tyArgs(0).tvar, tyArgs(0).qvar, tyArgs(0).bound, body, rt)) + } else error + } + + override def visitValue(ctx: ValueContext): Expr = { + if (ctx.TRUE != null) Expr(core.Expr.EBool(true)) + else if (ctx.FALSE != null) Expr(core.Expr.EBool(false)) + else if (ctx.UNIT != null) Expr(core.Expr.EUnit) + else if (ctx.INT != null) Expr(core.Expr.ENum(ctx.INT.getText.toInt)) + else if (ctx.lam != null) visitLam(ctx.lam) + else if (ctx.tyLam != null) visitTyLam(ctx.tyLam) + else error + } + + override def visitAlloc(ctx: AllocContext): Expr = Expr(core.Expr.EAlloc(visitExpr(ctx.expr).toCore)) + + override def visitDeref(ctx: DerefContext): Expr = Expr(core.Expr.EDeref(visitExpr(ctx.expr).toCore)) + + override def visitLet(ctx: LetContext): Expr = { + val rhs = visitExpr(ctx.expr(0)).toCore + val body = visitExpr(ctx.expr(1)).toCore + val isGlobal = (ctx.valDecl.TOPVAL != null) + if (ctx.ID != null) { + val x = ctx.ID.getText.toString + Expr(core.Expr.ELet(x, None, rhs, body, isGlobal)) + } else { + val Param(x, qty) = visitIdQty(ctx.idQty) + Expr(core.Expr.ELet(x, Some(qty), rhs, body, isGlobal)) + } + } + + override def visitArgs(ctx: ArgsContext): ArgList = + ArgList(ctx.expr.asScala.map(visitExpr(_).toCore).toList) + + override def visitTyArgs(ctx: TyArgsContext): TyArgList = + TyArgList(ctx.qty.asScala.map(visitQty(_).toCore).toList) + + override def visitExpr(ctx: ExprContext): Expr = { + val e: core.Expr = + if (ctx.ID != null) core.Expr.EVar(ctx.ID.getText.toString) + else if (ctx.op1 != null) { + val e = visitExpr(ctx.expr(0)).toCore + ??? + //core.Expr.EUnaryOp(ctx.op1.getText.toString, e) + } else if (ctx.boolOp2 != null) { + val arg1 = visitExpr(ctx.expr(0)).toCore + val arg2 = visitExpr(ctx.expr(1)).toCore + ??? + //core.Expr.EBinOp(ctx.boolOp2.getText.toString, arg1, arg2) + } else if (ctx.ADD != null) { + val arg1 = visitExpr(ctx.expr(0)).toCore + val arg2 = visitExpr(ctx.expr(1)).toCore + ??? + //core.Expr.EBinOp("+", arg1, arg2) + } else if (ctx.MINUS != null) { + val arg1 = visitExpr(ctx.expr(0)).toCore + val arg2 = visitExpr(ctx.expr(1)).toCore + ??? + //core.Expr.EBinOp("-", arg1, arg2) + } else if (ctx.MULT != null) { + val arg1 = visitExpr(ctx.expr(0)).toCore + val arg2 = visitExpr(ctx.expr(1)).toCore + ??? + //core.Expr.EBinOp("*", arg1, arg2) + } else if (ctx.DIV != null) { + val arg1 = visitExpr(ctx.expr(0)).toCore + val arg2 = visitExpr(ctx.expr(1)).toCore + ??? + //core.Expr.EBinOp("/", arg1, arg2) + } else if (ctx.COLONEQ != null) { + val lhs = visitExpr(ctx.expr(0)).toCore + val rhs = visitExpr(ctx.expr(1)).toCore + core.Expr.EAssign(lhs, rhs) + } else if (ctx.funDef != null) { + val body = visitExpr(ctx.expr(0)).toCore + super.visit(ctx.funDef).asInstanceOf[Def].toLet(body) + } else if (ctx.args != null) { + val f = visitExpr(ctx.expr(0)).toCore + val args = visitArgs(ctx.args).args + val fresh = if (ctx.AT != null) Some(true) else None + args.foldLeft(f) { case (f, arg) => core.Expr.EApp(f, arg, fresh) } + } else if (ctx.tyArgs != null) { + // there is at least one type in tyArgs + val fresh = if (ctx.AT != null) Some(true) else None + val f = visitExpr(ctx.expr(0)).toCore + val tyArgs = visitTyArgs(ctx.tyArgs).args + ??? + //tyArgs.foldLeft(f) { case (f, tyArg) => core.Expr.ETyApp(f, tyArg, fresh) } + } else if (ctx.IF != null && ctx.ELSE != null) { + val cnd = visitExpr(ctx.expr(0)).toCore + val thn = visitExpr(ctx.expr(1)).toCore + val els = visitExpr(ctx.expr(2)).toCore + ??? + //core.Expr.ECond(cnd, thn, els) + } else if (ctx.LPAREN != null && ctx.RPAREN != null) { + // term application without argument + val f = visitExpr(ctx.expr(0)).toCore + val arg = core.Expr.EUnit + core.Expr.EApp(f, arg, None) + } else { + super.visitExpr(ctx).asInstanceOf[Expr].toCore // value, alloc, deref, let, wrapped expr + } + Expr(e) + } + + override def visitWrapExpr(ctx: WrapExprContext): Expr = { + Expr(visitExpr(ctx.expr).toCore) + } + + override def visitProgram(ctx: ProgramContext): IR = { + val exprs = ctx.expr.asScala.map(visit(_)).toList + Program(exprs.asInstanceOf[List[Expr]]) + } +} + +object Parser { + def parse(input: String): ir.Program = { + Counter.reset + val charStream = new ANTLRInputStream(input) + val lexer = new DiamondLexer(charStream) + val tokens = new CommonTokenStream(lexer) + val parser = new DiamondParser(tokens) + val visitor = new DiamondVisitor() + val res = visitor.visit(parser.program).asInstanceOf[ir.Program] + res + } + + def parseFile(filepath: String): ir.Program = parse(scala.io.Source.fromFile(filepath).mkString) + + def parseToCore(input: String) = parse(input).toCore + + def parseFileToCore(filepath: String) = parseFile(filepath).toCore +} \ No newline at end of file diff --git a/src/main/scala/avoidancestlc/TypeCheck.scala b/src/main/scala/avoidancestlc/TypeCheck.scala index 566b07a..71166dc 100644 --- a/src/main/scala/avoidancestlc/TypeCheck.scala +++ b/src/main/scala/avoidancestlc/TypeCheck.scala @@ -1,8 +1,10 @@ package diamond.avoidancestlc import diamond._ -import Type._ -import Expr._ +import diamond.avoidancestlc.core._ + +import core.Type._ +import core.Expr._ /* Typing environment */ @@ -66,6 +68,20 @@ extension (q: Qual) def rename(from: String, to: String): Qual = q.subst(from, Qual.singleton(to)) +/* Auxiliary functions for expressions */ + +extension (e: Expr) + def freeVars: Set[String] = e match + case EUnit | ENum(_) | EBool(_) => Set() + case EVar(x) => Set(x) + case ELam(f, x, at, e, rt) => e.freeVars -- Set(f, x) + case EApp(e1, e2, _) => e1.freeVars ++ e2.freeVars + case ELet(x, _, rhs, body, _) => rhs.freeVars ++ (body.freeVars - x) + case EAlloc(e) => e.freeVars + case EAssign(e1, e2) => e1.freeVars ++ e2.freeVars + case EDeref(e) => e.freeVars + case EAscribe(e, t) => e.freeVars + /* Auxiliary functions for types */ extension (t: Type) @@ -209,6 +225,7 @@ def subtypeCheck(tenv: TEnv, t1: Type, t2: Type): (Qual /*filter*/, Qual /*growt val G1 = TFun(g, x1, qt3, qt4.rename(y, x1)) subtypeCheck(tenv, F1, G1) } else throw new RuntimeException("Impossible") + case _ => throw new RuntimeException(s"Not subtype $t1 <: $t2") } } @@ -410,6 +427,12 @@ def infer(tenv: TEnv, e: Expr): (Qual, QType) = { } else Qual.untrack val fl = fl1 ++ fl2 ++ fl3 ++ (r \ Qual(Set(f, x, Fresh()))) (fl, QType(u, r.subst(x, p1).subst(f, q))) + // We consider a lambda term with type annotation as "ascription" + case ELam(f, x, at, body, Some(rt)) => + val q = Qual((body.freeVars -- Set(f, x)).asInstanceOf[Set[QElem]]) + val tq = QType(TFun(f, x, at, rt), q) + val fl = check(tenv, e, tq) + (fl, tq) } } @@ -417,4 +440,11 @@ def checkInfer(tenv: TEnv, e: Expr, t: Type): (Qual/*filter*/, Qual/*qual*/) = { val (fl1, QType(t1, q)) = infer(tenv, e) val (fl2, gr) = subtypeCheck(tenv, t1, t) (fl1 ++ fl2, q ++ gr) +} + +def topTypeCheck(e: Expr): QType = { + println(e) + Counter.reset + val (fl, qt) = infer(TEnv.empty, e) + qt } \ No newline at end of file diff --git a/src/test/scala/avoidancestlc/TypeCheck.scala b/src/test/scala/avoidancestlc/TypeCheck.scala new file mode 100644 index 0000000..53bd357 --- /dev/null +++ b/src/test/scala/avoidancestlc/TypeCheck.scala @@ -0,0 +1,26 @@ +package diamond.avoidancestlc +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers + +import diamond._ +import diamond.avoidancestlc.core._ + +import core.Type._ +import core.Expr._ +import TypeSyntax._ +import ExprSyntax._ + +import TypeSyntax.given_Conversion_Type_QType +import ExprSyntax.given_Conversion_Int_ENum + +class AvoidanceSTLCTests extends AnyFunSuite { + /* + test("escaping closures") { + val e1 = + let("x" ⇐ alloc(3)) { + λ("f", "z")("f"♯(TNum ~> TNum)) { x.deref } + } + assert(topTypeCheck(e1) == (TFun("f", "z", TNum^(), TNum^()) ^ ◆)) + } + */ +} \ No newline at end of file