Skip to content

Fix #3989: Fix several unsoundness problems related to variant refinement #4013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 17, 2018
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
* - If a type proxy P is not a reference to a class, P's supertype is in G
*/
def isSubTypeOfParent(subtp: Type, tp: Type)(implicit ctx: Context): Boolean =
if (subtp <:< tp) true
if (constrainPatternType(subtp, tp)) true
else tp match {
case tp: TypeRef if tp.symbol.isClass => isSubTypeOfParent(subtp, tp.firstParent)
case tp: TypeProxy => isSubTypeOfParent(subtp, tp.superType)
Expand Down
57 changes: 57 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,63 @@ object Inferencing {
tree
}

/** Derive information about a pattern type by comparing it with some variant of the
* static scrutinee type. We have the following situation in case of a (dynamic) pattern match:
*
* StaticScrutineeType PatternType
* \ /
* DynamicScrutineeType
*
* If `PatternType` is not a subtype of `StaticScrutineeType, there's no information to be gained.
* Now let's say we can prove that `PatternType <: StaticScrutineeType`.
*
* StaticScrutineeType
* | \
* | \
* | \
* | PatternType
* | /
* DynamicScrutineeType
*
* What can we say about the relationship of parameter types between `PatternType` and
* `DynamicScrutineeType`?
*
* - If `DynamicScrutineeType` refines the type parameters of `StaticScrutineeType`
* in the same way as `PatternType` ("invariant refinement"), the subtype test
* `PatternType <:< StaticScrutineeType` tells us all we need to know.
* - Otherwise, if variant refinement is a possibility we can only make predictions
* about invariant parameters of `StaticScrutineeType`. Hence we do a subtype test
* where `PatternType <: widenVariantParams(StaticScrutineeType)`, where `widenVariantParams`
* replaces all type argument of variant parameters with empty bounds.
*
* Invariant refinement can be assumed if `PatternType`'s class(es) are final or
* case classes (because of `RefChecks#checkCaseClassInheritanceInvariant`).
*/
def constrainPatternType(tp: Type, pt: Type)(implicit ctx: Context): Boolean = {
def refinementIsInvariant(tp: Type): Boolean = tp match {
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
case tp: TypeProxy => refinementIsInvariant(tp.underlying)
case tp: AndType => refinementIsInvariant(tp.tp1) && refinementIsInvariant(tp.tp2)
case tp: OrType => refinementIsInvariant(tp.tp1) && refinementIsInvariant(tp.tp2)
case _ => false
}

def widenVariantParams = new TypeMap {
def apply(tp: Type) = mapOver(tp) match {
case tp @ AppliedType(tycon, args) =>
val args1 = args.zipWithConserve(tycon.typeParams)((arg, tparam) =>
if (tparam.paramVariance != 0) TypeBounds.empty else arg
)
tp.derivedAppliedType(tycon, args1)
case tp =>
tp
}
}

val widePt = if (ctx.scala2Mode || refinementIsInvariant(tp)) pt else widenVariantParams(pt)
tp <:< widePt
}

/** The list of uninstantiated type variables bound by some prefix of type `T` which
* occur in at least one formal parameter type of a prefix application.
* Considered prefixes are:
Expand Down
106 changes: 106 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/RefChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -594,12 +594,73 @@ object RefChecks {
checkNoAbstractDecls(bc.asClass.superClass)
}

// Check that every term member of this concrete class has a symbol that matches the member's type
// Member types are computed by intersecting the types of all members that have the same name
// and signature. But a member selection will pick one particular implementation, according to
// the rules of overriding and linearization. This method checks that the implementation has indeed
// a type that subsumes the full member type.
def checkMemberTypesOK() = {

// First compute all member names we need to check in `membersToCheck`.
// We do not check
// - types
// - synthetic members or bridges
// - members in other concrete classes, since these have been checked before
// (this is done for efficiency)
// - members in a prefix of inherited parents that all come from Java or Scala2
// (this is done to avoid false positives since Scala2's rules for checking are different)
val membersToCheck = new util.HashSet[Name](4096)
val seenClasses = new util.HashSet[Symbol](256)
def addDecls(cls: Symbol): Unit =
if (!seenClasses.contains(cls)) {
seenClasses.addEntry(cls)
for (mbr <- cls.info.decls)
if (mbr.isTerm && !mbr.is(Synthetic | Bridge) && mbr.memberCanMatchInheritedSymbols &&
!membersToCheck.contains(mbr.name))
membersToCheck.addEntry(mbr.name)
cls.info.parents.map(_.classSymbol)
.filter(_.is(AbstractOrTrait))
.dropWhile(_.is(JavaDefined | Scala2x))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why dropWhile and not filter?

.foreach(addDecls)
}
addDecls(clazz)

// For each member, check that the type of its symbol, as seen from `self`
// can override the info of this member
for (name <- membersToCheck) {
for (mbrd <- self.member(name).alternatives) {
val mbr = mbrd.symbol
val mbrType = mbr.info.asSeenFrom(self, mbr.owner)
if (!mbrType.overrides(mbrd.info, matchLoosely = true))
ctx.errorOrMigrationWarning(
em"""${mbr.showLocated} is not a legal implementation of `$name' in $clazz
| its type $mbrType
| does not conform to ${mbrd.info}""",
(if (mbr.owner == clazz) mbr else clazz).pos)
}
}
}

/** Check that inheriting a case class does not constitute a variant refinement
* of a base type of the case class. It is because of this restriction that we
* can assume invariant refinement for case classes in `constrainPatternType`.
*/
def checkCaseClassInheritanceInvariant() = {
for (caseCls <- clazz.info.baseClasses.tail.find(_.is(Case)))
for (baseCls <- caseCls.info.baseClasses.tail)
if (baseCls.typeParams.exists(_.paramVariance != 0))
for (problem <- variantInheritanceProblems(baseCls, caseCls, "non-variant", "case "))
ctx.errorOrMigrationWarning(problem(), clazz.pos)
}
checkNoAbstractMembers()
if (abstractErrors.isEmpty)
checkNoAbstractDecls(clazz)

if (abstractErrors.nonEmpty)
ctx.error(abstractErrorMessage, clazz.pos)

checkMemberTypesOK()
checkCaseClassInheritanceInvariant()
} else if (clazz.is(Trait) && !(clazz derivesFrom defn.AnyValClass)) {
// For non-AnyVal classes, prevent abstract methods in interfaces that override
// final members in Object; see #4431
Expand All @@ -613,6 +674,51 @@ object RefChecks {
}
}

