Skip to content

Two fixes to NamedTuple pattern matching #22953

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/FirstTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@ import NameKinds.OuterSelectName
import StdNames.*
import config.Feature
import inlines.Inlines.inInlineMethod
import util.Property

object FirstTransform {
val name: String = "firstTransform"
val description: String = "some transformations to put trees into a canonical form"

/** Attachment key for named argument patterns */
val WasNamedArg: Property.StickyKey[Unit] = Property.StickyKey()
}

/** The first tree transform
Expand All @@ -38,6 +42,7 @@ object FirstTransform {
*/
class FirstTransform extends MiniPhase with SymTransformer { thisPhase =>
import ast.tpd.*
import FirstTransform.*

override def phaseName: String = FirstTransform.name

Expand Down Expand Up @@ -156,7 +161,13 @@ class FirstTransform extends MiniPhase with SymTransformer { thisPhase =>

override def transformOther(tree: Tree)(using Context): Tree = tree match {
case tree: Export => EmptyTree
case tree: NamedArg => transformAllDeep(tree.arg)
case tree: NamedArg =>
val res = transformAllDeep(tree.arg)
if ctx.mode.is(Mode.Pattern) then
// Need to keep NamedArg status for pattern matcher to work correctly when faced
// with single-element named tuples.
res.pushAttachment(WasNamedArg, ())
res
case tree => if (tree.isType) toTypeTree(tree) else tree
}

Expand Down
17 changes: 14 additions & 3 deletions compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,20 @@ object PatternMatcher {
}
else
letAbstract(get) { getResult =>
val selectors =
if (args.tail.isEmpty) ref(getResult) :: Nil
else productSelectors(getResult.info).map(ref(getResult).select(_))
def isUnaryNamedTupleSelectArg(arg: Tree) =
get.tpe.widenDealias.isNamedTupleType
&& arg.removeAttachment(FirstTransform.WasNamedArg).isDefined
// Special case: Normally, we pull out the argument wholesale if
// there is only one. But if the argument is a named argument for
// a single-element named tuple, we have to select the field instead.
// NamedArg trees are eliminated in FirstTransform but for named arguments
// of patterns we add a WasNamedArg attachment, which is used to guide the
// logic here. See i22900.scala for test cases.
val selectors = args match
case arg :: Nil if !isUnaryNamedTupleSelectArg(arg) =>
ref(getResult) :: Nil
case _ =>
productSelectors(getResult.info).map(ref(getResult).select(_))
matchArgsPlan(selectors, args, onSuccess)
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/transform/patmat/Space.scala
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ object SpaceEngine {
|| unappResult <:< ConstantType(Constant(true)) // only for unapply
|| (unapp.symbol.is(Synthetic) && unapp.symbol.owner.linkedClass.is(Case)) // scala2 compatibility
|| unapplySeqTypeElemTp(unappResult).exists // only for unapplySeq
|| isProductMatch(unappResult, argLen)
|| isProductMatch(unappResult.stripNamedTuple, argLen)
|| extractorMemberType(unappResult, nme.isEmpty, NoSourcePosition) <:< ConstantType(Constant(false))
|| unappResult.derivesFrom(defn.NonEmptyTupleClass)
|| unapp.symbol == defn.TupleXXL_unapplySeq // Fixes TupleXXL.unapplySeq which returns Some but declares Option
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,8 @@ trait Checking {
pats.forall(recur(_, pt))
case Typed(arg, tpt) =>
check(pat, pt) && recur(arg, pt)
case NamedArg(name, pat) =>
recur(pat, pt)
case Ident(nme.WILDCARD) =>
true
case pat: QuotePattern =>
Expand Down
8 changes: 8 additions & 0 deletions tests/run/i22900.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
6
6
6
6
7
6
7
(6)
26 changes: 26 additions & 0 deletions tests/run/i22900.scala
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implementation looks good.

It seems the tests were intended to have check files (they println things rather than assert), but there are no checkfiles.

Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
object NameBaseExtractor {
def unapply(x: Int): Some[(someName: Int)] = Some((someName = x + 3))
}
object NameBaseExtractor2 {
def unapply(x: Int): Some[(someName: Int, age: Int)] = Some((someName = x + 3, age = x + 4))
}
@main
def Test =
val x1 = 3 match
case NameBaseExtractor(someName = x) => x
println(x1)
val NameBaseExtractor(someName = x2) = 3
println(x2)
val NameBaseExtractor((someName = x3)) = 3
println(x3)

val NameBaseExtractor2(someName = x4, age = x5) = 3
println(x4)
println(x5)

val NameBaseExtractor2((someName = x6, age = x7)) = 3
println(x6)
println(x7)

val NameBaseExtractor(y1) = 3
println(y1)
3 changes: 3 additions & 0 deletions tests/run/i22900a.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
3
6
3
15 changes: 15 additions & 0 deletions tests/run/i22900a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
case class C(someName: Int)

object NameBaseExtractor3 {
def unapply(x: Int): Some[C] = Some(C(someName = x + 3))
}

@main
def Test = {
val C(someName = xx) = C(3)
println(xx)
val NameBaseExtractor3(C(someName = x)) = 3
println(x)
C(3) match
case C(someName = xx) => println(xx)
}
27 changes: 27 additions & 0 deletions tests/warn/i22899.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
case class CaseClass(a: Int)

object ProductMatch_CaseClass {
def unapply(int: Int): CaseClass = CaseClass(int)
}

object ProductMatch_NamedTuple {
def unapply(int: Int): (a: Int) = (a = int)
}

object NameBasedMatch_CaseClass {
def unapply(int: Int): Some[CaseClass] = Some(CaseClass(int))
}

object NameBasedMatch_NamedTuple {
def unapply(int: Int): Some[(a: Int)] = Some((a = int))
}

object Test {
val ProductMatch_CaseClass(a = x1) = 1 // ok, was pattern's type (x1 : Int) is more specialized than the right hand side expression's type Int
val ProductMatch_NamedTuple(a = x2) = 2 // ok, was pattern binding uses refutable extractor `org.test.ProductMatch_NamedTuple`
val NameBasedMatch_CaseClass(a = x3) = 3 // ok, was pattern's type (x3 : Int) is more specialized than the right hand side expression's type Int
val NameBasedMatch_NamedTuple(a = x4) = 4 // ok, was pattern's type (x4 : Int) is more specialized than the right hand side expression's type Int

val CaseClass(a = x5) = CaseClass(5) // ok, was pattern's type (x5 : Int) is more specialized than the right hand side expression's type Int
val (a = x6) = (a = 6) // ok
}
Loading