Skip to content

Commit 6ea3ea6

Browse files
Replace quoted type variables in signature of HOAS pattern result (#16951)
To be able to construct the lambda returned by the HOAS pattern we need: first resolve the type variables and then use the result to construct the signature of the lambdas. To simplify this transformation, `QuoteMatcher` returns a `Seq[MatchResult]` instead of an untyped `Tuple` containing `Expr[?]`. The tuple is created once we have accumulated and processed all extracted values. Fixes #15165
2 parents 8020c77 + 20174d7 commit 6ea3ea6

File tree

9 files changed

+164
-74
lines changed

9 files changed

+164
-74
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package dotty.tools.dotc.util
2+
3+
import scala.util.boundary
4+
5+
/** Return type that indicates that the method returns a T or aborts to the enclosing boundary with a `None` */
6+
type optional[T] = boundary.Label[None.type] ?=> T
7+
8+
/** A prompt for `Option`, which establishes a boundary which `_.?` on `Option` can return */
9+
object optional:
10+
inline def apply[T](inline body: optional[T]): Option[T] =
11+
boundary(Some(body))
12+
13+
extension [T](r: Option[T])
14+
inline def ? (using label: boundary.Label[None.type]): T = r match
15+
case Some(x) => x
16+
case None => boundary.break(None)
17+
18+
inline def break()(using label: boundary.Label[None.type]): Nothing =
19+
boundary.break(None)

compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala

+79-61
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
package scala.quoted
22
package runtime.impl
33

4-
54
import dotty.tools.dotc.ast.tpd
65
import dotty.tools.dotc.core.Contexts.*
76
import dotty.tools.dotc.core.Flags.*
87
import dotty.tools.dotc.core.Names.*
98
import dotty.tools.dotc.core.Types.*
109
import dotty.tools.dotc.core.StdNames.nme
1110
import dotty.tools.dotc.core.Symbols.*
11+
import dotty.tools.dotc.util.optional
1212

1313
/** Matches a quoted tree against a quoted pattern tree.
1414
* A quoted pattern tree may have type and term holes in addition to normal terms.
@@ -103,12 +103,13 @@ import dotty.tools.dotc.core.Symbols.*
103103
object QuoteMatcher {
104104
import tpd.*
105105

106-
// TODO improve performance
107-
108106
// TODO use flag from Context. Maybe -debug or add -debug-macros
109107
private inline val debug = false
110108

111-
import Matching._
109+
/** Sequence of matched expressions.
110+
* These expressions are part of the scrutinee and will be bound to the quote pattern term splices.
111+
*/
112+
type MatchingExprs = Seq[MatchResult]
112113

113114
/** A map relating equivalent symbols from the scrutinee and the pattern
114115
* For example in
@@ -121,32 +122,34 @@ object QuoteMatcher {
121122

122123
private def withEnv[T](env: Env)(body: Env ?=> T): T = body(using env)
123124

124-
def treeMatch(scrutineeTree: Tree, patternTree: Tree)(using Context): Option[Tuple] =
125+
def treeMatch(scrutineeTree: Tree, patternTree: Tree)(using Context): Option[MatchingExprs] =
125126
given Env = Map.empty
126-
scrutineeTree =?= patternTree
127+
optional:
128+
scrutineeTree =?= patternTree
127129

128130
/** Check that all trees match with `mtch` and concatenate the results with &&& */
129-
private def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => Matching): Matching = (l1, l2) match {
131+
private def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => MatchingExprs): optional[MatchingExprs] = (l1, l2) match {
130132
case (x :: xs, y :: ys) => mtch(x, y) &&& matchLists(xs, ys)(mtch)
131133
case (Nil, Nil) => matched
132134
case _ => notMatched
133135
}
134136

135137
extension (scrutinees: List[Tree])
136-
private def =?= (patterns: List[Tree])(using Env, Context): Matching =
138+
private def =?= (patterns: List[Tree])(using Env, Context): optional[MatchingExprs] =
137139
matchLists(scrutinees, patterns)(_ =?= _)
138140

139141
extension (scrutinee0: Tree)
140142

141143
/** Check that the trees match and return the contents from the pattern holes.
142-
* Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes.
144+
* Return a sequence containing all the contents in the holes.
145+
* If it does not match, continues to the `optional` with `None`.
143146
*
144147
* @param scrutinee The tree being matched
145148
* @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes.
146149
* @param `summon[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
147-
* @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
150+
* @return The sequence with the contents of the holes of the matched expression.
148151
*/
149-
private def =?= (pattern0: Tree)(using Env, Context): Matching =
152+
private def =?= (pattern0: Tree)(using Env, Context): optional[MatchingExprs] =
150153

