Skip to content

Improve owner handling in Reflection #10352

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/src/scala/quoted/internal/impl/Matcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ object Matcher {
}
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
val resType = pattern.tpe
val res = Lambda(MethodType(names)(_ => argTypes, _ => resType), bodyFn)
val res = Lambda(Symbol.currentOwner, MethodType(names)(_ => argTypes, _ => resType), (meth, x) => bodyFn(x).changeOwner(meth))
matched(res.asExpr)

//
Expand Down
91 changes: 73 additions & 18 deletions compiler/src/scala/quoted/internal/impl/QuoteContextImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ object QuoteContextImpl {

class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickler, QuoteMatching:

private val yCheck: Boolean =
ctx.settings.Ycheck.value(using ctx).exists(x => x == "all" || x == "macros")

extension [T](self: scala.quoted.Expr[T]):
def show: String =
reflect.TreeMethodsImpl.show(reflect.Term.of(self))
Expand Down Expand Up @@ -118,6 +121,11 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
QuoteContextImpl.this.asExprOf[T](self.asExpr)(using tp)
end extension

extension [ThisTree <: Tree](self: ThisTree):
def changeOwner(newOwner: Symbol): ThisTree =
tpd.TreeOps(self).changeNonLocalOwners(newOwner).asInstanceOf[ThisTree]
end extension

end TreeMethodsImpl

type PackageClause = tpd.PackageDef
Expand Down Expand Up @@ -238,9 +246,9 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl

object DefDef extends DefDefModule:
def apply(symbol: Symbol, rhsFn: List[TypeRepr] => List[List[Term]] => Option[Term]): DefDef =
withDefaultPos(tpd.polyDefDef(symbol.asTerm, tparams => vparamss => rhsFn(tparams)(vparamss).getOrElse(tpd.EmptyTree)))
withDefaultPos(tpd.polyDefDef(symbol.asTerm, tparams => vparamss => yCheckedOwners(rhsFn(tparams)(vparamss), symbol).getOrElse(tpd.EmptyTree)))
def copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term]): DefDef =
tpd.cpy.DefDef(original)(name.toTermName, typeParams, paramss, tpt, rhs.getOrElse(tpd.EmptyTree))
tpd.cpy.DefDef(original)(name.toTermName, typeParams, paramss, tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
def unapply(ddef: DefDef): Option[(String, List[TypeDef], List[List[ValDef]], TypeTree, Option[Term])] =
Some((ddef.name.toString, ddef.typeParams, ddef.paramss, ddef.tpt, optional(ddef.rhs)))
end DefDef
Expand All @@ -264,9 +272,9 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl

object ValDef extends ValDefModule:
def apply(symbol: Symbol, rhs: Option[Term]): ValDef =
tpd.ValDef(symbol.asTerm, rhs.getOrElse(tpd.EmptyTree))
tpd.ValDef(symbol.asTerm, yCheckedOwners(rhs, symbol).getOrElse(tpd.EmptyTree))
def copy(original: Tree)(name: String, tpt: TypeTree, rhs: Option[Term]): ValDef =
tpd.cpy.ValDef(original)(name.toTermName, tpt, rhs.getOrElse(tpd.EmptyTree))
tpd.cpy.ValDef(original)(name.toTermName, tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
def unapply(vdef: ValDef): Option[(String, TypeTree, Option[Term])] =
Some((vdef.name.toString, vdef.tpt, optional(vdef.rhs)))

Expand Down Expand Up @@ -357,15 +365,15 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
def tpe: TypeRepr = self.tpe
def underlyingArgument: Term = new tpd.TreeOps(self).underlyingArgument
def underlying: Term = new tpd.TreeOps(self).underlying
def etaExpand: Term = self.tpe.widen match {
def etaExpand(owner: Symbol): Term = self.tpe.widen match {
case mtpe: Types.MethodType if !mtpe.isParamDependent =>
val closureResType = mtpe.resType match {
case t: Types.MethodType => t.toFunctionType()
case t => t
}
val closureTpe = Types.MethodType(mtpe.paramNames, mtpe.paramInfos, closureResType)
val closureMethod = dotc.core.Symbols.newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, closureTpe)
tpd.Closure(closureMethod, tss => new tpd.TreeOps(self).appliedToArgs(tss.head).etaExpand)
val closureMethod = dotc.core.Symbols.newSymbol(owner, nme.ANON_FUN, Synthetic | Method, closureTpe)
tpd.Closure(closureMethod, tss => new tpd.TreeOps(self).appliedToArgs(tss.head).etaExpand(closureMethod))
case _ => self
}

Expand Down Expand Up @@ -727,9 +735,9 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
end ClosureMethodsImpl

object Lambda extends LambdaModule:
def apply(tpe: MethodType, rhsFn: List[Tree] => Tree): Block =
val meth = dotc.core.Symbols.newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, tpe)
tpd.Closure(meth, tss => changeOwnerOfTree(rhsFn(tss.head), meth))
def apply(owner: Symbol, tpe: MethodType, rhsFn: (Symbol, List[Tree]) => Tree): Block =
val meth = dotc.core.Symbols.newSymbol(owner, nme.ANON_FUN, Synthetic | Method, tpe)
tpd.Closure(meth, tss => yCheckedOwners(rhsFn(meth, tss.head), meth))

def unapply(tree: Block): Option[(List[ValDef], Term)] = tree match {
case Block((ddef @ DefDef(_, _, params :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
Expand Down Expand Up @@ -2201,14 +2209,14 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
def requiredModule(path: String): Symbol = dotc.core.Symbols.requiredModule(path)
def requiredMethod(path: String): Symbol = dotc.core.Symbols.requiredMethod(path)
def classSymbol(fullName: String): Symbol = dotc.core.Symbols.requiredClass(fullName)
def newMethod(parent: Symbol, name: String, tpe: TypeRepr): Symbol =
newMethod(parent, name, tpe, Flags.EmptyFlags, noSymbol)
def newMethod(parent: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
dotc.core.Symbols.newSymbol(parent, name.toTermName, flags | dotc.core.Flags.Method, tpe, privateWithin)
def newVal(parent: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
dotc.core.Symbols.newSymbol(parent, name.toTermName, flags, tpe, privateWithin)
def newBind(parent: Symbol, name: String, flags: Flags, tpe: TypeRepr): Symbol =
dotc.core.Symbols.newSymbol(parent, name.toTermName, flags | Case, tpe)
def newMethod(owner: Symbol, name: String, tpe: TypeRepr): Symbol =
newMethod(owner, name, tpe, Flags.EmptyFlags, noSymbol)
def newMethod(owner: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags | dotc.core.Flags.Method, tpe, privateWithin)
def newVal(owner: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags, tpe, privateWithin)
def newBind(owner: Symbol, name: String, flags: Flags, tpe: TypeRepr): Symbol =
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags | Case, tpe)
def noSymbol: Symbol = dotc.core.Symbols.NoSymbol
end Symbol

Expand Down Expand Up @@ -2542,6 +2550,53 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext, QuoteUnpickl
private def withDefaultPos[T <: Tree](fn: Context ?=> T): T =
fn(using ctx.withSource(Position.ofMacroExpansion.source)).withSpan(Position.ofMacroExpansion.span)

/** Checks that all definitions in this tree have the expected owner.
* Nested definitions are ignored and assumed to be correct by construction.
*/
private def yCheckedOwners(tree: Option[Tree], owner: Symbol): tree.type =
if yCheck then
tree match
case Some(tree) =>
yCheckOwners(tree, owner)
case _ =>
tree

/** Checks that all definitions in this tree have the expected owner.
* Nested definitions are ignored and assumed to be correct by construction.
*/
private def yCheckedOwners(tree: Tree, owner: Symbol): tree.type =
if yCheck then
yCheckOwners(tree, owner)
tree

/** Checks that all definitions in this tree have the expected owner.
* Nested definitions are ignored and assumed to be correct by construction.
*/
private def yCheckOwners(tree: Tree, owner: Symbol): Unit =
new tpd.TreeTraverser {
def traverse(t: Tree)(using Context): Unit =
t match
case t: tpd.DefTree =>
val defOwner = t.symbol.owner
assert(defOwner == owner,
s"""Tree had an unexpected owner for ${t.symbol}
|Expected: $owner (${owner.fullName})
|But was: $defOwner (${defOwner.fullName})
|
|
|The code of the definition of ${t.symbol} is
|${TreeMethods.show(t)}
|
|which was found in the code
|${TreeMethods.show(tree)}
|
|which has the AST representation
|${TreeMethods.showExtractors(tree)}
|
|""".stripMargin)
case _ => traverseChildren(t)
}.traverse(tree)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very helpful, greatly improves debuggability 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not have been able to fix the tests without this. It proved to be quite useful.


end reflect

def unpickleExpr[T](pickled: String | List[String], typeHole: (Int, Seq[Any]) => scala.quoted.Type[?], termHole: (Int, Seq[Any], scala.quoted.QuoteContext) => scala.quoted.Expr[?]): scala.quoted.Expr[T] =
Expand Down
28 changes: 25 additions & 3 deletions library/src/scala/quoted/QuoteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ trait QuoteContext { self: internal.QuoteUnpickler & internal.QuoteMatching =>
// CONTEXTS //
//////////////

/** Compilation context */
/** Context containing information on the current owner */
type Context <: AnyRef

/** Context of the macro expansion */
Expand Down Expand Up @@ -230,6 +230,12 @@ trait QuoteContext { self: internal.QuoteUnpickler & internal.QuoteMatching =>
/** Convert this tree to an `quoted.Expr[T]` if the tree is a valid expression or throws */
extension [T](self: Tree)
def asExprOf(using scala.quoted.Type[T]): scala.quoted.Expr[T]

extension [ThisTree <: Tree](self: ThisTree):
/** Changes the owner of the symbols in the tree */
def changeOwner(newOwner: Symbol): ThisTree
end extension

}

/** Tree representing a pacakage clause in the source code */
Expand Down Expand Up @@ -469,7 +475,7 @@ trait QuoteContext { self: internal.QuoteUnpickler & internal.QuoteMatching =>
def underlying: Term

/** Converts a partally applied term into a lambda expression */
def etaExpand: Term
def etaExpand(owner: Symbol): Term

/** A unary apply node with given argument: `tree(arg)` */
def appliedTo(arg: Term): Term
Expand Down Expand Up @@ -954,8 +960,24 @@ trait QuoteContext { self: internal.QuoteUnpickler & internal.QuoteMatching =>
val Lambda: LambdaModule

trait LambdaModule { this: Lambda.type =>
/** Matches a lambda definition of the form
* ```
* Block((DefDef(_, _, params :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
* ```
* Extracts the parameter definitions and body.
*
*/
def unapply(tree: Block): Option[(List[ValDef], Term)]
def apply(tpe: MethodType, rhsFn: List[Tree] => Tree): Block

/** Generates a lambda with the given method type.
* ```
* Block((DefDef(_, _, params :: Nil, _, Some(rhsFn(meth, paramRefs)))) :: Nil, Closure(meth, _))
* ```
* @param owner: owner of the generated `meth` symbol
* @param tpe: Type of the definition
* @param rhsFn: Funtion that recieves the `meth` symbol and the a list of references to the `params`
*/
def apply(owner: Symbol, tpe: MethodType, rhsFn: (Symbol, List[Tree]) => Tree): Block
}

given TypeTest[Tree, If] = IfTypeTest
Expand Down
93 changes: 93 additions & 0 deletions tests/pos-macros/i10151/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package x

import scala.quoted._

trait CB[T]:
def map[S](f: T=>S): CB[S] = ???
def flatMap[S](f: T=>CB[S]): CB[S] = ???

class MyArr[AK,AV]:
def map1[BK,BV](f: ((AK,AV)) => (BK, BV)):MyArr[BK,BV] = ???
def map1Out[BK, BV](f: ((AK,AV)) => CB[(BK,BV)]): CB[MyArr[BK,BV]] = ???

def await[T](x:CB[T]):T = ???

object CBM:
def pure[T](t:T):CB[T] = ???

object X:

inline def process[T](inline f:T) = ${
processImpl[T]('f)
}

def processImpl[T:Type](f:Expr[T])(using qctx: QuoteContext):Expr[CB[T]] =
import qctx.reflect._

def transform(term:Term):Term =
term match
case Apply(TypeApply(Select(obj,"map1"),targs),args) =>
val nArgs = args.map(x => shiftLambda(x))
val nSelect = Select.unique(obj, "map1Out")
Apply(TypeApply(nSelect,targs),nArgs)
case Apply(TypeApply(Ident("await"),targs),args) => args.head
case a@Apply(x,List(y,z)) =>
val mty=MethodType(List("y1"))( _ => List(y.tpe.widen), _ => TypeRepr.of[CB].appliedTo(a.tpe.widen))
val mtz=MethodType(List("z1"))( _ => List(z.tpe.widen), _ => a.tpe.widen)
Apply(
TypeApply(Select.unique(transform(y),"flatMap"),
List(Inferred(a.tpe.widen))
),
List(
Lambda(Symbol.currentOwner, mty, (meth, yArgs) =>
Apply(
TypeApply(Select.unique(transform(z),"map"),
List(Inferred(a.tpe.widen))
),
List(
Lambda(Symbol.currentOwner, mtz, (_, zArgs) => {
val termYArgs = yArgs.asInstanceOf[List[Term]]
val termZArgs = zArgs.asInstanceOf[List[Term]]
Apply(x,List(termYArgs.head,termZArgs.head))
})
)
).changeOwner(meth)
)
)
)
case Block(stats, last) => Block(stats, transform(last))
case Inlined(x,List(),body) => transform(body)
case l@Literal(x) =>
l.asExpr match
case '{ $l: lit } =>
Term.of('{ CBM.pure(${term.asExprOf[lit]}) })
case other =>
throw RuntimeException(s"Not supported $other")

def shiftLambda(term:Term): Term =
term match
case lt@Lambda(params, body) =>
val paramTypes = params.map(_.tpt.tpe)
val paramNames = params.map(_.name)
val mt = MethodType(paramNames)(_ => paramTypes, _ => TypeRepr.of[CB].appliedTo(body.tpe.widen) )
Lambda(Symbol.currentOwner, mt, (meth, args) => changeArgs(params,args,transform(body)).changeOwner(meth) )
case Block(stats, last) =>
Block(stats, shiftLambda(last))
case _ =>
throw RuntimeException("lambda expected")

def changeArgs(oldArgs:List[Tree], newArgs:List[Tree], body:Term):Term =
val association: Map[Symbol, Term] = (oldArgs zip newArgs).foldLeft(Map.empty){
case (m, (oldParam, newParam: Term)) => m.updated(oldParam.symbol, newParam)
case (m, (oldParam, newParam: Tree)) => throw RuntimeException("Term expected")
}
val changes = new TreeMap() {
override def transformTerm(tree:Term)(using Context): Term =
tree match
case ident@Ident(name) => association.getOrElse(ident.symbol, super.transformTerm(tree))
case _ => super.transformTerm(tree)
}
changes.transformTerm(body)

val r = transform(Term.of(f)).asExprOf[CB[T]]
r
14 changes: 14 additions & 0 deletions tests/pos-macros/i10151/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package x

object Main {

def main(args:Array[String]):Unit =
val arr = new MyArr[Int,Int]()
val r = X.process{
arr.map1( (x,y) =>
( 1, await(CBM.pure(x)) )
)
}
println("r")

}
Loading