Skip to content

Commit f85877d

Browse files
authored
Merge pull request #6389 from dotty-staging/add-checked-patdef
Fix #2578 Part 1: Tighten type checking of pattern bindings
2 parents a5ac12d + 80a3a67 commit f85877d

File tree

20 files changed

+196
-56
lines changed

20 files changed

+196
-56
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+8-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ object desugar {
3232
*/
3333
val DerivingCompanion: Property.Key[SourcePosition] = new Property.Key
3434

35+
/** An attachment for match expressions generated from a PatDef */
36+
val PatDefMatch: Property.Key[Unit] = new Property.Key
37+
3538
/** Info of a variable in a pattern: The named tree and its type */
3639
private type VarInfo = (NameTree, Tree)
3740

@@ -956,7 +959,11 @@ object desugar {
956959
// - `pat` is a tuple of N variables or wildcard patterns like `(x1, x2, ..., xN)`
957960
val tupleOptimizable = forallResults(rhs, isMatchingTuple)
958961

959-
def rhsUnchecked = makeAnnotated("scala.unchecked", rhs)
962+
def rhsUnchecked = {
963+
val rhs1 = makeAnnotated("scala.unchecked", rhs)
964+
rhs1.pushAttachment(PatDefMatch, ())
965+
rhs1
966+
}
960967
val vars =
961968
if (tupleOptimizable) // include `_`
962969
pat match {

compiler/src/dotty/tools/dotc/core/Types.scala

+9-5
Original file line numberDiff line numberDiff line change
@@ -1528,13 +1528,17 @@ object Types {
15281528
*/
15291529
def signature(implicit ctx: Context): Signature = Signature.NotAMethod
15301530

1531-
def dropRepeatedAnnot(implicit ctx: Context): Type = this match {
1532-
case AnnotatedType(parent, annot) if annot.symbol eq defn.RepeatedAnnot => parent
1533-
case tp @ AnnotatedType(parent, annot) =>
1534-
tp.derivedAnnotatedType(parent.dropRepeatedAnnot, annot)
1535-
case tp => tp
1531+
/** Drop annotation of given `cls` from this type */
1532+
def dropAnnot(cls: Symbol)(implicit ctx: Context): Type = stripTypeVar match {
1533+
case self @ AnnotatedType(pre, annot) =>
1534+
if (annot.symbol eq cls) pre
1535+
else self.derivedAnnotatedType(pre.dropAnnot(cls), annot)
1536+
case _ =>
1537+
this
15361538
}
15371539

1540+
def dropRepeatedAnnot(implicit ctx: Context): Type = dropAnnot(defn.RepeatedAnnot)
1541+
15381542
def annotatedToRepeated(implicit ctx: Context): Type = this match {
15391543
case tp @ ExprType(tp1) => tp.derivedExprType(tp1.annotatedToRepeated)
15401544
case AnnotatedType(tp, annot) if annot matches defn.RepeatedAnnot =>

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+41-18
Original file line numberDiff line numberDiff line change
@@ -1200,14 +1200,15 @@ object Parsers {
12001200
* | ForExpr
12011201
* | [SimpleExpr `.'] id `=' Expr
12021202
* | SimpleExpr1 ArgumentExprs `=' Expr
1203-
* | PostfixExpr [Ascription]
1204-
* | [‘inline’] PostfixExpr `match' `{' CaseClauses `}'
1203+
* | Expr2
1204+
* | [‘inline’] Expr2 `match' `{' CaseClauses `}'
12051205
* | `implicit' `match' `{' ImplicitCaseClauses `}'
1206-
* Bindings ::= `(' [Binding {`,' Binding}] `)'
1207-
* Binding ::= (id | `_') [`:' Type]
1208-
* Ascription ::= `:' CompoundType
1209-
* | `:' Annotation {Annotation}
1210-
* | `:' `_' `*'
1206+
* Bindings ::= `(' [Binding {`,' Binding}] `)'
1207+
* Binding ::= (id | `_') [`:' Type]
1208+
* Expr2 ::= PostfixExpr [Ascription]
1209+
* Ascription ::= `:' InfixType
1210+
* | `:' Annotation {Annotation}
1211+
* | `:' `_' `*'
12111212
*/
12121213
val exprInParens: () => Tree = () => expr(Location.InParens)
12131214

@@ -1324,15 +1325,16 @@ object Parsers {
13241325
t
13251326
}
13261327
case COLON =>
1327-
ascription(t, location)
1328+
in.nextToken()
1329+
val t1 = ascription(t, location)
1330+
if (in.token == MATCH) expr1Rest(t1, location) else t1
13281331
case MATCH =>
13291332
matchExpr(t, startOffset(t), Match)
13301333
case _ =>
13311334
t
13321335
}
13331336

13341337
def ascription(t: Tree, location: Location.Value): Tree = atSpan(startOffset(t)) {
1335-
in.skipToken()
13361338
in.token match {
13371339
case USCORE =>
13381340
val uscoreStart = in.skipToken()
@@ -1801,7 +1803,10 @@ object Parsers {
18011803
*/
18021804
def pattern1(): Tree = {
18031805
val p = pattern2()
1804-
if (isVarPattern(p) && in.token == COLON) ascription(p, Location.InPattern)
1806+
if (isVarPattern(p) && in.token == COLON) {
1807+
in.nextToken()
1808+
ascription(p, Location.InPattern)
1809+
}
18051810
else p
18061811
}
18071812

@@ -2353,14 +2358,32 @@ object Parsers {
23532358
tmplDef(start, mods)
23542359
}
23552360

2356-
/** PatDef ::= Pattern2 {`,' Pattern2} [`:' Type] `=' Expr
2357-
* VarDef ::= PatDef | id {`,' id} `:' Type `=' `_'
2358-
* ValDcl ::= id {`,' id} `:' Type
2359-
* VarDcl ::= id {`,' id} `:' Type
2361+
/** PatDef ::= ids [‘:’ Type] ‘=’ Expr
2362+
* | Pattern2 [‘:’ Type | Ascription] ‘=’ Expr
2363+
* VarDef ::= PatDef | id {`,' id} `:' Type `=' `_'
2364+
* ValDcl ::= id {`,' id} `:' Type
2365+
* VarDcl ::= id {`,' id} `:' Type
23602366
*/
23612367
def patDefOrDcl(start: Offset, mods: Modifiers): Tree = atSpan(start, nameStart) {
2362-
val lhs = commaSeparated(pattern2)
2363-
val tpt = typedOpt()
2368+
val first = pattern2()
2369+
var lhs = first match {
2370+
case id: Ident if in.token == COMMA =>
2371+
in.nextToken()
2372+
id :: commaSeparated(() => termIdent())
2373+
case _ =>
2374+
first :: Nil
2375+
}
2376+
def emptyType = TypeTree().withSpan(Span(in.lastOffset))
2377+
val tpt =
2378+
if (in.token == COLON) {
2379+
in.nextToken()
2380+
if (in.token == AT && lhs.tail.isEmpty) {
2381+
lhs = ascription(first, Location.ElseWhere) :: Nil
2382+
emptyType
2383+
}
2384+
else toplevelTyp()
2385+
}
2386+
else emptyType
23642387
val rhs =
23652388
if (tpt.isEmpty || in.token == EQUALS) {
23662389
accept(EQUALS)
@@ -2374,9 +2397,9 @@ object Parsers {
23742397
lhs match {
23752398
case (id: BackquotedIdent) :: Nil if id.name.isTermName =>
23762399
finalizeDef(BackquotedValDef(id.name.asTermName, tpt, rhs), mods, start)
2377-
case Ident(name: TermName) :: Nil => {
2400+
case Ident(name: TermName) :: Nil =>
23782401
finalizeDef(ValDef(name, tpt, rhs), mods, start)
2379-
} case _ =>
2402+
case _ =>
23802403
PatDef(mods, lhs, tpt, rhs)
23812404
}
23822405
}

compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ object TypeTestsCasts {
220220

221221
if (expr.tpe <:< testType)
222222
if (expr.tpe.isNotNull) {
223-
ctx.warning(TypeTestAlwaysSucceeds(foundCls, testCls), tree.sourcePos)
223+
if (!inMatch) ctx.warning(TypeTestAlwaysSucceeds(foundCls, testCls), tree.sourcePos)
224224
constant(expr, Literal(Constant(true)))
225225
}
226226
else expr.testNotNull

compiler/src/dotty/tools/dotc/transform/patmat/Space.scala

+18-13
Original file line numberDiff line numberDiff line change
@@ -280,11 +280,25 @@ trait SpaceLogic {
280280
}
281281
}
282282

283+
object SpaceEngine {
284+
285+
/** Is the unapply irrefutable?
286+
* @param unapp The unapply function reference
287+
*/
288+
def isIrrefutableUnapply(unapp: tpd.Tree)(implicit ctx: Context): Boolean = {
289+
val unappResult = unapp.tpe.widen.finalResultType
290+
unappResult.isRef(defn.SomeClass) ||
291+
unappResult =:= ConstantType(Constant(true)) ||
292+
(unapp.symbol.is(Synthetic) && unapp.symbol.owner.linkedClass.is(Case)) ||
293+
productArity(unappResult) > 0
294+
}
295+
}
296+
283297
/** Scala implementation of space logic */
284298
class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
285299
import tpd._
300+
import SpaceEngine._
286301

287-
private val scalaSomeClass = ctx.requiredClass("scala.Some")
288302
private val scalaSeqFactoryClass = ctx.requiredClass("scala.collection.generic.SeqFactory")
289303
private val scalaListType = ctx.requiredClassRef("scala.collection.immutable.List")
290304
private val scalaNilType = ctx.requiredModuleRef("scala.collection.immutable.Nil")
@@ -309,15 +323,6 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
309323
else Typ(AndType(tp1, tp2), true)
310324
}
311325

312-
/** Whether the extractor is irrefutable */
313-
def irrefutable(unapp: Tree): Boolean = {
314-
// TODO: optionless patmat
315-
unapp.tpe.widen.finalResultType.isRef(scalaSomeClass) ||
316-
unapp.tpe.widen.finalResultType =:= ConstantType(Constant(true)) ||
317-
(unapp.symbol.is(Synthetic) && unapp.symbol.owner.linkedClass.is(Case)) ||
318-
productArity(unapp.tpe.widen.finalResultType) > 0
319-
}
320-
321326
/** Return the space that represents the pattern `pat` */
322327
def project(pat: Tree): Space = pat match {
323328
case Literal(c) =>
@@ -340,12 +345,12 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
340345
else {
341346
val (arity, elemTp, resultTp) = unapplySeqInfo(fun.tpe.widen.finalResultType, fun.sourcePos)
342347
if (elemTp.exists)
343-
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, projectSeq(pats) :: Nil, irrefutable(fun))
348+
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, projectSeq(pats) :: Nil, isIrrefutableUnapply(fun))
344349
else
345-
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, pats.take(arity - 1).map(project) :+ projectSeq(pats.drop(arity - 1)), irrefutable(fun))
350+
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, pats.take(arity - 1).map(project) :+ projectSeq(pats.drop(arity - 1)),isIrrefutableUnapply(fun))
346351
}
347352
else
348-
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, pats.map(project), irrefutable(fun))
353+
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, pats.map(project), isIrrefutableUnapply(fun))
349354
case Typed(pat @ UnApply(_, _, _), _) => project(pat)
350355
case Typed(expr, tpt) =>
351356
Typ(erase(expr.tpe.stripAnnots), true)

compiler/src/dotty/tools/dotc/typer/Applications.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -1107,7 +1107,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
11071107
if (selType <:< unapplyArgType) {
11081108
unapp.println(i"case 1 $unapplyArgType ${ctx.typerState.constraint}")
11091109
fullyDefinedType(unapplyArgType, "pattern selector", tree.span)
1110-
selType
1110+
selType.dropAnnot(defn.UncheckedAnnot) // need to drop @unchecked. Just because the selector is @unchecked, the pattern isn't.
11111111
} else if (isSubTypeOfParent(unapplyArgType, selType)(ctx.addMode(Mode.GADTflexible))) {
11121112
val patternBound = maximizeType(unapplyArgType, tree.span, fromScala2x)
11131113
if (patternBound.nonEmpty) unapplyFn = addBinders(unapplyFn, patternBound)

compiler/src/dotty/tools/dotc/typer/Checking.scala

+46-1
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,17 @@ import ProtoTypes._
1616
import Scopes._
1717
import CheckRealizable._
1818
import ErrorReporting.errorTree
19+
import rewrites.Rewrites.patch
20+
import util.Spans.Span
1921

2022
import util.SourcePosition
2123
import transform.SymUtils._
2224
import Decorators._
2325
import ErrorReporting.{err, errorType}
24-
import config.Printers.typr
26+
import config.Printers.{typr, patmatch}
2527
import NameKinds.DefaultGetterName
28+
import Applications.unapplyArgs
29+
import transform.patmat.SpaceEngine.isIrrefutableUnapply
2630

2731
import collection.mutable
2832
import SymDenotations.{NoCompleter, NoDenotation}
@@ -594,6 +598,47 @@ trait Checking {
594598
ctx.error(ex"$cls cannot be instantiated since it${rstatus.msg}", pos)
595599
}
596600

601+
/** Check that pattern `pat` is irrefutable for scrutinee tye `pt`.
602+
* This means `pat` is either marked @unchecked or `pt` conforms to the
603+
* pattern's type. If pattern is an UnApply, do the check recursively.
604+
*/
605+
def checkIrrefutable(pat: Tree, pt: Type)(implicit ctx: Context): Boolean = {
606+
patmatch.println(i"check irrefutable $pat: ${pat.tpe} against $pt")
607+
608+
def fail(pat: Tree, pt: Type): Boolean = {
609+
ctx.errorOrMigrationWarning(
610+
ex"""pattern's type ${pat.tpe} is more specialized than the right hand side expression's type ${pt.dropAnnot(defn.UncheckedAnnot)}
611+
|
612+
|If the narrowing is intentional, this can be communicated by writing `: @unchecked` after the full pattern.${err.rewriteNotice}""",
613+
pat.sourcePos)
614+
false
615+
}
616+
617+
def check(pat: Tree, pt: Type): Boolean = (pt <:< pat.tpe) || fail(pat, pt)
618+
619+
!ctx.settings.strict.value || // only in -strict mode for now since mitigations work only after this PR
620+
pat.tpe.widen.hasAnnotation(defn.UncheckedAnnot) || {
621+
pat match {
622+
case Bind(_, pat1) =>
623+
checkIrrefutable(pat1, pt)
624+
case UnApply(fn, _, pats) =>
625+
check(pat, pt) &&
626+
(isIrrefutableUnapply(fn) || fail(pat, pt)) && {
627+
val argPts = unapplyArgs(fn.tpe.widen.finalResultType, fn, pats, pat.sourcePos)
628+
pats.corresponds(argPts)(checkIrrefutable)
629+
}
630+
case Alternative(pats) =>
631+
pats.forall(checkIrrefutable(_, pt))
632+
case Typed(arg, tpt) =>
633+
check(pat, pt) && checkIrrefutable(arg, pt)
634+
case Ident(nme.WILDCARD) =>
635+
true
636+
case _ =>
637+
check(pat, pt)
638+
}
639+
}
640+
}
641+
597642
/** Check that `path` is a legal prefix for an import or export clause */
598643
def checkLegalImportPath(path: Tree)(implicit ctx: Context): Unit = {
599644
checkStable(path.tpe, path.sourcePos)

compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala

+4
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ object ErrorReporting {
157157
}
158158
"""\$\{\w*\}""".r.replaceSomeIn(raw, m => translate(m.matched.drop(2).init))
159159
}
160+
161+
def rewriteNotice: String =
162+
if (ctx.scala2Mode) "\nThis patch can be inserted automatically under -rewrite."
163+
else ""
160164
}
161165

162166
def err(implicit ctx: Context): Errors = new Errors

compiler/src/dotty/tools/dotc/typer/Typer.scala

+14-3
Original file line numberDiff line numberDiff line change
@@ -1036,7 +1036,15 @@ class Typer extends Namer
10361036
if (tree.isInline) checkInInlineContext("inline match", tree.posd)
10371037
val sel1 = typedExpr(tree.selector)
10381038
val selType = fullyDefinedType(sel1.tpe, "pattern selector", tree.span).widen
1039-
typedMatchFinish(tree, sel1, selType, tree.cases, pt)
1039+
val result = typedMatchFinish(tree, sel1, selType, tree.cases, pt)
1040+
result match {
1041+
case Match(sel, CaseDef(pat, _, _) :: _)
1042+
if (tree.selector.removeAttachment(desugar.PatDefMatch).isDefined) =>
1043+
if (!checkIrrefutable(pat, sel.tpe) && ctx.scala2Mode)
1044+
patch(Span(pat.span.end), ": @unchecked")
1045+
case _ =>
1046+
}
1047+
result
10401048
}
10411049
}
10421050

@@ -1827,8 +1835,11 @@ class Typer extends Namer
18271835
}
18281836
case _ => arg1
18291837
}
1830-
val tpt = TypeTree(AnnotatedType(arg1.tpe.widenIfUnstable, Annotation(annot1)))
1831-
assignType(cpy.Typed(tree)(arg2, tpt), tpt)
1838+
val argType =
1839+
if (arg1.isInstanceOf[Bind]) arg1.tpe.widen // bound symbol is not accessible outside of Bind node
1840+
else arg1.tpe.widenIfUnstable
1841+
val annotatedTpt = TypeTree(AnnotatedType(argType, Annotation(annot1)))
1842+
assignType(cpy.Typed(tree)(arg2, annotatedTpt), annotatedTpt)
18321843
}
18331844
}
18341845

compiler/test-resources/repl/patdef

-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,3 @@ scala> val _ @ List(x) = List(1)
2121
val x: Int = 1
2222
scala> val List(_ @ List(x)) = List(List(2))
2323
val x: Int = 2
24-
scala> val B @ List(), C: List[Int] = List()
25-
val B: List[Int] = List()
26-
val C: List[Int] = List()

compiler/test/dotty/tools/dotc/CompilationTests.scala

+4-3
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ class CompilationTests extends ParallelTesting {
150150
aggregateTests(
151151
compileFilesInDir("tests/neg", defaultOptions),
152152
compileFilesInDir("tests/neg-tailcall", defaultOptions),
153+
compileFilesInDir("tests/neg-strict", defaultOptions.and("-strict")),
153154
compileFilesInDir("tests/neg-no-kind-polymorphism", defaultOptions and "-Yno-kind-polymorphism"),
154155
compileFilesInDir("tests/neg-custom-args/deprecation", defaultOptions.and("-Xfatal-warnings", "-deprecation")),
155156
compileFilesInDir("tests/neg-custom-args/fatal-warnings", defaultOptions.and("-Xfatal-warnings")),
@@ -160,8 +161,6 @@ class CompilationTests extends ParallelTesting {
160161
compileFile("tests/neg-custom-args/i3246.scala", scala2Mode),
161162
compileFile("tests/neg-custom-args/overrideClass.scala", scala2Mode),
162163
compileFile("tests/neg-custom-args/autoTuplingTest.scala", defaultOptions.and("-language:noAutoTupling")),
163-
compileFile("tests/neg-custom-args/i1050.scala", defaultOptions.and("-strict")),
164-
compileFile("tests/neg-custom-args/nullless.scala", defaultOptions.and("-strict")),
165164
compileFile("tests/neg-custom-args/nopredef.scala", defaultOptions.and("-Yno-predef")),
166165
compileFile("tests/neg-custom-args/noimports.scala", defaultOptions.and("-Yno-imports")),
167166
compileFile("tests/neg-custom-args/noimports2.scala", defaultOptions.and("-Yno-imports")),
@@ -249,7 +248,9 @@ class CompilationTests extends ParallelTesting {
249248

250249
val lib =
251250
compileList("src", librarySources,
252-
defaultOptions.and("-Ycheck-reentrant", "-strict", "-priorityclasspath", defaultOutputDir))(libGroup)
251+
defaultOptions.and("-Ycheck-reentrant",
252+
// "-strict", // TODO: re-enable once we allow : @unchecked in pattern definitions. Right now, lots of narrowing pattern definitions fail.
253+
"-priorityclasspath", defaultOutputDir))(libGroup)
253254

254255
val compilerSources = sources(Paths.get("compiler/src"))
255256
val compilerManagedSources = sources(Properties.dottyCompilerManagedSources)

0 commit comments

Comments
 (0)