Skip to content

Commit 9b872eb

Browse files
authored
Merge pull request #6448 from dotty-staging/add-case-for
Fix #2578: (part 2) Make for-generators filter only if prefixed with `case`.
2 parents 5906ccf + 9a7c709 commit 9b872eb

File tree

19 files changed

+486
-105
lines changed

19 files changed

+486
-105
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+69-44
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,19 @@ object desugar {
3333
*/
3434
val DerivingCompanion: Property.Key[SourcePosition] = new Property.Key
3535

36-
/** An attachment for match expressions generated from a PatDef */
37-
val PatDefMatch: Property.Key[Unit] = new Property.Key
36+
/** An attachment for match expressions generated from a PatDef or GenFrom.
37+
* Value of key == one of IrrefutablePatDef, IrrefutableGenFrom
38+
*/
39+
val CheckIrrefutable: Property.Key[MatchCheck] = new Property.StickyKey
40+
41+
/** What static check should be applied to a Match (none, irrefutable, exhaustive) */
42+
class MatchCheck(val n: Int) extends AnyVal
43+
object MatchCheck {
44+
val None = new MatchCheck(0)
45+
val Exhaustive = new MatchCheck(1)
46+
val IrrefutablePatDef = new MatchCheck(2)
47+
val IrrefutableGenFrom = new MatchCheck(3)
48+
}
3849

3950
/** Info of a variable in a pattern: The named tree and its type */
4051
private type VarInfo = (NameTree, Tree)
@@ -926,6 +937,22 @@ object desugar {
926937
}
927938
}
928939

