Skip to content

Commit 3ab18a9

Browse files
committed
Create fresh type variables to keep constraints level-correct
This completes the implementation of `LevelAvoidMap` from the previous commit: we now make sure that the nonParamBounds of a type variable does not refer to variables of a higher level by creating fresh variables of the appropriate level if necessary. Each fresh variable will be upper- or lower-bounded by the existing variable it is substituted for depending on variance (an idea that I got from [1]), in the invariant case the existing variable will be instantiated to the fresh one (unlike [2], we can't simply mutate the nestingLevel of the existing variable after running avoidance on its bounds because the constraint containing these bounds might later be retracted). Additionally: - When unifying two type variables, keep the one with the lowest level in the constraint set and make sure the bounds transferred from the other one are level-correct. This required some changes in `Constraint#addLess` which previously assumed that `unify` would always keep the second parameter. - When instantiating a type variable to its full lower- or upper-bound, we also need to avoid any type variable of a higher level among its param bound. This commit is necessary to avoid leaking local types in i8900a2.scala and i8900a3.scala, these kind of leaks will become compile-time error in the next commit. This commit required making a type parameter explicit both in SnippetChecker.scala and i13809/Macros_1.scala, in both situations the problem is that the lambda passed to `map` can only be typed if the type argument of `map` contains a wildcard, but LevelAvoidMap instead creates a fresh type variable of a lower level at a point where we don't know yet that this cannot work. Since this situation seems very rare in practice, I believe this is an acceptable trade-off for soundness. [1]: Lionel Parreaux. "The simple essence of algebraic subtyping: principal type inference with subtyping made easy (functional pearl)." https://dl.acm.org/doi/abs/10.1145/3409006 [2]: Oleg Kiselyov. "How OCaml type checker works -- or what polymorphism and garbage collection have in common" https://okmij.org/ftp/ML/generalization.html
1 parent 8ed6bde commit 3ab18a9

21 files changed

+381
-128
lines changed

compiler/src/dotty/tools/dotc/core/Constraint.scala

+20-6
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,15 @@ abstract class Constraint extends Showable {
9393
/** A constraint that includes the relationship `p1 <: p2`.
9494
* `<:` relationships between parameters ("edges") are propagated, but
9595
* non-parameter bounds are left alone.
96+
*
97+
* @param direction Must be set to `KeepParam1` or `KeepParam2` when
98+
* `p2 <: p1` is already true depending on which parameter
99+
* the caller intends to keep. This will avoid propagating
100+
* bounds that will be redundant after `p1` and `p2` are
101+
* unified.
96102
*/
97-
def addLess(p1: TypeParamRef, p2: TypeParamRef)(using Context): This
98-
99-
/** A constraint resulting from adding p2 = p1 to this constraint, and at the same
100-
* time transferring all bounds of p2 to p1
101-
*/
102-
def unify(p1: TypeParamRef, p2: TypeParamRef)(using Context): This
103+
def addLess(p1: TypeParamRef, p2: TypeParamRef,
104+
direction: UnificationDirection = UnificationDirection.NoUnification)(using Context): This
103105

104106
/** A new constraint which is derived from this constraint by removing
105107
* the type parameter `param` from the domain and replacing all top-level occurrences
@@ -174,3 +176,15 @@ abstract class Constraint extends Showable {
174176
*/
175177
def checkConsistentVars()(using Context): Unit
176178
}
179+
180+
/** When calling `Constraint#addLess(p1, p2, ...)`, the caller might end up
181+
* unifying one parameter with the other, this enum lets `addLess` know which
182+
* direction the unification will take.
183+
*/
184+
enum UnificationDirection:
185+
/** Neither p1 nor p2 will be instantiated. */
186+
case NoUnification
187+
/** `p2 := p1`, p1 left uninstantiated. */
188+
case KeepParam1
189+
/** `p1 := p2`, p2 left uninstantiated. */
190+
case KeepParam2

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

+136-13
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@ import Flags._
1010
import config.Config
1111
import config.Printers.typr
1212
import reporting.trace
13-
import typer.ProtoTypes.newTypeVar
13+
import typer.ProtoTypes.{newTypeVar, representedParamRef}
1414
import StdNames.tpnme
15+
import UnificationDirection.*
16+
import NameKinds.AvoidNameKind
1517

1618
/** Methods for adding constraints and solving them.
1719
*
@@ -85,13 +87,39 @@ trait ConstraintHandling {
8587
case tv: TypeVar => tv.nestingLevel
8688
case _ => Int.MaxValue
8789

90+
/** If `param` is nested deeper than `maxLevel`, try to instantiate it to a
91+
* fresh type variable of level `maxLevel` and return the new variable.
92+
* If this isn't possible, throw a TypeError.
93+
*/
94+
def atLevel(maxLevel: Int, param: TypeParamRef)(using Context): TypeParamRef =
95+
if nestingLevel(param) <= maxLevel then return param
96+
LevelAvoidMap(0, maxLevel)(param) match
97+
case freshVar: TypeVar => freshVar.origin
98+
case _ => throw new TypeError(
99+
i"Could not decrease the nesting level of ${param} from ${nestingLevel(param)} to $maxLevel in $constraint")
100+
88101
def nonParamBounds(param: TypeParamRef)(using Context): TypeBounds = constraint.nonParamBounds(param)
89102

103+
/** The full lower bound of `param` includes both the `nonParamBounds` and the
104+
* params in the constraint known to be `<: param`, except that
105+
* params with a `nestingLevel` higher than `param` will be instantiated
106+
* to a fresh param at a legal level. See the documentation of `TypeVar`
107+
* for details.
108+
*/
90109
def fullLowerBound(param: TypeParamRef)(using Context): Type =
91-
constraint.minLower(param).foldLeft(nonParamBounds(param).lo)(_ | _)
110+
val maxLevel = nestingLevel(param)
111+
var loParams = constraint.minLower(param)
112+
if maxLevel != Int.MaxValue then
113+
loParams = loParams.mapConserve(atLevel(maxLevel, _))
114+
loParams.foldLeft(nonParamBounds(param).lo)(_ | _)
92115

116+
/** The full upper bound of `param`, see the documentation of `fullLowerBounds` above. */
93117
def fullUpperBound(param: TypeParamRef)(using Context): Type =
94-
constraint.minUpper(param).foldLeft(nonParamBounds(param).hi)(_ & _)
118+
val maxLevel = nestingLevel(param)
119+
var hiParams = constraint.minUpper(param)
120+
if maxLevel != Int.MaxValue then
121+
hiParams = hiParams.mapConserve(atLevel(maxLevel, _))
122+
hiParams.foldLeft(nonParamBounds(param).hi)(_ & _)
95123

96124
/** Full bounds of `param`, including other lower/upper params.
97125
*
@@ -116,10 +144,64 @@ trait ConstraintHandling {
116144
def toAvoid(tp: NamedType): Boolean =
117145
tp.prefix == NoPrefix && !tp.symbol.isStatic && !levelOK(tp.symbol.nestingLevel)
118146

147+
/** Return a (possibly fresh) type variable of a level no greater than `maxLevel` which is:
148+
* - lower-bounded by `tp` if variance >= 0
149+
* - upper-bounded by `tp` if variance <= 0
150+
* If this isn't possible, return the empty range.
151+
*/
152+
def legalVar(tp: TypeVar): Type =
153+
val oldParam = tp.origin
154+
val nameKind =
155+
if variance > 0 then AvoidNameKind.UpperBound
156+
else if variance < 0 then AvoidNameKind.LowerBound
157+
else AvoidNameKind.BothBounds
158+
159+
/** If it exists, return the first param in the list created in a previous call to `legalVar(tp)`
160+
* with the appropriate level and variance.
161+
*/
162+
def findParam(params: List[TypeParamRef]): Option[TypeParamRef] =
163+
params.find(p =>
164+
nestingLevel(p) <= maxLevel && representedParamRef(p) == oldParam &&
165+
(p.paramName.is(AvoidNameKind.BothBounds) ||
166+
variance != 0 && p.paramName.is(nameKind)))
167+
168+
// First, check if we can reuse an existing parameter, this is more than an optimization
169+
// since it avoids an infinite loop in tests/pos/i8900-cycle.scala
170+
findParam(constraint.lower(oldParam)).orElse(findParam(constraint.upper(oldParam))) match
171+
case Some(param) =>
172+
constraint.typeVarOfParam(param)
173+
case _ =>
174+
// Otherwise, try to return a fresh type variable at `maxLevel` with
175+
// the appropriate constraints.
176+
val name = nameKind(oldParam.paramName.toTermName).toTypeName
177+
val freshVar = newTypeVar(TypeBounds.upper(tp.topType), name,
178+
nestingLevel = maxLevel, represents = oldParam)
179+
val ok =
180+
if variance < 0 then
181+
addLess(freshVar.origin, oldParam)
182+
else if variance > 0 then
183+
addLess(oldParam, freshVar.origin)
184+
else
185+
unify(freshVar.origin, oldParam)
186+
if ok then freshVar else emptyRange
187+
end legalVar
188+
189+
override def apply(tp: Type): Type = tp match
190+
case tp: TypeVar if !tp.isInstantiated && !levelOK(tp.nestingLevel) =>
191+
legalVar(tp)
192+
// TypeParamRef can occur in tl bounds
193+
case tp: TypeParamRef =>
194+
constraint.typeVarOfParam(tp) match
195+
case tvar: TypeVar =>
196+
apply(tvar)
197+
case _ => super.apply(tp)
198+
case _ =>
199+
super.apply(tp)
200+
119201
override def mapWild(t: WildcardType) =
120202
if ctx.mode.is(Mode.TypevarsMissContext) then super.mapWild(t)
121203
else
122-
val tvar = newTypeVar(apply(t.effectiveBounds).toBounds)
204+
val tvar = newTypeVar(apply(t.effectiveBounds).toBounds, nestingLevel = maxLevel)
123205
tvar
124206
end LevelAvoidMap
125207

@@ -140,7 +222,16 @@ trait ConstraintHandling {
140222
// flip the variance to under-approximate.
141223
if necessaryConstraintsOnly then variance = -variance
142224

143-
val approx = LevelAvoidMap(variance, nestingLevel(param))
225+
val approx = new LevelAvoidMap(variance, nestingLevel(param)):
226+
override def legalVar(tp: TypeVar): Type =
227+
// `legalVar` will create a type variable whose bounds depend on
228+
// `variance`, but whether the variance is positive or negative,
229+
// we can still infer necessary constraints since just creating a
230+
// type variable doesn't reduce the set of possible solutions.
231+
// Therefore, we can safely "unflip" the variance flipped above.
232+
// This is necessary for i8900-unflip.scala to typecheck.
233+
val v = if necessaryConstraintsOnly then -this.variance else this.variance
234+
atVariance(v)(super.legalVar(tp))
144235
approx(rawBound)
145236
end legalBound
146237

@@ -246,19 +337,50 @@ trait ConstraintHandling {
246337

247338
def location(using Context) = "" // i"in ${ctx.typerState.stateChainStr}" // use for debugging
248339

249-
/** Make p2 = p1, transfer all bounds of p2 to p1
250-
* @pre less(p1)(p2)
340+
/** Unify p1 with p2: one parameter will be kept in the constraint, the
341+
* other will be removed and its bounds transferred to the remaining one.
342+
*
343+
* If p1 and p2 have different `nestingLevel`, the parameter with the lowest
344+
* level will be kept and the transferred bounds from the other parameter
345+
* will be adjusted for level-correctness.
251346
*/
252347
private def unify(p1: TypeParamRef, p2: TypeParamRef)(using Context): Boolean = {
253348
constr.println(s"unifying $p1 $p2")
254-
assert(constraint.isLess(p1, p2))
255-
constraint = constraint.addLess(p2, p1)
349+
if !constraint.isLess(p1, p2) then
350+
constraint = constraint.addLess(p1, p2)
351+
352+
val level1 = nestingLevel(p1)
353+
val level2 = nestingLevel(p2)
354+
val pKept = if level1 <= level2 then p1 else p2
355+
val pRemoved = if level1 <= level2 then p2 else p1
356+
357+
constraint = constraint.addLess(p2, p1, direction = if pKept eq p1 then KeepParam2 else KeepParam1)
358+
359+
val boundKept = constraint.nonParamBounds(pKept).substParam(pRemoved, pKept)
360+
var boundRemoved = constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept)
361+
362+
if level1 != level2 then
363+
boundRemoved = LevelAvoidMap(-1, math.min(level1, level2))(boundRemoved)
364+
val TypeBounds(lo, hi) = boundRemoved
365+
// After avoidance, the interval might be empty, e.g. in
366+
// tests/pos/i8900-promote.scala:
367+
// >: x.type <: Singleton
368+
// becomes:
369+
// >: Int <: Singleton
370+
// In that case, we can still get a legal constraint
371+
// by replacing the lower-bound to get:
372+
// >: Int & Singleton <: Singleton
373+
if !isSub(lo, hi) then
374+
boundRemoved = TypeBounds(lo & hi, hi)
375+
256376
val down = constraint.exclusiveLower(p2, p1)
257377
val up = constraint.exclusiveUpper(p1, p2)
258-
constraint = constraint.unify(p1, p2)
259-
val bounds = constraint.nonParamBounds(p1)
260-
val lo = bounds.lo
261-
val hi = bounds.hi
378+
379+
val newBounds = (boundKept & boundRemoved).bounds
380+
constraint = constraint.updateEntry(pKept, newBounds).replace(pRemoved, pKept)
381+
382+
val lo = newBounds.lo
383+
val hi = newBounds.hi
262384
isSub(lo, hi) &&
263385
down.forall(addOneBound(_, hi, isUpper = true)) &&
264386
up.forall(addOneBound(_, lo, isUpper = false))
@@ -311,6 +433,7 @@ trait ConstraintHandling {
311433
final def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
312434
constraint.entry(param) match
313435
case entry: TypeBounds =>
436+
val maxLevel = nestingLevel(param)
314437
val useLowerBound = fromBelow || param.occursIn(entry.hi)
315438
val inst = if useLowerBound then fullLowerBound(param) else fullUpperBound(param)
316439
typr.println(s"approx ${param.show}, from below = $fromBelow, inst = ${inst.show}")

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import collection.mutable
1111
import printing._
1212

1313
import scala.annotation.internal.sharable
14+
import scala.annotation.unused
1415

1516
/** Represents GADT constraints currently in scope */
1617
sealed abstract class GadtConstraint extends Showable {

compiler/src/dotty/tools/dotc/core/NameKinds.scala

+7
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,13 @@ object NameKinds {
358358
val ProtectedAccessorName: PrefixNameKind = new PrefixNameKind(PROTECTEDACCESSOR, "protected$")
359359
val InlineAccessorName: PrefixNameKind = new PrefixNameKind(INLINEACCESSOR, "inline$")
360360

361+
/** See `ConstraintHandling#LevelAvoidMap`. */
362+
enum AvoidNameKind(tag: Int, prefix: String) extends PrefixNameKind(tag, prefix):
363+
override def definesNewName = true
364+
case UpperBound extends AvoidNameKind(AVOIDUPPER, "(upper)")
365+
case LowerBound extends AvoidNameKind(AVOIDLOWER, "(lower)")
366+
case BothBounds extends AvoidNameKind(AVOIDBOTH, "(avoid)")
367+
361368
val BodyRetainerName: SuffixNameKind = new SuffixNameKind(BODYRETAINER, "$retainedBody")
362369
val FieldName: SuffixNameKind = new SuffixNameKind(FIELD, "$$local") {
363370
override def mkString(underlying: TermName, info: ThisInfo) = underlying.toString

compiler/src/dotty/tools/dotc/core/NameTags.scala

+5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ object NameTags extends TastyFormat.NameTags {
3232

3333
inline val SETTER = 34 // A synthesized += suffix.
3434

35+
// Name of type variables created by `ConstraintHandling#LevelAvoidMap`.
36+
final val AVOIDUPPER = 35
37+
final val AVOIDLOWER = 36
38+
final val AVOIDBOTH = 37
39+
3540
def nameTagToString(tag: Int): String = tag match {
3641
case UTF8 => "UTF8"
3742
case QUALIFIED => "QUALIFIED"

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

+19-15
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
134134
private val lowerMap : ParamOrdering,
135135
private val upperMap : ParamOrdering) extends Constraint {
136136

137+
import UnificationDirection.*
138+
137139
type This = OrderingConstraint
138140

139141
// ----------- Basic indices --------------------------------------------------
@@ -350,29 +352,37 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
350352
/** Add the fact `param1 <: param2` to the constraint `current` and propagate
351353
* `<:<` relationships between parameters ("edges") but not bounds.
352354
*/
353-
private def order(current: This, param1: TypeParamRef, param2: TypeParamRef)(using Context): This =
355+
def order(current: This, param1: TypeParamRef, param2: TypeParamRef, direction: UnificationDirection = NoUnification)(using Context): This =
354356
if (param1 == param2 || current.isLess(param1, param2)) this
355357
else {
356358
assert(contains(param1), i"$param1")
357359
assert(contains(param2), i"$param2")
358-
// Is `order` called during parameter unification?
359-
val unifying = isLess(param2, param1)
360+
val unifying = direction != NoUnification
360361
val newUpper = {
361362
val up = exclusiveUpper(param2, param1)
362363
if unifying then
363364
// Since param2 <:< param1 already holds now, filter out param1 to avoid adding
364365
// duplicated orderings.
365-
param2 :: up.filterNot(_ eq param1)
366+
val filtered = up.filterNot(_ eq param1)
367+
// Only add bounds for param2 if it will be kept in the constraint after unification.
368+
if direction == KeepParam2 then
369+
param2 :: filtered
370+
else
371+
filtered
366372
else
367373
param2 :: up
368374
}
369375
val newLower = {
370376
val lower = exclusiveLower(param1, param2)
371377
if unifying then
372-
// Do not add bounds for param1 since it will be unified to param2 soon.
373-
// And, similarly filter out param2 from lowerly-ordered parameters
378+
// Similarly, filter out param2 from lowerly-ordered parameters
374379
// to avoid duplicated orderings.
375-
lower.filterNot(_ eq param2)
380+
val filtered = lower.filterNot(_ eq param2)
381+
// Only add bounds for param1 if it will be kept in the constraint after unification.
382+
if direction == KeepParam1 then
383+
param1 :: filtered
384+
else
385+
filtered
376386
else
377387
param1 :: lower
378388
}
@@ -416,14 +426,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
416426
def updateEntry(param: TypeParamRef, tp: Type)(using Context): This =
417427
updateEntry(this, param, ensureNonCyclic(param, tp)).checkNonCyclic()
418428

419-
def addLess(param1: TypeParamRef, param2: TypeParamRef)(using Context): This =
420-
order(this, param1, param2).checkNonCyclic()
421-
422-
def unify(p1: TypeParamRef, p2: TypeParamRef)(using Context): This =
423-
val bound1 = nonParamBounds(p1).substParam(p2, p1)
424-
val bound2 = nonParamBounds(p2).substParam(p2, p1)
425-
val p1Bounds = bound1 & bound2
426-
updateEntry(p1, p1Bounds).replace(p2, p1)
429+
def addLess(param1: TypeParamRef, param2: TypeParamRef, direction: UnificationDirection)(using Context): This =
430+
order(this, param1, param2, direction).checkNonCyclic()
427431

428432
// ---------- Replacements and Removals -------------------------------------
429433

compiler/src/dotty/tools/dotc/core/TypeApplications.scala

+10
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,16 @@ class TypeApplications(val self: Type) extends AnyVal {
231231
(alias ne self) && alias.hasSimpleKind
232232
}
233233

234+
/** The top type with the same kind as `self`. */
235+
def topType(using Context): Type =
236+
if self.hasSimpleKind then
237+
defn.AnyType
238+
else EtaExpand(self.typeParams) match
239+
case tp: HKTypeLambda =>
240+
tp.derivedLambdaType(resType = tp.resultType.topType)
241+
case _ =>
242+
defn.AnyKindType
243+
234244
/** If self type is higher-kinded, its result type, otherwise NoType.
235245
* Note: The hkResult of an any-kinded type is again AnyKind.
236246
*/

compiler/src/dotty/tools/dotc/core/Types.scala

+11-2
Original file line numberDiff line numberDiff line change
@@ -4666,6 +4666,15 @@ object Types {
46664666
*
46674667
* @param origin The parameter that's tracked by the type variable.
46684668
* @param creatorState The typer state in which the variable was created.
4669+
* @param nestingLevel Symbols with a nestingLevel strictly greater than this
4670+
* will not appear in the instantiation of this type variable.
4671+
* This is enforced in `ConstraintHandling` by:
4672+
* - Maintaining the invariant that the `nonParamBounds`
4673+
* of a type variable never refer to a type with a
4674+
* greater `nestingLevel` (see `legalBound` for the reason
4675+
* why this cannot be delayed until instantiation).
4676+
* - On instantiation, replacing any param in the param bound
4677+
* with a level greater than nestingLevel (see `fullLowerBound`).
46694678
*/
46704679
final class TypeVar private(initOrigin: TypeParamRef, creatorState: TyperState, val nestingLevel: Int) extends CachedProxyType with ValueType {
46714680
private var currentOrigin = initOrigin
@@ -4774,8 +4783,8 @@ object Types {
47744783
}
47754784
}
47764785
object TypeVar:
4777-
def apply(initOrigin: TypeParamRef, creatorState: TyperState)(using Context) =
4778-
new TypeVar(initOrigin, creatorState, ctx.nestingLevel)
4786+
def apply(using Context)(initOrigin: TypeParamRef, creatorState: TyperState, nestingLevel: Int = ctx.nestingLevel) =
4787+
new TypeVar(initOrigin, creatorState, nestingLevel)
47794788

47804789
type TypeVars = SimpleIdentitySet[TypeVar]
47814790

0 commit comments

Comments
 (0)