diff --git a/library/src-3.x/scala/internal/quoted/Matcher.scala b/library/src-3.x/scala/internal/quoted/Matcher.scala index 52104fd48fa1..12aca77298ea 100644 --- a/library/src-3.x/scala/internal/quoted/Matcher.scala +++ b/library/src-3.x/scala/internal/quoted/Matcher.scala @@ -31,11 +31,23 @@ object Matcher { * @return None if it did not match, `Some(tup)` if it matched where `tup` contains `Expr[Ti]`` */ def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] = { + // TODO improve performance import reflection.{Bind => BindPattern, _} + import Matching._ type Env = Set[(Symbol, Symbol)] - // TODO improve performance + inline def withEnv[T](env: Env)(body: => given Env => T): T = body given env + + /** Check that all trees match with =#= and concatenate the results with && */ + def (scrutinees: List[Tree]) =##= (patterns: List[Tree]) given Env: Matching = { + def rec(l1: List[Tree], l2: List[Tree]): Matching = (l1, l2) match { + case (x :: xs, y :: ys) => x =#= y && rec(xs, ys) + case (Nil, Nil) => matched + case _ => notMatched + } + rec(scrutinees, patterns) + } /** Check that the trees match and return the contents from the pattern holes. * Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes. @@ -45,7 +57,7 @@ object Matcher { * @param `the[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`. * @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes. */ - def treeMatches(scrutinee: Tree, pattern: Tree) given Env: Option[Tuple] = { + def (scrutinee: Tree) =#= (pattern: Tree) given Env: Matching = { /** Check that both are `val` or both are `lazy val` or both are `var` **/ def checkValFlags(): Boolean = { @@ -56,7 +68,7 @@ object Matcher { } def bindingMatch(sym: Symbol) = - Some(Tuple1(new Bind(sym.name, sym))) + matched(new Bind(sym.name, sym)) def hasBindTypeAnnotation(tpt: TypeTree): Boolean = tpt match { case Annotated(tpt2, Apply(Select(New(TypeIdent("patternBindHole")), ""), Nil)) => true @@ -67,10 +79,6 @@ object Matcher { def hasBindAnnotation(sym: Symbol) = sym.annots.exists { case Apply(Select(New(TypeIdent("patternBindHole")),""),List()) => true; case _ => true } - def treesMatch(scrutinees: List[Tree], patterns: List[Tree]): Option[Tuple] = - if (scrutinees.size != patterns.size) None - else foldMatchings(scrutinees.zip(patterns).map(treeMatches): _*) - /** Normalieze the tree */ def normalize(tree: Tree): Tree = tree match { case Block(Nil, expr) => normalize(expr) @@ -85,126 +93,130 @@ object Matcher { if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole && s.tpe <:< tpt.tpe && tpt2.tpe.derivesFrom(definitions.RepeatedParamClass) => - Some(Tuple1(scrutinee.seal)) + matched(scrutinee.seal) // Match a scala.internal.Quoted.patternHole and return the scrutinee tree case (IsTerm(scrutinee), TypeApply(patternHole, tpt :: Nil)) if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole && scrutinee.tpe <:< tpt.tpe => - Some(Tuple1(scrutinee.seal)) + matched(scrutinee.seal) // // Match two equivalent trees // case (Literal(constant1), Literal(constant2)) if constant1 == constant2 => - Some(()) + matched case (Typed(expr1, tpt1), Typed(expr2, tpt2)) => - foldMatchings(treeMatches(expr1, expr2), treeMatches(tpt1, tpt2)) + expr1 =#= expr2 && tpt1 =#= tpt2 case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || the[Env].apply((scrutinee.symbol, pattern.symbol)) => - Some(()) + matched case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol => - treeMatches(qual1, qual2) + qual1 =#= qual2 case (IsRef(_), IsRef(_)) if scrutinee.symbol == pattern.symbol => - Some(()) + matched case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol => - foldMatchings(treeMatches(fn1, fn2), treesMatch(args1, args2)) + fn1 =#= fn2 && args1 =##= args2 case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol => - foldMatchings(treeMatches(fn1, fn2), treesMatch(args1, args2)) + fn1 =#= fn2 && args1 =##= args2 case (Block(stats1, expr1), Block(stats2, expr2)) => - foldMatchings(treesMatch(stats1, stats2), treeMatches(expr1, expr2)) + withEnv(the[Env] ++ stats1.map(_.symbol).zip(stats2.map(_.symbol))) { + stats1 =##= stats2 && expr1 =#= expr2 + } case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) => - foldMatchings(treeMatches(cond1, cond2), treeMatches(thenp1, thenp2), treeMatches(elsep1, elsep2)) + cond1 =#= cond2 && thenp1 =#= thenp2 && elsep1 =#= elsep2 case (Assign(lhs1, rhs1), Assign(lhs2, rhs2)) => val lhsMatch = - if (treeMatches(lhs1, lhs2).isDefined) Some(()) - else None - foldMatchings(lhsMatch, treeMatches(rhs1, rhs2)) + if ((lhs1 =#= lhs2).isMatch) matched + else notMatched + lhsMatch && rhs1 =#= rhs2 case (While(cond1, body1), While(cond2, body2)) => - foldMatchings(treeMatches(cond1, cond2), treeMatches(body1, body2)) + cond1 =#= cond2 && body1 =#= body2 case (NamedArg(name1, expr1), NamedArg(name2, expr2)) if name1 == name2 => - treeMatches(expr1, expr2) + expr1 =#= expr2 case (New(tpt1), New(tpt2)) => - treeMatches(tpt1, tpt2) + tpt1 =#= tpt2 case (This(_), This(_)) if scrutinee.symbol == pattern.symbol => - Some(()) + matched case (Super(qual1, mix1), Super(qual2, mix2)) if mix1 == mix2 => - treeMatches(qual1, qual2) + qual1 =#= qual2 case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size => - treesMatch(elems1, elems2) + elems1 =##= elems2 case (IsTypeTree(scrutinee @ TypeIdent(_)), IsTypeTree(pattern @ TypeIdent(_))) if scrutinee.symbol == pattern.symbol => - Some(()) + matched case (IsInferred(scrutinee), IsInferred(pattern)) if scrutinee.tpe <:< pattern.tpe => - Some(()) + matched case (Applied(tycon1, args1), Applied(tycon2, args2)) => - foldMatchings(treeMatches(tycon1, tycon2), treesMatch(args1, args2)) + tycon1 =#= tycon2 && args1 =##= args2 case (ValDef(_, tpt1, rhs1), ValDef(_, tpt2, rhs2)) if checkValFlags() => val bindMatch = if (hasBindAnnotation(pattern.symbol) || hasBindTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol) - else Some(()) - val returnTptMatch = treeMatches(tpt1, tpt2) + else matched + val returnTptMatch = tpt1 =#= tpt2 val rhsEnv = the[Env] + (scrutinee.symbol -> pattern.symbol) val rhsMatchings = treeOptMatches(rhs1, rhs2) given rhsEnv - foldMatchings(bindMatch, returnTptMatch, rhsMatchings) + bindMatch && returnTptMatch && rhsMatchings case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) => - val typeParmasMatch = treesMatch(typeParams1, typeParams2) + val typeParmasMatch = typeParams1 =##= typeParams2 val paramssMatch = - if (paramss1.size != paramss2.size) None - else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => treesMatch(params1, params2) }: _*) + if (paramss1.size != paramss2.size) notMatched + else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => params1 =##= params2 }: _*) val bindMatch = if (hasBindAnnotation(pattern.symbol)) bindingMatch(scrutinee.symbol) - else Some(()) - val tptMatch = treeMatches(tpt1, tpt2) + else matched + val tptMatch = tpt1 =#= tpt2 val rhsEnv = the[Env] + (scrutinee.symbol -> pattern.symbol) ++ typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++ paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol) - val rhsMatch = treeMatches(rhs1, rhs2) given rhsEnv + val rhsMatch = (rhs1 =#= rhs2) given rhsEnv - foldMatchings(bindMatch, typeParmasMatch, paramssMatch, tptMatch, rhsMatch) + bindMatch && typeParmasMatch && paramssMatch && tptMatch && rhsMatch case (Lambda(_, tpt1), Lambda(_, tpt2)) => // TODO match tpt1 with tpt2? - Some(()) + matched case (Match(scru1, cases1), Match(scru2, cases2)) => - val scrutineeMacth = treeMatches(scru1, scru2) + val scrutineeMacth = scru1 =#= scru2 val casesMatch = - if (cases1.size != cases2.size) None + if (cases1.size != cases2.size) notMatched else foldMatchings(cases1.zip(cases2).map(caseMatches): _*) - foldMatchings(scrutineeMacth, casesMatch) + scrutineeMacth && casesMatch case (Try(body1, cases1, finalizer1), Try(body2, cases2, finalizer2)) => - val bodyMacth = treeMatches(body1, body2) + val bodyMacth = body1 =#= body2 val casesMatch = - if (cases1.size != cases2.size) None + if (cases1.size != cases2.size) notMatched else foldMatchings(cases1.zip(cases2).map(caseMatches): _*) val finalizerMatch = treeOptMatches(finalizer1, finalizer2) - foldMatchings(bodyMacth, casesMatch, finalizerMatch) + bodyMacth && casesMatch && finalizerMatch // Ignore type annotations - case (Annotated(tpt, _), _) => treeMatches(tpt, pattern) - case (_, Annotated(tpt, _)) => treeMatches(scrutinee, tpt) + case (Annotated(tpt, _), _) => + tpt =#= pattern + case (_, Annotated(tpt, _)) => + scrutinee =#= tpt // No Match case _ => @@ -225,26 +237,24 @@ object Matcher { | | |""".stripMargin) - None + notMatched } } - def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Env: Option[Tuple] = { + def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Env: Matching = { (scrutinee, pattern) match { - case (Some(x), Some(y)) => treeMatches(x, y) - case (None, None) => Some(()) - case _ => None + case (Some(x), Some(y)) => x =#= y + case (None, None) => matched + case _ => notMatched } } - def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Env: Option[Tuple] = { - val (caseEnv, patternMatch) = patternMatches(scrutinee.pattern, pattern.pattern) - - { - implied for Env = caseEnv + def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Env: Matching = { + val (caseEnv, patternMatch) = scrutinee.pattern =%= pattern.pattern + withEnv(caseEnv) { val guardMatch = treeOptMatches(scrutinee.guard, pattern.guard) - val rhsMatch = treeMatches(scrutinee.rhs, pattern.rhs) - foldMatchings(patternMatch, guardMatch, rhsMatch) + val rhsMatch = scrutinee.rhs =#= pattern.rhs + patternMatch && guardMatch && rhsMatch } } @@ -258,34 +268,34 @@ object Matcher { * @return The new environment containing the bindings defined in this pattern tuppled with * `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes. */ - def patternMatches(scrutinee: Pattern, pattern: Pattern) given Env: (Env, Option[Tuple]) = (scrutinee, pattern) match { + def (scrutinee: Pattern) =%= (pattern: Pattern) given Env: (Env, Matching) = (scrutinee, pattern) match { case (Pattern.Value(v1), Pattern.Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil)) if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" => - (the[Env], Some(Tuple1(v1.seal))) + (the[Env], matched(v1.seal)) case (Pattern.Value(v1), Pattern.Value(v2)) => - (the[Env], treeMatches(v1, v2)) + (the[Env], v1 =#= v2) case (Pattern.Bind(name1, body1), Pattern.Bind(name2, body2)) => val bindEnv = the[Env] + (scrutinee.symbol -> pattern.symbol) - patternMatches(body1, body2) given bindEnv + (body1 =%= body2) given bindEnv case (Pattern.Unapply(fun1, implicits1, patterns1), Pattern.Unapply(fun2, implicits2, patterns2)) => - val funMatch = treeMatches(fun1, fun2) + val funMatch = fun1 =#= fun2 val implicitsMatch = - if (implicits1.size != implicits2.size) None - else foldMatchings(implicits1.zip(implicits2).map(treeMatches): _*) + if (implicits1.size != implicits2.size) notMatched + else foldMatchings(implicits1.zip(implicits2).map((i1, i2) => i1 =#= i2): _*) val (patEnv, patternsMatch) = foldPatterns(patterns1, patterns2) - (patEnv, foldMatchings(funMatch, implicitsMatch, patternsMatch)) + (patEnv, funMatch && implicitsMatch && patternsMatch) case (Pattern.Alternatives(patterns1), Pattern.Alternatives(patterns2)) => foldPatterns(patterns1, patterns2) case (Pattern.TypeTest(tpt1), Pattern.TypeTest(tpt2)) => - (the[Env], treeMatches(tpt1, tpt2)) + (the[Env], tpt1 =#= tpt2) case (Pattern.WildcardPattern(), Pattern.WildcardPattern()) => - (the[Env], Some(())) + (the[Env], matched) case _ => if (debug) @@ -305,30 +315,57 @@ object Matcher { | | |""".stripMargin) - (the[Env], None) + (the[Env], notMatched) } - def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Env: (Env, Option[Tuple]) = { - if (patterns1.size != patterns2.size) (the[Env], None) - else patterns1.zip(patterns2).foldLeft((the[Env], Option[Tuple](()))) { (acc, x) => - val (env, res) = patternMatches(x._1, x._2) given acc._1 - (env, foldMatchings(acc._2, res)) + def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Env: (Env, Matching) = { + if (patterns1.size != patterns2.size) (the[Env], notMatched) + else patterns1.zip(patterns2).foldLeft((the[Env], matched)) { (acc, x) => + val (env, res) = (x._1 =%= x._2) given acc._1 + (env, acc._2 && res) } } implied for Env = Set.empty - treeMatches(scrutineeExpr.unseal, patternExpr.unseal).asInstanceOf[Option[Tup]] + (scrutineeExpr.unseal =#= patternExpr.unseal).asOptionOfTuple.asInstanceOf[Option[Tup]] } - /** Joins the mattchings into a single matching. If any matching is `None` the result is `None`. - * Otherwise the result is `Some` of the concatenation of the tupples. - */ - private def foldMatchings(matchings: Option[Tuple]*): Option[Tuple] = { - // TODO improve performance - matchings.foldLeft[Option[Tuple]](Some(())) { - case (Some(acc), Some(holes)) => Some(acc ++ holes) - case (_, _) => None + /** Result of matching a part of an expression */ + private opaque type Matching = Option[Tuple] + + private object Matching { + + def notMatched: Matching = None + val matched: Matching = Some(()) + def matched(x: Any): Matching = Some(Tuple1(x)) + + def (self: Matching) asOptionOfTuple: Option[Tuple] = self + + /** Concatenates the contents of two sucessful matchings or return a `notMatched` */ + // FIXME inline to avoid alocation of by name closure (see #6395) + /*inline*/ def (self: Matching) && (that: => Matching): Matching = self match { + case Some(x) => + that match { + case Some(y) => Some(x ++ y) + case _ => None + } + case _ => None } + + /** Is this matching the result of a successful match */ + def (self: Matching) isMatch: Boolean = self.isDefined + + /** Joins the mattchings into a single matching. If any matching is `None` the result is `None`. + * Otherwise the result is `Some` of the concatenation of the tupples. + */ + def foldMatchings(matchings: Matching*): Matching = { + // TODO improve performance + matchings.foldLeft[Matching](Some(())) { + case (Some(acc), Some(holes)) => Some(acc ++ holes) + case (_, _) => None + } + } + } } diff --git a/tests/run-macros/quote-matcher-runtime.check b/tests/run-macros/quote-matcher-runtime.check index 54ac86017e9f..a630c9d5285a 100644 --- a/tests/run-macros/quote-matcher-runtime.check +++ b/tests/run-macros/quote-matcher-runtime.check @@ -302,6 +302,38 @@ Pattern: { } Result: None +Scrutinee: { + val a: scala.Int = 45 + a.+(a) +} +Pattern: { + val x: scala.Int = 45 + x.+(x) +} +Result: Some(List()) + +Scrutinee: { + val a: scala.Int = 45 + val b: scala.Int = a + () +} +Pattern: { + val x: scala.Int = 45 + val y: scala.Int = x + () +} +Result: Some(List()) + +Scrutinee: { + val a: scala.Int = 45 + a.+(a) +} +Pattern: { + val x: scala.Int = 45 + x.+(scala.internal.Quoted.patternHole[scala.Int]) +} +Result: Some(List(Expr(a))) + Scrutinee: { lazy val a: scala.Int = 45 () @@ -572,6 +604,26 @@ Pattern: { } Result: Some(List()) +Scrutinee: { + def a: scala.Int = a + a.+(a) +} +Pattern: { + def a: scala.Int = a + a.+(a) +} +Result: Some(List()) + +Scrutinee: { + def a: scala.Int = a + a.+(a) +} +Pattern: { + def a: scala.Int = scala.internal.Quoted.patternHole[scala.Int] + a.+(scala.internal.Quoted.patternHole[scala.Int]) +} +Result: Some(List(Expr(a), Expr(a))) + Scrutinee: { lazy val a: scala.Int = a () diff --git a/tests/run-macros/quote-matcher-runtime/quoted_2.scala b/tests/run-macros/quote-matcher-runtime/quoted_2.scala index ca424f94de4a..a88342a79657 100644 --- a/tests/run-macros/quote-matcher-runtime/quoted_2.scala +++ b/tests/run-macros/quote-matcher-runtime/quoted_2.scala @@ -90,6 +90,9 @@ object Test { matches({ val a: Int = 45 }, { lazy val a: Int = 45 }) matches({ val a: Int = 45 }, { var a: Int = 45 }) matches({ val a: Int = 45 }, { @patternBindHole var a: Int = patternHole }) + matches({ val a: Int = 45; a + a }, { val x: Int = 45; x + x }) + matches({ val a: Int = 45; val b = a }, { val x: Int = 45; val y = x }) + matches({ val a: Int = 45; a + a }, { val x: Int = 45; x + patternHole[Int] }) matches({ lazy val a: Int = 45 }, { val a: Int = 45 }) matches({ lazy val a: Int = 45 }, { lazy val a: Int = 45 }) matches({ lazy val a: Int = 45 }, { var a: Int = 45 }) @@ -117,6 +120,8 @@ object Test { matches({ def a(x: Int): Int = 45 }, { def a(x: Int @patternBindHole): Int = 45 }) matches({ def a(x: Int): Int = x }, { def b(y: Int): Int = y }) matches({ def a: Int = a }, { def b: Int = b }) + matches({ def a: Int = a; a + a }, { def a: Int = a; a + a }) + matches({ def a: Int = a; a + a }, { def a: Int = patternHole[Int]; a + patternHole[Int] }) matches({ lazy val a: Int = a }, { lazy val b: Int = b }) matches(1 match { case _ => 2 }, 1 match { case _ => 2 }) matches(1 match { case _ => 2 }, patternHole[Int] match { case _ => patternHole[Int] })