if (!clazz.is(Trait)) {
// check that parameterized base classes and traits are typed in the same way as from the superclass
// I.e. say we have
//
// Sub extends Super extends* Base
//
// where `Base` has value parameters. Enforce that
//
// Sub.thisType.baseType(Base) =:= Sub.thisType.baseType(Super).baseType(Base)
//
// This is necessary because parameter values are determined directly or indirectly
// by `Super`. So we cannot pretend they have a different type when seen from `Sub`.
def checkParameterizedTraitsOK() = {
val mixins = clazz.mixins
for {
cls <- clazz.info.baseClasses.tail
if cls.paramAccessors.nonEmpty && !mixins.contains(cls)
problem <- variantInheritanceProblems(cls, clazz.asClass.superClass, "parameterized", "super")
} ctx.error(problem(), clazz.pos)
}

checkParameterizedTraitsOK()
}

/** Check that `site` does not inherit conflicting generic instances of `baseCls`,
* when doing a direct base type or going via intermediate class `middle`. I.e, we require:
*
* site.baseType(baseCls) =:= site.baseType(middle).baseType(baseCls)
*
* Return an optional by name error message if this test fails.
*/
def variantInheritanceProblems(
baseCls: Symbol, middle: Symbol, baseStr: String, middleStr: String): Option[() => String] = {
val superBT = self.baseType(middle)
val thisBT = self.baseType(baseCls)
val combinedBT = superBT.baseType(baseCls)
if (combinedBT =:= thisBT) None // ok
else
Some(() =>
em"""illegal inheritance: $clazz inherits conflicting instances of $baseStr base $baseCls.
|
| Direct basetype: $thisBT
| Basetype via $middleStr$middle: $combinedBT""")
}

