Skip to content

Commit cc17c4e

Browse files
committed
Add missing Inlined nodes on quotes
1 parent 909f10a commit cc17c4e

File tree

17 files changed

+158
-22
lines changed

17 files changed

+158
-22
lines changed

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
4242
Super(qual, if (mixName.isEmpty) untpd.EmptyTypeIdent else untpd.Ident(mixName), inConstrCall, mixinClass)
4343

4444
def Apply(fn: Tree, args: List[Tree])(implicit ctx: Context): Apply = {
45-
assert(fn.isInstanceOf[RefTree] || fn.isInstanceOf[GenericApply[_]])
45+
assert(fn.isInstanceOf[RefTree] || fn.isInstanceOf[GenericApply[_]] || fn.isInstanceOf[Inlined])
4646
ta.assignType(untpd.Apply(fn, args), fn, args)
4747
}
4848

4949
def TypeApply(fn: Tree, args: List[Tree])(implicit ctx: Context): TypeApply = {
50-
assert(fn.isInstanceOf[RefTree] || fn.isInstanceOf[GenericApply[_]])
50+
assert(fn.isInstanceOf[RefTree] || fn.isInstanceOf[GenericApply[_]] || fn.isInstanceOf[Inlined])
5151
ta.assignType(untpd.TypeApply(fn, args), fn, args)
5252
}
5353

