Skip to content

Commit 49dd34d

Browse files
authored
Merge pull request #6398 from dotty-staging/intersection-based-gadts
Intersection based gadts
2 parents fe749ea + 01096ff commit 49dd34d

27 files changed

+851
-137
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

+8-3
Original file line numberDiff line numberDiff line change
@@ -336,16 +336,21 @@ trait ConstraintHandling[AbstractContext] {
336336
* L2 <: L1, and
337337
* U1 <: U2
338338
*
339-
* Both `c1` and `c2` are required to derive from constraint `pre`, possibly
340-
* narrowing it with further bounds.
339+
* Both `c1` and `c2` are required to derive from constraint `pre`, without adding
340+
* any new type variables but possibly narrowing already registered ones with further bounds.
341341
*/
342342
protected final def subsumes(c1: Constraint, c2: Constraint, pre: Constraint)(implicit actx: AbstractContext): Boolean =
343343
if (c2 eq pre) true
344344
else if (c1 eq pre) false
345345
else {
346346
val saved = constraint
347347
try
348-
c2.forallParams(p =>
348+
// We iterate over params of `pre`, instead of `c2` as the documentation may suggest.
349+
// As neither `c1` nor `c2` can have more params than `pre`, this only matters in one edge case.
350+
// Constraint#forallParams only iterates over params that can be directly constrained.
351+
// If `c2` has, compared to `pre`, instantiated a param and we iterated over params of `c2`,
352+
// we could miss that param being instantiated to an incompatible type in `c1`.
353+
pre.forallParams(p =>
349354
c1.contains(p) &&
350355
c2.upper(p).forall(c1.isLess(p, _)) &&
351356
isSubTypeWhenFrozen(c1.nonParamBounds(p), c2.nonParamBounds(p)))

compiler/src/dotty/tools/dotc/core/Mode.scala

+5-2
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,11 @@ object Mode {
4949
/** We are in a pattern alternative */
5050
val InPatternAlternative: Mode = newMode(7, "InPatternAlternative")
5151

52-
/** Infer GADT constraints during type comparisons `A <:< B` */
53-
val GADTflexible: Mode = newMode(8, "GADTflexible")
52+
/** Make subtyping checks instead infer constraints necessarily following from given subtyping relation.
53+
*
54+
* This enables changing [[GadtConstraint]] and alters how [[TypeComparer]] approximates constraints.
55+
*/
56+
val GadtConstraintInference: Mode = newMode(8, "GadtConstraintInference")
5457

5558
/** Assume -language:strictEquality */
5659
val StrictEquality: Mode = newMode(9, "StrictEquality")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
package dotty.tools
2+
package dotc
3+
package core
4+
5+
import Decorators._
6+
import Symbols._
7+
import Types._
8+
import Flags._
9+
import dotty.tools.dotc.reporting.trace
10+
import config.Printers._
11+
12+
trait PatternTypeConstrainer { self: TypeComparer =>
13+
14+
/** Derive type and GADT constraints that necessarily follow from a pattern with the given type matching
15+
* a scrutinee of the given type.
16+
*
17+
* This function breaks down scrutinee and pattern types into subcomponents between which there must be
18+
* a subtyping relationship, and derives constraints from those relationships. We have the following situation
19+
* in case of a (dynamic) pattern match:
20+
*
21+
* StaticScrutineeType PatternType
22+
* \ /
23+
* DynamicScrutineeType
24+
*
25+
* In simple cases, it must hold that `PatternType <: StaticScrutineeType`:
26+
*
27+
* StaticScrutineeType
28+
* | \
29+
* | PatternType
30+
* | /
31+
* DynamicScrutineeType
32+
*
33+
* A good example of a situation where the above must hold is when static scrutinee type is the root of an enum,
34+
* and the pattern is an unapply of a case class, or a case object literal (of that enum).
35+
*
36+
* In slightly more complex cases, we may need to upcast `StaticScrutineeType`:
37+
*
38+
* SharedPatternScrutineeSuperType
39+
* / \
40+
* StaticScrutineeType PatternType
41+
* \ /
42+
* DynamicScrutineeType
43+
*
44+
* This may be the case if the scrutinee is a singleton type or a path-dependent type. It is also the case
45+
* for the following definitions:
46+
*
47+
* trait Expr[T]
48+
* trait IntExpr extends Expr[T]
49+
* trait Const[T] extends Expr[T]
50+
*
51+
* StaticScrutineeType = Const[T]
52+
* PatternType = IntExpr
53+
*
54+
* Union and intersection types are an additional complication - if either scrutinee or pattern are a union type,
55+
* then the above relationships only need to hold for the "leaves" of the types.
56+
*
57+
* Finally, if pattern type contains hk-types applied to concrete types (as opposed to type variables),
58+
* or either scrutinee or pattern type contain type member refinements, the above relationships do not need
59+
* to hold at all. Consider (where `T1`, `T2` are unrelated traits):
60+
*
61+
* StaticScrutineeType = { type T <: T1 }
62+
* PatternType = { type T <: T2 }
63+
*
64+
* In the above situation, DynamicScrutineeType can equal { type T = T1 & T2 }, but there is no useful relationship
65+
* between StaticScrutineeType and PatternType (nor any of their subcomponents). Similarly:
66+
*
67+
* StaticScrutineeType = Option[T1]
68+
* PatternType = Some[T2]
69+
*
70+
* Again, DynamicScrutineeType may equal Some[T1 & T2], and there's no useful relationship between the static
71+
* scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables,
72+
* in which case the subtyping relationship "heals" the type.
73+
*/
74+
def constrainPatternType(pat: Type, scrut: Type): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) {
75+
76+
def classesMayBeCompatible: Boolean = {
77+
import Flags._
78+
val patClassSym = pat.widenSingleton.classSymbol
79+
val scrutClassSym = scrut.widenSingleton.classSymbol
80+
!patClassSym.exists || !scrutClassSym.exists || {
81+
if (patClassSym.is(Final)) patClassSym.derivesFrom(scrutClassSym)
82+
else if (scrutClassSym.is(Final)) scrutClassSym.derivesFrom(patClassSym)
83+
else if (!patClassSym.is(Flags.Trait) && !scrutClassSym.is(Flags.Trait))
84+
patClassSym.derivesFrom(scrutClassSym) || scrutClassSym.derivesFrom(patClassSym)
85+
else true
86+
}
87+
}
88+
89+
def stripRefinement(tp: Type): Type = tp match {
90+
case tp: RefinedOrRecType => stripRefinement(tp.parent)
91+
case tp => tp
92+
}
93+
94+
def constrainUpcasted(scrut: Type): Boolean = trace(i"constrainUpcasted($scrut)", gadts) {
95+
val upcasted: Type = scrut match {
96+
case scrut: TypeRef if scrut.symbol.isClass =>
97+
// we do not infer constraints following from all parents for performance reasons
98+
// in principle however, if `A extends B, C`, then `A` can be treated as `B & C`
99+
scrut.firstParent
100+
case scrut @ AppliedType(tycon: TypeRef, _) if tycon.symbol.isClass =>
101+
val patClassSym = pat.classSymbol
102+
// as above, we do not consider all parents for performance reasons
103+
def firstParentSharedWithPat(tp: Type, tpClassSym: ClassSymbol): Symbol = {
104+
var parents = tpClassSym.info.parents
105+
parents match {
106+
case first :: rest =>
107+
if (first.classSymbol == defn.ObjectClass) parents = rest
108+
case _ => ;
109+
}
110+
parents match {
111+
case first :: _ =>
112+
val firstClassSym = first.classSymbol.asClass
113+
val res = if (patClassSym.derivesFrom(firstClassSym)) firstClassSym
114+
else firstParentSharedWithPat(first, firstClassSym)
115+
res
116+
case _ => NoSymbol
117+
}
118+
}
119+
val sym = firstParentSharedWithPat(tycon, tycon.symbol.asClass)
120+
if (sym.exists) scrut.baseType(sym) else NoType
121+
case scrut: TypeProxy => scrut.superType
122+
case _ => NoType
123+
}
124+
if (upcasted.exists)
125+
constrainSimplePatternType(pat, upcasted) || constrainUpcasted(upcasted)
126+
else true
127+
}
128+
129+
scrut.dealias match {
130+
case OrType(scrut1, scrut2) =>
131+
either(constrainPatternType(pat, scrut1), constrainPatternType(pat, scrut2))
132+
case AndType(scrut1, scrut2) =>
133+
constrainPatternType(pat, scrut1) && constrainPatternType(pat, scrut2)
134+
case scrut: RefinedOrRecType =>
135+
constrainPatternType(pat, stripRefinement(scrut))
136+
case scrut => pat.dealias match {
137+
case OrType(pat1, pat2) =>
138+
either(constrainPatternType(pat1, scrut), constrainPatternType(pat2, scrut))
139+
case AndType(pat1, pat2) =>
140+
constrainPatternType(pat1, scrut) && constrainPatternType(pat2, scrut)
141+
case scrut: RefinedOrRecType =>
142+
constrainPatternType(stripRefinement(scrut), pat)
143+
case pat =>
144+
constrainSimplePatternType(pat, scrut) || classesMayBeCompatible && constrainUpcasted(scrut)
145+
}
146+
}
147+
}
148+
149+
/** Constrain "simple" patterns (see `constrainPatternType`).
150+
*
151+
* This function attempts to modify pattern and scrutinee type s.t. the pattern must be a subtype of the scrutinee,
152+
* or otherwise it cannot possibly match. In order to do that, we:
153+
*
154+
* 1. Rely on `constrainPatternType` to break the actual scrutinee/pattern types into subcomponents
155+
* 2. Widen type parameters of scrutinee type that are not invariantly refined (see below) by the pattern type.
156+
* 3. Wrap the pattern type in a skolem to avoid overconstraining top-level abstract types in scrutinee type
157+
* 4. Check that `WidenedScrutineeType <: NarrowedPatternType`
158+
*
159+
* Importantly, note that the pattern type may contain type variables.
160+
*
161+
* ## Invariant refinement
162+
* Essentially, we say that `D[B] extends C[B]` s.t. refines parameter `A` of `trait C[A]` invariantly if
163+
* when `c: C[T]` and `c` is instance of `D`, then necessarily `c: D[T]`. This is violated if `A` is variant:
164+
*
165+
* trait C[+A]
166+
* trait D[+B](val b: B) extends C[B]
167+
* trait E extends D[Any](0) with C[String]
168+
*
169+
* `E` is a counter-example to the above - if `e: E`, then `e: C[String]` and `e` is instance of `D`, but
170+
* it is false that `e: D[String]`! This is a problem if we're constraining a pattern like the below:
171+
*
172+
* def foo[T](c: C[T]): T = c match {
173+
* case d: D[t] => d.b
174+
* }
175+
*
176+
* It'd be unsound for us to say that `t <: T`, even though that follows from `D[t] <: C[T]`.
177+
* Note, however, that if `D` was a final class, we *could* rely on that relationship.
178+
* To support typical case classes, we also assume that this relationship holds for them and their parent traits.
179+
* This is enforced by checking that classes inheriting from case classes do not extend the parent traits of those
180+
* case classes without also appropriately extending the relevant case class
181+
* (see `RefChecks#checkCaseClassInheritanceInvariant`).
182+
*/
183+
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type): Boolean = {
184+
def refinementIsInvariant(tp: Type): Boolean = tp match {
185+
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
186+
case tp: TypeProxy => refinementIsInvariant(tp.underlying)
187+
case _ => false
188+
}
189+
190+
def widenVariantParams = new TypeMap {
191+
def apply(tp: Type) = mapOver(tp) match {
192+
case tp @ AppliedType(tycon, args) =>
193+
val args1 = args.zipWithConserve(tycon.typeParams)((arg, tparam) =>
194+
if (tparam.paramVariance != 0) TypeBounds.empty else arg
195+
)
196+
tp.derivedAppliedType(tycon, args1)
197+
case tp =>
198+
tp
199+
}
200+
}
201+
202+
val widePt = if (ctx.scala2Mode || refinementIsInvariant(patternTp)) scrutineeTp else widenVariantParams(scrutineeTp)
203+
val narrowTp = SkolemType(patternTp)
204+
trace(i"constraining simple pattern type $narrowTp <:< $widePt", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") {
205+
isSubType(narrowTp, widePt)
206+
}
207+
}
208+
}

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+23-6
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ object AbsentContext {
2525

2626
/** Provides methods to compare types.
2727
*/
28-
class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
28+
class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] with PatternTypeConstrainer {
2929
import TypeComparer._
3030
implicit def ctx(implicit nc: AbsentContext): Context = initctx
3131

@@ -141,6 +141,13 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
141141
*/
142142
private [this] var leftRoot: Type = _
143143

144+
/** Are we forbidden from recording GADT constraints?
145+
*
146+
* This flag is set when we're already in [[Mode.GadtConstraintInference]],
147+
* to signify that we temporarily cannot record any GADT constraints.
148+
*/
149+
private[this] var frozenGadt = false
150+
144151
protected def isSubType(tp1: Type, tp2: Type, a: ApproxState): Boolean = {
145152
val savedApprox = approx
146153
val savedLeftRoot = leftRoot
@@ -840,8 +847,18 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
840847
gadtBoundsContain(tycon1sym, tycon2) ||
841848
gadtBoundsContain(tycon2sym, tycon1)
842849
) &&
843-
isSubType(tycon1.prefix, tycon2.prefix) &&
844-
isSubArgs(args1, args2, tp1, tparams)
850+
isSubType(tycon1.prefix, tycon2.prefix) && {
851+
// check both tycons to deal with the case when they are equal b/c of GADT constraint
852+
val tyconIsInjective = tycon1sym.isClass || tycon2sym.isClass
853+
def checkSubArgs() = isSubArgs(args1, args2, tp1, tparams)
854+
// we only record GADT constraints if tycon is guaranteed to be injective
855+
if (tyconIsInjective) checkSubArgs()
856+
else {
857+
val savedFrozenGadt = frozenGadt
858+
frozenGadt = true
859+
try checkSubArgs() finally frozenGadt = savedFrozenGadt
860+
}
861+
}
845862
if (res && touchedGADTs) GADTused = true
846863
res
847864
case _ =>
@@ -1227,8 +1244,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
12271244
* @see [[sufficientEither]] for the normal case
12281245
* @see [[necessaryEither]] for the GADTFlexible case
12291246
*/
1230-
private def either(op1: => Boolean, op2: => Boolean): Boolean =
1231-
if (ctx.mode.is(Mode.GADTflexible)) necessaryEither(op1, op2) else sufficientEither(op1, op2)
1247+
protected def either(op1: => Boolean, op2: => Boolean): Boolean =
1248+
if (ctx.mode.is(Mode.GadtConstraintInference)) necessaryEither(op1, op2) else sufficientEither(op1, op2)
12321249

12331250
/** Returns true iff the result of evaluating either `op1` or `op2` is true,
12341251
* trying at the same time to keep the constraint as wide as possible.
@@ -1476,7 +1493,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
14761493
*/
14771494
private def narrowGADTBounds(tr: NamedType, bound: Type, approx: ApproxState, isUpper: Boolean): Boolean = {
14781495
val boundImprecise = approx.high || approx.low
1479-
ctx.mode.is(Mode.GADTflexible) && !frozenConstraint && !boundImprecise && {
1496+
ctx.mode.is(Mode.GadtConstraintInference) && !frozenGadt && !frozenConstraint && !boundImprecise && {
14801497
val tparam = tr.symbol
14811498
gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}")
14821499
if (bound.isRef(tparam)) false

compiler/src/dotty/tools/dotc/typer/Applications.scala

+5-22
Original file line numberDiff line numberDiff line change
@@ -1085,21 +1085,6 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
10851085

10861086
def fromScala2x = unapplyFn.symbol.exists && (unapplyFn.symbol.owner is Scala2x)
10871087

1088-
/** Is `subtp` a subtype of `tp` or of some generalization of `tp`?
1089-
* The generalizations of a type T are the smallest set G such that
1090-
*
1091-
* - T is in G
1092-
* - If a typeref R in G represents a class or trait, R's superclass is in G.
1093-
* - If a type proxy P is not a reference to a class, P's supertype is in G
1094-
*/
1095-
def isSubTypeOfParent(subtp: Type, tp: Type)(implicit ctx: Context): Boolean =
1096-
if (constrainPatternType(subtp, tp)) true
1097-
else tp match {
1098-
case tp: TypeRef if tp.symbol.isClass => isSubTypeOfParent(subtp, tp.firstParent)
1099-
case tp: TypeProxy => isSubTypeOfParent(subtp, tp.superType)
1100-
case _ => false
1101-
}
1102-
11031088
unapplyFn.tpe.widen match {
11041089
case mt: MethodType if mt.paramInfos.length == 1 =>
11051090
val unapplyArgType = mt.paramInfos.head
@@ -1109,17 +1094,15 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
11091094
unapp.println(i"case 1 $unapplyArgType ${ctx.typerState.constraint}")
11101095
fullyDefinedType(unapplyArgType, "pattern selector", tree.span)
11111096
selType.dropAnnot(defn.UncheckedAnnot) // need to drop @unchecked. Just because the selector is @unchecked, the pattern isn't.
1112-
} else if (isSubTypeOfParent(unapplyArgType, selType)(ctx.addMode(Mode.GADTflexible))) {
1097+
} else {
1098+
// We ignore whether constraining the pattern succeeded.
1099+
// Constraining only fails if the pattern cannot possibly match,
1100+
// but useless pattern checks detect more such cases, so we simply rely on them instead.
1101+
ctx.addMode(Mode.GadtConstraintInference).typeComparer.constrainPatternType(unapplyArgType, selType)
11131102
val patternBound = maximizeType(unapplyArgType, tree.span, fromScala2x)
11141103
if (patternBound.nonEmpty) unapplyFn = addBinders(unapplyFn, patternBound)
11151104
unapp.println(i"case 2 $unapplyArgType ${ctx.typerState.constraint}")
11161105
unapplyArgType
1117-
} else {
1118-
unapp.println("Neither sub nor super")
1119-
unapp.println(TypeComparer.explained(implicit ctx => unapplyArgType <:< selType))
1120-
errorType(
1121-
ex"Pattern type $unapplyArgType is neither a subtype nor a supertype of selector type $selType",
1122-
tree.sourcePos)
11231106
}
11241107
val dummyArg = dummyTreeOfType(ownType)
11251108
val unapplyApp = typedExpr(untpd.TypedSplice(Apply(unapplyFn, dummyArg :: Nil)))

0 commit comments

Comments
 (0)