Skip to content

Commit 5d2812a

Browse files
authored
Refine criterion when to widen types (#17180)
2 parents 46e28b8 + 09f5e4c commit 5d2812a

15 files changed

+192
-45
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

+8
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@ trait ConstraintHandling {
6464
*/
6565
protected var canWidenAbstract: Boolean = true
6666

67+
/**
68+
* Used for match type reduction.
69+
* When an abstract type may not be widened, according to `widenAbstractOKFor`,
70+
* we record it in this set, so that we can ultimately fail the reduction, but
71+
* with all the information that comes out from continuing to widen the abstract type.
72+
*/
73+
protected var poisoned: Set[TypeParamRef] = Set.empty
74+
6775
protected var myNecessaryConstraintsOnly = false
6876
/** When collecting the constraints needed for a particular subtyping
6977
* judgment to be true, we sometimes need to approximate the constraint

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+50-7
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import typer.Applications.productSelectorTypes
2424
import reporting.trace
2525
import annotation.constructorOnly
2626
import cc.{CapturingType, derivedCapturingType, CaptureSet, stripCapturing, isBoxedCapturing, boxed, boxedUnlessFun, boxedIfTypeParam, isAlwaysPure}
27+
import NameKinds.WildcardParamName
2728

2829
/** Provides methods to compare types.
2930
*/
@@ -865,10 +866,36 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
865866
fourthTry
866867
}
867868

869+
/** Can we widen an abstract type when comparing with `tp`?
870+
* This is the case with the following cases:
871+
* - if `canWidenAbstract` is true.
872+
*
873+
* Secondly, if `tp` is a type parameter, we can widen if:
874+
* - if `tp` is not a type parameter of the matched-against case lambda
875+
* - if `tp` is an invariant or wildcard type parameter
876+
* - finally, allow widening, but record the type parameter in `poisoned`,
877+
* so that can be accounted for during the reduction step
878+
*/
879+
def widenAbstractOKFor(tp: Type): Boolean =
880+
val acc = new TypeAccumulator[Boolean]:
881+
override def apply(x: Boolean, t: Type) =
882+
x && t.match
883+
case t: TypeParamRef =>
884+
variance == 0
885+
|| (t.binder ne caseLambda)
886+
|| t.paramName.is(WildcardParamName)
887+
|| { poisoned += t; true }
888+
case _ =>
889+
foldOver(x, t)
890+
891+
canWidenAbstract && acc(true, tp)
892+
868893
def tryBaseType(cls2: Symbol) = {
869894
val base = nonExprBaseType(tp1, cls2).boxedIfTypeParam(tp1.typeSymbol)
870895
if base.exists && (base ne tp1)
871-
&& (!caseLambda.exists || canWidenAbstract || tp1.widen.underlyingClassRef(refinementOK = true).exists)
896+
&& (!caseLambda.exists
897+
|| widenAbstractOKFor(tp2)
898+
|| tp1.widen.underlyingClassRef(refinementOK = true).exists)
872899
then
873900
isSubType(base, tp2, if (tp1.isRef(cls2)) approx else approx.addLow)
874901
&& recordGadtUsageIf { MatchType.thatReducesUsingGadt(tp1) }
@@ -889,8 +916,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
889916
|| narrowGADTBounds(tp1, tp2, approx, isUpper = true))
890917
&& (tp2.isAny || GADTusage(tp1.symbol))
891918

892-
(!caseLambda.exists || canWidenAbstract)
893-
&& isSubType(hi1.boxedIfTypeParam(tp1.symbol), tp2, approx.addLow) && (trustBounds || isSubType(lo1, tp2, approx.addLow))
919+
(!caseLambda.exists || widenAbstractOKFor(tp2))
920+
&& isSubType(hi1.boxedIfTypeParam(tp1.symbol), tp2, approx.addLow) && (trustBounds || isSubType(lo1, tp2, approx.addLow))
894921
|| compareGADT
895922
|| tryLiftedToThis1
896923
case _ =>
@@ -984,7 +1011,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
9841011
tp1.cases.corresponds(tp2.cases)(isSubType)
9851012
case _ => false
9861013
}
987-
recur(tp1.underlying, tp2) || compareMatch
1014+
(!caseLambda.exists || canWidenAbstract) && recur(tp1.underlying, tp2) || compareMatch
9881015
case tp1: AnnotatedType if tp1.isRefining =>
9891016
isNewSubType(tp1.parent)
9901017
case JavaArrayType(elem1) =>
@@ -3091,6 +3118,9 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
30913118
}
30923119

