Skip to content

Commit ab8a613

Browse files
committed
Generalize HOAS patterns to take type parameters (experimental feature)
1 parent 2616c8b commit ab8a613

36 files changed

+629
-118
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -2025,7 +2025,7 @@ object desugar {
20252025
case Quote(body, _) =>
20262026
new UntypedTreeTraverser {
20272027
def traverse(tree: untpd.Tree)(using Context): Unit = tree match {
2028-
case SplicePattern(body, _) => collect(body)
2028+
case SplicePattern(body, _, _) => collect(body)
20292029
case _ => traverseChildren(tree)
20302030
}
20312031
}.traverse(body)

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
816816
}
817817
private object quotePatVars extends TreeAccumulator[List[Symbol]] {
818818
def apply(syms: List[Symbol], tree: Tree)(using Context) = tree match {
819-
case SplicePattern(pat, _) => outer.apply(syms, pat)
819+
case SplicePattern(pat, _, _) => outer.apply(syms, pat)
820820
case _ => foldOver(syms, tree)
821821
}
822822
}

compiler/src/dotty/tools/dotc/ast/Trees.scala

+9-8
Original file line numberDiff line numberDiff line change
@@ -760,9 +760,10 @@ object Trees {
760760
* `SplicePattern` can only be contained within a `QuotePattern`.
761761
*
762762
* @param body The tree that was spliced
763+
* @param typeargs The type arguments of the splice (the HOAS arguments)
763764
* @param args The arguments of the splice (the HOAS arguments)
764765
*/
765-
case class SplicePattern[+T <: Untyped] private[ast] (body: Tree[T], args: List[Tree[T]])(implicit @constructorOnly src: SourceFile)
766+
case class SplicePattern[+T <: Untyped] private[ast] (body: Tree[T], typeargs: List[Tree[T]], args: List[Tree[T]])(implicit @constructorOnly src: SourceFile)
766767
extends TermTree[T] {
767768
type ThisTree[+T <: Untyped] = SplicePattern[T]
768769
}
@@ -1367,9 +1368,9 @@ object Trees {
13671368
case tree: QuotePattern if (bindings eq tree.bindings) && (body eq tree.body) && (quotes eq tree.quotes) => tree
13681369
case _ => finalize(tree, untpd.QuotePattern(bindings, body, quotes)(sourceFile(tree)))
13691370
}
1370-
def SplicePattern(tree: Tree)(body: Tree, args: List[Tree])(using Context): SplicePattern = tree match {
1371-
case tree: SplicePattern if (body eq tree.body) && (args eq tree.args) => tree
1372-
case _ => finalize(tree, untpd.SplicePattern(body, args)(sourceFile(tree)))
1371+
def SplicePattern(tree: Tree)(body: Tree, typeargs: List[Tree], args: List[Tree])(using Context): SplicePattern = tree match {
1372+
case tree: SplicePattern if (body eq tree.body) && (typeargs eq tree.typeargs) & (args eq tree.args) => tree
1373+
case _ => finalize(tree, untpd.SplicePattern(body, typeargs, args)(sourceFile(tree)))
13731374
}
13741375
def SingletonTypeTree(tree: Tree)(ref: Tree)(using Context): SingletonTypeTree = tree match {
13751376
case tree: SingletonTypeTree if (ref eq tree.ref) => tree
@@ -1617,8 +1618,8 @@ object Trees {
16171618
cpy.Splice(tree)(transform(expr)(using spliceContext))
16181619
case tree @ QuotePattern(bindings, body, quotes) =>
16191620
cpy.QuotePattern(tree)(transform(bindings), transform(body)(using quoteContext), transform(quotes))
1620-
case tree @ SplicePattern(body, args) =>
1621-
cpy.SplicePattern(tree)(transform(body)(using spliceContext), transform(args))
1621+
case tree @ SplicePattern(body, targs, args) =>
1622+
cpy.SplicePattern(tree)(transform(body)(using spliceContext), transform(targs), transform(args))
16221623
case tree @ Hole(isTerm, idx, args, content) =>
16231624
cpy.Hole(tree)(isTerm, idx, transform(args), transform(content))
16241625
case _ =>
@@ -1766,8 +1767,8 @@ object Trees {
17661767
this(x, expr)(using spliceContext)
17671768
case QuotePattern(bindings, body, quotes) =>
17681769
this(this(this(x, bindings), body)(using quoteContext), quotes)
1769-
case SplicePattern(body, args) =>
1770-
this(this(x, body)(using spliceContext), args)
1770+
case SplicePattern(body, typeargs, args) =>
1771+
this(this(this(x, body)(using spliceContext), typeargs), args)
17711772
case Hole(_, _, args, content) =>
17721773
this(this(x, args), content)
17731774
case _ =>

compiler/src/dotty/tools/dotc/ast/untpd.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
409409
def Quote(body: Tree, tags: List[Tree])(implicit src: SourceFile): Quote = new Quote(body, tags)
410410
def Splice(expr: Tree)(implicit src: SourceFile): Splice = new Splice(expr)
411411
def QuotePattern(bindings: List[Tree], body: Tree, quotes: Tree)(implicit src: SourceFile): QuotePattern = new QuotePattern(bindings, body, quotes)
412-
def SplicePattern(body: Tree, args: List[Tree])(implicit src: SourceFile): SplicePattern = new SplicePattern(body, args)
412+
def SplicePattern(body: Tree, typeargs: List[Tree], args: List[Tree])(implicit src: SourceFile): SplicePattern = new SplicePattern(body, typeargs, args)
413413
def TypeTree()(implicit src: SourceFile): TypeTree = new TypeTree()
414414
def InferredTypeTree()(implicit src: SourceFile): TypeTree = new InferredTypeTree()
415415
def SingletonTypeTree(ref: Tree)(implicit src: SourceFile): SingletonTypeTree = new SingletonTypeTree(ref)

compiler/src/dotty/tools/dotc/config/Feature.scala

+4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ object Feature:
3333
val pureFunctions = experimental("pureFunctions")
3434
val captureChecking = experimental("captureChecking")
3535
val into = experimental("into")
36+
val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions")
3637

3738
val globalOnlyImports: Set[TermName] = Set(pureFunctions, captureChecking)
3839

@@ -84,6 +85,9 @@ object Feature:
8485

8586
def scala2ExperimentalMacroEnabled(using Context) = enabled(scala2macros)
8687

88+
def quotedPatternsWithPolymorphicFunctionsEnabled(using Context) =
89+
enabled(quotedPatternsWithPolymorphicFunctions)
90+
8791
/** Is pureFunctions enabled for this compilation unit? */
8892
def pureFunsEnabled(using Context) =
8993
enabledBySetting(pureFunctions)

compiler/src/dotty/tools/dotc/core/Definitions.scala

+1
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,7 @@ class Definitions {
882882
@tu lazy val QuotedRuntimePatterns: Symbol = requiredModule("scala.quoted.runtime.Patterns")
883883
@tu lazy val QuotedRuntimePatterns_patternHole: Symbol = QuotedRuntimePatterns.requiredMethod("patternHole")
884884
@tu lazy val QuotedRuntimePatterns_higherOrderHole: Symbol = QuotedRuntimePatterns.requiredMethod("higherOrderHole")
885+
@tu lazy val QuotedRuntimePatterns_higherOrderHoleWithTypes: Symbol = QuotedRuntimePatterns.requiredMethod("higherOrderHoleWithTypes")
885886
@tu lazy val QuotedRuntimePatterns_patternTypeAnnot: ClassSymbol = QuotedRuntimePatterns.requiredClass("patternType")
886887
@tu lazy val QuotedRuntimePatterns_fromAboveAnnot: ClassSymbol = QuotedRuntimePatterns.requiredClass("fromAbove")
887888

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -1787,7 +1787,7 @@ object Parsers {
17871787
syntaxError(em"$msg\n\nHint: $hint", Span(start, in.lastOffset))
17881788
Ident(nme.ERROR.toTypeName)
17891789
else if inPattern then
1790-
SplicePattern(expr, Nil)
1790+
SplicePattern(expr, Nil, Nil)
17911791
else
17921792
Splice(expr)
17931793
}

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

+4-3
Original file line numberDiff line numberDiff line change
@@ -749,11 +749,12 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
749749
val open = if (body.isTerm) keywordStr("{") else keywordStr("[")
750750
val close = if (body.isTerm) keywordStr("}") else keywordStr("]")
751751
keywordStr("'") ~ quotesText ~ open ~ bindingsText ~ toTextGlobal(body) ~ close
752-
case SplicePattern(pattern, args) =>
752+
case SplicePattern(pattern, typeargs, args) =>
753753
val spliceTypeText = (keywordStr("[") ~ toTextGlobal(tree.typeOpt) ~ keywordStr("]")).provided(printDebug && tree.typeOpt.exists)
754754
keywordStr("$") ~ spliceTypeText ~ {
755-
if args.isEmpty then keywordStr("{") ~ inPattern(toText(pattern)) ~ keywordStr("}")
756-
else toText(pattern.symbol.name) ~ "(" ~ toTextGlobal(args, ", ") ~ ")"
755+
if typeargs.isEmpty && args.isEmpty then keywordStr("{") ~ inPattern(toText(pattern)) ~ keywordStr("}")
756+
else if typeargs.isEmpty then toText(pattern.symbol.name) ~ "(" ~ toTextGlobal(args, ", ") ~ ")"
757+
else toText(pattern.symbol.name) ~ "[" ~ toTextGlobal(typeargs, ", ")~ "]" ~ "(" ~ toTextGlobal(args, ", ") ~ ")"
757758
}
758759
case Hole(isTerm, idx, args, content) =>
759760
val (prefix, postfix) = if isTerm then ("{{{", "}}}") else ("[[[", "]]]")

compiler/src/dotty/tools/dotc/quoted/QuotePatterns.scala

+103-34
Original file line numberDiff line numberDiff line change
@@ -26,31 +26,90 @@ object QuotePatterns:
2626
import tpd._
2727

2828
/** Check for restricted patterns */
29-
def checkPattern(quotePattern: QuotePattern)(using Context): Unit = new tpd.TreeTraverser {
30-
def traverse(tree: Tree)(using Context): Unit = tree match {
31-
case _: SplicePattern =>
32-
case tdef: TypeDef if tdef.symbol.isClass =>
33-
val kind = if tdef.symbol.is(Module) then "objects" else "classes"
34-
report.error(em"Implementation restriction: cannot match $kind", tree.srcPos)
35-
case tree: NamedDefTree =>
36-
if tree.name.is(NameKinds.WildcardParamName) then
37-
report.warning(
38-
"Use of `_` for lambda in quoted pattern. Use explicit lambda instead or use `$_` to match any term.",
39-
tree.srcPos)
40-
if tree.name.isTermName && !tree.nameSpan.isSynthetic && tree.name != nme.ANON_FUN && tree.name.startsWith("$") then
41-
report.error("Names cannot start with $ quote pattern", tree.namePos)
42-
traverseChildren(tree)
43-
case _: Match =>
44-
report.error("Implementation restriction: cannot match `match` expressions", tree.srcPos)
45-
case _: Try =>
46-
report.error("Implementation restriction: cannot match `try` expressions", tree.srcPos)
47-
case _: Return =>
48-
report.error("Implementation restriction: cannot match `return` statements", tree.srcPos)
49-
case _ =>
50-
traverseChildren(tree)
51-
}
29+
def checkPattern(quotePattern: QuotePattern)(using Context): Unit =
30+
def validatePatternAndCollectTypeVars(): Set[Symbol] = new tpd.TreeAccumulator[Set[Symbol]] {
31+
override def apply(typevars: Set[Symbol], tree: tpd.Tree)(using Context): Set[Symbol] =
32+
// Collect type variables
33+
val typevars1 = tree match
34+
case tree @ DefDef(_, paramss, _, _) =>
35+
typevars union paramss.flatMap{ params => params match
36+
case TypeDefs(tdefs) => tdefs.map(_.symbol)
37+
case _ => List.empty
38+
}.toSet union typevars
39+
case _ => typevars
40+
41+
// Validate pattern
42+
tree match
43+
case _: SplicePattern => typevars1
44+
case tdef: TypeDef if tdef.symbol.isClass =>
45+
val kind = if tdef.symbol.is(Module) then "objects" else "classes"
46+
report.error(em"Implementation restriction: cannot match $kind", tree.srcPos)
47+
typevars1
48+
case tree: NamedDefTree =>
49+
if tree.name.is(NameKinds.WildcardParamName) then
50+
report.warning(
51+
"Use of `_` for lambda in quoted pattern. Use explicit lambda instead or use `$_` to match any term.",
52+
tree.srcPos)
53+
if tree.name.isTermName && !tree.nameSpan.isSynthetic && tree.name != nme.ANON_FUN && tree.name.startsWith("$") then
54+
report.error("Names cannot start with $ quote pattern", tree.namePos)
55+
foldOver(typevars1, tree)
56+
case _: Match =>
57+
report.error("Implementation restriction: cannot match `match` expressions", tree.srcPos)
58+
typevars1
59+
case _: Try =>
60+
report.error("Implementation restriction: cannot match `try` expressions", tree.srcPos)
61+
typevars1
62+
case _: Return =>
63+
report.error("Implementation restriction: cannot match `return` statements", tree.srcPos)
64+
typevars1
65+
case _ =>
66+
foldOver(typevars1, tree)
67+
}.apply(Set.empty, quotePattern.body)
68+
69+
val boundTypeVars = validatePatternAndCollectTypeVars()
5270

53-
}.traverse(quotePattern.body)
71+
/*
72+
* This part checks well-formedness of arguments to hoas patterns.
73+
* (1) Type arguments of a hoas patterns must be introduced in the quote pattern.ctxShow
74+
* Examples
75+
* well-formed: '{ [A] => (x : A) => $a[A](x) } // A is introduced in the quote pattern
76+
* ill-formed: '{ (x : Int) => $a[Int](x) } // Int is defined outside of the quote pattern
77+
* (2) If value arguments of a hoas pattern has a type with type variables that are introduced in
78+
* the quote pattern, those type variables should be in type arguments to the hoas patternHole
79+
* Examples
80+
* well-formed: '{ [A] => (x : A) => $a[A](x) } // a : [A] => (x:A) => A
81+
* ill-formed: '{ [A] => (x : A) => $a(x) } // a : (x:A) => A ...but A is undefined; hence ill-formed
82+
*/
83+
new tpd.TreeTraverser {
84+
override def traverse(tree: tpd.Tree)(using Context): Unit = tree match {
85+
case tree: SplicePattern =>
86+
def uncapturedTypeVars(arg: tpd.Tree, capturedTypeVars: List[tpd.Tree]): Set[Type] =
87+
/* Sometimes arg is untyped when a splice pattern is ill-formed.
88+
* Return early in such case.
89+
* Refer to QuoteAndSplices::typedSplicePattern
90+
*/
91+
if !arg.hasType then return Set.empty
92+
93+
val capturedTypeVarsSet = capturedTypeVars.map(_.symbol).toSet
94+
new TypeAccumulator[Set[Type]] {
95+
def apply(x: Set[Type], tp: Type): Set[Type] =
96+
if boundTypeVars.contains(tp.typeSymbol) && !capturedTypeVarsSet.contains(tp.typeSymbol) then
97+
foldOver(x + tp, tp)
98+
else
99+
foldOver(x, tp)
100+
}.apply(Set.empty, arg.tpe)
101+
102+
for (typearg <- tree.typeargs) // case (1)
103+
do
104+
if !boundTypeVars.contains(typearg.symbol) then
105+
report.error("Type arguments of a hoas pattern needs to be defined inside the quoted pattern", typearg.srcPos)
106+
for (arg <- tree.args) // case (2)
107+
do
108+
if !uncapturedTypeVars(arg, tree.typeargs).isEmpty then
109+
report.error("Type variables that this argument depends on are not captured in this hoas pattern", arg.srcPos)
110+
case _ => traverseChildren(tree)
111+
}
112+
}.traverse(quotePattern.body)
54113

55114
/** Encode the quote pattern into an `unapply` that the pattern matcher can handle.
56115
*
@@ -74,7 +133,7 @@ object QuotePatterns:
74133
* .ExprMatch // or TypeMatch
75134
* .unapply[
76135
* KCons[t1 >: l1 <: b1, ...KCons[tn >: ln <: bn, KNil]...], // scala.quoted.runtime.{KCons, KNil}
77-
* (T1, T2, (A1, ..., An) => T3, ...)
136+
* (Expr[T1], Expr[T2], Expr[(A1, ..., An) => T3], ...)
78137
* ](
79138
* '{
80139
* type t1' >: l1' <: b1'
@@ -197,16 +256,24 @@ object QuotePatterns:
197256
val patBuf = new mutable.ListBuffer[Tree]
198257
val shape = new tpd.TreeMap {
199258
override def transform(tree: Tree)(using Context) = tree match {
200-
case Typed(splice @ SplicePattern(pat, Nil), tpt) if !tpt.tpe.derivesFrom(defn.RepeatedParamClass) =>
259+
case Typed(splice @ SplicePattern(pat, Nil, Nil), tpt) if !tpt.tpe.derivesFrom(defn.RepeatedParamClass) =>
201260
transform(tpt) // Collect type bindings
202261
transform(splice)
203-
case SplicePattern(pat, args) =>
262+
case SplicePattern(pat, typeargs, args) =>
204263
val patType = pat.tpe.widen
205264
val patType1 = patType.translateFromRepeated(toArray = false)
206265
val pat1 = if (patType eq patType1) pat else pat.withType(patType1)
207266
patBuf += pat1
208-
if args.isEmpty then ref(defn.QuotedRuntimePatterns_patternHole.termRef).appliedToType(tree.tpe).withSpan(tree.span)
209-
else ref(defn.QuotedRuntimePatterns_higherOrderHole.termRef).appliedToType(tree.tpe).appliedTo(SeqLiteral(args, TypeTree(defn.AnyType))).withSpan(tree.span)
267+
if typeargs.isEmpty && args.isEmpty then ref(defn.QuotedRuntimePatterns_patternHole.termRef).appliedToType(tree.tpe).withSpan(tree.span)
268+
else if typeargs.isEmpty then
269+
ref(defn.QuotedRuntimePatterns_higherOrderHole.termRef)
270+
.appliedToType(tree.tpe)
271+
.appliedTo(SeqLiteral(args, TypeTree(defn.AnyType)))
272+
.withSpan(tree.span)
273+
else ref(defn.QuotedRuntimePatterns_higherOrderHoleWithTypes.termRef)
274+
.appliedToTypeTrees(List(TypeTree(tree.tpe), tpd.hkNestedPairsTypeTree(typeargs)))
275+
.appliedTo(SeqLiteral(args, TypeTree(defn.AnyType)))
276+
.withSpan(tree.span)
210277
case _ =>
211278
super.transform(tree)
212279
}
@@ -232,17 +299,19 @@ object QuotePatterns:
232299
fun match
233300
// <quotes>.asInstanceOf[QuoteMatching].{ExprMatch,TypeMatch}.unapply[<typeBindings>, <resTypes>]
234301
case TypeApply(Select(Select(TypeApply(Select(quotes, _), _), _), _), typeBindings :: resTypes :: Nil) =>
235-
val bindings = unrollBindings(typeBindings)
302+
val bindings = unrollHkNestedPairsTypeTree(typeBindings)
236303
val addPattenSplice = new TreeMap {
237304
private val patternIterator = patterns.iterator.filter {
238305
case pat: Bind => !pat.symbol.name.is(PatMatGivenVarName)
239306
case _ => true
240307
}
241308
override def transform(tree: tpd.Tree)(using Context): tpd.Tree = tree match
242309
case TypeApply(patternHole, _) if patternHole.symbol == defn.QuotedRuntimePatterns_patternHole =>
243-
cpy.SplicePattern(tree)(patternIterator.next(), Nil)
310+
cpy.SplicePattern(tree)(patternIterator.next(), Nil, Nil)
244311
case Apply(patternHole, SeqLiteral(args, _) :: Nil) if patternHole.symbol == defn.QuotedRuntimePatterns_higherOrderHole =>
245-
cpy.SplicePattern(tree)(patternIterator.next(), args)
312+
cpy.SplicePattern(tree)(patternIterator.next(), Nil, args)
313+
case Apply(TypeApply(patternHole, List(_, targsTpe)), SeqLiteral(args, _) :: Nil) if patternHole.symbol == defn.QuotedRuntimePatterns_higherOrderHoleWithTypes =>
314+
cpy.SplicePattern(tree)(patternIterator.next(), unrollHkNestedPairsTypeTree(targsTpe), args)
246315
case _ => super.transform(tree)
247316
}
248317
val body = addPattenSplice.transform(shape) match
@@ -260,7 +329,7 @@ object QuotePatterns:
260329
case body => body
261330
cpy.QuotePattern(tree)(bindings, body, quotes)
262331

263-
private def unrollBindings(tree: Tree)(using Context): List[Tree] = tree match
332+
private def unrollHkNestedPairsTypeTree(tree: Tree)(using Context): List[Tree] = tree match
264333
case AppliedTypeTree(tupleN, bindings) if defn.isTupleClass(tupleN.symbol) => bindings // TupleN, 1 <= N <= 22
265-
case AppliedTypeTree(_, head :: tail :: Nil) => head :: unrollBindings(tail) // KCons or *:
334+
case AppliedTypeTree(_, head :: tail :: Nil) => head :: unrollHkNestedPairsTypeTree(tail) // KCons or *:
266335
case _ => Nil // KNil or EmptyTuple
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
{
2+
"folders": [
3+
{
4+
"path": "../../../../../.."
5+
},
6+
{
7+
"path": "../../../../../../../workspace"
8+
}
9+
],
10+
"settings": {
11+
"files.watcherExclude": {
12+
"**/target": true
13+
}
14+
}
15+
}

compiler/src/dotty/tools/dotc/typer/Applications.scala

+9
Original file line numberDiff line numberDiff line change
@@ -1114,6 +1114,8 @@ trait Applications extends Compatibility {
11141114
}
11151115
else {
11161116
val app = tree.fun match
1117+
case untpd.TypeApply(_: untpd.SplicePattern, _) if Feature.quotedPatternsWithPolymorphicFunctionsEnabled =>
1118+
typedAppliedSpliceWithTypes(tree, pt)
11171119
case _: untpd.SplicePattern => typedAppliedSplice(tree, pt)
11181120
case _ => realApply
11191121
app match {
@@ -1164,9 +1166,16 @@ trait Applications extends Compatibility {
11641166
if (ctx.mode.is(Mode.Pattern))
11651167
return errorTree(tree, em"invalid pattern")
11661168

1169+
tree.fun match {
1170+
case _: untpd.SplicePattern if Feature.quotedPatternsWithPolymorphicFunctionsEnabled =>
1171+
return errorTree(tree, em"Implementation restriction: A higher-order pattern must carry value arguments")
1172+
case _ =>
1173+
}
1174+
11671175
val isNamed = hasNamedArg(tree.args)
11681176
val typedArgs = if (isNamed) typedNamedArgs(tree.args) else tree.args.mapconserve(typedType(_))
11691177
record("typedTypeApply")
1178+
11701179
typedExpr(tree.fun, PolyProto(typedArgs, pt)) match {
11711180
case fun: TypeApply if !ctx.isAfterTyper =>
11721181
val function = fun.fun

0 commit comments

Comments
 (0)