/* Returns whether there is a symbol declared in class `inclazz`
* (which must be different from `clazz`) whose name and type
* seen as a member of `class.thisType` matches `member`'s.
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ class Typer extends Namer
def typedTpt = checkSimpleKinded(typedType(tree.tpt))
def handlePattern: Tree = {
val tpt1 = typedTpt
if (!ctx.isAfterTyper) tpt1.tpe.<:<(pt)(ctx.addMode(Mode.GADTflexible))
if (!ctx.isAfterTyper) constrainPatternType(tpt1.tpe, pt)(ctx.addMode(Mode.GADTflexible))
// special case for an abstract type that comes with a class tag
tryWithClassTag(ascription(tpt1, isWildcard = true), pt)
}
Expand Down Expand Up @@ -1460,9 +1460,13 @@ class Typer extends Namer
ref
}

val seenParents = mutable.Set[Symbol]()

def typedParent(tree: untpd.Tree): Tree = {
var result = if (tree.isType) typedType(tree)(superCtx) else typedExpr(tree)(superCtx)
val psym = result.tpe.typeSymbol
if (seenParents.contains(psym)) ctx.error(i"$psym is extended twice", tree.pos)
seenParents += psym
if (tree.isType) {
if (psym.is(Trait) && !cls.is(Trait) && !cls.superClass.isSubClass(psym))
result = maybeCall(result, psym, psym.primaryConstructor.info)
Expand Down
12 changes: 8 additions & 4 deletions compiler/src/dotty/tools/dotc/typer/VarianceChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class VarianceChecker()(implicit ctx: Context) {
def ignoreVarianceIn(base: Symbol): Boolean = (
base.isTerm
|| base.is(Package)
|| base.is(Local)
|| base.is(PrivateLocal)
)

/** The variance of a symbol occurrence of `tvar` seen at the level of the definition of `base`.
Expand Down Expand Up @@ -112,9 +112,13 @@ class VarianceChecker()(implicit ctx: Context) {
def checkVariance(sym: Symbol, pos: Position) = Validator.validateDefinition(sym) match {
case Some(VarianceError(tvar, required)) =>
def msg = i"${varianceString(tvar.flags)} $tvar occurs in ${varianceString(required)} position in type ${sym.info} of $sym"
if (ctx.scala2Mode && sym.owner.isConstructor) {
if (ctx.scala2Mode &&
(sym.owner.isConstructor || sym.ownersIterator.exists(_.is(ProtectedLocal)))) {
ctx.migrationWarning(s"According to new variance rules, this is no longer accepted; need to annotate with @uncheckedVariance:\n$msg", pos)
patch(Position(pos.end), " @scala.annotation.unchecked.uncheckedVariance") // TODO use an import or shorten if possible
// patch(Position(pos.end), " @scala.annotation.unchecked.uncheckedVariance")
// Patch is disabled until two TODOs are solved:
// TODO use an import or shorten if possible
// TODO need to use a `:' if annotation is on term
}
else ctx.error(msg, pos)
case None =>
Expand All @@ -125,7 +129,7 @@ class VarianceChecker()(implicit ctx: Context) {
// No variance check for private/protected[this] methods/values.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update comment? protected[this] is no longer skipped

def skip =
!sym.exists ||
sym.is(Local) || // !!! watch out for protected local!
sym.is(PrivateLocal) ||
sym.is(TypeParam) && sym.owner.isClass // already taken care of in primary constructor of class
tree match {
case defn: MemberDef if skip =>
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/util/HashSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class HashSet[T >: Null <: AnyRef](powerOfTwoInitialCapacity: Int, loadFactor: F
private[this] var limit: Int = _
private[this] var table: Array[AnyRef] = _

assert(Integer.bitCount(powerOfTwoInitialCapacity) == 1)
protected def isEqual(x: T, y: T): Boolean = x.equals(y)

// Counters for Stats
Expand Down
6 changes: 4 additions & 2 deletions compiler/test/dotty/tools/dotc/IdempotencyTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,14 @@ class IdempotencyTests extends ParallelTesting {
}
val allChecks = {
check("CheckOrderIdempotency") +
check("CheckStrawmanIdempotency") +
// Disabled until strawman is fixed
// check("CheckStrawmanIdempotency") +
check("CheckPosIdempotency")
}

val allTests = {
strawmanIdempotency +
// Disabled until strawman is fixed
// strawmanIdempotency +
orderIdempotency +
posIdempotency
}
Expand Down
4 changes: 3 additions & 1 deletion compiler/test/dotty/tools/dotc/LinkTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ class LinkTests extends ParallelTesting {
def testFilter = Properties.testsFilter


@Test def linkTest: Unit = {
// Disabled until strawman is fixed
// @Test
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import org.junit.Ignore
@Ignore("Disabled until strawman is fixed")
@Test
def linkTest: Unit = {

This logs that a test is ignored. Otherwise we might very well forget about this test

def linkTest: Unit = {
// Setup and compile libraries
val strawmanLibGroup = TestGroup("linkTest/strawmanLibrary")
val strawmanLibTestGroup = TestGroup(strawmanLibGroup + "/tests")
Expand Down
2 changes: 1 addition & 1 deletion tests/neg/i1240b.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ abstract class A[X] extends T[X] {
trait U[X] extends T[X] {
abstract override def foo(x: X): X = super.foo(x)
}
object Test extends A[String] with U[String] // error: accidental override
object Test extends A[String] with U[String] // error: accidental override // error: merge error
7 changes: 7 additions & 0 deletions tests/neg/i3989.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
object Test extends App {
trait A[+X] { def get: X }
case class B[X](x: X) extends A[X] { def get: X = x }
class C[X](x: Any) extends B[Any](x) with A[X] // error: not a legal implementation of `get'
def g(a: A[Int]): Int = a.get
g(new C[Int]("foo"))
}
10 changes: 10 additions & 0 deletions tests/neg/i3989a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
object Test extends App {
trait A[+X]
class B[+X](val x: X) extends A[X]
class C[+X](x: Any) extends B[Any](x) with A[X]
def f(a: A[Int]): Int = a match {
case a: B[_] => a.x // error
case _ => 0
}
f(new C[Int]("foo"))
}
10 changes: 10 additions & 0 deletions tests/neg/i3989b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
object Test extends App {
trait A[+X]
case class B[+X](val x: X) extends A[X]
class C[+X](x: Any) extends B[Any](x) with A[X] // error
def f(a: A[Int]): Int = a match {
case B(i) => i
case _ => 0
}
f(new C[Int]("foo"))
}
15 changes: 15 additions & 0 deletions tests/neg/i3989c.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import scala.Option
object Test extends App {
trait A[+X]
class B[+X](val x: X) extends A[X]
object B {
def unapply[X](b: B[X]): Option[X] = Some(b.x)
}

class C[+X](x: Any) extends B[Any](x) with A[X]
def f(a: A[Int]): Int = a match {
case B(i) => i // error
case _ => 0
}
f(new C[Int]("foo"))
}
18 changes: 18 additions & 0 deletions tests/neg/i3989d.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
trait A[+_X] {
protected[this] type X = _X // error: variance
def f: X
}

trait B extends A[B] {
def f: X = new B {}
}

class C extends B with A[C] {
// should be required because the inherited f is of type B, not C
// override def f: X = new C
}

object Test extends App {
val c1 = new C
val c2: C = c1.f
}
11 changes: 11 additions & 0 deletions tests/neg/i3989e.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
object Test extends App {
trait A[+X](val x: X)
class B extends A(5) with A("hello") // error: A is extended twice

def f(a: A[Int]): Int = a match {
case b: B => b.x
case _ => 0
}

f(new B)
}
13 changes: 13 additions & 0 deletions tests/neg/i3989f.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
object Test extends App {
trait A[+X](val x: X)
class B[+X](val y: X) extends A[X](y)
class C extends B(5) with A[String] // error: illegal inheritance

class D extends B(5) with A[Any] // ok

def f(a: A[Int]): String = a match {
case c: C => c.x
case _ => "hello"
}
f(new C)
}
Loading