Skip to content

Commit f2aef39

Browse files
committed
Treat soft and hard unions differently when widening
1 parent e49e94f commit f2aef39

File tree

6 files changed

+150
-63
lines changed

6 files changed

+150
-63
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+45-34
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
426426
case Some(b) => return b
427427
case None =>
428428

429+
def widenOK =
430+
(tp2.widenSingletons eq tp2)
431+
&& (tp1.widenSingletons ne tp1)
432+
&& recur(tp1.widenSingletons, tp2)
433+
429434
def joinOK = tp2.dealiasKeepRefiningAnnots match {
430435
case tp2: AppliedType if !tp2.tycon.typeSymbol.isClass =>
431436
// If we apply the default algorithm for `A[X] | B[Y] <: C[Z]` where `C` is a
@@ -437,25 +442,30 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
437442
false
438443
}
439444

445+
// If LHS is a hard union, constrain any type variables of the RHS with it as lower bound
446+
// before splitting the LHS into its constituents. That way, the RHS variables are
447+
// constraint by the hard union and can be instantiated to it. If we just split and add
448+
// the two parts of the LHS separately to the constraint, the lower bound would become
449+
// a soft union.
450+
def constrainRHSVars(tp2: Type): Boolean = tp2.dealiasKeepRefiningAnnots match
451+
case tp2: TypeParamRef if constraint contains tp2 => compareTypeParamRef(tp2)
452+
case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22)
453+
case _ => true
454+
455+
// An & on the left side loses information. We compensate by also trying the join.
456+
// This is less ad-hoc than it looks since we produce joins in type inference,
457+
// and then need to check that they are indeed supertypes of the original types
458+
// under -Ycheck. Test case is i7965.scala.
440459
def containsAnd(tp: Type): Boolean = tp.dealiasKeepRefiningAnnots match
441460
case tp: AndType => true
442461
case OrType(tp1, tp2) => containsAnd(tp1) || containsAnd(tp2)
443462
case _ => false
444463

445-
def widenOK =
446-
(tp2.widenSingletons eq tp2) &&
447-
(tp1.widenSingletons ne tp1) &&
448-
recur(tp1.widenSingletons, tp2)
449-
450464
widenOK
451465
|| joinOK
452-
|| recur(tp11, tp2) && recur(tp12, tp2)
466+
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
453467
|| containsAnd(tp1) && recur(tp1.join, tp2)
454-
// An & on the left side loses information. Compensate by also trying the join.
455-
// This is less ad-hoc than it looks since we produce joins in type inference,
456-
// and then need to check that they are indeed supertypes of the original types
457-
// under -Ycheck. Test case is i7965.scala.
458-
case tp1: MatchType =>
468+
case tp1: MatchType =>
459469
val reduced = tp1.reduced
460470
if (reduced.exists) recur(reduced, tp2) else thirdTry
461471
case _: FlexType =>
@@ -509,35 +519,36 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
509519
fourthTry
510520
}
511521

