Skip to content

Commit 98c84c3

Browse files
authored
Two fixes to NamedTuple pattern matching (#22953)
1. Fix translation for named patterns where the selector is a single-element named tuple. We used to take the whole tuple as result (which is correct for unnamed patterns) but for named patterns we have to select the field instead. 2. Take account of named patterns in the refutability check.
2 parents 44a61ea + ca80310 commit 98c84c3

File tree

9 files changed

+108
-5
lines changed

9 files changed

+108
-5
lines changed

compiler/src/dotty/tools/dotc/transform/FirstTransform.scala

+12-1
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,14 @@ import NameKinds.OuterSelectName
1919
import StdNames.*
2020
import config.Feature
2121
import inlines.Inlines.inInlineMethod
22+
import util.Property
2223

2324
object FirstTransform {
2425
val name: String = "firstTransform"
2526
val description: String = "some transformations to put trees into a canonical form"
27+
28+
/** Attachment key for named argument patterns */
29+
val WasNamedArg: Property.StickyKey[Unit] = Property.StickyKey()
2630
}
2731

2832
/** The first tree transform
@@ -38,6 +42,7 @@ object FirstTransform {
3842
*/
3943
class FirstTransform extends MiniPhase with SymTransformer { thisPhase =>
4044
import ast.tpd.*
45+
import FirstTransform.*
4146

4247
override def phaseName: String = FirstTransform.name
4348

@@ -156,7 +161,13 @@ class FirstTransform extends MiniPhase with SymTransformer { thisPhase =>
156161

157162
override def transformOther(tree: Tree)(using Context): Tree = tree match {
158163
case tree: Export => EmptyTree
159-
case tree: NamedArg => transformAllDeep(tree.arg)
164+
case tree: NamedArg =>
165+
val res = transformAllDeep(tree.arg)
166+
if ctx.mode.is(Mode.Pattern) then
167+
// Need to keep NamedArg status for pattern matcher to work correctly when faced
168+
// with single-element named tuples.
169+
res.pushAttachment(WasNamedArg, ())
170+
res
160171
case tree => if (tree.isType) toTypeTree(tree) else tree
161172
}
162173

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

+14-3
Original file line numberDiff line numberDiff line change
@@ -386,9 +386,20 @@ object PatternMatcher {
386386
}
387387
else
388388
letAbstract(get) { getResult =>
389-
val selectors =
390-
if (args.tail.isEmpty) ref(getResult) :: Nil
391-
else productSelectors(getResult.info).map(ref(getResult).select(_))
389+
def isUnaryNamedTupleSelectArg(arg: Tree) =
390+
get.tpe.widenDealias.isNamedTupleType
391+
&& arg.removeAttachment(FirstTransform.WasNamedArg).isDefined
392+
// Special case: Normally, we pull out the argument wholesale if
393+
// there is only one. But if the argument is a named argument for
394+
// a single-element named tuple, we have to select the field instead.
395+
// NamedArg trees are eliminated in FirstTransform but for named arguments
396+
// of patterns we add a WasNamedArg attachment, which is used to guide the
397+
// logic here. See i22900.scala for test cases.
398+
val selectors = args match
399+
case arg :: Nil if !isUnaryNamedTupleSelectArg(arg) =>
400+
ref(getResult) :: Nil
401+
case _ =>
402+
productSelectors(getResult.info).map(ref(getResult).select(_))
392403
matchArgsPlan(selectors, args, onSuccess)
393404
}
394405
}

compiler/src/dotty/tools/dotc/transform/patmat/Space.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ object SpaceEngine {
279279
|| unappResult <:< ConstantType(Constant(true)) // only for unapply
280280
|| (unapp.symbol.is(Synthetic) && unapp.symbol.owner.linkedClass.is(Case)) // scala2 compatibility
281281
|| unapplySeqTypeElemTp(unappResult).exists // only for unapplySeq
282-
|| isProductMatch(unappResult, argLen)
282+
|| isProductMatch(unappResult.stripNamedTuple, argLen)
283283
|| extractorMemberType(unappResult, nme.isEmpty, NoSourcePosition) <:< ConstantType(Constant(false))
284284
|| unappResult.derivesFrom(defn.NonEmptyTupleClass)
285285
|| unapp.symbol == defn.TupleXXL_unapplySeq // Fixes TupleXXL.unapplySeq which returns Some but declares Option

compiler/src/dotty/tools/dotc/typer/Checking.scala

+2
Original file line numberDiff line numberDiff line change
@@ -1040,6 +1040,8 @@ trait Checking {
10401040
pats.forall(recur(_, pt))
10411041
case Typed(arg, tpt) =>
10421042
check(pat, pt) && recur(arg, pt)
1043+
case NamedArg(name, pat) =>
1044+
recur(pat, pt)
10431045
case Ident(nme.WILDCARD) =>
10441046
true
10451047
case pat: QuotePattern =>

tests/run/i22900.check

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
6
2+
6
3+
6
4+
6
5+
7
6+
6
7+
7
8+
(6)

tests/run/i22900.scala

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
object NameBaseExtractor {
2+
def unapply(x: Int): Some[(someName: Int)] = Some((someName = x + 3))
3+
}
4+
object NameBaseExtractor2 {
5+
def unapply(x: Int): Some[(someName: Int, age: Int)] = Some((someName = x + 3, age = x + 4))
6+
}
7+
@main
8+
def Test =
9+
val x1 = 3 match
10+
case NameBaseExtractor(someName = x) => x
11+
println(x1)
12+
val NameBaseExtractor(someName = x2) = 3
13+
println(x2)
14+
val NameBaseExtractor((someName = x3)) = 3
15+
println(x3)
16+
17+
val NameBaseExtractor2(someName = x4, age = x5) = 3
18+
println(x4)
19+
println(x5)
20+
21+
val NameBaseExtractor2((someName = x6, age = x7)) = 3
22+
println(x6)
23+
println(x7)
24+
25+
val NameBaseExtractor(y1) = 3
26+
println(y1)

tests/run/i22900a.check

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
3
2+
6
3+
3

tests/run/i22900a.scala

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
case class C(someName: Int)
2+
3+
object NameBaseExtractor3 {
4+
def unapply(x: Int): Some[C] = Some(C(someName = x + 3))
5+
}
6+
7+
@main
8+
def Test = {
9+
val C(someName = xx) = C(3)
10+
println(xx)
11+
val NameBaseExtractor3(C(someName = x)) = 3
12+
println(x)
13+
C(3) match
14+
case C(someName = xx) => println(xx)
15+
}

tests/warn/i22899.scala

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
case class CaseClass(a: Int)
2+
3+
object ProductMatch_CaseClass {
4+
def unapply(int: Int): CaseClass = CaseClass(int)
5+
}
6+
7+
object ProductMatch_NamedTuple {
8+
def unapply(int: Int): (a: Int) = (a = int)
9+
}
10+
11+
object NameBasedMatch_CaseClass {
12+
def unapply(int: Int): Some[CaseClass] = Some(CaseClass(int))
13+
}
14+
15+
object NameBasedMatch_NamedTuple {
16+
def unapply(int: Int): Some[(a: Int)] = Some((a = int))
17+
}
18+
19+
object Test {
20+
val ProductMatch_CaseClass(a = x1) = 1 // ok, was pattern's type (x1 : Int) is more specialized than the right hand side expression's type Int
21+
val ProductMatch_NamedTuple(a = x2) = 2 // ok, was pattern binding uses refutable extractor `org.test.ProductMatch_NamedTuple`
22+
val NameBasedMatch_CaseClass(a = x3) = 3 // ok, was pattern's type (x3 : Int) is more specialized than the right hand side expression's type Int
23+
val NameBasedMatch_NamedTuple(a = x4) = 4 // ok, was pattern's type (x4 : Int) is more specialized than the right hand side expression's type Int
24+
25+
val CaseClass(a = x5) = CaseClass(5) // ok, was pattern's type (x5 : Int) is more specialized than the right hand side expression's type Int
26+
val (a = x6) = (a = 6) // ok
27+
}

0 commit comments

Comments
 (0)