Skip to content

Commit 7e45696

Browse files
Merge pull request #13824 from dotty-staging/fix-#13809
Add `-Xmacro-check` for Block constructors
2 parents 6ac2be5 + 46c4c63 commit 7e45696

File tree

4 files changed

+311
-3
lines changed

4 files changed

+311
-3
lines changed

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -753,9 +753,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
753753

754754
object Block extends BlockModule:
755755
def apply(stats: List[Statement], expr: Term): Block =
756-
withDefaultPos(tpd.Block(stats, expr))
756+
xCheckMacroBlockOwners(withDefaultPos(tpd.Block(stats, expr)))
757757
def copy(original: Tree)(stats: List[Statement], expr: Term): Block =
758-
tpd.cpy.Block(original)(stats, expr)
758+
xCheckMacroBlockOwners(tpd.cpy.Block(original)(stats, expr))
759759
def unapply(x: Block): (List[Statement], Term) =
760760
(x.statements, x.expr)
761761
end Block
@@ -2913,6 +2913,28 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
29132913
case _ => traverseChildren(t)
29142914
}.traverse(tree)
29152915

2916+
/** Checks that all definitions in this block have the same owner.
2917+
* Nested definitions are ignored and assumed to be correct by construction.
2918+
*/
2919+
private def xCheckMacroBlockOwners(tree: Tree): tree.type =
2920+
if xCheckMacro then
2921+
val defs = new tpd.TreeAccumulator[List[Tree]] {
2922+
def apply(defs: List[Tree], tree: Tree)(using Context): List[Tree] =
2923+
tree match
2924+
case tree: tpd.DefTree => tree :: defs
2925+
case _ => foldOver(defs, tree)
2926+
}.apply(Nil, tree)
2927+
val defOwners = defs.groupBy(_.symbol.owner)
2928+
assert(defOwners.size <= 1,
2929+
s"""Block contains definition with different owners.
2930+
|Found definitions ${defOwners.size} distinct owners: ${defOwners.keys.mkString(", ")}
2931+
|
2932+
|Block: ${Printer.TreeCode.show(tree)}
2933+
|
2934+
|${defOwners.map((owner, trees) => s"Definitions owned by $owner: \n${trees.map(Printer.TreeCode.show).mkString("\n")}").mkString("\n\n")}
2935+
|""".stripMargin)
2936+
tree
2937+
29162938
private def xCheckMacroValidExprs(terms: List[Term]): terms.type =
29172939
if xCheckMacro then terms.foreach(xCheckMacroValidExpr)
29182940
terms
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
package x
2+
3+
import scala.annotation._
4+
import scala.quoted._
5+
6+
trait CB[+T]
7+
8+
object CBM:
9+
def pure[T](t:T):CB[T] = ???
10+
def map[A,B](fa:CB[A])(f: A=>B):CB[B] = ???
11+
def flatMap[A,B](fa:CB[A])(f: A=>CB[B]):CB[B] = ???
12+
def spawn[A](op: =>CB[A]): CB[A] = ???
13+
14+
15+
@compileTimeOnly("await should be inside async block")
16+
def await[T](f: CB[T]): T = ???
17+
18+
19+
trait CpsExpr[T:Type](prev: Seq[Expr[?]]):
20+
21+
def fLast(using Quotes): Expr[CB[T]]
22+
def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T]
23+
def append[A:Type](chunk: CpsExpr[A])(using Quotes): CpsExpr[A]
24+
def syncOrigin(using Quotes): Option[Expr[T]]
25+
def map[A:Type](f: Expr[T => A])(using Quotes): CpsExpr[A] =
26+
MappedCpsExpr[T,A](Seq(),this,f)
27+
def flatMap[A:Type](f: Expr[T => CB[A]])(using Quotes): CpsExpr[A] =
28+
FlatMappedCpsExpr[T,A](Seq(),this,f)
29+
30+
def transformed(using Quotes): Expr[CB[T]] =
31+
import quotes.reflect._
32+
Block(prev.toList.map(_.asTerm), fLast.asTerm).asExprOf[CB[T]]
33+
34+
35+
case class GenericSyncCpsExpr[T:Type](prev: Seq[Expr[?]],last: Expr[T]) extends CpsExpr[T](prev):
36+
37+
override def fLast(using Quotes): Expr[CB[T]] =
38+
'{ CBM.pure(${last}:T) }
39+
40+
override def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T] =
41+
copy(prev = exprs ++: prev)
42+
43+
override def syncOrigin(using Quotes): Option[Expr[T]] =
44+
import quotes.reflect._
45+
Some(Block(prev.toList.map(_.asTerm), last.asTerm).asExprOf[T])
46+
47+
override def append[A:Type](e: CpsExpr[A])(using Quotes) =
48+
e.prependExprs(Seq(last)).prependExprs(prev)
49+
50+
override def map[A:Type](f: Expr[T => A])(using Quotes): CpsExpr[A] =
51+
copy(last = '{ $f($last) })
52+
53+
override def flatMap[A:Type](f: Expr[T => CB[A]])(using Quotes): CpsExpr[A] =
54+
GenericAsyncCpsExpr[A](prev, '{ CBM.flatMap(CBM.pure($last))($f) } )
55+
56+
57+
abstract class AsyncCpsExpr[T:Type](
58+
prev: Seq[Expr[?]]
59+
) extends CpsExpr[T](prev):
60+
61+
override def append[A:Type](e: CpsExpr[A])(using Quotes): CpsExpr[A] =
62+
flatMap( '{ (x:T) => ${e.transformed} })
63+
64+
override def syncOrigin(using Quotes): Option[Expr[T]] = None
65+
66+
67+
68+
case class GenericAsyncCpsExpr[T:Type](
69+
prev: Seq[Expr[?]],
70+
fLastExpr: Expr[CB[T]]
71+
) extends AsyncCpsExpr[T](prev):
72+
73+
override def fLast(using Quotes): Expr[CB[T]] = fLastExpr
74+
75+
override def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T] =
76+
copy(prev = exprs ++: prev)
77+
78+
override def map[A:Type](f: Expr[T => A])(using Quotes): CpsExpr[A] =
79+
MappedCpsExpr(Seq(),this,f)
80+
81+
override def flatMap[A:Type](f: Expr[T => CB[A]])(using Quotes): CpsExpr[A] =
82+
FlatMappedCpsExpr(Seq(),this,f)
83+
84+
85+
86+
case class MappedCpsExpr[S:Type, T:Type](
87+
prev: Seq[Expr[?]],
88+
point: CpsExpr[S],
89+
mapping: Expr[S=>T]
90+
) extends AsyncCpsExpr[T](prev):
91+
92+
override def fLast(using Quotes): Expr[CB[T]] =
93+
'{ CBM.map(${point.transformed})($mapping) }
94+
95+
override def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T] =
96+
copy(prev = exprs ++: prev)
97+
98+
99+
100+
case class FlatMappedCpsExpr[S:Type, T:Type](
101+
prev: Seq[Expr[?]],
102+
point: CpsExpr[S],
103+
mapping: Expr[S => CB[T]]
104+
) extends AsyncCpsExpr[T](prev):
105+
106+
override def fLast(using Quotes): Expr[CB[T]] =
107+
'{ CBM.flatMap(${point.transformed})($mapping) }
108+
109+
override def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T] =
110+
copy(prev = exprs ++: prev)
111+
112+
113+
class ValRhsFlatMappedCpsExpr[T:Type, V:Type](using thisQuotes: Quotes)
114+
(
115+
prev: Seq[Expr[?]],
116+
oldValDef: quotes.reflect.ValDef,
117+
cpsRhs: CpsExpr[V],
118+
next: CpsExpr[T]
119+
)
120+
extends AsyncCpsExpr[T](prev) {
121+
122+
override def fLast(using Quotes):Expr[CB[T]] =
123+
import quotes.reflect._
124+
next.syncOrigin match
125+
case Some(nextOrigin) =>
126+
// owner of this block is incorrect
127+
'{
128+
CBM.map(${cpsRhs.transformed})((vx:V) =>
129+
${buildAppendBlockExpr('vx, nextOrigin)})
130+
}
131+
case None =>
132+
'{
133+
CBM.flatMap(${cpsRhs.transformed})((v:V)=>
134+
${buildAppendBlockExpr('v, next.transformed)})
135+
}
136+
137+
138+
override def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T] =
139+
ValRhsFlatMappedCpsExpr(using thisQuotes)(exprs ++: prev,oldValDef,cpsRhs,next)
140+
141+
override def append[A:quoted.Type](e: CpsExpr[A])(using Quotes) =
142+
ValRhsFlatMappedCpsExpr(using thisQuotes)(prev,oldValDef,cpsRhs,next.append(e))
143+
144+
145+
private def buildAppendBlock(using Quotes)(rhs:quotes.reflect.Term,
146+
exprTerm:quotes.reflect.Term): quotes.reflect.Term =
147+
import quotes.reflect._
148+
import scala.quoted.Expr
149+
150+
val castedOldValDef = oldValDef.asInstanceOf[quotes.reflect.ValDef]
151+
val valDef = ValDef(castedOldValDef.symbol, Some(rhs.changeOwner(castedOldValDef.symbol)))
152+
exprTerm match
153+
case Block(stats,last) =>
154+
Block(valDef::stats, last)
155+
case other =>
156+
Block(valDef::Nil,other)
157+
158+
private def buildAppendBlockExpr[A:Type](using Quotes)(rhs: Expr[V], expr:Expr[A]):Expr[A] =
159+
import quotes.reflect._
160+
buildAppendBlock(rhs.asTerm,expr.asTerm).asExprOf[A]
161+
162+
}
163+
164+
165+
object CpsExpr:
166+
167+
def sync[T:Type](f: Expr[T]): CpsExpr[T] =
168+
GenericSyncCpsExpr[T](Seq(), f)
169+
170+
def async[T:Type](f: Expr[CB[T]]): CpsExpr[T] =
171+
GenericAsyncCpsExpr[T](Seq(), f)
172+
173+
174+
object Async:
175+
176+
transparent inline def transform[T](inline expr: T) = ${
177+
Async.transformImpl[T]('expr)
178+
}
179+
180+
def transformImpl[T:Type](f: Expr[T])(using Quotes): Expr[CB[T]] =
181+
import quotes.reflect._
182+
// println(s"before transformed: ${f.show}")
183+
val cpsExpr = rootTransform[T](f)
184+
val r = '{ CBM.spawn(${cpsExpr.transformed}) }
185+
// println(s"transformed value: ${r.show}")
186+
r
187+
188+
def rootTransform[T:Type](f: Expr[T])(using Quotes): CpsExpr[T] = {
189+
import quotes.reflect._
190+
f match
191+
case '{ while ($cond) { $repeat } } =>
192+
val cpsRepeat = rootTransform(repeat.asExprOf[Unit])
193+
CpsExpr.async('{
194+
def _whilefun():CB[Unit] =
195+
if ($cond) {
196+
${cpsRepeat.flatMap('{(x:Unit) => _whilefun()}).transformed}
197+
} else {
198+
CBM.pure(())
199+
}
200+
_whilefun()
201+
}.asExprOf[CB[T]])
202+
case _ =>
203+
val fTree = f.asTerm
204+
fTree match {
205+
case fun@Apply(fun1@TypeApply(obj2,targs2), args1) =>
206+
if (obj2.symbol.name == "await") {
207+
val awaitArg = args1.head
208+
CpsExpr.async(awaitArg.asExprOf[CB[T]])
209+
} else {
210+
???
211+
}
212+
case Assign(left,right) =>
213+
left match
214+
case id@Ident(x) =>
215+
right.tpe.widen.asType match
216+
case '[r] =>
217+
val cpsRight = rootTransform(right.asExprOf[r])
218+
CpsExpr.async(
219+
cpsRight.map[T](
220+
'{ (x:r) => ${Assign(left,'x.asTerm).asExprOf[T] }
221+
}).transformed )
222+
case _ => ???
223+
case Block(prevs,last) =>
224+
val rPrevs = prevs.map{ p =>
225+
p match
226+
case v@ValDef(vName,vtt,optRhs) =>
227+
optRhs.get.tpe.widen.asType match
228+
case '[l] =>
229+
val cpsRight = rootTransform(optRhs.get.asExprOf[l])
230+
ValRhsFlatMappedCpsExpr(using quotes)(Seq(), v, cpsRight, CpsExpr.sync('{}))
231+
case t: Term =>
232+
// TODO: rootTransform
233+
t.asExpr match
234+
case '{ $p: tp } =>
235+
rootTransform(p)
236+
case other =>
237+
printf(other.show)
238+
throw RuntimeException(s"can't handle term in block: $other")
239+
case other =>
240+
printf(other.show)
241+
throw RuntimeException(s"unknown tree type in block: $other")
242+
}
243+
val rLast = rootTransform(last.asExprOf[T])
244+
val blockResult = rPrevs.foldRight(rLast)((e,s) => e.append(s))
245+
val retval = CpsExpr.async(blockResult.transformed)
246+
retval
247+
//BlockTransform(cpsCtx).run(prevs,last)
248+
case id@Ident(name) =>
249+
CpsExpr.sync(id.asExprOf[T])
250+
case tid@Typed(Ident(name), tp) =>
251+
CpsExpr.sync(tid.asExprOf[T])
252+
case matchTerm@Match(scrutinee, caseDefs) =>
253+
val nCases = caseDefs.map{ old =>
254+
CaseDef.copy(old)(old.pattern, old.guard, rootTransform(old.rhs.asExprOf[T]).transformed.asTerm)
255+
}
256+
CpsExpr.async(Match(scrutinee, nCases).asExprOf[CB[T]])
257+
case inlinedTerm@ Inlined(call,List(),body) =>
258+
rootTransform(body.asExprOf[T])
259+
case constTerm@Literal(_)=>
260+
CpsExpr.sync(constTerm.asExprOf[T])
261+
case _ =>
262+
throw RuntimeException(s"language construction is not supported: ${fTree}")
263+
}
264+
}
265+

tests/neg-macros/i13809/Test_2.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package x
2+
3+
object VP1:
4+
5+
///*
6+
def allocateServiceOperator(optInUsername: Option[String]): CB[Unit] = Async.transform { // error
7+
val username = optInUsername match
8+
case None =>
9+
while(false) {
10+
val nextResult = await(op1())
11+
val countResult = await(op1())
12+
}
13+
case Some(inUsername) =>
14+
val x = await(op1())
15+
inUsername
16+
}
17+
//*/
18+
19+
def op1(): CB[String] = ???

tests/pos-macros/i10151/Macro_1.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ object X:
5555
)
5656
)
5757
)
58-
case Block(stats, last) => Block(stats, transform(last))
58+
case Block(stats, last) =>
59+
val recoverdOwner = stats.headOption.map(_.symbol.owner).getOrElse(Symbol.spliceOwner) // hacky workaround to missing owner tracking in transform
60+
Block(stats, transform(last).changeOwner(recoverdOwner))
5961
case Inlined(x,List(),body) => transform(body)
6062
case l@Literal(x) =>
6163
l.asExpr match

0 commit comments

Comments
 (0)