940+
/** The selector of a match, which depends of the given `checkMode`.
941+
* @param sel the original selector
942+
* @return if `checkMode` is
943+
* - None : sel @unchecked
944+
* - Exhaustive : sel
945+
* - IrrefutablePatDef,
946+
* IrrefutableGenFrom: sel @unchecked with attachment `CheckIrrefutable -> checkMode`
947+
*/
948+
def makeSelector(sel: Tree, checkMode: MatchCheck)(implicit ctx: Context): Tree =
949+
if (checkMode == MatchCheck.Exhaustive) sel
950+
else {
951+
val sel1 = Annotated(sel, New(ref(defn.UncheckedAnnotType)))
952+
if (checkMode != MatchCheck.None) sel1.pushAttachment(CheckIrrefutable, checkMode)
953+
sel1
954+
}
955+
929956
/** If `pat` is a variable pattern,
930957
*
931958
* val/var/lazy val p = e
@@ -960,11 +987,6 @@ object desugar {
960987
// - `pat` is a tuple of N variables or wildcard patterns like `(x1, x2, ..., xN)`
961988
val tupleOptimizable = forallResults(rhs, isMatchingTuple)
962989

963-
def rhsUnchecked = {
964-
val rhs1 = makeAnnotated("scala.unchecked", rhs)
965-
rhs1.pushAttachment(PatDefMatch, ())
966-
rhs1
967-
}
968990
val vars =
969991
if (tupleOptimizable) // include `_`
970992
pat match {
@@ -977,7 +999,7 @@ object desugar {
977999
val caseDef = CaseDef(pat, EmptyTree, makeTuple(ids))
9781000
val matchExpr =
9791001
if (tupleOptimizable) rhs
980-
else Match(rhsUnchecked, caseDef :: Nil)
1002+
else Match(makeSelector(rhs, MatchCheck.IrrefutablePatDef), caseDef :: Nil)
9811003
vars match {
9821004
case Nil =>
9831005
matchExpr
@@ -1120,20 +1142,16 @@ object desugar {
11201142
*
11211143
* { cases }
11221144
* ==>
1123-
* x$1 => (x$1 @unchecked) match { cases }
1145+
* x$1 => (x$1 @unchecked?) match { cases }
11241146
*
11251147
* If `nparams` != 1, expand instead to
11261148
*
1127-
* (x$1, ..., x$n) => (x$0, ..., x${n-1} @unchecked) match { cases }
1149+
* (x$1, ..., x$n) => (x$0, ..., x${n-1} @unchecked?) match { cases }
11281150
*/
1129-
def makeCaseLambda(cases: List[CaseDef], nparams: Int = 1, unchecked: Boolean = true)(implicit ctx: Context): Function = {
1151+
def makeCaseLambda(cases: List[CaseDef], checkMode: MatchCheck, nparams: Int = 1)(implicit ctx: Context): Function = {
11301152
val params = (1 to nparams).toList.map(makeSyntheticParameter(_))
11311153
val selector = makeTuple(params.map(p => Ident(p.name)))
1132-
1133-
if (unchecked)
1134-
Function(params, Match(Annotated(selector, New(ref(defn.UncheckedAnnotType))), cases))
1135-
else
1136-
Function(params, Match(selector, cases))
1154+
Function(params, Match(makeSelector(selector, checkMode), cases))
11371155
}
11381156

11391157
/** Map n-ary function `(p1, ..., pn) => body` where n != 1 to unary function as follows:
@@ -1262,15 +1280,19 @@ object desugar {
12621280
*/
12631281
def makeFor(mapName: TermName, flatMapName: TermName, enums: List[Tree], body: Tree): Tree = trace(i"make for ${ForYield(enums, body)}", show = true) {
12641282

1265-
/** Make a function value pat => body.
1266-
* If pat is a var pattern id: T then this gives (id: T) => body
1267-
* Otherwise this gives { case pat => body }
1283+
/** Let `pat` be `gen`'s pattern. Make a function value `pat => body`.
1284+
* If `pat` is a var pattern `id: T` then this gives `(id: T) => body`.
1285+
* Otherwise this gives `{ case pat => body }`, where `pat` is checked to be
1286+
* irrefutable if `gen`'s checkMode is GenCheckMode.Check.
12681287
*/
1269-
def makeLambda(pat: Tree, body: Tree): Tree = pat match {
1270-
case IdPattern(named, tpt) =>
1271-
Function(derivedValDef(pat, named, tpt, EmptyTree, Modifiers(Param)) :: Nil, body)
1288+
def makeLambda(gen: GenFrom, body: Tree): Tree = gen.pat match {
1289+
case IdPattern(named, tpt) if gen.checkMode != GenCheckMode.FilterAlways =>
1290+
Function(derivedValDef(gen.pat, named, tpt, EmptyTree, Modifiers(Param)) :: Nil, body)
12721291
case _ =>
1273-
makeCaseLambda(CaseDef(pat, EmptyTree, body) :: Nil)
1292+
val matchCheckMode =
1293+
if (gen.checkMode == GenCheckMode.Check) MatchCheck.IrrefutableGenFrom
1294+
else MatchCheck.None
1295+
makeCaseLambda(CaseDef(gen.pat, EmptyTree, body) :: Nil, matchCheckMode)
12741296
}
12751297

12761298
/** If `pat` is not an Identifier, a Typed(Ident, _), or a Bind, wrap
@@ -1316,7 +1338,7 @@ object desugar {
13161338
val cases = List(
13171339
CaseDef(pat, EmptyTree, Literal(Constant(true))),
13181340
CaseDef(Ident(nme.WILDCARD), EmptyTree, Literal(Constant(false))))
1319-
Apply(Select(rhs, nme.withFilter), makeCaseLambda(cases))
1341+
Apply(Select(rhs, nme.withFilter), makeCaseLambda(cases, MatchCheck.None))
13201342
}
13211343

13221344
/** Is pattern `pat` irrefutable when matched against `rhs`?
@@ -1342,41 +1364,47 @@ object desugar {
13421364
}
13431365
}
13441366

1345-
def isIrrefutableGenFrom(gen: GenFrom): Boolean =
1346-
gen.isInstanceOf[IrrefutableGenFrom] ||
1347-
IdPattern.unapply(gen.pat).isDefined ||
1348-
isIrrefutable(gen.pat, gen.expr)
1367+
def needsNoFilter(gen: GenFrom): Boolean =
1368+
if (gen.checkMode == GenCheckMode.FilterAlways) // pattern was prefixed by `case`
1369+
false
1370+
else (
1371+
gen.checkMode != GenCheckMode.FilterNow ||
1372+
IdPattern.unapply(gen.pat).isDefined ||
1373+
isIrrefutable(gen.pat, gen.expr)
1374+
)
13491375

13501376
/** rhs.name with a pattern filter on rhs unless `pat` is irrefutable when
13511377
* matched against `rhs`.
13521378
*/
13531379
def rhsSelect(gen: GenFrom, name: TermName) = {
1354-
val rhs = if (isIrrefutableGenFrom(gen)) gen.expr else makePatFilter(gen.expr, gen.pat)
1380+
val rhs = if (needsNoFilter(gen)) gen.expr else makePatFilter(gen.expr, gen.pat)
13551381
Select(rhs, name)
13561382
}
13571383

1384+
def checkMode(gen: GenFrom) =
1385+
if (gen.checkMode == GenCheckMode.Check) MatchCheck.IrrefutableGenFrom
1386+
else MatchCheck.None // refutable paterns were already eliminated in filter step
1387+
13581388
enums match {
13591389
case (gen: GenFrom) :: Nil =>
1360-
Apply(rhsSelect(gen, mapName), makeLambda(gen.pat, body))
1361-
case (gen: GenFrom) :: (rest @ (GenFrom(_, _) :: _)) =>
1390+
Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
1391+
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
13621392
val cont = makeFor(mapName, flatMapName, rest, body)
1363-
Apply(rhsSelect(gen, flatMapName), makeLambda(gen.pat, cont))
1364-
case (GenFrom(pat, rhs)) :: (rest @ GenAlias(_, _) :: _) =>
1393+
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
1394+
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
13651395
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
13661396
val pats = valeqs map { case GenAlias(pat, _) => pat }
13671397
val rhss = valeqs map { case GenAlias(_, rhs) => rhs }
1368-
val (defpat0, id0) = makeIdPat(pat)
1398+
val (defpat0, id0) = makeIdPat(gen.pat)
13691399
val (defpats, ids) = (pats map makeIdPat).unzip
13701400
val pdefs = (valeqs, defpats, rhss).zipped.map(makePatDef(_, Modifiers(), _, _))
1371-
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, rhs) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
1372-
val allpats = pat :: pats
1373-
val vfrom1 = new IrrefutableGenFrom(makeTuple(allpats), rhs1)
1401+
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
1402+
val allpats = gen.pat :: pats
1403+
val vfrom1 = new GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
13741404
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
13751405
case (gen: GenFrom) :: test :: rest =>
1376-
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen.pat, test))
1377-
val genFrom =
1378-
if (isIrrefutableGenFrom(gen)) new IrrefutableGenFrom(gen.pat, filtered)
1379-
else GenFrom(gen.pat, filtered)
1406+
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
1407+
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore)
13801408
makeFor(mapName, flatMapName, genFrom :: rest, body)
13811409
case _ =>
13821410
EmptyTree //may happen for erroneous input
@@ -1571,7 +1599,4 @@ object desugar {
15711599
collect(tree)
15721600
buf.toList
15731601
}
1574-
1575-
private class IrrefutableGenFrom(pat: Tree, expr: Tree)(implicit @constructorOnly src: SourceFile)
1576-
extends GenFrom(pat, expr)
15771602
}

