Skip to content

Commit d0f48b8

Browse files
committed
Use Skolems to infer GADT constraints
The rationale for using a Skolem here is: we want to record that there is at least one value that is both of the pattern type and the scrutinee type. All symbols are now considered valid for adding GADT constraints - the rationale is that set of constrainable symbols should be either selected on a per-(sub)pattern basis, or be the same for all matches. Previously, symbols which were only appearing variantly in a scrutinee type could be considered constrainable anyway because of an outer pattern match.
1 parent e85a2c2 commit d0f48b8

13 files changed

+232
-60
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ trait ConstraintHandling[AbstractContext] {
3030

3131
protected def isSubType(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Boolean
3232
protected def isSameType(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Boolean
33+
protected def typeLub(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Type
3334

3435
protected def constraint: Constraint
3536
protected def constraint_=(c: Constraint): Unit
@@ -131,7 +132,7 @@ trait ConstraintHandling[AbstractContext] {
131132
homogenizeArgs = Config.alignArgsInAnd
132133
try
133134
if (isUpper) oldBounds.derivedTypeBounds(lo, hi & bound)
134-
else oldBounds.derivedTypeBounds(lo | bound, hi)
135+
else oldBounds.derivedTypeBounds(typeLub(lo, bound), hi)
135136
finally homogenizeArgs = saved
136137
}
137138
val c1 = constraint.updateEntry(param, narrowedBounds)

compiler/src/dotty/tools/dotc/core/Contexts.scala

+4
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,10 @@ object Contexts {
809809
override def isSubType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2)
810810
override def isSameType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2)
811811

812+
override protected def typeLub(tp1: Type, tp2: Type)(implicit ctx: Context): Type = {
813+
ctx.typeComparer.lub(tp1, tp2, admitSingletons = true)
814+
}
815+
812816
override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = tvar(sym)
813817

814818
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = {

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+10-5
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
3434

3535
override protected def externalize(param: TypeParamRef)(implicit ctx: Context): Type = param
3636

37+
override protected def typeLub(tp1: Type, tp2: Type)(implicit actx: AbsentContext): Type =
38+
lub(tp1, tp2)
39+
3740
private[this] var pendingSubTypes: mutable.Set[(Type, Type)] = null
3841
private[this] var recCount = 0
3942
private[this] var monitored = false
@@ -1434,9 +1437,10 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
14341437

14351438
/** The least upper bound of two types
14361439
* @param canConstrain If true, new constraints might be added to simplify the lub.
1437-
* @note We do not admit singleton types in or-types as lubs.
1440+
* @param admitSingletons We only admit singletons as parts of lubs when we must maintain necessary conditions,
1441+
* such as when inferring GADT constraints.
14381442
*/
1439-
def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false): Type = /*>|>*/ trace(s"lub(${tp1.show}, ${tp2.show}, canConstrain=$canConstrain)", subtyping, show = true) /*<|<*/ {
1443+
def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false, admitSingletons: Boolean = false): Type = /*>|>*/ trace(s"lub(${tp1.show}, ${tp2.show}, canConstrain=$canConstrain)", subtyping, show = true) /*<|<*/ {
14401444
if (tp1 eq tp2) tp1
14411445
else if (!tp1.exists) tp1
14421446
else if (!tp2.exists) tp2
@@ -1448,6 +1452,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
14481452
else {
14491453
val t2 = mergeIfSuper(tp2, tp1, canConstrain)
14501454
if (t2.exists) t2
1455+
else if (admitSingletons) orType(tp1.widenExpr, tp2.widenExpr)
14511456
else {
14521457
val tp1w = tp1.widen
14531458
val tp2w = tp2.widen
@@ -1973,9 +1978,9 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
19731978
super.hasMatchingMember(name, tp1, tp2)
19741979
}
19751980

1976-
override def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false): Type =
1977-
traceIndented(s"lub(${show(tp1)}, ${show(tp2)}, canConstrain=$canConstrain)") {
1978-
super.lub(tp1, tp2, canConstrain)
1981+
override def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false, admitSingletons: Boolean = false): Type =
1982+
traceIndented(s"lub(${show(tp1)}, ${show(tp2)}, canConstrain=$canConstrain, admitSingletons=$admitSingletons)") {
1983+
super.lub(tp1, tp2, canConstrain, admitSingletons)
19791984
}
19801985

19811986
override def glb(tp1: Type, tp2: Type): Type =

compiler/src/dotty/tools/dotc/core/Types.scala

+6-1
Original file line numberDiff line numberDiff line change
@@ -3605,7 +3605,12 @@ object Types {
36053605

36063606
// ----- Skolem types -----------------------------------------------
36073607

3608-
/** A skolem type reference with underlying type `binder`. */
3608+
/** A skolem type reference with underlying type `info`.
3609+
*
3610+
* For Dotty, a skolem type is a singleton type of some unknown value of type `info`.
3611+
* Note that care is needed when creating them, since not all types need to be inhabited.
3612+
* A skolem is equal to itself and no other type.
3613+
*/
36093614
case class SkolemType(info: Type) extends UncachedProxyType with ValueType with SingletonType {
36103615
override def underlying(implicit ctx: Context): Type = info
36113616
def derivedSkolemType(info: Type)(implicit ctx: Context): SkolemType =

compiler/src/dotty/tools/dotc/typer/Applications.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -1036,7 +1036,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
10361036
* - If a type proxy P is not a reference to a class, P's supertype is in G
10371037
*/
10381038
def isSubTypeOfParent(subtp: Type, tp: Type)(implicit ctx: Context): Boolean =
1039-
if (constrainPatternType(subtp, tp)) true
1039+
if (constrainPatternType(subtp, tp, termPattern = true)) true
10401040
else tp match {
10411041
case tp: TypeRef if tp.symbol.isClass => isSubTypeOfParent(subtp, tp.firstParent)
10421042
case tp: TypeProxy => isSubTypeOfParent(subtp, tp.superType)

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

+15-33
Original file line numberDiff line numberDiff line change
@@ -150,41 +150,22 @@ object Inferencing {
150150
tree
151151
}
152152

153-
/** Derive information about a pattern type by comparing it with some variant of the
154-
* static scrutinee type. We have the following situation in case of a (dynamic) pattern match:
153+
/** Infer constraints that should be in scope for a case body with given pattern and scrutinee types.
155154
*
156-
* StaticScrutineeType PatternType
157-
* \ /
158-
* DynamicScrutineeType
155+
* If `termPattern`, infer constraints from knowing that there exists a value which of both scrutinee
156+
* and pattern types (which is the case for normal pattern matching). If not `termPattern`, instead
157+
* infer constraints from knowing that `tp <: pt`.
159158
*
160-
* If `PatternType` is not a subtype of `StaticScrutineeType, there's no information to be gained.
161-
* Now let's say we can prove that `PatternType <: StaticScrutineeType`.
159+
* If a pattern matches during normal pattern matching, we can be certain that there exists a value
160+
* which is of both scrutinee and pattern types (the value we're matching on). If this value
161+
* was in a variable, say `x`, then we could simply infer constraints from `x.type <: pt`. Since we might
162+
* be matching on an expression as well, we take a skolem of the scrutinee, which is essentially an existential
163+
* singleton type (see [[dotty.tools.dotc.core.Types.SkolemType]]).
162164
*
163-
* StaticScrutineeType
164-
* | \
165-
* | \
166-
* | \
167-
* | PatternType
168-
* | /
169-
* DynamicScrutineeType
170-
*
171-
* What can we say about the relationship of parameter types between `PatternType` and
172-
* `DynamicScrutineeType`?
173-
*
174-
* - If `DynamicScrutineeType` refines the type parameters of `StaticScrutineeType`
175-
* in the same way as `PatternType` ("invariant refinement"), the subtype test
176-
* `PatternType <:< StaticScrutineeType` tells us all we need to know.
177-
* - Otherwise, if variant refinement is a possibility we can only make predictions
178-
* about invariant parameters of `StaticScrutineeType`. Hence we do a subtype test
179-
* where `PatternType <: widenVariantParams(StaticScrutineeType)`, where `widenVariantParams`
180-
* replaces all type argument of variant parameters with empty bounds.
181-
*
182-
* Invariant refinement can be assumed if `PatternType`'s class(es) are final or
183-
* case classes (because of `RefChecks#checkCaseClassInheritanceInvariant`).
184-
*
185-
* TODO: Update so that GADT symbols can be variant, and we special case final class types in patterns
165+
* Note that we need to sometimes widen type parameters of the scrutinee type to avoid unsoundness -
166+
* see i3989c.scala and related issue discussion on Github.
186167
*/
187-
def constrainPatternType(tp: Type, pt: Type)(implicit ctx: Context): Boolean = {
168+
def constrainPatternType(tp: Type, pt: Type, termPattern: Boolean)(implicit ctx: Context): Boolean = {
188169
def refinementIsInvariant(tp: Type): Boolean = tp match {
189170
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
190171
case tp: TypeProxy => refinementIsInvariant(tp.underlying)
@@ -206,8 +187,9 @@ object Inferencing {
206187
}
207188

208189
val widePt = if (ctx.scala2Mode || refinementIsInvariant(tp)) pt else widenVariantParams(pt)
209-
trace(i"constraining pattern type $tp <:< $widePt", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") {
210-
tp <:< widePt
190+
val narrowTp = if (termPattern) SkolemType(tp) else tp
191+
trace(i"constraining pattern type $narrowTp <:< $widePt", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") {
192+
narrowTp <:< widePt
211193
}
212194
}
213195

compiler/src/dotty/tools/dotc/typer/Typer.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ class Typer extends Namer
599599
def handlePattern: Tree = {
600600
val tpt1 = typedTpt
601601
if (!ctx.isAfterTyper && pt != defn.ImplicitScrutineeTypeRef)
602-
constrainPatternType(tpt1.tpe, pt)(ctx.addMode(Mode.GADTflexible))
602+
constrainPatternType(tpt1.tpe, pt, termPattern = true)(ctx.addMode(Mode.GADTflexible))
603603
// special case for an abstract type that comes with a class tag
604604
tryWithClassTag(ascription(tpt1, isWildcard = true), pt)
605605
}
@@ -1095,7 +1095,7 @@ class Typer extends Namer
10951095
def caseRest(implicit ctx: Context) = {
10961096
val pat1 = checkSimpleKinded(typedType(cdef.pat)(ctx.addMode(Mode.Pattern)))
10971097
if (!ctx.isAfterTyper)
1098-
constrainPatternType(pat1.tpe, selType)(ctx.addMode(Mode.GADTflexible))
1098+
constrainPatternType(pat1.tpe, selType, termPattern = false)(ctx.addMode(Mode.GADTflexible))
10991099
val pat2 = indexPattern(cdef).transform(pat1)
11001100
val body1 = typedType(cdef.body, pt)
11011101
assignType(cpy.CaseDef(cdef)(pat2, EmptyTree, body1), pat2, body1)
+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
object buffer {
2+
object EssaInt {
3+
def unapply(i: Int): Some[Int] = Some(i)
4+
}
5+
6+
case class Inv[T](t: T)
7+
8+
enum EQ[A, B] { case Refl[T]() extends EQ[T, T] }
9+
enum SUB[A, +B] { case Refl[T]() extends SUB[T, T] } // A <: B
10+
11+
def test_eq1[A, B](eq: EQ[A, B], a: A, b: B): B =
12+
Inv(a) match { case Inv(_: Int) => // a >: Sko(Int)
13+
Inv(a) match { case Inv(_: Int) => // a >: Sko(Int) | Sko(Int)
14+
eq match { case EQ.Refl() => // a = b
15+
val success: A = b
16+
val fail: A = 0 // error
17+
0 // error
18+
}
19+
}
20+
}
21+
22+
def test_eq2[A, B](eq: EQ[A, B], a: A, b: B): B =
23+
Inv(a) match { case Inv(_: Int) => // a >: Sko(Int)
24+
Inv(b) match { case Inv(_: Int) => // b >: Sko(Int)
25+
eq match { case EQ.Refl() => // a = b
26+
val success: A = b
27+
val fail: A = 0 // error
28+
0 // error
29+
}
30+
}
31+
}
32+
33+
def test_sub1[A, B](sub: SUB[A, B], a: A, b: B): B =
34+
Inv(b) match { case Inv(_: Int) => // b >: Sko(Int)
35+
Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) | Sko(Int)
36+
sub match { case SUB.Refl() => // b >: a
37+
val success: B = a
38+
val fail: A = 0 // error
39+
0 // error
40+
}
41+
}
42+
}
43+
44+
def test_sub2[A, B](sub: SUB[A, B], a: A, b: B): B =
45+
Inv(a) match { case Inv(_: Int) => // a >: Sko(Int)
46+
Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) | Sko(Int)
47+
sub match { case SUB.Refl() => // b >: a
48+
val success: B = a
49+
val fail: A = 0 // error
50+
0 // error
51+
}
52+
}
53+
}
54+
55+
56+
def test_sub_eq[A, B, C](sub: SUB[A|B, C], eqA: EQ[A, 5], eqB: EQ[B, 6]): C =
57+
sub match { case SUB.Refl() => // C >: A | B
58+
eqA match { case EQ.Refl() => // A = 5
59+
eqB match { case EQ.Refl() => // B = 6
60+
val fail1: A = 0 // error
61+
val fail2: B = 0 // error
62+
0 // error
63+
}
64+
}
65+
}
66+
}

tests/neg/int-extractor.scala

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
object Test {
2+
object EssaInt {
3+
def unapply(i: Int): Some[Int] = Some(i)
4+
}
5+
6+
def foo1[T](t: T): T = t match {
7+
case EssaInt(_) =>
8+
0 // error
9+
}
10+
11+
def foo2[T](t: T): T = t match {
12+
case EssaInt(_) => t match {
13+
case EssaInt(_) =>
14+
0 // error
15+
}
16+
}
17+
18+
case class Inv[T](t: T)
19+
20+
def bar1[T](t: T): T = Inv(t) match {
21+
case Inv(EssaInt(_)) =>
22+
0 // error
23+
}
24+
25+
def bar2[T](t: T): T = t match {
26+
case Inv(EssaInt(_)) => t match {
27+
case Inv(EssaInt(_)) =>
28+
0 // error
29+
}
30+
}
31+
}

tests/neg/invariant-gadt.scala

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
object `invariant-gadt` {
2+
case class Invariant[T](value: T)
3+
4+
def unsound0[T](t: T): T = Invariant(t) match {
5+
case Invariant(_: Int) =>
6+
(0: Any) // error
7+
}
8+
9+
def unsound1[T](t: T): T = Invariant(t) match {
10+
case Invariant(_: Int) =>
11+
0 // error
12+
}
13+
14+
def unsound2[T](t: T): T = Invariant(t) match {
15+
case Invariant(value) => value match {
16+
case _: Int =>
17+
0 // error
18+
}
19+
}
20+
21+
def unsoundTwice[T](t: T): T = Invariant(t) match {
22+
case Invariant(_: Int) => Invariant(t) match {
23+
case Invariant(_: Int) =>
24+
0 // error
25+
}
26+
}
27+
}

tests/neg/typeclass-derivation2.scala

+16-4
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,13 @@ object TypeLevel {
111111
* It informs that type `T` has shape `S` and also implements runtime reflection on `T`.
112112
*/
113113
abstract class Shaped[T, S <: Shape] extends Reflected[T]
114+
115+
// substitute for erasedValue that allows precise matching
116+
final abstract class Type[-A, +B]
117+
type Subtype[t] = Type[_, t]
118+
type Supertype[t] = Type[t, _]
119+
type Exactly[t] = Type[t, t]
120+
erased def typeOf[T]: Type[T, T] = ???
114121
}
115122

116123
// An algebraic datatype
@@ -203,7 +210,7 @@ trait Show[T] {
203210
def show(x: T): String
204211
}
205212
object Show {
206-
import scala.compiletime.erasedValue
213+
import scala.compiletime.{erasedValue, error}
207214
import TypeLevel._
208215

209216
inline def tryShow[T](x: T): String = implicit match {
@@ -229,9 +236,14 @@ object Show {
229236
inline def showCases[T, Alts <: Tuple](r: Reflected[T], x: T): String =
230237
inline erasedValue[Alts] match {
231238
case _: (Shape.Case[alt, elems] *: alts1) =>
232-
x match {
233-
case x: `alt` => showCase[T, elems](r, x)
234-
case _ => showCases[T, alts1](r, x)
239+
inline typeOf[alt] match {
240+
case _: Subtype[T] =>
241+
x match {
242+
case x: `alt` => showCase[T, elems](r, x)
243+
case _ => showCases[T, alts1](r, x)
244+
}
245+
case _ =>
246+
error("invalid call to showCases: one of Alts is not a subtype of T")
235247
}
236248
case _: Unit =>
237249
throw new MatchError(x)

tests/pos/precise-pattern-type.scala

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
object `precise-pattern-type` {
2+
class Type {
3+
def isType: Boolean = true
4+
}
5+
6+
class Tree[-T >: Null] {
7+
def tpe: T @annotation.unchecked.uncheckedVariance = ???
8+
}
9+
10+
case class Select[-T >: Null](qual: Tree[T]) extends Tree[T]
11+
12+
def test[T <: Tree[Type]](tree: T) = tree match {
13+
case Select(q) =>
14+
q.tpe.isType
15+
}
16+
}

0 commit comments

Comments
 (0)