151154
/* Match block flattening */ // TODO move to cases
152155
/** Normalize the tree */
@@ -203,31 +206,12 @@ object QuoteMatcher {
203206
// Matches an open term and wraps it into a lambda that provides the free variables
204207
case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil)
205208
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) =>
206-
def hoasClosure = {
207-
val names: List[TermName] = args.map {
208-
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
209-
case arg => arg.symbol.name.asTermName
210-
}
211-
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
212-
val methTpe = MethodType(names)(_ => argTypes, _ => pattern.tpe)
213-
val meth = newAnonFun(ctx.owner, methTpe)
214-
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
215-
val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap
216-
val body = new TreeMap {
217-
override def transform(tree: Tree)(using Context): Tree =
218-
tree match
219-
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
220-
case tree => super.transform(tree)
221-
}.transform(scrutinee)
222-
TreeOps(body).changeNonLocalOwners(meth)
223-
}
224-
Closure(meth, bodyFn)
225-
}
209+
val env = summon[Env]
226210
val capturedArgs = args.map(_.symbol)
227-
val captureEnv = summon[Env].filter((k, v) => !capturedArgs.contains(v))
211+
val captureEnv = env.filter((k, v) => !capturedArgs.contains(v))
228212
withEnv(captureEnv) {
229213
scrutinee match
230-
case ClosedPatternTerm(scrutinee) => matched(hoasClosure)
214+
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env)
231215
case _ => notMatched
232216
}
233217

@@ -431,7 +415,6 @@ object QuoteMatcher {
431415
case _ => scrutinee
432416
val pattern = patternTree.symbol
433417

434-
435418
devirtualizedScrutinee == pattern
436419
|| summon[Env].get(devirtualizedScrutinee).contains(pattern)
437420
|| devirtualizedScrutinee.allOverriddenSymbols.contains(pattern)
@@ -452,32 +435,67 @@ object QuoteMatcher {
452435
accumulator.apply(Set.empty, term)
453436
}
454437

455-
/** Result of matching a part of an expression */
456-
private type Matching = Option[Tuple]
457-
458-
private object Matching {
459-
460-
def notMatched: Matching = None
461-
462-
val matched: Matching = Some(Tuple())
463-
464-
def matched(tree: Tree)(using Context): Matching =
465-
Some(Tuple1(new ExprImpl(tree, SpliceScope.getCurrent)))
466-
467-
extension (self: Matching)
468-
def asOptionOfTuple: Option[Tuple] = self
469-
470-
/** Concatenates the contents of two successful matchings or return a `notMatched` */
471-
def &&& (that: => Matching): Matching = self match {
472-
case Some(x) =>
473-
that match {
474-
case Some(y) => Some(x ++ y)
475-
case _ => None
476-
}
477-
case _ => None
478-
}
479-
end extension
480-
481-
}
438+
enum MatchResult:
439+
/** Closed pattern extracted value
440+
* @param tree Scrutinee sub-tree that matched
441+
*/
442+
case ClosedTree(tree: Tree)
443+
/** HOAS pattern extracted value
444+
*
445+
* @param tree Scrutinee sub-tree that matched
446+
* @param patternTpe Type of the pattern hole (from the pattern)
447+
* @param args HOAS arguments (from the pattern)
448+
* @param env Mapping between scrutinee and pattern variables
449+
*/
450+
case OpenTree(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)
451+
452+
/** Return the expression that was extracted from a hole.
453+
*
454+
* If it was a closed expression it returns that expression. Otherwise,
455+
* if it is a HOAS pattern, the surrounding lambda is generated using
456+
* `mapTypeHoles` to create the signature of the lambda.
457+
*
458+
* This expression is assumed to be a valid expression in the given splice scope.
459+
*/
460+
def toExpr(mapTypeHoles: TypeMap, spliceScope: Scope)(using Context): Expr[Any] = this match
461+
case MatchResult.ClosedTree(tree) =>
462+
new ExprImpl(tree, spliceScope)
463+
case MatchResult.OpenTree(tree, patternTpe, args, env) =>
464+
val names: List[TermName] = args.map {
465+
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
466+
case arg => arg.symbol.name.asTermName
467+
}
468+
val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr))
469+
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
470+
val meth = newAnonFun(ctx.owner, methTpe)
471+
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
472+
val argsMap = args.view.map(_.symbol).zip(lambdaArgss.head).toMap
473+
val body = new TreeMap {
474+
override def transform(tree: Tree)(using Context): Tree =
475+
tree match
476+
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
477+
case tree => super.transform(tree)
478+
}.transform(tree)
479+
TreeOps(body).changeNonLocalOwners(meth)
480+
}
481+
val hoasClosure = Closure(meth, bodyFn)
482+
new ExprImpl(hoasClosure, spliceScope)
483+
484+
private inline def notMatched: optional[MatchingExprs] =
485+
optional.break()
486+
487+
private inline def matched: MatchingExprs =
488+
Seq.empty
489+
490+
private inline def matched(tree: Tree)(using Context): MatchingExprs =
491+
Seq(MatchResult.ClosedTree(tree))
492+
493+
private def matchedOpen(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)(using Context): MatchingExprs =
494+
Seq(MatchResult.OpenTree(tree, patternTpe, args, env))
495+
496+
extension (self: MatchingExprs)
497+
/** Concatenates the contents of two successful matchings */
498+
def &&& (that: MatchingExprs): MatchingExprs = self ++ that
499+
end extension
482500

