diff --git a/compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala b/compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala index 5e350ac1eec1..9e5b7f036aa0 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala @@ -85,8 +85,6 @@ class TreeMapWithImplicits extends tpd.TreeMap { } override def transform(tree: Tree)(using Context): Tree = { - def localCtx = - if (tree.hasType && tree.symbol.exists) ctx.withOwner(tree.symbol) else ctx try tree match { case Block(stats, expr) => inContext(nestedScopeCtx(stats)) { @@ -97,19 +95,13 @@ class TreeMapWithImplicits extends tpd.TreeMap { else super.transform(tree) } case tree: DefDef => - inContext(localCtx) { + inContext(localCtx(tree)) { cpy.DefDef(tree)( tree.name, transformParamss(tree.paramss), transform(tree.tpt), transform(tree.rhs)(using nestedScopeCtx(tree.paramss.flatten))) } - case EmptyValDef => - tree - case _: MemberDef => - super.transform(tree)(using localCtx) - case _: PackageDef => - super.transform(tree)(using ctx.withOwner(tree.symbol.moduleClass)) case impl @ Template(constr, parents, self, _) => cpy.Template(tree)( transformSub(constr), diff --git a/compiler/src/dotty/tools/dotc/ast/Trees.scala b/compiler/src/dotty/tools/dotc/ast/Trees.scala index ba2dcc2e16d2..423e5e991288 100644 --- a/compiler/src/dotty/tools/dotc/ast/Trees.scala +++ b/compiler/src/dotty/tools/dotc/ast/Trees.scala @@ -1296,6 +1296,9 @@ object Trees { */ protected def inlineContext(call: Tree)(using Context): Context = ctx + /** The context to use when mapping or accumulating over a tree */ + def localCtx(tree: Tree)(using Context): Context + abstract class TreeMap(val cpy: TreeCopier = inst.cpy) { self => def transform(tree: Tree)(using Context): Tree = { inContext( @@ -1304,9 +1307,6 @@ object Trees { else ctx ){ Stats.record(s"TreeMap.transform/$getClass") - def localCtx = - if (tree.hasType && tree.symbol.exists) ctx.withOwner(tree.symbol) else ctx - if (skipTransform(tree)) tree else tree match { case Ident(name) => @@ -1362,11 +1362,11 @@ object Trees { case AppliedTypeTree(tpt, args) => cpy.AppliedTypeTree(tree)(transform(tpt), transform(args)) case LambdaTypeTree(tparams, body) => - inContext(localCtx) { + inContext(localCtx(tree)) { cpy.LambdaTypeTree(tree)(transformSub(tparams), transform(body)) } case TermLambdaTypeTree(params, body) => - inContext(localCtx) { + inContext(localCtx(tree)) { cpy.TermLambdaTypeTree(tree)(transformSub(params), transform(body)) } case MatchTypeTree(bound, selector, cases) => @@ -1384,17 +1384,17 @@ object Trees { case EmptyValDef => tree case tree @ ValDef(name, tpt, _) => - inContext(localCtx) { + inContext(localCtx(tree)) { val tpt1 = transform(tpt) val rhs1 = transform(tree.rhs) cpy.ValDef(tree)(name, tpt1, rhs1) } case tree @ DefDef(name, paramss, tpt, _) => - inContext(localCtx) { + inContext(localCtx(tree)) { cpy.DefDef(tree)(name, transformParamss(paramss), transform(tpt), transform(tree.rhs)) } case tree @ TypeDef(name, rhs) => - inContext(localCtx) { + inContext(localCtx(tree)) { cpy.TypeDef(tree)(name, transform(rhs)) } case tree @ Template(constr, parents, self, _) if tree.derived.isEmpty => @@ -1404,7 +1404,10 @@ object Trees { case Export(expr, selectors) => cpy.Export(tree)(transform(expr), selectors) case PackageDef(pid, stats) => - cpy.PackageDef(tree)(transformSub(pid), transformStats(stats, pid.symbol.moduleClass)(using localCtx)) + val pid1 = transformSub(pid) + inContext(localCtx(tree)) { + cpy.PackageDef(tree)(pid1, transformStats(stats, ctx.owner)) + } case Annotated(arg, annot) => cpy.Annotated(tree)(transform(arg), transform(annot)) case Thicket(trees) => @@ -1450,8 +1453,6 @@ object Trees { foldOver(x, tree)(using ctx.withSource(tree.source)) else { Stats.record(s"TreeAccumulator.foldOver/$getClass") - def localCtx = - if (tree.hasType && tree.symbol.exists) ctx.withOwner(tree.symbol) else ctx tree match { case Ident(name) => x @@ -1506,11 +1507,11 @@ object Trees { case AppliedTypeTree(tpt, args) => this(this(x, tpt), args) case LambdaTypeTree(tparams, body) => - inContext(localCtx) { + inContext(localCtx(tree)) { this(this(x, tparams), body) } case TermLambdaTypeTree(params, body) => - inContext(localCtx) { + inContext(localCtx(tree)) { this(this(x, params), body) } case MatchTypeTree(bound, selector, cases) => @@ -1526,15 +1527,15 @@ object Trees { case UnApply(fun, implicits, patterns) => this(this(this(x, fun), implicits), patterns) case tree @ ValDef(_, tpt, _) => - inContext(localCtx) { + inContext(localCtx(tree)) { this(this(x, tpt), tree.rhs) } case tree @ DefDef(_, paramss, tpt, _) => - inContext(localCtx) { + inContext(localCtx(tree)) { this(this(paramss.foldLeft(x)(apply), tpt), tree.rhs) } case TypeDef(_, rhs) => - inContext(localCtx) { + inContext(localCtx(tree)) { this(x, rhs) } case tree @ Template(constr, parents, self, _) if tree.derived.isEmpty => @@ -1544,7 +1545,7 @@ object Trees { case Export(expr, _) => this(x, expr) case PackageDef(pid, stats) => - this(this(x, pid), stats)(using localCtx) + this(this(x, pid), stats)(using localCtx(tree)) case Annotated(arg, annot) => this(this(x, arg), annot) case Thicket(ts) => diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index 60201ae532a0..0bba42f45881 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -565,6 +565,14 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { else foldOver(sym, tree) } + /** The owner to be used in a local context when traversin a tree */ + def localOwner(tree: Tree)(using Context): Symbol = + val sym = tree.symbol + (if sym.is(PackageVal) then sym.moduleClass else sym).orElse(ctx.owner) + + /** The local context to use when traversing trees */ + def localCtx(tree: Tree)(using Context): Context = ctx.withOwner(localOwner(tree)) + override val cpy: TypedTreeCopier = // Type ascription needed to pick up any new members in TreeCopier (currently there are none) TypedTreeCopier() diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index 9a5b222617bd..40467dc5be3f 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -535,6 +535,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { // --------- Copier/Transformer/Accumulator classes for untyped trees ----- + def localCtx(tree: Tree)(using Context): Context = ctx + override val cpy: UntypedTreeCopier = UntypedTreeCopier() class UntypedTreeCopier extends TreeCopier { diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index a1e4eb6a9ad6..60f675ff0b0d 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -542,7 +542,7 @@ object Contexts { def iinfo(using Context) = if (ctx.importInfo == null) "" else i"${ctx.importInfo.selectors}%, %" def cinfo(using Context) = - val core = s" owner = ${ctx.owner}, scope = ${ctx.scope}, import = ${iinfo(using ctx)}" + val core = s" owner = ${ctx.owner}, scope = ${ctx.scope}, import = $iinfo" if (ctx ne NoContext) && (ctx.implicits ne ctx.outer.implicits) then s"$core, implicits = ${ctx.implicits}" else diff --git a/compiler/src/dotty/tools/dotc/transform/MacroTransform.scala b/compiler/src/dotty/tools/dotc/transform/MacroTransform.scala index 34ba86ce4c9e..6a246e509e61 100644 --- a/compiler/src/dotty/tools/dotc/transform/MacroTransform.scala +++ b/compiler/src/dotty/tools/dotc/transform/MacroTransform.scala @@ -31,11 +31,8 @@ abstract class MacroTransform extends Phase { class Transformer extends TreeMap(cpy = cpyBetweenPhases) { - protected def localCtx(tree: Tree)(using Context): FreshContext = { - val sym = tree.symbol - val owner = if (sym.is(PackageVal)) sym.moduleClass else sym - ctx.fresh.setTree(tree).setOwner(owner) - } + protected def localCtx(tree: Tree)(using Context): FreshContext = + ctx.fresh.setTree(tree).setOwner(localOwner(tree)) override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = { def transformStat(stat: Tree): Tree = stat match {