30933120
def matchCases(scrut: Type, cases: List[Type])(using Context): Type = {
3121+
// a reference for the type parameters poisoned during matching
3122+
// for use during the reduction step
3123+
var poisoned: Set[TypeParamRef] = Set.empty
30943124

30953125
def paramInstances(canApprox: Boolean) = new TypeAccumulator[Array[Type]]:
30963126
def apply(insts: Array[Type], t: Type) = t match
@@ -3102,16 +3132,24 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
31023132
case entry: TypeBounds =>
31033133
val lo = fullLowerBound(param)
31043134
val hi = fullUpperBound(param)
3105-
if isSubType(hi, lo) then lo.simplified else Range(lo, hi)
3135+
if !poisoned(param) && isSubType(hi, lo) then lo.simplified else Range(lo, hi)
31063136
case inst =>
31073137
assert(inst.exists, i"param = $param\nconstraint = $constraint")
3108-
inst.simplified
3138+
if !poisoned(param) then inst.simplified else Range(inst, inst)
31093139
insts
31103140
case _ =>
31113141
foldOver(insts, t)
31123142

31133143
def instantiateParams(insts: Array[Type]) = new ApproximatingTypeMap {
31143144
variance = 0
3145+
3146+
override def range(lo: Type, hi: Type): Type =
3147+
if variance == 0 && (lo eq hi) then
3148+
// override the default `lo eq hi` test, which removes the Range
3149+
// which leads to a Reduced result, instead of NoInstance
3150+
Range(lower(lo), upper(hi))
3151+
else super.range(lo, hi)
3152+
31153153
def apply(t: Type) = t match {
31163154
case t @ TypeParamRef(b, n) if b `eq` caseLambda => insts(n)
31173155
case t: LazyRef => apply(t.ref)
@@ -3133,9 +3171,14 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
31333171

31343172
def matches(canWidenAbstract: Boolean): Boolean =
31353173
val saved = this.canWidenAbstract
3174+
val savedPoisoned = this.poisoned
31363175
this.canWidenAbstract = canWidenAbstract
3176+
this.poisoned = Set.empty
31373177
try necessarySubType(scrut, pat)
3138-
finally this.canWidenAbstract = saved
3178+
finally
3179+
poisoned = this.poisoned
3180+
this.poisoned = savedPoisoned
3181+
this.canWidenAbstract = saved
31393182

31403183
def redux(canApprox: Boolean): MatchResult =
31413184
caseLambda match

compiler/test/dotc/pos-test-pickling.blacklist

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ i6505.scala
5050
i15158.scala
5151
i15155.scala
5252
i15827.scala
53+
i17149.scala
5354

5455
# Opaque type
5556
i5720.scala

docs/_docs/reference/new-types/match-types.md

+2
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ An instantiation `Is` is _minimal_ for `Xs` if all type variables in `Xs` that
140140
appear covariantly and nonvariantly in `Is` are as small as possible and all
141141
type variables in `Xs` that appear contravariantly in `Is` are as large as
142142
possible. Here, "small" and "large" are understood with respect to `<:`.
143+
But, type parameter will not be "large" if a pattern containing it is matched
144+
against lambda case in co- or contra-variant position.
143145

144146
For simplicity, we have omitted constraint handling so far. The full formulation
145147
of subtyping tests describes them as a function from a constraint and a pair of

tests/neg/11982.scala

+6-6
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ type Head[X] = X match {
44
}
55

66
object Unpair {
7-
def unpair[X <: Tuple2[Any, Any]]: Head[X] = 1
8-
unpair[Tuple2["msg", 42]]: "msg" // error
7+
def unpair[X <: Tuple2[Any, Any]]: Head[X] = 1 // error
8+
unpair[Tuple2["msg", 42]]: "msg"
99
}
1010

1111

@@ -14,8 +14,8 @@ type Head2[X] = X match {
1414
}
1515

1616
object Unpair2 {
17-
def unpair[X <: Tuple2[Tuple2[Any, Any], Tuple2[Any, Any]]]: Head2[X] = 1
18-
unpair[Tuple2[Tuple2["msg", 42], Tuple2[41, 40]]]: "msg" // error
17+
def unpair[X <: Tuple2[Tuple2[Any, Any], Tuple2[Any, Any]]]: Head2[X] = 1 // error
18+
unpair[Tuple2[Tuple2["msg", 42], Tuple2[41, 40]]]: "msg"
1919
}
2020

2121

@@ -35,6 +35,6 @@ type Head4[X] = X match {
3535
}
3636

3737
object Unpair4 {
38-
def unpair[X <: Foo[Any, Any]]: Head4[Foo[X, X]] = 1
39-
unpair[Foo["msg", 42]]: "msg" // error
38+
def unpair[X <: Foo[Any, Any]]: Head4[Foo[X, X]] = 1 // error
39+
unpair[Foo["msg", 42]]: "msg"
4040
}

tests/neg/i11982.check

-4
This file was deleted.

tests/neg/i11982.scala

-27
This file was deleted.

tests/neg/i13780.check

+20
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,23 @@
1+
-- [E007] Type Mismatch Error: tests/neg/i13780.scala:12:32 ------------------------------------------------------------
2+
12 | def unpair[X <: Y]: Head[X] = "" // error
3+
| ^^
4+
| Found: ("" : String)
5+
| Required: Head[X]
6+
|
7+
| where: X is a type in method unpair with bounds <: A.this.Y
8+
|
9+
|
10+
| Note: a match type could not be fully reduced:
11+
|
12+
| trying to reduce Head[X]
13+
| failed since selector X
14+
| does not uniquely determine parameters a, b in
15+
| case (a, b) => a
16+
| The computed bounds for the parameters are:
17+
| a >: Any
18+
| b >: Any
19+
|
20+
| longer explanation available when compiling with `-explain`
121
-- [E007] Type Mismatch Error: tests/neg/i13780.scala:18:31 ------------------------------------------------------------
222
18 | def int[X <: Y]: Int = unpair[X] // error
323
| ^^^^^^^^^

tests/neg/i13780.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ trait Z {
99

1010
class A extends Z {
1111
type Y <: Tuple2[Any, Any]
12-
def unpair[X <: Y]: Head[X] = ""
12+
def unpair[X <: Y]: Head[X] = "" // error
1313
def any[X <: Y]: Any = unpair[X]
1414
}
1515

tests/neg/i17149.scala

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
type Ext1[S] = S match {
2+
case Seq[t] => t
3+
}
4+
type Ext2[S] = S match {
5+
case Seq[_] => Int
6+
}
7+
type Ext3[S] = S match {
8+
case Array[t] => t
9+
}
10+
type Ext4[S] = S match {
11+
case Array[_] => Int
12+
}
13+
def foo[T <: Seq[Any], A <: Array[B], B] =
14+
summon[Ext1[T] =:= T] // error
15+
summon[Ext2[T] =:= Int] // ok
16+
summon[Ext3[A] =:= B] // ok
17+
summon[Ext4[A] =:= Int] // ok

tests/pos/i15926.contra.scala

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
trait Show[-A >: Nothing]
2+
3+
type MT1[I <: Show[Nothing], N] = I match
4+
case Show[a] => N match
5+
case Int => a
6+
7+
val a = summon[MT1[Show[String], Int] =:= String]

tests/pos/i15926.extract.scala

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// like pos/i15926.scala
2+
// but with the nested match type extracted
3+
// which is a workaround that fixed the problem
4+
sealed trait Nat
5+
final case class Zero() extends Nat
6+
final case class Succ[+N <: Nat]() extends Nat
7+
8+
final case class Neg[+N <: Succ[Nat]]()
9+
10+
type Sum[X, Y] = Y match
11+
case Zero => X
12+
case Succ[y] => Sum[Succ[X], y]
13+
14+
type IntSum[A, B] = B match
15+
case Neg[b] => IntSumNeg[A, b]
16+
17+
type IntSumNeg[A, B] = A match
18+
case Neg[a] => Neg[Sum[a, B]]
19+
20+
type One = Succ[Zero]
21+
type Two = Succ[One]
22+
23+
class Test:
24+
def test() = summon[IntSum[Neg[One], Neg[One]] =:= Neg[Two]]

tests/pos/i15926.min.scala

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// like pos/i15926.scala
2+
// but minimised to the subset of paths needed
3+
// to fail the specific test case
4+
sealed trait Nat
5+
final case class Zero() extends Nat
6+
final case class Succ[+N <: Nat]() extends Nat
7+
8+
final case class Neg[+N <: Succ[Nat]]()
9+
10+
type Sum[X, Y] = Y match
11+
case Zero => X
12+
case Succ[y] => Sum[Succ[X], y]
13+
14+
type IntSum[A, B] = B match
15+
case Neg[b] => A match
16+
case Neg[a] => Neg[Sum[a, b]]
17+
18+
type One = Succ[Zero]
19+
type Two = Succ[One]
20+
21+
class Test:
22+
def test() = summon[IntSum[Neg[One], Neg[One]] =:= Neg[Two]]

tests/pos/i15926.scala

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
2+
@main def main(): Unit =
3+
println(summon[Sum[Minus[Succ[Zero]], Minus[Succ[Zero]]] =:= Minus[Succ[Succ[Zero]]]])
4+
5+
sealed trait IntT
6+
sealed trait NatT extends IntT
7+
final case class Zero() extends NatT
8+
final case class Succ[+N <: NatT](n: N) extends NatT
9+
final case class Minus[+N <: Succ[NatT]](n: N) extends IntT
10+
11+
type NatSum[X <: NatT, Y <: NatT] <: NatT = Y match
12+
case Zero => X
13+
case Succ[y] => NatSum[Succ[X], y]
14+
15+
type NatDif[X <: NatT, Y <: NatT] <: IntT = Y match
16+
case Zero => X
17+
case Succ[y] => X match
18+
case Zero => Minus[Y]
19+
case Succ[x] => NatDif[x, y]
20+
21+
type Sum[X <: IntT, Y <: IntT] <: IntT = Y match
22+
case Zero => X
23+
case Minus[y] => X match
24+
case Minus[x] => Minus[NatSum[x, y]]
25+
case _ => NatDif[X, y]
26+
case _ => X match
27+
case Minus[x] => NatDif[Y, x]
28+
case _ => NatSum[X, Y]

tests/pos/i17149.scala

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
type Ext[S <: Seq[_]] = S match {
2+
case Seq[t] => t
3+
}
4+
5+
val _ = implicitly[Ext[Seq[Int]] =:= Int] // e.scala: Cannot prove that e.Ext[Seq[Int]] =:= Int
6+
val _ = summon[Ext[Seq[Int]] =:= Int]

0 commit comments

Comments
 (0)