483501
}

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

+20-13
Original file line numberDiff line numberDiff line change
@@ -3137,20 +3137,27 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
31373137
ctx1.gadtState.addToConstraint(typeHoles)
31383138
ctx1
31393139

3140-
val matchings = QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1)
3141-
3142-
if typeHoles.isEmpty then matchings
3143-
else {
3144-
// After matching and doing all subtype checks, we have to approximate all the type bindings
3145-
// that we have found, seal them in a quoted.Type and add them to the result
3146-
def typeHoleApproximation(sym: Symbol) =
3147-
val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot)
3148-
val fullBounds = ctx1.gadt.fullBounds(sym)
3149-
val tp = if fromAboveAnnot then fullBounds.hi else fullBounds.lo
3150-
reflect.TypeReprMethods.asType(tp)
3151-
matchings.map { tup =>
3152-
Tuple.fromIArray(typeHoles.map(typeHoleApproximation).toArray.asInstanceOf[IArray[Object]]) ++ tup
3140+
// After matching and doing all subtype checks, we have to approximate all the type bindings
3141+
// that we have found, seal them in a quoted.Type and add them to the result
3142+
def typeHoleApproximation(sym: Symbol) =
3143+
val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot)
3144+
val fullBounds = ctx1.gadt.fullBounds(sym)
3145+
if fromAboveAnnot then fullBounds.hi else fullBounds.lo
3146+
3147+
QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1).map { matchings =>
3148+
import QuoteMatcher.MatchResult.*
3149+
lazy val spliceScope = SpliceScope.getCurrent
3150+
val typeHoleApproximations = typeHoles.map(typeHoleApproximation)
3151+
val typeHoleMapping = Map(typeHoles.zip(typeHoleApproximations)*)
3152+
val typeHoleMap = new Types.TypeMap {
3153+
def apply(tp: Types.Type): Types.Type = tp match
3154+
case Types.TypeRef(Types.NoPrefix, _) => typeHoleMapping.getOrElse(tp.typeSymbol, tp)
3155+
case _ => mapOver(tp)
31533156
}
3157+
val matchedExprs = matchings.map(_.toExpr(typeHoleMap, spliceScope))
3158+
val matchedTypes = typeHoleApproximations.map(reflect.TypeReprMethods.asType)
3159+
val results = matchedTypes ++ matchedExprs
3160+
Tuple.fromIArray(IArray.unsafeFromArray(results.toArray))
31543161
}
31553162
}
31563163

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import scala.quoted.*
2+
3+
inline def valToFun[T](inline expr: T): T =
4+
${ impl('expr) }
5+
6+
def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
7+
expr match
8+
case '{ { val ident = ($a: α); $rest(ident): T } } =>
9+
'{ { (y: α) => $rest(y) }.apply(???) }

tests/pos-macros/i15165a/Test_2.scala

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def test = valToFun {
2+
val a: Int = 1
3+
a + 1
4+
}
+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import scala.quoted.*
2+
3+
inline def valToFun[T](inline expr: T): T =
4+
${ impl('expr) }
5+
6+
def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
7+
expr match
8+
case '{ { val ident = ($a: α); $rest(ident): T } } =>
9+
'{
10+
{ (y: α) =>
11+
${
12+
val bound = '{ ${ rest }(y) }
13+
Expr.betaReduce(bound)
14+
}
15+
}.apply($a)
16+
}

tests/pos-macros/i15165b/Test_2.scala

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def test = valToFun {
2+
val a: Int = 1
3+
a + 1
4+
}
+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import scala.quoted.*
2+
3+
inline def valToFun[T](inline expr: T): T =
4+
${ impl('expr) }
5+
6+
def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
7+
expr match
8+
case '{ type α; { val ident = ($a: `α`); $rest(ident): `α` & T } } =>
9+
'{ { (y: α) => $rest(y) }.apply(???) }

tests/pos-macros/i15165c/Test_2.scala

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def test = valToFun {
2+
val a: Int = 1
3+
a + 1
4+
}

0 commit comments

Comments
 (0)