Skip to content

Introduce Labeled blocks, and use them in PatternMatcher #4982

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 27, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class DottyBackendInterface(outputDirectory: AbstractFile, val superCallsMap: Ma
type If = tpd.If
type ValDef = tpd.ValDef
type Throw = tpd.Apply
type Labeled = tpd.Labeled
type Return = tpd.Return
type Block = tpd.Block
type Typed = tpd.Typed
Expand Down Expand Up @@ -193,6 +194,7 @@ class DottyBackendInterface(outputDirectory: AbstractFile, val superCallsMap: Ma
implicit val LabelDefTag: ClassTag[LabelDef] = ClassTag[LabelDef](classOf[LabelDef])
implicit val ValDefTag: ClassTag[ValDef] = ClassTag[ValDef](classOf[ValDef])
implicit val ThrowTag: ClassTag[Throw] = ClassTag[Throw](classOf[Throw])
implicit val LabeledTag: ClassTag[Labeled] = ClassTag[Labeled](classOf[Labeled])
implicit val ReturnTag: ClassTag[Return] = ClassTag[Return](classOf[Return])
implicit val LiteralTag: ClassTag[Literal] = ClassTag[Literal](classOf[Literal])
implicit val BlockTag: ClassTag[Block] = ClassTag[Block](classOf[Block])
Expand Down Expand Up @@ -1076,8 +1078,14 @@ class DottyBackendInterface(outputDirectory: AbstractFile, val superCallsMap: Ma
def apply(s: Symbol): This = tpd.This(s.asClass)
}

object Labeled extends LabeledDeconstructor {
def _1: Bind = field.bind
def _2: Tree = field.expr
}

object Return extends ReturnDeconstructor {
def get = field.expr
def _1: Tree = field.expr
def _2: Symbol = if (field.from.symbol.isLabel) field.from.symbol else NoSymbol
}

object Ident extends IdentDeconstructor {
Expand Down
18 changes: 17 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -524,8 +524,15 @@ object Trees {
type ThisTree[-T >: Untyped] = CaseDef[T]
}

/** label[tpt]: { expr } */
case class Labeled[-T >: Untyped] private[ast] (bind: Bind[T], expr: Tree[T])
extends NameTree[T] {
type ThisTree[-T >: Untyped] = Labeled[T]
def name: Name = bind.name
}

/** return expr
* where `from` refers to the method from which the return takes place
* where `from` refers to the method or label from which the return takes place
* After program transformations this is not necessarily the enclosing method, because
* closures can intervene.
*/
Expand Down Expand Up @@ -886,6 +893,7 @@ object Trees {
type Closure = Trees.Closure[T]
type Match = Trees.Match[T]
type CaseDef = Trees.CaseDef[T]
type Labeled = Trees.Labeled[T]
type Return = Trees.Return[T]
type Try = Trees.Try[T]
type SeqLiteral = Trees.SeqLiteral[T]
Expand Down Expand Up @@ -1028,6 +1036,10 @@ object Trees {
case tree: CaseDef if (pat eq tree.pat) && (guard eq tree.guard) && (body eq tree.body) => tree
case _ => finalize(tree, untpd.CaseDef(pat, guard, body))
}
def Labeled(tree: Tree)(bind: Bind, expr: Tree)(implicit ctx: Context): Labeled = tree match {
case tree: Labeled if (bind eq tree.bind) && (expr eq tree.expr) => tree
case _ => finalize(tree, untpd.Labeled(bind, expr))
}
def Return(tree: Tree)(expr: Tree, from: Tree)(implicit ctx: Context): Return = tree match {
case tree: Return if (expr eq tree.expr) && (from eq tree.from) => tree
case _ => finalize(tree, untpd.Return(expr, from))
Expand Down Expand Up @@ -1202,6 +1214,8 @@ object Trees {
cpy.Match(tree)(transform(selector), transformSub(cases))
case CaseDef(pat, guard, body) =>
cpy.CaseDef(tree)(transform(pat), transform(guard), transform(body))
case Labeled(bind, expr) =>
cpy.Labeled(tree)(transformSub(bind), transform(expr))
case Return(expr, from) =>
cpy.Return(tree)(transform(expr), transformSub(from))
case Try(block, cases, finalizer) =>
Expand Down Expand Up @@ -1334,6 +1348,8 @@ object Trees {
this(this(x, selector), cases)
case CaseDef(pat, guard, body) =>
this(this(this(x, pat), guard), body)
case Labeled(bind, expr) =>
this(this(x, bind), expr)
case Return(expr, from) =>
this(this(x, expr), from)
case Try(block, handler, finalizer) =>
Expand Down
9 changes: 9 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
def Match(selector: Tree, cases: List[CaseDef])(implicit ctx: Context): Match =
ta.assignType(untpd.Match(selector, cases), cases)

def Labeled(bind: Bind, expr: Tree)(implicit ctx: Context): Labeled =
ta.assignType(untpd.Labeled(bind, expr))

def Labeled(sym: TermSymbol, expr: Tree)(implicit ctx: Context): Labeled =
Labeled(Bind(sym, EmptyTree), expr)

def Return(expr: Tree, from: Tree)(implicit ctx: Context): Return =
ta.assignType(untpd.Return(expr, from))

Expand Down Expand Up @@ -594,6 +600,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
}
}

override def Labeled(tree: Tree)(bind: Bind, expr: Tree)(implicit ctx: Context): Labeled =
ta.assignType(untpd.cpy.Labeled(tree)(bind, expr))

override def Return(tree: Tree)(expr: Tree, from: Tree)(implicit ctx: Context): Return =
ta.assignType(untpd.cpy.Return(tree)(expr, from))

Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def Closure(env: List[Tree], meth: Tree, tpt: Tree): Closure = new Closure(env, meth, tpt)
def Match(selector: Tree, cases: List[CaseDef]): Match = new Match(selector, cases)
def CaseDef(pat: Tree, guard: Tree, body: Tree): CaseDef = new CaseDef(pat, guard, body)
def Labeled(bind: Bind, expr: Tree): Labeled = new Labeled(bind, expr)
def Return(expr: Tree, from: Tree): Return = new Return(expr, from)
def Try(expr: Tree, cases: List[CaseDef], finalizer: Tree): Try = new Try(expr, cases, finalizer)
def SeqLiteral(elems: List[Tree], elemtpt: Tree): SeqLiteral = new SeqLiteral(elems, elemtpt)
Expand Down
8 changes: 2 additions & 6 deletions compiler/src/dotty/tools/dotc/core/NameKinds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -314,12 +314,8 @@ object NameKinds {

/** Kinds of unique names generated by the pattern matcher */
val PatMatStdBinderName = new UniqueNameKind("x")
val PatMatPiName = new UniqueNameKind("pi") // FIXME: explain what this is
val PatMatPName = new UniqueNameKind("p") // FIXME: explain what this is
val PatMatOName = new UniqueNameKind("o") // FIXME: explain what this is
val PatMatCaseName = new UniqueNameKind("case")
val PatMatMatchFailName = new UniqueNameKind("matchFail")
val PatMatSelectorName = new UniqueNameKind("selector")
val PatMatAltsName = new UniqueNameKind("matchAlts")
val PatMatResultName = new UniqueNameKind("matchResult")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice simplification!


val LocalOptInlineLocalObj = new UniqueNameKind("ilo")

Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/TypeErasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ object TypeErasure {
if (defn.isPolymorphicAfterErasure(sym)) eraseParamBounds(sym.info.asInstanceOf[PolyType])
else if (sym.isAbstractType) TypeAlias(WildcardType)
else if (sym.isConstructor) outer.addParam(sym.owner.asClass, erase(tp)(erasureCtx))
else if (sym.is(Label, butNot = Method)) erase.eraseResult(sym.info)(erasureCtx)
else erase.eraseInfo(tp, sym)(erasureCtx) match {
case einfo: MethodType =>
if (sym.isGetter && einfo.resultType.isRef(defn.UnitClass))
Expand Down
8 changes: 7 additions & 1 deletion compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,14 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
else changePrec(GlobalPrec) { toText(sel) ~ keywordStr(" match ") ~ blockText(cases) }
case CaseDef(pat, guard, body) =>
keywordStr("case ") ~ inPattern(toText(pat)) ~ optText(guard)(keywordStr(" if ") ~ _) ~ " => " ~ caseBlockText(body)
case Labeled(bind, expr) =>
changePrec(GlobalPrec) { toText(bind.name) ~ keywordStr("[") ~ toText(bind.symbol.info) ~ keywordStr("]: ") ~ toText(expr) }
case Return(expr, from) =>
changePrec(GlobalPrec) { keywordStr("return") ~ optText(expr)(" " ~ _) }
val sym = from.symbol
if (sym.is(Label))
changePrec(GlobalPrec) { keywordStr("return[") ~ toText(sym.name) ~ keywordStr("]") ~ optText(expr)(" " ~ _) }
else
changePrec(GlobalPrec) { keywordStr("return") ~ optText(expr)(" " ~ _) }
case Try(expr, cases, finalizer) =>
changePrec(GlobalPrec) {
keywordStr("try ") ~ toText(expr) ~ optText(cases)(keywordStr(" catch ") ~ _) ~ optText(finalizer)(keywordStr(" finally ") ~ _)
Expand Down
23 changes: 23 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/MegaPhase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ object MegaPhase {
def prepareForClosure(tree: Closure)(implicit ctx: Context) = ctx
def prepareForMatch(tree: Match)(implicit ctx: Context) = ctx
def prepareForCaseDef(tree: CaseDef)(implicit ctx: Context) = ctx
def prepareForLabeled(tree: Labeled)(implicit ctx: Context) = ctx
def prepareForReturn(tree: Return)(implicit ctx: Context) = ctx
def prepareForTry(tree: Try)(implicit ctx: Context) = ctx
def prepareForSeqLiteral(tree: SeqLiteral)(implicit ctx: Context) = ctx
Expand Down Expand Up @@ -98,6 +99,7 @@ object MegaPhase {
def transformClosure(tree: Closure)(implicit ctx: Context): Tree = tree
def transformMatch(tree: Match)(implicit ctx: Context): Tree = tree
def transformCaseDef(tree: CaseDef)(implicit ctx: Context): Tree = tree
def transformLabeled(tree: Labeled)(implicit ctx: Context): Tree = tree
def transformReturn(tree: Return)(implicit ctx: Context): Tree = tree
def transformTry(tree: Try)(implicit ctx: Context): Tree = tree
def transformSeqLiteral(tree: SeqLiteral)(implicit ctx: Context): Tree = tree
Expand Down Expand Up @@ -165,6 +167,7 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase {
case tree: ValDef => goValDef(tree, start)
case tree: DefDef => goDefDef(tree, start)
case tree: TypeDef => goTypeDef(tree, start)
case tree: Labeled => goLabeled(tree, start)
case tree: Bind => goBind(tree, start)
case _ => goOther(tree, start)
}
Expand Down Expand Up @@ -249,6 +252,11 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase {
implicit val ctx = prepTypeDef(tree, start)(outerCtx)
val rhs = transformTree(tree.rhs, start)(localContext)
goTypeDef(cpy.TypeDef(tree)(tree.name, rhs), start)
case tree: Labeled =>
implicit val ctx = prepLabeled(tree, start)(outerCtx)
val bind = transformTree(tree.bind, start).asInstanceOf[Bind]
val expr = transformTree(tree.expr, start)
goLabeled(cpy.Labeled(tree)(bind, expr), start)
case tree: Bind =>
implicit val ctx = prepBind(tree, start)(outerCtx)
val body = transformTree(tree.body, start)
Expand Down Expand Up @@ -745,6 +753,21 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase {
}
}

def prepLabeled(tree: Labeled, start: Int)(implicit ctx: Context): Context = {
val phase = nxReturnPrepPhase(start)
if (phase == null) ctx
else prepLabeled(tree, phase.idxInGroup + 1)(phase.prepareForLabeled(tree))
}

def goLabeled(tree: Labeled, start: Int)(implicit ctx: Context): Tree = {
val phase = nxReturnTransPhase(start)
if (phase == null) tree
else phase.transformLabeled(tree)(ctx) match {
case tree1: Labeled => goLabeled(tree1, phase.idxInGroup + 1)
case tree1 => transformNode(tree1, phase.idxInGroup + 1)
}
}

def prepReturn(tree: Return, start: Int)(implicit ctx: Context): Context = {
val phase = nxReturnPrepPhase(start)
if (phase == null) ctx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import collection.mutable
object NonLocalReturns {
import ast.tpd._
def isNonLocalReturn(ret: Return)(implicit ctx: Context) =
ret.from.symbol != ctx.owner.enclosingMethod || ctx.owner.is(Lazy)
!ret.from.symbol.is(Label) && (ret.from.symbol != ctx.owner.enclosingMethod || ctx.owner.is(Lazy))
}

/** Implement non-local returns using NonLocalReturnControl exceptions.
Expand Down
Loading