Skip to content

Commit 01c9a8d

Browse files
Merge pull request #10352 from dotty-staging/improve-owner-handling
Improve owner handling in Reflection
2 parents 7c7a8c8 + 49dd953 commit 01c9a8d

File tree

12 files changed

+338
-34
lines changed

12 files changed

+338
-34
lines changed

compiler/src/scala/quoted/internal/impl/Matcher.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ object Matcher {
211211
}
212212
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
213213
val resType = pattern.tpe
214-
val res = Lambda(MethodType(names)(_ => argTypes, _ => resType), bodyFn)
214+
val res = Lambda(Symbol.currentOwner, MethodType(names)(_ => argTypes, _ => resType), (meth, x) => bodyFn(x).changeOwner(meth))
215215
matched(res.asExpr)
216216

217217
//

compiler/src/scala/quoted/internal/impl/QuoteContextImpl.scala

+73-18
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ object QuoteContextImpl {
4545

4646
class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickler, QuoteMatching:
4747

48+
private val yCheck: Boolean =
49+
ctx.settings.Ycheck.value(using ctx).exists(x => x == "all" || x == "macros")
50+
4851
extension [T](self: scala.quoted.Expr[T]):
4952
def show: String =
5053
reflect.TreeMethodsImpl.show(reflect.Term.of(self))
@@ -118,6 +121,11 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
118121
QuoteContextImpl.this.asExprOf[T](self.asExpr)(using tp)
119122
end extension
120123

124+
extension [ThisTree <: Tree](self: ThisTree):
125+
def changeOwner(newOwner: Symbol): ThisTree =
126+
tpd.TreeOps(self).changeNonLocalOwners(newOwner).asInstanceOf[ThisTree]
127+
end extension
128+
121129
end TreeMethodsImpl
122130

123131
type PackageClause = tpd.PackageDef
@@ -238,9 +246,9 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
238246

239247
object DefDef extends DefDefModule:
240248
def apply(symbol: Symbol, rhsFn: List[TypeRepr] => List[List[Term]] => Option[Term]): DefDef =
241-
withDefaultPos(tpd.polyDefDef(symbol.asTerm, tparams => vparamss => rhsFn(tparams)(vparamss).getOrElse(tpd.EmptyTree)))
249+
withDefaultPos(tpd.polyDefDef(symbol.asTerm, tparams => vparamss => yCheckedOwners(rhsFn(tparams)(vparamss), symbol).getOrElse(tpd.EmptyTree)))
242250
def copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term]): DefDef =
243-
tpd.cpy.DefDef(original)(name.toTermName, typeParams, paramss, tpt, rhs.getOrElse(tpd.EmptyTree))
251+
tpd.cpy.DefDef(original)(name.toTermName, typeParams, paramss, tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
244252
def unapply(ddef: DefDef): Option[(String, List[TypeDef], List[List[ValDef]], TypeTree, Option[Term])] =
245253
Some((ddef.name.toString, ddef.typeParams, ddef.paramss, ddef.tpt, optional(ddef.rhs)))
246254
end DefDef
@@ -264,9 +272,9 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
264272

265273
object ValDef extends ValDefModule:
266274
def apply(symbol: Symbol, rhs: Option[Term]): ValDef =
267-
tpd.ValDef(symbol.asTerm, rhs.getOrElse(tpd.EmptyTree))
275+
tpd.ValDef(symbol.asTerm, yCheckedOwners(rhs, symbol).getOrElse(tpd.EmptyTree))
268276
def copy(original: Tree)(name: String, tpt: TypeTree, rhs: Option[Term]): ValDef =
269-
tpd.cpy.ValDef(original)(name.toTermName, tpt, rhs.getOrElse(tpd.EmptyTree))
277+
tpd.cpy.ValDef(original)(name.toTermName, tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
270278
def unapply(vdef: ValDef): Option[(String, TypeTree, Option[Term])] =
271279
Some((vdef.name.toString, vdef.tpt, optional(vdef.rhs)))
272280

@@ -357,15 +365,15 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
357365
def tpe: TypeRepr = self.tpe
358366
def underlyingArgument: Term = new tpd.TreeOps(self).underlyingArgument
359367
def underlying: Term = new tpd.TreeOps(self).underlying
360-
def etaExpand: Term = self.tpe.widen match {
368+
def etaExpand(owner: Symbol): Term = self.tpe.widen match {
361369
case mtpe: Types.MethodType if !mtpe.isParamDependent =>
362370
val closureResType = mtpe.resType match {
363371
case t: Types.MethodType => t.toFunctionType()
364372
case t => t
365373
}
366374
val closureTpe = Types.MethodType(mtpe.paramNames, mtpe.paramInfos, closureResType)
367-
val closureMethod = dotc.core.Symbols.newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, closureTpe)
368-
tpd.Closure(closureMethod, tss => new tpd.TreeOps(self).appliedToArgs(tss.head).etaExpand)
375+
val closureMethod = dotc.core.Symbols.newSymbol(owner, nme.ANON_FUN, Synthetic | Method, closureTpe)
376+
tpd.Closure(closureMethod, tss => new tpd.TreeOps(self).appliedToArgs(tss.head).etaExpand(closureMethod))
369377
case _ => self
370378
}
371379

@@ -727,9 +735,9 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
727735
end ClosureMethodsImpl
728736

729737
object Lambda extends LambdaModule:
730-
def apply(tpe: MethodType, rhsFn: List[Tree] => Tree): Block =
731-
val meth = dotc.core.Symbols.newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, tpe)
732-
tpd.Closure(meth, tss => changeOwnerOfTree(rhsFn(tss.head), meth))
738+
def apply(owner: Symbol, tpe: MethodType, rhsFn: (Symbol, List[Tree]) => Tree): Block =
739+
val meth = dotc.core.Symbols.newSymbol(owner, nme.ANON_FUN, Synthetic | Method, tpe)
740+
tpd.Closure(meth, tss => yCheckedOwners(rhsFn(meth, tss.head), meth))
733741

734742
def unapply(tree: Block): Option[(List[ValDef], Term)] = tree match {
735743
case Block((ddef @ DefDef(_, _, params :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
@@ -2201,14 +2209,14 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
22012209
def requiredModule(path: String): Symbol = dotc.core.Symbols.requiredModule(path)
22022210
def requiredMethod(path: String): Symbol = dotc.core.Symbols.requiredMethod(path)
22032211
def classSymbol(fullName: String): Symbol = dotc.core.Symbols.requiredClass(fullName)
2204-
def newMethod(parent: Symbol, name: String, tpe: TypeRepr): Symbol =
2205-
newMethod(parent, name, tpe, Flags.EmptyFlags, noSymbol)
2206-
def newMethod(parent: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
2207-
dotc.core.Symbols.newSymbol(parent, name.toTermName, flags | dotc.core.Flags.Method, tpe, privateWithin)
2208-
def newVal(parent: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
2209-
dotc.core.Symbols.newSymbol(parent, name.toTermName, flags, tpe, privateWithin)
2210-
def newBind(parent: Symbol, name: String, flags: Flags, tpe: TypeRepr): Symbol =
2211-
dotc.core.Symbols.newSymbol(parent, name.toTermName, flags | Case, tpe)
2212+
def newMethod(owner: Symbol, name: String, tpe: TypeRepr): Symbol =
2213+
newMethod(owner, name, tpe, Flags.EmptyFlags, noSymbol)
2214+
def newMethod(owner: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
2215+
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags | dotc.core.Flags.Method, tpe, privateWithin)
2216+
def newVal(owner: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
2217+
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags, tpe, privateWithin)
2218+
def newBind(owner: Symbol, name: String, flags: Flags, tpe: TypeRepr): Symbol =
2219+
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags | Case, tpe)
22122220
def noSymbol: Symbol = dotc.core.Symbols.NoSymbol
22132221
end Symbol
22142222

@@ -2542,6 +2550,53 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
25422550
private def withDefaultPos[T <: Tree](fn: Context ?=> T): T =
25432551
fn(using ctx.withSource(Position.ofMacroExpansion.source)).withSpan(Position.ofMacroExpansion.span)
25442552

2553+
/** Checks that all definitions in this tree have the expected owner.
2554+
* Nested definitions are ignored and assumed to be correct by construction.
2555+
*/
2556+
private def yCheckedOwners(tree: Option[Tree], owner: Symbol): tree.type =
2557+
if yCheck then
2558+
tree match
2559+
case Some(tree) =>
2560+
yCheckOwners(tree, owner)
2561+
case _ =>
2562+
tree
2563+
2564+
/** Checks that all definitions in this tree have the expected owner.
2565+
* Nested definitions are ignored and assumed to be correct by construction.
2566+
*/
2567+
private def yCheckedOwners(tree: Tree, owner: Symbol): tree.type =
2568+
if yCheck then
2569+
yCheckOwners(tree, owner)
2570+
tree
2571+
2572+
/** Checks that all definitions in this tree have the expected owner.
2573+
* Nested definitions are ignored and assumed to be correct by construction.
2574+
*/
2575+
private def yCheckOwners(tree: Tree, owner: Symbol): Unit =
2576+
new tpd.TreeTraverser {
2577+
def traverse(t: Tree)(using Context): Unit =
2578+
t match
2579+
case t: tpd.DefTree =>
2580+
val defOwner = t.symbol.owner
2581+
assert(defOwner == owner,
2582+
s"""Tree had an unexpected owner for ${t.symbol}
2583+
|Expected: $owner (${owner.fullName})
2584+
|But was: $defOwner (${defOwner.fullName})
2585+
|
2586+
|
2587+
|The code of the definition of ${t.symbol} is
2588+
|${TreeMethods.show(t)}
2589+
|
2590+
|which was found in the code
2591+
|${TreeMethods.show(tree)}
2592+
|
2593+
|which has the AST representation
2594+
|${TreeMethods.showExtractors(tree)}
2595+
|
2596+
|""".stripMargin)
2597+
case _ => traverseChildren(t)
2598+
}.traverse(tree)
2599+
25452600
end reflect
25462601

25472602
def unpickleExpr[T](pickled: String | List[String], typeHole: (Int, Seq[Any]) => scala.quoted.Type[?], termHole: (Int, Seq[Any], scala.quoted.QuoteContext) => scala.quoted.Expr[?]): scala.quoted.Expr[T] =

library/src/scala/quoted/QuoteContext.scala

+25-3
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ trait QuoteContext { self: internal.QuoteUnpickler & internal.QuoteMatching =>
174174
// CONTEXTS //
175175
//////////////
176176

177-
/** Compilation context */
177+
/** Context containing information on the current owner */
178178
type Context <: AnyRef
179179

180180
/** Context of the macro expansion */
@@ -230,6 +230,12 @@ trait QuoteContext { self: internal.QuoteUnpickler & internal.QuoteMatching =>
230230
/** Convert this tree to an `quoted.Expr[T]` if the tree is a valid expression or throws */
231231
extension [T](self: Tree)
232232
def asExprOf(using scala.quoted.Type[T]): scala.quoted.Expr[T]
233+
234+
extension [ThisTree <: Tree](self: ThisTree):
235+
/** Changes the owner of the symbols in the tree */
236+
def changeOwner(newOwner: Symbol): ThisTree
237+
end extension
238+
233239
}
234240

235241
/** Tree representing a pacakage clause in the source code */
@@ -469,7 +475,7 @@ trait QuoteContext { self: internal.QuoteUnpickler & internal.QuoteMatching =>
469475
def underlying: Term
470476

471477
/** Converts a partally applied term into a lambda expression */
472-
def etaExpand: Term
478+
def etaExpand(owner: Symbol): Term
473479

474480
/** A unary apply node with given argument: `tree(arg)` */
475481
def appliedTo(arg: Term): Term
@@ -954,8 +960,24 @@ trait QuoteContext { self: internal.QuoteUnpickler & internal.QuoteMatching =>
954960
val Lambda: LambdaModule
955961

956962
trait LambdaModule { this: Lambda.type =>
963+
/** Matches a lambda definition of the form
964+
* ```
965+
* Block((DefDef(_, _, params :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
966+
* ```
967+
* Extracts the parameter definitions and body.
968+
*
969+
*/
957970
def unapply(tree: Block): Option[(List[ValDef], Term)]
958-
def apply(tpe: MethodType, rhsFn: List[Tree] => Tree): Block
971+
972+
/** Generates a lambda with the given method type.
973+
* ```
974+
* Block((DefDef(_, _, params :: Nil, _, Some(rhsFn(meth, paramRefs)))) :: Nil, Closure(meth, _))
975+
* ```
976+
* @param owner: owner of the generated `meth` symbol
977+
* @param tpe: Type of the definition
978+
* @param rhsFn: Funtion that recieves the `meth` symbol and the a list of references to the `params`
979+
*/
980+
def apply(owner: Symbol, tpe: MethodType, rhsFn: (Symbol, List[Tree]) => Tree): Block
959981
}
960982

961983
given TypeTest[Tree, If] = IfTypeTest

tests/pos-macros/i10151/Macro_1.scala

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package x
2+
3+
import scala.quoted._
4+
5+
trait CB[T]:
6+
def map[S](f: T=>S): CB[S] = ???
7+
def flatMap[S](f: T=>CB[S]): CB[S] = ???
8+
9+
class MyArr[AK,AV]:
10+
def map1[BK,BV](f: ((AK,AV)) => (BK, BV)):MyArr[BK,BV] = ???
11+
def map1Out[BK, BV](f: ((AK,AV)) => CB[(BK,BV)]): CB[MyArr[BK,BV]] = ???
12+
13+
def await[T](x:CB[T]):T = ???
14+
15+
object CBM:
16+
def pure[T](t:T):CB[T] = ???
17+
18+
object X:
19+
20+
inline def process[T](inline f:T) = ${
21+
processImpl[T]('f)
22+
}
23+
24+
def processImpl[T:Type](f:Expr[T])(using qctx: QuoteContext):Expr[CB[T]] =
25+
import qctx.reflect._
26+
27+
def transform(term:Term):Term =
28+
term match
29+
case Apply(TypeApply(Select(obj,"map1"),targs),args) =>
30+
val nArgs = args.map(x => shiftLambda(x))
31+
val nSelect = Select.unique(obj, "map1Out")
32+
Apply(TypeApply(nSelect,targs),nArgs)
33+
case Apply(TypeApply(Ident("await"),targs),args) => args.head
34+
case a@Apply(x,List(y,z)) =>
35+
val mty=MethodType(List("y1"))( _ => List(y.tpe.widen), _ => TypeRepr.of[CB].appliedTo(a.tpe.widen))
36+
val mtz=MethodType(List("z1"))( _ => List(z.tpe.widen), _ => a.tpe.widen)
37+
Apply(
38+
TypeApply(Select.unique(transform(y),"flatMap"),
39+
List(Inferred(a.tpe.widen))
40+
),
41+
List(
42+
Lambda(Symbol.currentOwner, mty, (meth, yArgs) =>
43+
Apply(
44+
TypeApply(Select.unique(transform(z),"map"),
45+
List(Inferred(a.tpe.widen))
46+
),
47+
List(
48+
Lambda(Symbol.currentOwner, mtz, (_, zArgs) => {
49+
val termYArgs = yArgs.asInstanceOf[List[Term]]
50+
val termZArgs = zArgs.asInstanceOf[List[Term]]
51+
Apply(x,List(termYArgs.head,termZArgs.head))
52+
})
53+
)
54+
).changeOwner(meth)
55+
)
56+
)
57+
)
58+
case Block(stats, last) => Block(stats, transform(last))
59+
case Inlined(x,List(),body) => transform(body)
60+
case l@Literal(x) =>
61+
l.asExpr match
62+
case '{ $l: lit } =>
63+
Term.of('{ CBM.pure(${term.asExprOf[lit]}) })
64+
case other =>
65+
throw RuntimeException(s"Not supported $other")
66+
67+
def shiftLambda(term:Term): Term =
68+
term match
69+
case lt@Lambda(params, body) =>
70+
val paramTypes = params.map(_.tpt.tpe)
71+
val paramNames = params.map(_.name)
72+
val mt = MethodType(paramNames)(_ => paramTypes, _ => TypeRepr.of[CB].appliedTo(body.tpe.widen) )
73+
Lambda(Symbol.currentOwner, mt, (meth, args) => changeArgs(params,args,transform(body)).changeOwner(meth) )
74+
case Block(stats, last) =>
75+
Block(stats, shiftLambda(last))
76+
case _ =>
77+
throw RuntimeException("lambda expected")
78+
79+
def changeArgs(oldArgs:List[Tree], newArgs:List[Tree], body:Term):Term =
80+
val association: Map[Symbol, Term] = (oldArgs zip newArgs).foldLeft(Map.empty){
81+
case (m, (oldParam, newParam: Term)) => m.updated(oldParam.symbol, newParam)
82+
case (m, (oldParam, newParam: Tree)) => throw RuntimeException("Term expected")
83+
}
84+
val changes = new TreeMap() {
85+
override def transformTerm(tree:Term)(using Context): Term =
86+
tree match
87+
case ident@Ident(name) => association.getOrElse(ident.symbol, super.transformTerm(tree))
88+
case _ => super.transformTerm(tree)
89+
}
90+
changes.transformTerm(body)
91+
92+
val r = transform(Term.of(f)).asExprOf[CB[T]]
93+
r

tests/pos-macros/i10151/Test_2.scala

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package x
2+
3+
object Main {
4+
5+
def main(args:Array[String]):Unit =
6+
val arr = new MyArr[Int,Int]()
7+
val r = X.process{
8+
arr.map1( (x,y) =>
9+
( 1, await(CBM.pure(x)) )
10+
)
11+
}
12+
println("r")
13+
14+
}

0 commit comments

Comments
 (0)