From 68c6e4338dc90852bb5aac4260673d6bbd9ecb6f Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Mon, 2 Nov 2020 11:40:29 +0100 Subject: [PATCH] Fix #9894: Fix owner when constructing a Lambda --- .../tools/dotc/quoted/PickledQuotes.scala | 21 ++---- .../tools/dotc/quoted/QuoteContextImpl.scala | 5 +- .../dotty/tools/dotc/quoted/QuoteUtils.scala | 32 +++++++++ .../dotty/tools/dotc/transform/Splicer.scala | 4 +- tests/pos-macros/i9894/Macro_1.scala | 67 +++++++++++++++++++ tests/pos-macros/i9894/Test_2.scala | 15 +++++ 6 files changed, 124 insertions(+), 20 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/quoted/QuoteUtils.scala create mode 100644 tests/pos-macros/i9894/Macro_1.scala create mode 100644 tests/pos-macros/i9894/Test_2.scala diff --git a/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala b/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala index efdda53babe8..098927670e3b 100644 --- a/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala +++ b/compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala @@ -23,6 +23,8 @@ import scala.internal.quoted.PickledQuote import scala.quoted.QuoteContext import scala.collection.mutable +import QuoteUtils._ + object PickledQuotes { import tpd._ @@ -39,14 +41,14 @@ object PickledQuotes { def quotedExprToTree[T](expr: quoted.Expr[T])(using Context): Tree = { val expr1 = expr.asInstanceOf[scala.internal.quoted.Expr[Tree]] QuoteContextImpl.checkScopeId(expr1.scopeId) - healOwner(expr1.tree) + changeOwnerOfTree(expr1.tree, ctx.owner) } /** Transform the expression into its fully spliced TypeTree */ def quotedTypeToTree(tpe: quoted.Type[?])(using Context): Tree = { val tpe1 = tpe.asInstanceOf[scala.internal.quoted.Type[Tree]] QuoteContextImpl.checkScopeId(tpe1.scopeId) - healOwner(tpe1.typeTree) + changeOwnerOfTree(tpe1.typeTree, ctx.owner) } /** Unpickle the tree contained in the TastyExpr */ @@ -195,19 +197,4 @@ object PickledQuotes { tree } - /** Make sure that the owner of this tree is `ctx.owner` */ - def healOwner(tree: Tree)(using Context): Tree = { - val getCurrentOwner = new TreeAccumulator[Option[Symbol]] { - def apply(x: Option[Symbol], tree: tpd.Tree)(using Context): Option[Symbol] = - if (x.isDefined) x - else tree match { - case tree: DefTree => Some(tree.symbol.owner) - case _ => foldOver(x, tree) - } - } - getCurrentOwner(None, tree) match { - case Some(owner) if owner != ctx.owner => tree.changeOwner(owner, ctx.owner) - case _ => tree - } - } } diff --git a/compiler/src/dotty/tools/dotc/quoted/QuoteContextImpl.scala b/compiler/src/dotty/tools/dotc/quoted/QuoteContextImpl.scala index 930306f1bd64..9e65455d052e 100644 --- a/compiler/src/dotty/tools/dotc/quoted/QuoteContextImpl.scala +++ b/compiler/src/dotty/tools/dotc/quoted/QuoteContextImpl.scala @@ -10,6 +10,7 @@ import dotty.tools.dotc.core.Flags._ import dotty.tools.dotc.core.NameKinds import dotty.tools.dotc.core.StdNames._ import dotty.tools.dotc.quoted.reflect._ +import dotty.tools.dotc.quoted.QuoteUtils._ import dotty.tools.dotc.core.Decorators._ import scala.quoted.QuoteContext @@ -741,7 +742,9 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext: object Lambda extends LambdaModule: def apply(tpe: MethodType, rhsFn: List[Tree] => Tree): Block = - tpd.Lambda(tpe, rhsFn) + val meth = dotc.core.Symbols.newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, tpe) + tpd.Closure(meth, tss => changeOwnerOfTree(rhsFn(tss.head), meth)) + def unapply(tree: Block): Option[(List[ValDef], Term)] = tree match { case Block((ddef @ DefDef(_, _, params :: Nil, _, Some(body))) :: Nil, Closure(meth, _)) if ddef.symbol == meth.symbol => diff --git a/compiler/src/dotty/tools/dotc/quoted/QuoteUtils.scala b/compiler/src/dotty/tools/dotc/quoted/QuoteUtils.scala new file mode 100644 index 000000000000..56c8d3347205 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/quoted/QuoteUtils.scala @@ -0,0 +1,32 @@ +package dotty.tools.dotc.quoted + +import dotty.tools.dotc.ast.Trees._ +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.Contexts._ +import dotty.tools.dotc.core.Decorators._ +import dotty.tools.dotc.core.Symbols._ + +object QuoteUtils: + import tpd._ + + /** Get the owner of a tree if it has one */ + def treeOwner(tree: Tree)(using Context): Option[Symbol] = { + val getCurrentOwner = new TreeAccumulator[Option[Symbol]] { + def apply(x: Option[Symbol], tree: tpd.Tree)(using Context): Option[Symbol] = + if (x.isDefined) x + else tree match { + case tree: DefTree => Some(tree.symbol.owner) + case _ => foldOver(x, tree) + } + } + getCurrentOwner(None, tree) + } + + /** Changes the owner of the tree based on the current owner of the tree */ + def changeOwnerOfTree(tree: Tree, owner: Symbol)(using Context): Tree = { + treeOwner(tree) match + case Some(oldOwner) if oldOwner != owner => tree.changeOwner(oldOwner, owner) + case _ => tree + } + +end QuoteUtils diff --git a/compiler/src/dotty/tools/dotc/transform/Splicer.scala b/compiler/src/dotty/tools/dotc/transform/Splicer.scala index 34c9edc944e9..5784e543cd30 100644 --- a/compiler/src/dotty/tools/dotc/transform/Splicer.scala +++ b/compiler/src/dotty/tools/dotc/transform/Splicer.scala @@ -323,10 +323,10 @@ object Splicer { } private def interpretQuote(tree: Tree)(implicit env: Env): Object = - new scala.internal.quoted.Expr(Inlined(EmptyTree, Nil, PickledQuotes.healOwner(tree)).withSpan(tree.span), QuoteContextImpl.scopeId) + new scala.internal.quoted.Expr(Inlined(EmptyTree, Nil, QuoteUtils.changeOwnerOfTree(tree, ctx.owner)).withSpan(tree.span), QuoteContextImpl.scopeId) private def interpretTypeQuote(tree: Tree)(implicit env: Env): Object = - new scala.internal.quoted.Type(PickledQuotes.healOwner(tree), QuoteContextImpl.scopeId) + new scala.internal.quoted.Type(QuoteUtils.changeOwnerOfTree(tree, ctx.owner), QuoteContextImpl.scopeId) private def interpretLiteral(value: Any)(implicit env: Env): Object = value.asInstanceOf[Object] diff --git a/tests/pos-macros/i9894/Macro_1.scala b/tests/pos-macros/i9894/Macro_1.scala new file mode 100644 index 000000000000..67e29d29fb69 --- /dev/null +++ b/tests/pos-macros/i9894/Macro_1.scala @@ -0,0 +1,67 @@ +package x + +import scala.quoted._ + +trait CB[T]: + def map[S](f: T=>S): CB[S] = ??? + +class MyArr[A]: + def map1[B](f: A=>B):MyArr[B] = ??? + def map1Out[B](f: A=> CB[B]): CB[MyArr[B]] = ??? + +def await[T](x:CB[T]):T = ??? + +object CBM: + def pure[T](t:T):CB[T] = ??? + +object X: + + inline def process[T](inline f:T) = ${ + processImpl[T]('f) + } + + def processImpl[T:Type](f:Expr[T])(using qctx: QuoteContext):Expr[CB[T]] = + import qctx.reflect._ + + def transform(term:Term):Term = + term match + case ap@Apply(TypeApply(Select(obj,"map1"),targs),args) => + val nArgs = args.map(x => shiftLambda(x)) + val nSelect = Select.unique(obj, "map1Out") + Apply(TypeApply(nSelect,targs),nArgs) + //Apply.copy(ap)(TypeApply(nSelect,targs),nArgs) + case Apply(TypeApply(Ident("await"),targs),args) => args.head + case Apply(x,y) => + Apply(x, y.map(transform)) + case Block(stats, last) => Block(stats, transform(last)) + case Inlined(x,List(),body) => transform(body) + case l@Literal(x) => + '{ CBM.pure(${term.seal}) }.unseal + case other => + throw RuntimeException(s"Not supported $other") + + def shiftLambda(term:Term): Term = + term match + case lt@Lambda(params, body) => + val paramTypes = params.map(_.tpt.tpe) + val paramNames = params.map(_.name) + val mt = MethodType(paramNames)(_ => paramTypes, _ => Type[CB].unseal.tpe.appliedTo(body.tpe.widen) ) + val r = Lambda(mt, args => changeArgs(params,args,transform(body)) ) + r + case _ => + throw RuntimeException("lambda expected") + + def changeArgs(oldArgs:List[Tree], newArgs:List[Tree], body:Term):Term = + val association: Map[Symbol, Term] = (oldArgs zip newArgs).foldLeft(Map.empty){ + case (m, (oldParam, newParam: Term)) => m.updated(oldParam.symbol, newParam) + case (m, (oldParam, newParam: Tree)) => throw RuntimeException("Term expected") + } + val changes = new TreeMap() { + override def transformTerm(tree:Term)(using Context): Term = + tree match + case ident@Ident(name) => association.getOrElse(ident.symbol, super.transformTerm(tree)) + case _ => super.transformTerm(tree) + } + changes.transformTerm(body) + + transform(f.unseal).seal.cast[CB[T]] diff --git a/tests/pos-macros/i9894/Test_2.scala b/tests/pos-macros/i9894/Test_2.scala new file mode 100644 index 000000000000..88249a9ac329 --- /dev/null +++ b/tests/pos-macros/i9894/Test_2.scala @@ -0,0 +1,15 @@ +package x + + +object Main { + + def main(args:Array[String]):Unit = + val arr = new MyArr[Int]() + val r = X.process{ + arr.map1( zDebug => + await(CBM.pure(1).map(a => zDebug + a)) + ) + } + println("r") + +} \ No newline at end of file