Skip to content

Commit 449008e

Browse files
authored
Merge pull request #6206 from dotty-staging/add-quote-patterns-runntime
Add quote patterns runtime
2 parents 5dced62 + fb889a2 commit 449008e

File tree

18 files changed

+1191
-59
lines changed

18 files changed

+1191
-59
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,12 @@ class Definitions {
723723
lazy val InternalQuoted_patternHoleR: TermRef = InternalQuotedModule.requiredMethodRef("patternHole")
724724
def InternalQuoted_patternHole(implicit ctx: Context): Symbol = InternalQuoted_patternHoleR.symbol
725725

726+
lazy val InternalQuotedMatcherModuleRef: TermRef = ctx.requiredModuleRef("scala.internal.quoted.Matcher")
727+
def InternalQuotedMatcherModule(implicit ctx: Context): Symbol = InternalQuotedMatcherModuleRef.symbol
728+
729+
lazy val InternalQuotedMatcher_unapplyR: TermRef = InternalQuotedMatcherModule.requiredMethodRef(nme.unapply)
730+
def InternalQuotedMatcher_unapply(implicit ctx: Context) = InternalQuotedMatcher_unapplyR.symbol
731+
726732
lazy val QuotedExprsModule: TermSymbol = ctx.requiredModule("scala.quoted.Exprs")
727733
def QuotedExprsClass(implicit ctx: Context): ClassSymbol = QuotedExprsModule.asClass
728734

@@ -745,12 +751,6 @@ class Definitions {
745751
lazy val TastyReflectionModule: TermSymbol = ctx.requiredModule("scala.tasty.Reflection")
746752
lazy val TastyReflection_macroContext: TermSymbol = TastyReflectionModule.requiredMethod("macroContext")
747753

748-
lazy val QuotedMatcherModuleRef: TermRef = ctx.requiredModuleRef("scala.runtime.quoted.Matcher")
749-
def QuotedMatcherModule(implicit ctx: Context): Symbol = QuotedMatcherModuleRef.symbol
750-
751-
lazy val QuotedMatcher_unapplyR: TermRef = QuotedMatcherModule.requiredMethodRef(nme.unapply)
752-
def QuotedMatcher_unapply(implicit ctx: Context) = QuotedMatcher_unapplyR.symbol
753-
754754
lazy val EqlType: TypeRef = ctx.requiredClassRef("scala.Eql")
755755
def EqlClass(implicit ctx: Context): ClassSymbol = EqlType.symbol.asClass
756756
def EqlModule(implicit ctx: Context): Symbol = EqlClass.companionModule

compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,11 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
239239

240240
type Ref = tpd.RefTree
241241

242+
def matchRef(tree: Tree)(implicit ctx: Context): Option[Ref] = tree match {
243+
case x: tpd.RefTree if x.isTerm => Some(x)
244+
case _ => None
245+
}
246+
242247
def Ref_apply(sym: Symbol)(implicit ctx: Context): Ref =
243248
withDefaultPos(ctx => tpd.ref(sym)(ctx).asInstanceOf[tpd.RefTree])
244249

@@ -1730,6 +1735,8 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
17301735
// DEFINITIONS
17311736
//
17321737

1738+
// Symbols
1739+
17331740
def Definitions_RootPackage: Symbol = defn.RootPackage
17341741
def Definitions_RootClass: Symbol = defn.RootClass
17351742

@@ -1778,6 +1785,10 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
17781785
defn.FunctionClass(arity, isImplicit, isErased).asClass
17791786
def Definitions_TupleClass(arity: Int): Symbol = defn.TupleType(arity).classSymbol.asClass
17801787

1788+
def Definitions_InternalQuoted_patternHole: Symbol = defn.InternalQuoted_patternHole
1789+
1790+
// Types
1791+
17811792
def Definitions_UnitType: Type = defn.UnitType
17821793
def Definitions_ByteType: Type = defn.ByteType
17831794
def Definitions_ShortType: Type = defn.ShortType

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1944,7 +1944,7 @@ class Typer extends Namer
19441944
val patType = defn.tupleType(splices.tpes.map(_.widen))
19451945
val splicePat = typed(untpd.Tuple(splices.map(untpd.TypedSplice(_))).withSpan(quoted.span), patType)
19461946
UnApply(
1947-
fun = ref(defn.QuotedMatcher_unapplyR).appliedToType(patType),
1947+
fun = ref(defn.InternalQuotedMatcher_unapplyR).appliedToType(patType),
19481948
implicits =
19491949
ref(defn.InternalQuoted_exprQuoteR).appliedToType(shape.tpe).appliedTo(shape) ::
19501950
implicitArgTree(defn.TastyReflectionType, tree.span) :: Nil,
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
package scala.internal.quoted
2+
3+
import scala.annotation.internal.sharable
4+
5+
import scala.quoted._
6+
import scala.tasty._
7+
8+
object Matcher {
9+
10+
private final val debug = false
11+
12+
/** Pattern matches an the scrutineeExpr aquainsnt the patternExpr and returns a tuple
13+
* with the matched holes if successful.
14+
*
15+
* Examples:
16+
* - `Matcher.unapply('{ f(0, myInt) })('{ f(0, myInt) }, _)`
17+
* will return `Some(())` (where `()` is a tuple of arity 0)
18+
* - `Matcher.unapply('{ f(0, myInt) })('{ f(patternHole[Int], patternHole[Int]) }, _)`
19+
* will return `Some(Tuple2('{0}, '{ myInt }))`
20+
* - `Matcher.unapply('{ f(0, "abc") })('{ f(0, patternHole[Int]) }, _)`
21+
* will return `None` due to the missmatch of types in the hole
22+
*
23+
* Holes:
24+
* - scala.internal.Quoted.patternHole[T]: hole that matches an expression `x` of type `Expr[U]`
25+
* if `U <:< T` and returns `x` as part of the match.
26+
*
27+
* @param scrutineeExpr `Expr[_]` on which we are pattern matching
28+
* @param patternExpr `Expr[_]` containing the pattern tree
29+
* @param reflection instance of the reflection API (implicitly provided by the macro)
30+
* @return None if it did not match, `Some(tup)` if it matched where `tup` contains `Expr[Ti]``
31+
*/
32+
def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] = {
33+
import reflection._
34+
// TODO improve performance
35+
36+
/** Check that the trees match and return the contents from the pattern holes.
37+
* Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes.
38+
*
39+
* @param scrutinee The tree beeing matched
40+
* @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes.
41+
* @param env Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
42+
* @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
43+
*/
44+
def treeMatches(scrutinee: Tree, pattern: Tree)(implicit env: Set[(Symbol, Symbol)]): Option[Tuple] = {
45+
46+
/** Check that both are `val` or both are `lazy val` or both are `var` **/
47+
def checkValFlags(): Boolean = {
48+
import Flags._
49+
val sFlags = scrutinee.symbol.flags
50+
val pFlags = pattern.symbol.flags
51+
sFlags.is(Lazy) == pFlags.is(Lazy) && sFlags.is(Mutable) == pFlags.is(Mutable)
52+
}
53+
54+
def treesMatch(scrutinees: List[Tree], patterns: List[Tree]): Option[Tuple] =
55+
if (scrutinees.size != patterns.size) None
56+
else foldMatchings(scrutinees.zip(patterns).map(treeMatches): _*)
57+
58+
/** Normalieze the tree */
59+
def normalize(tree: Tree): Tree = tree match {
60+
case Block(Nil, expr) => normalize(expr)
61+
case Inlined(_, Nil, expr) => normalize(expr)
62+
case _ => tree
63+
}
64+
65+
(normalize(scrutinee), normalize(pattern)) match {
66+
67+
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
68+
case (IsTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
69+
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole && scrutinee.tpe <:< tpt.tpe =>
70+
Some(Tuple1(scrutinee.seal))
71+
72+
//
73+
// Match two equivalent trees
74+
//
75+
76+
case (Literal(constant1), Literal(constant2)) if constant1 == constant2 =>
77+
Some(())
78+
79+
case (Typed(expr1, tpt1), Typed(expr2, tpt2)) =>
80+
foldMatchings(treeMatches(expr1, expr2), treeMatches(tpt1, tpt2))
81+
82+
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || env((scrutinee.symbol, pattern.symbol)) =>
83+
Some(())
84+
85+
case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
86+
treeMatches(qual1, qual2)
87+
88+
case (IsRef(_), IsRef(_, _)) if scrutinee.symbol == pattern.symbol =>
89+
Some(())
90+
91+
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol =>
92+
foldMatchings(treeMatches(fn1, fn2), treesMatch(args1, args2))
93+
94+
case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol =>
95+
foldMatchings(treeMatches(fn1, fn2), treesMatch(args1, args2))
96+
97+
case (Block(stats1, expr1), Block(stats2, expr2)) =>
98+
foldMatchings(treesMatch(stats1, stats2), treeMatches(expr1, expr2))
99+
100+
case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) =>
101+
foldMatchings(treeMatches(cond1, cond2), treeMatches(thenp1, thenp2), treeMatches(elsep1, elsep2))
102+
103+
case (Assign(lhs1, rhs1), Assign(lhs2, rhs2)) =>
104+
val lhsMatch =
105+
if (treeMatches(lhs1, lhs2).isDefined) Some(())
106+
else None
107+
foldMatchings(lhsMatch, treeMatches(rhs1, rhs2))
108+
109+
case (While(cond1, body1), While(cond2, body2)) =>
110+
foldMatchings(treeMatches(cond1, cond2), treeMatches(body1, body2))
111+
112+
case (NamedArg(name1, expr1), NamedArg(name2, expr2)) if name1 == name2 =>
113+
treeMatches(expr1, expr2)
114+
115+
case (New(tpt1), New(tpt2)) =>
116+
treeMatches(tpt1, tpt2)
117+
118+
case (This(_), This(_)) if scrutinee.symbol == pattern.symbol =>
119+
Some(())
120+
121+
case (Super(qual1, mix1), Super(qual2, mix2)) if mix1 == mix2 =>
122+
treeMatches(qual1, qual2)
123+
124+
case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size =>
125+
treesMatch(elems1, elems2)
126+
127+
case (IsTypeTree(scrutinee @ TypeIdent(_)), IsTypeTree(pattern @ TypeIdent(_))) if scrutinee.symbol == pattern.symbol =>
128+
Some(())
129+
130+
case (IsInferred(scrutinee), IsInferred(pattern)) if scrutinee.tpe <:< pattern.tpe =>
131+
Some(())
132+
133+
case (Applied(tycon1, args1), Applied(tycon2, args2)) =>
134+
foldMatchings(treeMatches(tycon1, tycon2), treesMatch(args1, args2))
135+
136+
case (ValDef(_, tpt1, rhs1), ValDef(_, tpt2, rhs2)) if checkValFlags() =>
137+
val returnTptMatch = treeMatches(tpt1, tpt2)
138+
val rhsEnv = env + (scrutinee.symbol -> pattern.symbol)
139+
val rhsMatchings = treeOptMatches(rhs1, rhs2)(rhsEnv)
140+
foldMatchings(returnTptMatch, rhsMatchings)
141+
142+
case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
143+
val typeParmasMatch = treesMatch(typeParams1, typeParams2)
144+
val paramssMatch =
145+
if (paramss1.size != paramss2.size) None
146+
else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => treesMatch(params1, params2) }: _*)
147+
val tptMatch = treeMatches(tpt1, tpt2)
148+
val rhsEnv =
149+
env + (scrutinee.symbol -> pattern.symbol) ++
150+
typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
151+
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
152+
val rhsMatch = treeMatches(rhs1, rhs2)(rhsEnv)
153+
154+
foldMatchings(typeParmasMatch, paramssMatch, tptMatch, rhsMatch)
155+
156+
case (Lambda(_, tpt1), Lambda(_, tpt2)) =>
157+
// TODO match tpt1 with tpt2?
158+
Some(())
159+
160+
case (Match(scru1, cases1), Match(scru2, cases2)) =>
161+
val scrutineeMacth = treeMatches(scru1, scru2)
162+
val casesMatch =
163+
if (cases1.size != cases2.size) None
164+
else foldMatchings(cases1.zip(cases2).map(caseMatches): _*)
165+
foldMatchings(scrutineeMacth, casesMatch)
166+
167+
case (Try(body1, cases1, finalizer1), Try(body2, cases2, finalizer2)) =>
168+
val bodyMacth = treeMatches(body1, body2)
169+
val casesMatch =
170+
if (cases1.size != cases2.size) None
171+
else foldMatchings(cases1.zip(cases2).map(caseMatches): _*)
172+
val finalizerMatch = treeOptMatches(finalizer1, finalizer2)
173+
foldMatchings(bodyMacth, casesMatch, finalizerMatch)
174+
175+
// No Match
176+
case _ =>
177+
if (debug)
178+
println(
179+
s""">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
180+
|Scrutinee
181+
| ${scrutinee.showCode}
182+
|
183+
|${scrutinee.show}
184+
|
185+
|did not match pattern
186+
| ${pattern.showCode}
187+
|
188+
|${pattern.show}
189+
|
190+
|
191+
|
192+
|
193+
|""".stripMargin)
194+
None
195+
}
196+
}
197+
198+
def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree])(implicit env: Set[(Symbol, Symbol)]): Option[Tuple] = {
199+
(scrutinee, pattern) match {
200+
case (Some(x), Some(y)) => treeMatches(x, y)
201+
case (None, None) => Some(())
202+
case _ => None
203+
}
204+
}
205+
206+
def caseMatches(scrutinee: CaseDef, pattern: CaseDef)(implicit env: Set[(Symbol, Symbol)]): Option[Tuple] = {
207+
val (caseEnv, patternMatch) = patternMatches(scrutinee.pattern, pattern.pattern)
208+
val guardMatch = treeOptMatches(scrutinee.guard, pattern.guard)(caseEnv)
209+
val rhsMatch = treeMatches(scrutinee.rhs, pattern.rhs)(caseEnv)
210+
foldMatchings(patternMatch, guardMatch, rhsMatch)
211+
}
212+
213+
/** Check that the pattern trees match and return the contents from the pattern holes.
214+
* Return a tuple with the new environment containing the bindings defined in this pattern and a matching.
215+
* The matching is None if the pattern trees do not match otherwise return Some of a tuple containing all the contents in the holes.
216+
*
217+
* @param scrutinee The pattern tree beeing matched
218+
* @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes.
219+
* @param env Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
220+
* @return The new environment containing the bindings defined in this pattern tuppled with
221+
* `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
222+
*/
223+
def patternMatches(scrutinee: Pattern, pattern: Pattern)(implicit env: Set[(Symbol, Symbol)]): (Set[(Symbol, Symbol)], Option[Tuple]) = (scrutinee, pattern) match {
224+
case (Pattern.Value(v1), Pattern.Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil))
225+
if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" =>
226+
(env, Some(Tuple1(v1.seal)))
227+
228+
case (Pattern.Value(v1), Pattern.Value(v2)) =>
229+
(env, treeMatches(v1, v2))
230+
231+
case (Pattern.Bind(name1, body1), Pattern.Bind(name2, body2)) =>
232+
val bindEnv = env + (scrutinee.symbol -> pattern.symbol)
233+
patternMatches(body1, body2)(bindEnv)
234+
235+
case (Pattern.Unapply(fun1, implicits1, patterns1), Pattern.Unapply(fun2, implicits2, patterns2)) =>
236+
val funMatch = treeMatches(fun1, fun2)
237+
val implicitsMatch =
238+
if (implicits1.size != implicits2.size) None
239+
else foldMatchings(implicits1.zip(implicits2).map(treeMatches): _*)
240+
val (patEnv, patternsMatch) = foldPatterns(patterns1, patterns2)
241+
(patEnv, foldMatchings(funMatch, implicitsMatch, patternsMatch))
242+
243+
case (Pattern.Alternatives(patterns1), Pattern.Alternatives(patterns2)) =>
244+
foldPatterns(patterns1, patterns2)
245+
246+
case (Pattern.TypeTest(tpt1), Pattern.TypeTest(tpt2)) =>
247+
(env, treeMatches(tpt1, tpt2))
248+
249+
case _ =>
250+
if (debug)
251+
println(
252+
s""">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
253+
|Scrutinee
254+
| ${scrutinee.showCode}
255+
|
256+
|${scrutinee.show}
257+
|
258+
|did not match pattern
259+
| ${pattern.showCode}
260+
|
261+
|${pattern.show}
262+
|
263+
|
264+
|
265+
|
266+
|""".stripMargin)
267+
(env, None)
268+
}
269+
270+
def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern])(implicit env: Set[(Symbol, Symbol)]): (Set[(Symbol, Symbol)], Option[Tuple]) = {
271+
if (patterns1.size != patterns2.size) (env, None)
272+
else patterns1.zip(patterns2).foldLeft((env, Option[Tuple](()))) { (acc, x) =>
273+
val (env, res) = patternMatches(x._1, x._2)(acc._1)
274+
(env, foldMatchings(acc._2, res))
275+
}
276+
}
277+
278+
treeMatches(scrutineeExpr.unseal, patternExpr.unseal)(Set.empty).asInstanceOf[Option[Tup]]
279+
}
280+
281+
/** Joins the mattchings into a single matching. If any matching is `None` the result is `None`.
282+
* Otherwise the result is `Some` of the concatenation of the tupples.
283+
*/
284+
private def foldMatchings(matchings: Option[Tuple]*): Option[Tuple] = {
285+
// TODO improve performance
286+
matchings.foldLeft[Option[Tuple]](Some(())) {
287+
case (Some(acc), Some(holes)) => Some(acc ++ holes)
288+
case (_, _) => None
289+
}
290+
}
291+
292+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package scala.internal.quoted
2+
3+
import scala.quoted.Expr
4+
import scala.tasty.Reflection
5+
6+
object Matcher {
7+
8+
def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] =
9+
throw new Exception("running on non bootstrapped library")
10+
11+
}

0 commit comments

Comments
 (0)