Skip to content

Tree tm #153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 90 additions & 3 deletions src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._, Symbols.
import CheckTrees._, Denotations._, Decorators._
import config.Printers._
import typer.ErrorReporting._
import scala.annotation.tailrec

/** Some creators for typed trees */
object tpd extends Trees.Instance[Type] with TypedTreeInfo {
Expand Down Expand Up @@ -413,6 +414,92 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
def tpes: List[Type] = xs map (_.tpe)
}

/** A tree map that retypes some nodes if their element types have changed,
* instead of simply copying the original type. The potential retyped nodes
* are those nodes where the element type may be part of the parent type.
*/
class RetypingTreeMap extends TreeMap {
def retypeSelect(tree: Select, qualifier: Tree, name: Name)(implicit ctx: Context) = {
val tree1 = cpy.Select(tree, qualifier, name)
if ((tree1 eq tree) || (qualifier.tpe eq tree.qualifier.tpe)) tree1
else (tree1.tpe match {
case tpe: NamedType => tree1.withType(tpe.derivedSelect(qualifier.tpe))
case _ => tree1
})
}
def retypePair(tree: Pair, left: Tree, right: Tree)(implicit ctx: Context) = {
val tree1 = cpy.Pair(tree, left, right)
if ((tree1 eq tree) || ((left.tpe eq tree.left.tpe) && (right.tpe eq tree.right.tpe))) tree1
else ta.assignType(tree1, left, right)
}
def retypeBlock(tree: Block, stats: List[Tree], expr: Tree)(implicit ctx: Context) = {
val tree1 = cpy.Block(tree, stats, expr)
if ((tree1 eq tree) || (expr.tpe eq tree.expr.tpe)) tree1
else ta.assignType(tree1, stats, expr)
}
def retypeIf(tree: If, cond: Tree, thenp: Tree, elsep: Tree)(implicit ctx: Context) = {
val tree1 = cpy.If(tree, cond, thenp, elsep)
if ((tree1 eq tree) || (thenp.tpe eq tree.thenp.tpe) && (elsep.tpe eq tree.elsep.tpe)) tree1
else ta.assignType(tree1, thenp, elsep)
}
def retypeMatch(tree: Match, selector: Tree, cases: List[CaseDef])(implicit ctx: Context) = {
val tree1 = cpy.Match(tree, selector, cases)
if ((tree1 eq tree) || sameTypes(cases, tree.cases)) tree1
else ta.assignType(tree1, cases)
}
def retypeCaseDef(tree: CaseDef, pat: Tree, guard: Tree, body: Tree)(implicit ctx: Context) = {
val tree1 = cpy.CaseDef(tree, pat, guard, body)
if ((tree eq tree1) || (body.tpe eq tree.body.tpe)) tree1
else ta.assignType(tree1, body)
}
def retypeTry(tree: Try, expr: Tree, handler: Tree, finalizer: Tree)(implicit ctx: Context) = {
val tree1 = cpy.Try(tree, expr, handler, finalizer)
if ((tree1 eq tree) || ((expr.tpe eq tree.expr.tpe) && (handler.tpe eq tree.handler.tpe))) tree
else ta.assignType(tree1, expr, handler)
}
def retypeSeqLiteral(tree: SeqLiteral, elems: List[Tree])(implicit ctx: Context) = {
val tree1 = cpy.SeqLiteral(tree, elems)
if ((tree1 eq tree) || sameTypes(elems, tree.elems)) tree1
else ta.assignType(tree1, elems)
}
def retypeAnnotated(tree: Annotated, annot: Tree, arg: Tree)(implicit ctx: Context) = {
val tree1 = cpy.Annotated(tree, annot, arg)
if ((tree1 eq tree) || (arg.tpe eq tree.arg.tpe) && (annot eq tree.annot)) tree1
else ta.assignType(tree1, annot, arg)
}
override def transform(tree: Tree)(implicit ctx: Context): Tree = tree match {
case tree: Ident => // left here for performance
super.transform(tree)
case tree @ Select(qualifier, name) =>
retypeSelect(tree, transform(qualifier), name)
case tree @ Pair(left, right) =>
retypePair(tree, transform(left), transform(right))
case tree @ Block(stats, expr) =>
retypeBlock(tree, transformStats(stats), transform(expr))
case tree @ If(cond, thenp, elsep) =>
retypeIf(tree, transform(cond), transform(thenp), transform(elsep))
case tree @ Match(selector, cases) =>
retypeMatch(tree, transform(selector), transformSub(cases))
case tree @ CaseDef(pat, guard, body) =>
retypeCaseDef(tree, transform(pat), transform(guard), transform(body))
case tree @ Try(block, handler, finalizer) =>
retypeTry(tree, transform(block), transform(handler), transform(finalizer))
case tree @ SeqLiteral(elems) =>
retypeSeqLiteral(tree, transform(elems))
case tree @ Annotated(annot, arg) =>
retypeAnnotated(tree, transform(annot), transform(arg))
case _ =>
super.transform(tree)
}

@tailrec
final def sameTypes(trees: List[tpd.Tree], trees1: List[tpd.Tree]): Boolean = {
if (trees.isEmpty) trees.isEmpty
else if (trees1.isEmpty) trees.isEmpty
else (trees.head.tpe eq trees1.head.tpe) && sameTypes(trees.tail, trees1.tail)
}
}