compiler/src/dotty/tools/dotc/core/quoted/PickledQuotes.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ object PickledQuotes {
131131
val x1 = SyntheticValDef(NameKinds.UniqueName.fresh("x".toTermName), x)
132132
def x1Ref() = ref(x1.symbol)
133133
def rec(f: Tree): Tree = f match {
134+
case Inlined(call, bindings, expansion) =>
135+
// this case must go before closureDef to avoid dropping the inline node
136+
cpy.Inlined(f)(call, bindings, rec(expansion))
134137
case closureDef(ddef) =>
135138
val paramSym = ddef.vparamss.head.head.symbol
136139
new TreeTypeMap(
@@ -140,8 +143,6 @@ object PickledQuotes {
140143
).transform(ddef.rhs)
141144
case Block(stats, expr) =>
142145
seq(stats, rec(expr))
143-
case Inlined(call, bindings, expansion) =>
144-
Inlined(call, bindings, rec(expansion))
145146
case _ =>
146147
f.select(nme.apply).appliedTo(x1Ref())
147148
}

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ package transform
44
import dotty.tools.dotc.ast.{Trees, tpd, untpd}
55
import scala.collection.mutable
66
import core._
7-
import typer.{Checking, VarianceChecker}
7+
import dotty.tools.dotc.typer.Checking
8+
import dotty.tools.dotc.typer.Inliner
9+
import dotty.tools.dotc.typer.VarianceChecker
810
import Types._, Contexts._, Names._, Flags._, DenotTransformers._, Phases._
911
import SymDenotations._, StdNames._, Annotations._, Trees._, Scopes._
1012
import Decorators._
@@ -237,17 +239,7 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
237239
super.transform(tree1)
238240
}
239241
case Inlined(call, bindings, expansion) if !call.isEmpty =>
240-
// Leave only a call trace consisting of
241-
// - a reference to the top-level class from which the call was inlined,
242-
// - the call's position
243-
// in the call field of an Inlined node.
244-
// The trace has enough info to completely reconstruct positions.
245-
// The minimization is done for two reasons:
246-
// 1. To save space (calls might contain large inline arguments, which would otherwise
247-
// be duplicated
248-
// 2. To enable correct pickling (calls can share symbols with the inlined code, which
249-
// would trigger an assertion when pickling).
250-
val callTrace = Ident(call.symbol.topLevelClass.typeRef).withPos(call.pos)
242+
val callTrace = Inliner.inlineCallTrace(call.symbol, call.pos)
251243
cpy.Inlined(tree)(callTrace, transformSub(bindings), transform(expansion)(inlineContext(call)))
252244
case tree: Template =>
253245
withNoCheckNews(tree.parents.flatMap(newPart)) {

compiler/src/dotty/tools/dotc/transform/Splicer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ object Splicer {
4141
val interpreter = new Interpreter(pos, classLoader)
4242
try {
4343
// Some parts of the macro are evaluated during the unpickling performed in quotedExprToTree
44-
val interpreted = interpreter.interpret[scala.quoted.Expr[Any]](tree)
45-
interpreted.fold(tree)(x => PickledQuotes.quotedExprToTree(x))
44+
val interpretedExpr = interpreter.interpret[scala.quoted.Expr[Any]](tree)
45+
interpretedExpr.fold(tree)(x => PickledQuotes.quotedExprToTree(x))
4646
}
4747
catch {
4848
case ex: scala.quoted.QuoteError =>

compiler/src/dotty/tools/dotc/transform/Staging.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import typer.Implicits.SearchFailureType
1515
import scala.collection.mutable
1616
import dotty.tools.dotc.core.StdNames._
1717
import dotty.tools.dotc.core.quoted._
18+
import dotty.tools.dotc.typer.Inliner
1819
import dotty.tools.dotc.util.SourcePosition
1920

2021

@@ -381,7 +382,12 @@ class Staging extends MacroTransformWithImplicits {
381382
capturers(body.symbol)(body)
382383
case _=>
383384
val (body1, splices) = nested(isQuote = true).split(body)
384-
if (level == 0 && !ctx.inInlineMethod) pickledQuote(body1, splices, body.tpe, isType).withPos(quote.pos)
385+
if (level == 0 && !ctx.inInlineMethod) {
386+
val body2 =
387+
if (body1.isType) body1
388+
else Inlined(Inliner.inlineCallTrace(ctx.owner, quote.pos), Nil, body1)
389+
pickledQuote(body2, splices, body.tpe, isType).withPos(quote.pos)
390+
}
385391
else {
386392
// In top-level splice in an inline def. Keep the tree as it is, it will be transformed at inline site.
387393
body
@@ -432,7 +438,8 @@ class Staging extends MacroTransformWithImplicits {
432438
else if (level == 1) {
433439
val (body1, quotes) = nested(isQuote = false).split(splice.qualifier)
434440
val tpe = outer.embedded.getHoleType(splice)
435-
makeHole(body1, quotes, tpe).withPos(splice.pos)
441+
val hole = makeHole(body1, quotes, tpe).withPos(splice.pos)
442+
if (splice.isType) hole else Inlined(EmptyTree, Nil, hole)
436443
}
437444
else if (enclosingInlineds.nonEmpty) { // level 0 in an inlined call
438445
val spliceCtx = ctx.outer // drop the last `inlineContext`

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,15 @@ object Inliner {
149149
(new Reposition).transformInline(inlined)
150150
}
151151
}
152+
153+
/** Leave only a call trace consisting of
154+
* - a reference to the top-level class from which the call was inlined,
155+
* - the call's position
156+
* in the call field of an Inlined node.
157+
* The trace has enough info to completely reconstruct positions.
158+
*/
159+
def inlineCallTrace(callSym: Symbol, pos: Position)(implicit ctx: Context): Tree =
160+
Ident(callSym.topLevelClass.typeRef).withPos(pos)
152161
}
153162

154163
/** Produces an inlined version of `call` via its `inlined` method.
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
assertImpl: Test$.main(Test_2.scala:7)
2+
true
3+
assertImpl: Test$.main(Test_2.scala:8)
4+
false
5+
assertImpl: Test$.main(Test_2.scala:9)
6+
hi: Test$.main(Test_2.scala:10)
7+
hi again: Test$.main(Test_2.scala:11)
8+
false
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import scala.quoted._
2+
3+
object Macros {
4+
def printStack(tag: String): Unit = {
5+
println(tag + ": "+ new Exception().getStackTrace().apply(1))
6+
}
7+
def assertImpl(expr: Expr[Boolean]) = '{
8+
printStack("assertImpl")
9+
println(~expr)
10+
}
11+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
object Test {
2+
3+
inline def assert2(expr: => Boolean): Unit = ~Macros.assertImpl('(expr))
4+
5+
def main(args: Array[String]): Unit = {
6+
val x = 1
7+
assert2(x != 0)
8+
assert2(x == 0)
9+
assert2 {
10+
Macros.printStack("hi")
11+
Macros.printStack("hi again")
12+
x == 0
13+
}
14+
}
15+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
assertImpl: Test$.main(Test_2.scala:7)
2+
true
3+
assertImpl: Test$.main(Test_2.scala:8)
4+
false
5+
assertImpl: Test$.main(Test_2.scala:9)
6+
hi: Test$.main(Test_2.scala:10)
7+
hi again: Test$.main(Test_2.scala:11)
8+
false
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import scala.quoted._
2+
3+
object Macros {
4+
def printStack(tag: String): Unit = {
5+
println(tag + ": "+ new Exception().getStackTrace().apply(1))
6+
}
7+
def assertImpl(expr: Expr[Boolean]) = '{
8+
printStack("assertImpl")
9+
println(~expr)
10+
}
11+
12+
inline def assert2(expr: => Boolean): Unit = ~Macros.assertImpl('(expr))
13+
14+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
object Test {
2+
3+
import Macros._
4+
5+
def main(args: Array[String]): Unit = {
6+
val x = 1
7+
assert2(x != 0)
8+
assert2(x == 0)
9+
assert2 {
10+
Macros.printStack("hi")
11+
Macros.printStack("hi again")
12+
x == 0
13+
}
14+
}
15+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
42
2+
43
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import scala.quoted._
2+
3+
import scala.quoted.Toolbox.Default._
4+
5+
object Test {
6+
7+
def main(args: Array[String]): Unit = {
8+
val f = f1.run
9+
println(f(42))
10+
println(f(43))
11+
}
12+
13+
def f1: Expr[Int => Int] = '{ n => ~f2('(n)) }
14+
def f2: Expr[Int => Int] = '{ n => ~f3('(n)) }
15+
def f3: Expr[Int => Int] = '{ n => ~f4('(n)) }
16+
def f4: Expr[Int => Int] = '{ n => n }
17+
}

tests/run-with-compiler/quote-unrolled-foreach.check

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
var i: scala.Int = 0
3434
while (i.<(size)) {
3535
val element: scala.Int = arr.apply(i)
36-
3736
((i: scala.Int) => java.lang.System.out.println(i)).apply(element)
3837
i = i.+(1)
3938
}
@@ -86,7 +85,6 @@
8685
var i: scala.Int = 0
8786
while (i.<(size)) {
8887
val element: scala.Int = arr1.apply(i)
89-
9088
((x: scala.Int) => scala.Predef.println(x)).apply(element)
9189
i = i.+(1)
9290
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
xyz
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import scala.quoted._
2+
3+
object Test {
4+
5+
sealed trait Var {
6+
def get: Expr[String]
7+
def update(x: Expr[String]): Expr[Unit]
8+
}
9+
10+
object Var {
11+
def apply(init: Expr[String])(body: Var => Expr[String]): Expr[String] = '{
12+
var x = ~init
13+
~body(
14+
new Var {
15+
def get: Expr[String] = '(x)
16+
def update(e: Expr[String]): Expr[Unit] = '{ x = ~e }
17+
}
18+
)
19+
}
20+
}
21+
22+
23+
def test1(): Expr[String] = Var('("abc")) { x =>
24+
'{
25+
~x.update('("xyz"))
26+
~x.get
27+
}
28+
}
29+
30+
def main(args: Array[String]): Unit = {
31+
implicit val toolbox: scala.quoted.Toolbox = scala.quoted.Toolbox.make
32+
33+
println(test1().run)
34+
}
35+
}
36+
37+
38+

0 commit comments

Comments
 (0)