522+
def compareTypeParamRef(tp2: TypeParamRef): Boolean =
523+
assumedTrue(tp2) || {
524+
val alwaysTrue =
525+
// The following condition is carefully formulated to catch all cases
526+
// where the subtype relation is true without needing to add a constraint
527+
// It's tricky because we might need to either approximate tp2 by its
528+
// lower bound or else widen tp1 and check that the result is a subtype of tp2.
529+
// So if the constraint is not yet frozen, we do the same comparison again
530+
// with a frozen constraint, which means that we get a chance to do the
531+
// widening in `fourthTry` before adding to the constraint.
532+
if (frozenConstraint) recur(tp1, bounds(tp2).lo)
533+
else isSubTypeWhenFrozen(tp1, tp2)
534+
alwaysTrue ||
535+
frozenConstraint && (tp1 match {
536+
case tp1: TypeParamRef => constraint.isLess(tp1, tp2)
537+
case _ => false
538+
}) || {
539+
if (canConstrain(tp2) && !approx.low)
540+
addConstraint(tp2, tp1.widenExpr, fromBelow = true)
541+
else fourthTry
542+
}
543+
}
544+
512545
def thirdTry: Boolean = tp2 match {
513546
case tp2 @ AppliedType(tycon2, args2) =>
514547
compareAppliedType2(tp2, tycon2, args2)
515548
case tp2: NamedType =>
516549
thirdTryNamed(tp2)
517550
case tp2: TypeParamRef =>
518-
def compareTypeParamRef =
519-
assumedTrue(tp2) || {
520-
val alwaysTrue =
521-
// The following condition is carefully formulated to catch all cases
522-
// where the subtype relation is true without needing to add a constraint
523-
// It's tricky because we might need to either approximate tp2 by its
524-
// lower bound or else widen tp1 and check that the result is a subtype of tp2.
525-
// So if the constraint is not yet frozen, we do the same comparison again
526-
// with a frozen constraint, which means that we get a chance to do the
527-
// widening in `fourthTry` before adding to the constraint.
528-
if (frozenConstraint) recur(tp1, bounds(tp2).lo)
529-
else isSubTypeWhenFrozen(tp1, tp2)
530-
alwaysTrue ||
531-
frozenConstraint && (tp1 match {
532-
case tp1: TypeParamRef => constraint.isLess(tp1, tp2)
533-
case _ => false
534-
}) || {
535-
if (canConstrain(tp2) && !approx.low)
536-
addConstraint(tp2, tp1.widenExpr, fromBelow = true)
537-
else fourthTry
538-
}
539-
}
540-
compareTypeParamRef
551+
compareTypeParamRef(tp2)
541552
case tp2: RefinedType =>
542553
def compareRefinedSlow: Boolean = {
543554
val name2 = tp2.refinedName

compiler/src/dotty/tools/dotc/core/TypeOps.scala

+8-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,14 @@ object TypeOps:
150150
tp.derivedAlias(simplify(tp.alias, theMap))
151151
case AndType(l, r) if !ctx.mode.is(Mode.Type) =>
152152
simplify(l, theMap) & simplify(r, theMap)
153-
case OrType(l, r) if !ctx.mode.is(Mode.Type) =>
153+
case tp as OrType(l, r)
154+
if !ctx.mode.is(Mode.Type)
155+
&& (tp.isSoft || defn.isBottomType(l) || defn.isBottomType(r)) =>
156+
// Normalize A | Null and Null | A to A even if the union is hard (i.e.
157+
// explicitly declared), but not if -Yexplicit-nulls is set. The reason is
158+
// that in this case the normal asSeenFrom machinery is not prepared to deal
159+
// with Nulls (which have no base classes). Under -Yexplicit-nulls, we take
160+
// corrective steps, so no widening is wanted.
154161
simplify(l, theMap) | simplify(r, theMap)
155162
case AnnotatedType(parent, annot)
156163
if !ctx.mode.is(Mode.Type) && annot.symbol == defn.UncheckedVarianceAnnot =>

compiler/src/dotty/tools/dotc/core/Types.scala

+68-25
Original file line numberDiff line numberDiff line change
@@ -1151,10 +1151,11 @@ object Types {
11511151
case _ => this
11521152
}
11531153

1154-
/** Widen this type and if the result contains embedded union types, replace
1154+
/** Widen this type and if the result contains embedded soft union types, replace
11551155
* them by their joins.
1156-
* "Embedded" means: inside type lambdas, intersections or recursive types, or in prefixes of refined types.
1157-
* If an embedded union is found, we first try to simplify or eliminate it by
1156+
* "Embedded" means: inside type lambdas, intersections or recursive types,
1157+
* in prefixes of refined types, or in hard union types.
1158+
* If an embedded soft union is found, we first try to simplify or eliminate it by
11581159
* re-lubbing it while allowing type parameters to be constrained further.
11591160
* Any remaining union types are replaced by their joins.
11601161
*
@@ -1165,36 +1166,78 @@ object Types {
11651166
* is approximated by constraining `A` to be =:= to `Int` and returning `ArrayBuffer[Int]`
11661167
* instead of `ArrayBuffer[? >: Int | A <: Int & A]`
11671168
*
1169+
* Hard unions inside soft ones are treated specially. For illustration assume we
1170+
* want to widen the type `(A | C) \/ (B | C)` where `\/` means soft union and `|`
1171+
* means hard union. In that case, the hard unions `A | C` and `B | C` are treated
1172+
* in an asymmetric way. Only the first parts `A` and `B` are joined and the rest
1173+
* is added again with a hard union to the result. So
1174+
*
1175+
* widenUnion[ (A | C) \/ (B | C) ]
1176+
* = widenUnion[ A \/ B ] | C | C
1177+
* = D | C | C
1178+
* = D | C
1179+
*
1180+
* In general, If a hard union A | B_1 | ... | B_n is part of of a soft union,
1181+
* only A forms part of the join, and B_1, ..., B_n are pushed out, just `C` is
1182+
* pushed out above. All types that are pushed out are recombined with the result
1183+
* of the join with a lub, but that lub yields again a hard union, not a soft one.
1184+
*
11681185
* Exception (if `-YexplicitNulls` is set): if this type is a nullable union (i.e. of the form `T | Null`),
11691186
* then the top-level union isn't widened. This is needed so that type inference can infer nullable types.
11701187
*/
1171-
def widenUnion(using Context): Type = widen match {
1188+
def widenUnion(using Context): Type = widen.match {
11721189
case tp @ OrNull(tp1): OrType =>
11731190
// Don't widen `T|Null`, since otherwise we wouldn't be able to infer nullable unions.
11741191
val tp1Widen = tp1.widenUnionWithoutNull
11751192
if (tp1Widen.isRef(defn.AnyClass)) tp1Widen
11761193
else tp.derivedOrType(tp1Widen, defn.NullType)
11771194
case tp =>
11781195
tp.widenUnionWithoutNull
1179-
}
1196+
}.reporting(i"widenUnion($this) = $result")
11801197

1181-
def widenUnionWithoutNull(using Context): Type = widen match {
1182-
case tp @ OrType(lhs, rhs) =>
1183-
TypeComparer.lub(lhs.widenUnionWithoutNull, rhs.widenUnionWithoutNull, canConstrain = true) match {
1184-
case union: OrType => union.join
1185-
case res => res
1186-
}
1187-
case tp @ AndType(tp1, tp2) =>
1188-
tp derived_& (tp1.widenUnionWithoutNull, tp2.widenUnionWithoutNull)
1189-
case tp: RefinedType =>
1190-
tp.derivedRefinedType(tp.parent.widenUnion, tp.refinedName, tp.refinedInfo)
1191-
case tp: RecType =>
1192-
tp.rebind(tp.parent.widenUnion)
1193-
case tp: HKTypeLambda =>
1194-
tp.derivedLambdaType(resType = tp.resType.widenUnion)
1195-
case tp =>
1196-
tp
1197-
}
1198+
def widenUnionWithoutNull(using Context): Type =
1199+
1200+
// Split hard union `A | B1 | ... | Bn` into leftmost part `A` and list of
1201+
// pushed out parts `B1, ..., Bn`.
1202+
def splitAlts(tp: Type, follow: List[Type]): (Type, List[Type]) = tp match
1203+
case tp as OrType(lhs, rhs) if !tp.isSoft =>
1204+
splitAlts(lhs, rhs :: follow)
1205+
case _ =>
1206+
(tp, follow)
1207+
1208+
// Convert any soft unions in result of lub to hard ones */
1209+
def harden(tp: Type): Type = tp match
1210+
case tp as OrType(tp1, tp2) if tp.isSoft =>
1211+
OrType(harden(tp1), harden(tp2), soft = false)
1212+
case _ =>
1213+
tp
1214+
1215+
def recombine(tp1: Type, tp2: Type) = harden(TypeComparer.lub(tp1, tp2))
1216+
1217+
widen match
1218+
case tp @ OrType(lhs, rhs) =>
1219+
if tp.isSoft then
1220+
val (lhsCore, lhsExtras) = splitAlts(lhs.widenUnionWithoutNull, Nil)
1221+
val (rhsCore, rhsExtras) = splitAlts(rhs.widenUnionWithoutNull, Nil)
1222+
val core = TypeComparer.lub(lhsCore, rhsCore, canConstrain = true) match
1223+
case union: OrType => union.join
1224+
case res => res
1225+
rhsExtras.foldLeft(lhsExtras.foldLeft(core)(recombine))(recombine)
1226+
else
1227+
val lhs1 = lhs.widenUnionWithoutNull
1228+
val rhs1 = rhs.widenUnionWithoutNull
1229+
if (lhs1 eq lhs) && (rhs1 eq rhs) then tp else recombine(lhs1, rhs1)
1230+
case tp @ AndType(tp1, tp2) =>
1231+
tp derived_& (tp1.widenUnionWithoutNull, tp2.widenUnionWithoutNull)
1232+
case tp: RefinedType =>
1233+
tp.derivedRefinedType(tp.parent.widenUnion, tp.refinedName, tp.refinedInfo)
1234+
case tp: RecType =>
1235+
tp.rebind(tp.parent.widenUnion)
1236+
case tp: HKTypeLambda =>
1237+
tp.derivedLambdaType(resType = tp.resType.widenUnion)
1238+
case tp =>
1239+
tp
1240+
end widenUnionWithoutNull
11981241

11991242
/** Widen all top-level singletons reachable by dealiasing
12001243
* and going to the operands of & and |.
@@ -3044,9 +3087,9 @@ object Types {
30443087
myWidened
30453088
}
30463089

3047-
def derivedOrType(tp1: Type, tp2: Type)(using Context): Type =
3048-
if ((tp1 eq this.tp1) && (tp2 eq this.tp2)) this
3049-
else OrType.make(tp1, tp2, isSoft)
3090+
def derivedOrType(tp1: Type, tp2: Type, soft: Boolean = isSoft)(using Context): Type =
3091+
if ((tp1 eq this.tp1) && (tp2 eq this.tp2) && soft == isSoft) this
3092+
else OrType.make(tp1, tp2, soft)
30503093

30513094
override def computeHash(bs: Binders): Int =
30523095
doHash(bs, if isSoft then 0 else 1, tp1, tp2)

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,10 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
214214
// of AndType and OrType to account for associativity
215215
case AndType(tp1, tp2) =>
216216
toTextInfixType(tpnme.raw.AMP, tp1, tp2) { toText(tpnme.raw.AMP) }
217-
case OrType(tp1, tp2) =>
218-
toTextInfixType(tpnme.raw.BAR, tp1, tp2) { toText(tpnme.raw.BAR) }
217+
case tp as OrType(tp1, tp2) =>
218+
toTextInfixType(tpnme.raw.BAR, tp1, tp2) {
219+
if tp.isSoft && printDebug then toText(tpnme.ZOR) else toText(tpnme.raw.BAR)
220+
}
219221
case tp @ EtaExpansion(tycon)
220222
if !printDebug && appliedText(tp.asInstanceOf[HKLambda].resType).isEmpty =>
221223
// don't eta contract if the application would be printed specially

tests/pos/widen-union.scala

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
2+
object Test1:
3+
val x: Int | String = 1
4+
val y = x
5+
val z: Int | String = y
6+
7+
object Test2:
8+
type Sig = Int | String
9+
def consistent(x: Sig, y: Sig): Boolean = ???// x == y
10+
11+
def consistentLists(xs: List[Sig], ys: List[Sig]): Boolean =
12+
xs.corresponds(ys)(consistent) // OK
13+
|| xs.corresponds(ys)(consistent(_, _)) // error, found: Any, required: Int | String
14+
15+
object Test3:
16+
17+
def g[X](x: X | String): Int = ???
18+
def y: Boolean | String = ???
19+
g[Boolean](y)
20+
g(y)
21+
g[Boolean](identity(y))
22+
g(identity(y))
23+
24+

tests/run/i8726.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
case class A(a: Int)
2-
object C { def unapply(a: A): true | true = true }
2+
object C { def unapply(a: A): true = true }
33

44
@main
55
def Test = (A(1): A | A) match { case C() => "OK" }

0 commit comments

Comments
 (0)