From 50eb0e979407cf2348d82eb34d6fe8a168ba5171 Mon Sep 17 00:00:00 2001 From: Guillaume Martres Date: Thu, 24 Nov 2022 11:04:01 +0100 Subject: [PATCH] Avoid incorrect simplifications when updating bounds in the constraint When combining an old and a new bound, we use `Type#&`/`Type#|` which perform simplifications. This is usually fine, but if the new bounds refer to the parameter currently being updated, we can run into cyclic reasoning issues which make the simplifications invalid after the update. We already have logic for handling self-references in parameter bounds: `updateEntry` calls `ensureNonCyclic` which sanitizes the type, but at this point the simplifications have already occured. This commit simply moves the logic out of `updateEntry` so that we can sanitize the new bounds before simplification. More precisely, we rename `ensureNonCyclic` to `validBoundsFor` which calls `validBoundFor`. Both are used to sanitize bounds where needed in `addOneBound` and `unify`. Since all calls to `updateEntry` now have sanitized bounds, we no longer need to sanitize them in `updateEntry` itself, we document this change by adding a pre-condition to `updateEntry`. For the record, here's how `ConstraintsTest#validBoundsInit` used to fail. It defines a method: def foo[S >: T <: T | Int, T <: String]: Any Before this commit, when `foo` was added to the current constraints, the constraint `S <: T | Int` was propagated to the lower bound `T` of `S`. The updated upper bound of `T` was thus set to: String & (T | Int) But because `Type#&` performs simplifications, this became T | (String & Int) by relying on the fact that at this point, `T <: String`. But in fact this simplified bound no longer ensures that `T <: String`! The self-reference was then replaced by `Any` in `OrderingConstraint#ensureNonCyclic`. After this commit, the problematic simplification no longer occurs since the new `T | Int` is sanitized to `Any` before being intersected with the old bound. --- .../dotty/tools/dotc/core/Constraint.scala | 19 +++++++ .../tools/dotc/core/ConstraintHandling.scala | 10 ++-- .../tools/dotc/core/OrderingConstraint.scala | 53 ++++++++----------- .../tools/dotc/core/ConstraintsTest.scala | 24 +++++++++ 4 files changed, 71 insertions(+), 35 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/Constraint.scala b/compiler/src/dotty/tools/dotc/core/Constraint.scala index fb87aed77c41..b849c7aa7093 100644 --- a/compiler/src/dotty/tools/dotc/core/Constraint.scala +++ b/compiler/src/dotty/tools/dotc/core/Constraint.scala @@ -88,6 +88,8 @@ abstract class Constraint extends Showable { * - Another type, indicating a solution for the parameter * * @pre `this contains param`. + * @pre `tp` does not contain top-level references to `param` + * (see `validBoundsFor`) */ def updateEntry(param: TypeParamRef, tp: Type)(using Context): This @@ -172,6 +174,23 @@ abstract class Constraint extends Showable { */ def occursAtToplevel(param: TypeParamRef, tp: Type)(using Context): Boolean + /** Sanitize `bound` to make it either a valid upper or lower bound for + * `param` depending on `isUpper`. + * + * Toplevel references to `param`, are replaced by `Any` if `isUpper` is true + * and `Nothing` otherwise. + * + * @see `occursAtTopLevel` for a definition of "toplevel" + * @see `validBoundsFor` to sanitize both the lower and upper bound at once. + */ + def validBoundFor(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Type + + /** Sanitize `bounds` to make them valid constraints for `param`. + * + * @see `validBoundFor` for details. + */ + def validBoundsFor(param: TypeParamRef, bounds: TypeBounds)(using Context): Type + /** A string that shows the reverse dependencies maintained by this constraint * (coDeps and contraDeps for OrderingConstraints). */ diff --git a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala index c7c005b5220f..6207e0a3d728 100644 --- a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala +++ b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala @@ -257,7 +257,7 @@ trait ConstraintHandling { end LevelAvoidMap /** Approximate `rawBound` if needed to make it a legal bound of `param` by - * avoiding wildcards and types with a level strictly greater than its + * avoiding cycles, wildcards and types with a level strictly greater than its * `nestingLevel`. * * Note that level-checking must be performed here and cannot be delayed @@ -283,7 +283,7 @@ trait ConstraintHandling { // This is necessary for i8900-unflip.scala to typecheck. val v = if necessaryConstraintsOnly then -this.variance else this.variance atVariance(v)(super.legalVar(tp)) - approx(rawBound) + constraint.validBoundFor(param, approx(rawBound), isUpper) end legalBound protected def addOneBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Boolean = @@ -413,8 +413,10 @@ trait ConstraintHandling { constraint = constraint.addLess(p2, p1, direction = if pKept eq p1 then KeepParam2 else KeepParam1) - val boundKept = constraint.nonParamBounds(pKept).substParam(pRemoved, pKept) - var boundRemoved = constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept) + val boundKept = constraint.validBoundsFor(pKept, + constraint.nonParamBounds( pKept).substParam(pRemoved, pKept).bounds) + var boundRemoved = constraint.validBoundsFor(pKept, + constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept).bounds) if level1 != level2 then boundRemoved = LevelAvoidMap(-1, math.min(level1, level2))(boundRemoved) diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index 1f65fa324147..603d7a3cb0e3 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -525,20 +525,11 @@ class OrderingConstraint(private val boundsMap: ParamBounds, // ---------- Updates ------------------------------------------------------------ - /** If `inst` is a TypeBounds, make sure it does not contain toplevel - * references to `param` (see `Constraint#occursAtToplevel` for a definition - * of "toplevel"). - * Any such references are replaced by `Nothing` in the lower bound and `Any` - * in the upper bound. - * References can be direct or indirect through instantiations of other - * parameters in the constraint. - */ - private def ensureNonCyclic(param: TypeParamRef, inst: Type)(using Context): Type = - - def recur(tp: Type, fromBelow: Boolean): Type = tp match + def validBoundFor(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Type = + def recur(tp: Type): Type = tp match case tp: AndOrType => - val r1 = recur(tp.tp1, fromBelow) - val r2 = recur(tp.tp2, fromBelow) + val r1 = recur(tp.tp1) + val r2 = recur(tp.tp2) if (r1 eq tp.tp1) && (r2 eq tp.tp2) then tp else tp.match case tp: OrType => @@ -547,35 +538,34 @@ class OrderingConstraint(private val boundsMap: ParamBounds, r1 & r2 case tp: TypeParamRef => if tp eq param then - if fromBelow then defn.NothingType else defn.AnyType + if isUpper then defn.AnyType else defn.NothingType else entry(tp) match case NoType => tp - case TypeBounds(lo, hi) => if lo eq hi then recur(lo, fromBelow) else tp - case inst => recur(inst, fromBelow) + case TypeBounds(lo, hi) => if lo eq hi then recur(lo) else tp + case inst => recur(inst) case tp: TypeVar => - val underlying1 = recur(tp.underlying, fromBelow) + val underlying1 = recur(tp.underlying) if underlying1 ne tp.underlying then underlying1 else tp case CapturingType(parent, refs) => - val parent1 = recur(parent, fromBelow) + val parent1 = recur(parent) if parent1 ne parent then tp.derivedCapturingType(parent1, refs) else tp case tp: AnnotatedType => - val parent1 = recur(tp.parent, fromBelow) + val parent1 = recur(tp.parent) if parent1 ne tp.parent then tp.derivedAnnotatedType(parent1, tp.annot) else tp case _ => val tp1 = tp.dealiasKeepAnnots if tp1 ne tp then - val tp2 = recur(tp1, fromBelow) + val tp2 = recur(tp1) if tp2 ne tp1 then tp2 else tp else tp - inst match - case bounds: TypeBounds => - bounds.derivedTypeBounds( - recur(bounds.lo, fromBelow = true), - recur(bounds.hi, fromBelow = false)) - case _ => - inst - end ensureNonCyclic + recur(bound) + end validBoundFor + + def validBoundsFor(param: TypeParamRef, bounds: TypeBounds)(using Context): Type = + bounds.derivedTypeBounds( + validBoundFor(param, bounds.lo, isUpper = false), + validBoundFor(param, bounds.hi, isUpper = true)) /** Add the fact `param1 <: param2` to the constraint `current` and propagate * `<:<` relationships between parameters ("edges") but not bounds. @@ -658,9 +648,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds, current1 } - /** The public version of `updateEntry`. Guarantees that there are no cycles */ def updateEntry(param: TypeParamRef, tp: Type)(using Context): This = - updateEntry(this, param, ensureNonCyclic(param, tp)).checkWellFormed() + updateEntry(this, param, tp).checkWellFormed() def addLess(param1: TypeParamRef, param2: TypeParamRef, direction: UnificationDirection)(using Context): This = order(this, param1, param2, direction).checkWellFormed() @@ -703,7 +692,9 @@ class OrderingConstraint(private val boundsMap: ParamBounds, def replaceParamIn(other: TypeParamRef) = val oldEntry = current.entry(other) - val newEntry = current.ensureNonCyclic(other, oldEntry.substParam(param, replacement)) + val newEntry = oldEntry.substParam(param, replacement) match + case tp: TypeBounds => validBoundsFor(other, tp) + case tp => tp current = boundsLens.update(this, current, other, newEntry) var oldDepEntry = oldEntry var newDepEntry = newEntry diff --git a/compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala b/compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala index 5ab162b9f05c..ad8578fa3e61 100644 --- a/compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala +++ b/compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala @@ -53,3 +53,27 @@ class ConstraintsTest: i"Merging constraints `?S <: ?T` and `Int <: ?S` should result in `Int <:< ?T`: ${ctx.typerState.constraint}") } end mergeBoundsTransitivity + + @Test def validBoundsInit: Unit = inCompilerContext( + TestConfiguration.basicClasspath, + scalaSources = "trait A { def foo[S >: T <: T | Int, T <: String]: Any }") { + val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2 + val List(s, t) = tvars.tpes + + val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.asInstanceOf[TypeVar].origin): @unchecked + assert(lo =:= defn.NothingType, i"Unexpected lower bound $lo for $t: ${ctx.typerState.constraint}") + assert(hi =:= defn.StringType, i"Unexpected upper bound $hi for $t: ${ctx.typerState.constraint}") // used to be Any + } + + @Test def validBoundsUnify: Unit = inCompilerContext( + TestConfiguration.basicClasspath, + scalaSources = "trait A { def foo[S >: T <: T | Int, T <: String | Int]: Any }") { + val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2 + val List(s, t) = tvars.tpes + + s <:< t + + val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.asInstanceOf[TypeVar].origin): @unchecked + assert(lo =:= defn.NothingType, i"Unexpected lower bound $lo for $t: ${ctx.typerState.constraint}") + assert(hi =:= (defn.StringType | defn.IntType), i"Unexpected upper bound $hi for $t: ${ctx.typerState.constraint}") + }