compiler/src/dotty/tools/dotc/ast/untpd.scala

+16-7
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
9999
case class DoWhile(body: Tree, cond: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
100100
case class ForYield(enums: List[Tree], expr: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
101101
case class ForDo(enums: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
102-
case class GenFrom(pat: Tree, expr: Tree)(implicit @constructorOnly src: SourceFile) extends Tree
102+
case class GenFrom(pat: Tree, expr: Tree, checkMode: GenCheckMode)(implicit @constructorOnly src: SourceFile) extends Tree
103103
case class GenAlias(pat: Tree, expr: Tree)(implicit @constructorOnly src: SourceFile) extends Tree
104104
case class ContextBounds(bounds: TypeBoundsTree, cxBounds: List[Tree])(implicit @constructorOnly src: SourceFile) extends TypTree
105105
case class PatDef(mods: Modifiers, pats: List[Tree], tpt: Tree, rhs: Tree)(implicit @constructorOnly src: SourceFile) extends DefTree
@@ -116,6 +116,15 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
116116
* `Positioned#checkPos` */
117117
class XMLBlock(stats: List[Tree], expr: Tree)(implicit @constructorOnly src: SourceFile) extends Block(stats, expr)
118118

119+
/** An enum to control checking or filtering of patterns in GenFrom trees */
120+
class GenCheckMode(val x: Int) extends AnyVal
121+
object GenCheckMode {
122+
val Ignore = new GenCheckMode(0) // neither filter nor check since filtering was done before
123+
val Check = new GenCheckMode(1) // check that pattern is irrefutable
124+
val FilterNow = new GenCheckMode(2) // filter out non-matching elements since we are not in -strict
125+
val FilterAlways = new GenCheckMode(3) // filter out non-matching elements since pattern is prefixed by `case`
126+
}
127+
119128
// ----- Modifiers -----------------------------------------------------
120129
/** Mod is intended to record syntactic information about modifiers, it's
121130
* NOT a replacement of FlagSet.
@@ -525,9 +534,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
525534
case tree: ForDo if (enums eq tree.enums) && (body eq tree.body) => tree
526535
case _ => finalize(tree, untpd.ForDo(enums, body)(tree.source))
527536
}
528-
def GenFrom(tree: Tree)(pat: Tree, expr: Tree)(implicit ctx: Context): Tree = tree match {
529-
case tree: GenFrom if (pat eq tree.pat) && (expr eq tree.expr) => tree
530-
case _ => finalize(tree, untpd.GenFrom(pat, expr)(tree.source))
537+
def GenFrom(tree: Tree)(pat: Tree, expr: Tree, checkMode: GenCheckMode)(implicit ctx: Context): Tree = tree match {
538+
case tree: GenFrom if (pat eq tree.pat) && (expr eq tree.expr) && (checkMode == tree.checkMode) => tree
539+
case _ => finalize(tree, untpd.GenFrom(pat, expr, checkMode)(tree.source))
531540
}
532541
def GenAlias(tree: Tree)(pat: Tree, expr: Tree)(implicit ctx: Context): Tree = tree match {
533542
case tree: GenAlias if (pat eq tree.pat) && (expr eq tree.expr) => tree
@@ -589,8 +598,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
589598
cpy.ForYield(tree)(transform(enums), transform(expr))
590599
case ForDo(enums, body) =>
591600
cpy.ForDo(tree)(transform(enums), transform(body))
592-
case GenFrom(pat, expr) =>
593-
cpy.GenFrom(tree)(transform(pat), transform(expr))
601+
case GenFrom(pat, expr, checkMode) =>
602+
cpy.GenFrom(tree)(transform(pat), transform(expr), checkMode)
594603
case GenAlias(pat, expr) =>
595604
cpy.GenAlias(tree)(transform(pat), transform(expr))
596605
case ContextBounds(bounds, cxBounds) =>
@@ -644,7 +653,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
644653
this(this(x, enums), expr)
645654
case ForDo(enums, body) =>
646655
this(this(x, enums), body)
647-
case GenFrom(pat, expr) =>
656+
case GenFrom(pat, expr, _) =>
648657
this(this(x, pat), expr)
649658
case GenAlias(pat, expr) =>
650659
this(this(x, pat), expr)

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+29-19
Original file line numberDiff line numberDiff line change
@@ -1725,18 +1725,28 @@ object Parsers {
17251725
*/
17261726
def enumerator(): Tree =
17271727
if (in.token == IF) guard()
1728+
else if (in.token == CASE) generator()
17281729
else {
17291730
val pat = pattern1()
17301731
if (in.token == EQUALS) atSpan(startOffset(pat), in.skipToken()) { GenAlias(pat, expr()) }
1731-
else generatorRest(pat)
1732+
else generatorRest(pat, casePat = false)
17321733
}
17331734

1734-
/** Generator ::= Pattern `<-' Expr
1735+
/** Generator ::= [‘case’] Pattern `<-' Expr
17351736
*/
1736-
def generator(): Tree = generatorRest(pattern1())
1737+
def generator(): Tree = {
1738+
val casePat = if (in.token == CASE) { in.skipCASE(); true } else false
1739+
generatorRest(pattern1(), casePat)
1740+
}
17371741

1738-
def generatorRest(pat: Tree): GenFrom =
1739-
atSpan(startOffset(pat), accept(LARROW)) { GenFrom(pat, expr()) }
1742+
def generatorRest(pat: Tree, casePat: Boolean): GenFrom =
1743+
atSpan(startOffset(pat), accept(LARROW)) {
1744+
val checkMode =
1745+
if (casePat) GenCheckMode.FilterAlways
1746+
else if (ctx.settings.strict.value) GenCheckMode.Check
1747+
else GenCheckMode.FilterNow // filter for now, to keep backwards compat
1748+
GenFrom(pat, expr(), checkMode)
1749+
}
17401750

17411751
/** ForExpr ::= `for' (`(' Enumerators `)' | `{' Enumerators `}')
17421752
* {nl} [`yield'] Expr
@@ -1749,16 +1759,20 @@ object Parsers {
17491759
else if (in.token == LPAREN) {
17501760
val lparenOffset = in.skipToken()
17511761
openParens.change(LPAREN, 1)
1752-
val pats = patternsOpt()
1753-
val pat =
1754-
if (in.token == RPAREN || pats.length > 1) {
1755-
wrappedEnums = false
1756-
accept(RPAREN)
1757-
openParens.change(LPAREN, -1)
1758-
atSpan(lparenOffset) { makeTupleOrParens(pats) } // note: alternatives `|' need to be weeded out by typer.
1762+
val res =
1763+
if (in.token == CASE) enumerators()
1764+
else {
1765+
val pats = patternsOpt()
1766+
val pat =
1767+
if (in.token == RPAREN || pats.length > 1) {
1768+
wrappedEnums = false
1769+
accept(RPAREN)
1770+
openParens.change(LPAREN, -1)
1771+
atSpan(lparenOffset) { makeTupleOrParens(pats) } // note: alternatives `|' need to be weeded out by typer.
1772+
}
1773+
else pats.head
1774+
generatorRest(pat, casePat = false) :: enumeratorsRest()
17591775
}
1760-
else pats.head
1761-
val res = generatorRest(pat) :: enumeratorsRest()
17621776
if (wrappedEnums) {
17631777
accept(RPAREN)
17641778
openParens.change(LPAREN, -1)
@@ -2640,11 +2654,7 @@ object Parsers {
26402654
*/
26412655
def enumCase(start: Offset, mods: Modifiers): DefTree = {
26422656
val mods1 = addMod(mods, atSpan(in.offset)(Mod.Enum())) | Case
2643-
accept(CASE)
2644-
2645-
in.adjustSepRegions(ARROW)
2646-
// Scanner thinks it is in a pattern match after seeing the `case`.
2647-
// We need to get it out of that mode by telling it we are past the `=>`
2657+
in.skipCASE()
26482658

26492659
atSpan(start, nameStart) {
26502660
val id = termIdent()

compiler/src/dotty/tools/dotc/parsing/Scanners.scala

+10
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,16 @@ object Scanners {
351351
case _ =>
352352
}
353353

354+
/** Advance beyond a case token without marking the CASE in sepRegions.
355+
* This method should be called to skip beyond CASE tokens that are
356+
* not part of matches, i.e. no ARROW is expected after them.
357+
*/
358+
def skipCASE() = {
359+
assert(token == CASE)
360+
nextToken()
361+
sepRegions = sepRegions.tail
362+
}
363+
354364
/** Produce next token, filling TokenData fields of Scanner.
355365
*/
356366
def nextToken(): Unit = {

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
570570
forText(enums, expr, keywordStr(" yield "))
571571
case ForDo(enums, expr) =>
572572
forText(enums, expr, keywordStr(" do "))
573-
case GenFrom(pat, expr) =>
573+
case GenFrom(pat, expr, checkMode) =>
574+
(Str("case ") provided checkMode == untpd.GenCheckMode.FilterAlways) ~
574575
toText(pat) ~ " <- " ~ toText(expr)
575576
case GenAlias(pat, expr) =>
576577
toText(pat) ~ " = " ~ toText(expr)

0 commit comments

Comments
 (0)