Skip to content

Commit fc15089

Browse files
committed
WIP: Pattern match against hoas pattern wtih type vars
1 parent e672e2c commit fc15089

File tree

1 file changed

+97
-27
lines changed

1 file changed

+97
-27
lines changed

compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala

+97-27
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,42 @@ class QuoteMatcher(debug: Boolean) {
288288
val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v))
289289
withEnv(captureEnv) {
290290
scrutinee match
291-
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), env)
291+
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), Nil, env)
292+
case _ => notMatched
293+
}
294+
295+
/* Higher order term hole */
296+
// Matches an open term and wraps it into a lambda that provides the free variables
297+
case Apply(TypeApply(Ident(_), List(TypeTree(), targs)), SeqLiteral(args, _) :: Nil)
298+
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHoleWithTypes) =>
299+
300+
/* Some of method symbols in arguments of higher-order term hole are eta-expanded.
301+
* e.g.
302+
* g: (Int) => Int
303+
* => {
304+
* def $anonfun(y: Int): Int = g(y)
305+
* closure($anonfun)
306+
* }
307+
*
308+
* f: (using Int) => Int
309+
* => f(using x)
310+
* This function restores the symbol of the original method from
311+
* the eta-expanded function.
312+
*/
313+
def getCapturedIdent(arg: Tree)(using Context): Ident =
314+
arg match
315+
case id: Ident => id
316+
case Apply(fun, _) => getCapturedIdent(fun)
317+
case Block((ddef: DefDef) :: _, _: Closure) => getCapturedIdent(ddef.rhs)
318+
case Typed(expr, _) => getCapturedIdent(expr)
319+
320+
val env = summon[Env]
321+
val capturedIds = args.map(getCapturedIdent)
322+
val capturedSymbols = capturedIds.map(_.symbol)
323+
val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v))
324+
withEnv(captureEnv) {
325+
scrutinee match
326+
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), targs, env)
292327
case _ => notMatched
293328
}
294329

@@ -558,9 +593,10 @@ class QuoteMatcher(debug: Boolean) {
558593
* @param patternTpe Type of the pattern hole (from the pattern)
559594
* @param argIds Identifiers of HOAS arguments (from the pattern)
560595
* @param argTypes Eta-expanded types of HOAS arguments (from the pattern)
596+
* @param typeArgs type arguments from the pattern
561597
* @param env Mapping between scrutinee and pattern variables
562598
*/
563-
case OpenTree(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env)
599+
case OpenTree(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], typeArgs: List[Type], env: Env)
564600

565601
/** Return the expression that was extracted from a hole.
566602
*
@@ -573,29 +609,63 @@ class QuoteMatcher(debug: Boolean) {
573609
def toExpr(mapTypeHoles: Type => Type, spliceScope: Scope)(using Context): Expr[Any] = this match
574610
case MatchResult.ClosedTree(tree) =>
575611
new ExprImpl(tree, spliceScope)
576-
case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env) =>
577-
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
578-
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
579-
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
580-
val meth = newAnonFun(ctx.owner, methTpe)
581-
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
582-
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap
583-
val body = new TreeMap {
584-
override def transform(tree: Tree)(using Context): Tree =
585-
tree match
586-
/*
587-
* When matching a method call `f(0)` against a HOAS pattern `p(g)` where
588-
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
589-
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
590-
*/
591-
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
592-
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
593-
case tree => super.transform(tree)
594-
}.transform(tree)
595-
TreeOps(body).changeNonLocalOwners(meth)
596-
}
597-
val hoasClosure = Closure(meth, bodyFn)
598-
new ExprImpl(hoasClosure, spliceScope)
612+
case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, typeArgs, env) =>
613+
if typeArgs.isEmpty then
614+
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
615+
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
616+
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
617+
val meth = newAnonFun(ctx.owner, methTpe)
618+
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
619+
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap
620+
val body = new TreeMap {
621+
override def transform(tree: Tree)(using Context): Tree =
622+
tree match
623+
/*
624+
* When matching a method call `f(0)` against a HOAS pattern `p(g)` where
625+
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
626+
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
627+
*/
628+
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
629+
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
630+
case tree => super.transform(tree)
631+
}.transform(tree)
632+
TreeOps(body).changeNonLocalOwners(meth)
633+
}
634+
val hoasClosure = Closure(meth, bodyFn)
635+
new ExprImpl(hoasClosure, spliceScope)
636+
else
637+
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
638+
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
639+
640+
val typeArgs1 = PolyType.syntheticParamNames(typeArgs.length)
641+
val bounds = typeArgs map (_ => TypeBounds.empty)
642+
val resultTypeExp = (pt: PolyType) => {
643+
val fromSymbols = typeArgs.map(_.typeSymbol)
644+
val argTypes1 = argTypes.map(_.subst(fromSymbols, pt.paramRefs))
645+
val resultType1 = mapTypeHoles(patternTpe).subst(fromSymbols, pt.paramRefs)
646+
MethodType(argTypes1, resultType1)
647+
}
648+
val methTpe = PolyType(typeArgs1)(_ => bounds, resultTypeExp)
649+
val meth = newAnonFun(ctx.owner, methTpe)
650+
// TODO-18271
651+
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
652+
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap
653+
val body = new TreeMap {
654+
override def transform(tree: Tree)(using Context): Tree =
655+
tree match
656+
/*
657+
* When matching a method call `f(0)` against a HOAS pattern `p(g)` where
658+
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
659+
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
660+
*/
661+
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
662+
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
663+
case tree => super.transform(tree)
664+
}.transform(tree)
665+
TreeOps(body).changeNonLocalOwners(meth)
666+
}
667+
val hoasClosure = Closure(meth, bodyFn)
668+
new ExprImpl(hoasClosure, spliceScope)
599669

600670
private inline def notMatched[T]: optional[T] =
601671
optional.break()
@@ -606,8 +676,8 @@ class QuoteMatcher(debug: Boolean) {
606676
private inline def matched(tree: Tree)(using Context): MatchingExprs =
607677
Seq(MatchResult.ClosedTree(tree))
608678

609-
private def matchedOpen(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env)(using Context): MatchingExprs =
610-
Seq(MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env))
679+
private def matchedOpen(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], typeArgs: List[Type], env: Env)(using Context): MatchingExprs =
680+
Seq(MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, typeArgs, env))
611681

612682
extension (self: MatchingExprs)
613683
/** Concatenates the contents of two successful matchings */

0 commit comments

Comments
 (0)