/** A map that applies three functions together to a tree and makes sure
* they are coordinated so that the result is well-typed. The functions are
* @param typeMap A function from Type to type that gets applied to the
Expand All @@ -425,7 +512,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
final class TreeTypeMap(
val typeMap: Type => Type = IdentityTypeMap,
val ownerMap: Symbol => Symbol = identity _,
val treeMap: Tree => Tree = identity _)(implicit ctx: Context) extends TreeMap {
val treeMap: Tree => Tree = identity _)(implicit ctx: Context) extends RetypingTreeMap {

override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = {
val tree1 = treeMap(tree)
Expand All @@ -436,10 +523,10 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
cpy.DefDef(ddef, mods, name, tparams1, vparamss1, tmap2.transform(tpt), tmap2.transform(rhs))
case blk @ Block(stats, expr) =>
val (tmap1, stats1) = transformDefs(stats)
cpy.Block(blk, stats1, tmap1.transform(expr))
retypeBlock(blk, stats1, tmap1.transform(expr))
case cdef @ CaseDef(pat, guard, rhs) =>
val tmap = withMappedSyms(patVars(pat))
cpy.CaseDef(cdef, tmap.transform(pat), tmap.transform(guard), tmap.transform(rhs))
retypeCaseDef(cdef, tmap.transform(pat), tmap.transform(guard), tmap.transform(rhs))
case tree1 =>
super.transform(tree1)
}
Expand Down
6 changes: 3 additions & 3 deletions src/dotty/tools/dotc/transform/FullParameterization.scala
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ trait FullParameterization {
* followed by the class parameters of its enclosing class.
*/
private def allInstanceTypeParams(originalDef: DefDef)(implicit ctx: Context): List[Symbol] =
originalDef.tparams.map(_.symbol) ::: originalDef.symbol.owner.typeParams
originalDef.tparams.map(_.symbol) ::: originalDef.symbol.enclosingClass.typeParams

/** Given an instance method definition `originalDef`, return a
* fully parameterized method definition derived from `originalDef`, which
Expand All @@ -152,7 +152,7 @@ trait FullParameterization {
def fullyParameterizedDef(derived: TermSymbol, originalDef: DefDef)(implicit ctx: Context): Tree =
polyDefDef(derived, trefs => vrefss => {
val origMeth = originalDef.symbol
val origClass = origMeth.owner.asClass
val origClass = origMeth.enclosingClass.asClass
val origTParams = allInstanceTypeParams(originalDef)
val origVParams = originalDef.vparamss.flatten map (_.symbol)
val thisRef :: argRefs = vrefss.flatten
Expand Down Expand Up @@ -219,7 +219,7 @@ trait FullParameterization {
def forwarder(derived: TermSymbol, originalDef: DefDef)(implicit ctx: Context): Tree =
ref(derived.termRef)
.appliedToTypes(allInstanceTypeParams(originalDef).map(_.typeRef))
.appliedTo(This(originalDef.symbol.owner.asClass))
.appliedTo(This(originalDef.symbol.enclosingClass.asClass))
.appliedToArgss(originalDef.vparamss.nestedMap(vparam => ref(vparam.symbol)))
.withPos(originalDef.rhs.pos)
}
2 changes: 1 addition & 1 deletion src/dotty/tools/dotc/transform/TreeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class TreeChecker {
ownerMatches(symOwner, ctxOwner.owner)
for (tree <- trees if tree.isDef)
assert(ownerMatches(tree.symbol.owner, ctx.owner),
i"bad owner; $tree has owner ${tree.symbol.owner}, expected was ${ctx.owner}")
i"bad owner; $tree has owner ${tree.symbol.owner.showLocated}, expected was ${ctx.owner.showLocated}")
super.index(trees)
}
}
Expand Down