Skip to content

Commit 14ddc26

Browse files
authored
Merge pull request #10112 from dotty-staging/change-union-types
Fix #10108: Distinguish between soft and hard union types
2 parents 300e9a2 + 0463200 commit 14ddc26

20 files changed

+153
-128
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

+1-2
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,7 @@ class Compiler {
7676
new sjs.ExplicitJSClasses, // Make all JS classes explicit (Scala.js only)
7777
new ExplicitOuter, // Add accessors to outer classes from nested ones.
7878
new ExplicitSelf, // Make references to non-trivial self types explicit as casts
79-
new StringInterpolatorOpt, // Optimizes raw and s string interpolators by rewriting them to string concatentations
80-
new CrossCastAnd) :: // Normalize selections involving intersection types.
79+
new StringInterpolatorOpt) :: // Optimizes raw and s string interpolators by rewriting them to string concatentations
8180
List(new PruneErasedDefs, // Drop erased definitions from scopes and simplify erased expressions
8281
new InlinePatterns, // Remove placeholders of inlined patterns
8382
new VCInlineMethods, // Inlines calls to value class methods

compiler/src/dotty/tools/dotc/core/Definitions.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ class Definitions {
440440
def AnyKindType: TypeRef = AnyKindClass.typeRef
441441

442442
@tu lazy val andType: TypeSymbol = enterBinaryAlias(tpnme.AND, AndType(_, _))
443-
@tu lazy val orType: TypeSymbol = enterBinaryAlias(tpnme.OR, OrType(_, _))
443+
@tu lazy val orType: TypeSymbol = enterBinaryAlias(tpnme.OR, OrType(_, _, soft = false))
444444

445445
/** Marker method to indicate an argument to a call-by-name parameter.
446446
* Created by byNameClosures and elimByName, eliminated by Erasure,

compiler/src/dotty/tools/dotc/core/Hashable.scala

+3
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ trait Hashable {
9393
protected final def doHash(bs: Binders, x1: Int, tp2: Type): Int =
9494
finishHash(bs, hashing.mix(hashSeed, x1), 1, tp2)
9595

96+
protected final def doHash(bs: Binders, x1: Int, tp2: Type, tp3: Type): Int =
97+
finishHash(bs, hashing.mix(hashSeed, x1), 1, tp2, tp3)
98+
9699
protected final def doHash(bs: Binders, tp1: Type, tp2: Type): Int =
97100
finishHash(bs, hashSeed, 0, tp1, tp2)
98101

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+49-38
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
428428
case Some(b) => return b
429429
case None =>
430430

431+
def widenOK =
432+
(tp2.widenSingletons eq tp2)
433+
&& (tp1.widenSingletons ne tp1)
434+
&& recur(tp1.widenSingletons, tp2)
435+
431436
def joinOK = tp2.dealiasKeepRefiningAnnots match {
432437
case tp2: AppliedType if !tp2.tycon.typeSymbol.isClass =>
433438
// If we apply the default algorithm for `A[X] | B[Y] <: C[Z]` where `C` is a
@@ -439,25 +444,30 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
439444
false
440445
}
441446

447+
// If LHS is a hard union, constrain any type variables of the RHS with it as lower bound
448+
// before splitting the LHS into its constituents. That way, the RHS variables are
449+
// constraint by the hard union and can be instantiated to it. If we just split and add
450+
// the two parts of the LHS separately to the constraint, the lower bound would become
451+
// a soft union.
452+
def constrainRHSVars(tp2: Type): Boolean = tp2.dealiasKeepRefiningAnnots match
453+
case tp2: TypeParamRef if constraint contains tp2 => compareTypeParamRef(tp2)
454+
case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22)
455+
case _ => true
456+
457+
// An & on the left side loses information. We compensate by also trying the join.
458+
// This is less ad-hoc than it looks since we produce joins in type inference,
459+
// and then need to check that they are indeed supertypes of the original types
460+
// under -Ycheck. Test case is i7965.scala.
442461
def containsAnd(tp: Type): Boolean = tp.dealiasKeepRefiningAnnots match
443462
case tp: AndType => true
444463
case OrType(tp1, tp2) => containsAnd(tp1) || containsAnd(tp2)
445464
case _ => false
446465

447-
def widenOK =
448-
(tp2.widenSingletons eq tp2) &&
449-
(tp1.widenSingletons ne tp1) &&
450-
recur(tp1.widenSingletons, tp2)
451-
452466
widenOK
453467
|| joinOK
454-
|| recur(tp11, tp2) && recur(tp12, tp2)
468+
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
455469
|| containsAnd(tp1) && recur(tp1.join, tp2)
456-
// An & on the left side loses information. Compensate by also trying the join.
457-
// This is less ad-hoc than it looks since we produce joins in type inference,
458-
// and then need to check that they are indeed supertypes of the original types
459-
// under -Ycheck. Test case is i7965.scala.
460-
case tp1: MatchType =>
470+
case tp1: MatchType =>
461471
val reduced = tp1.reduced
462472
if (reduced.exists) recur(reduced, tp2) else thirdTry
463473
case _: FlexType =>
@@ -511,35 +521,36 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
511521
fourthTry
512522
}
513523

524+
def compareTypeParamRef(tp2: TypeParamRef): Boolean =
525+
assumedTrue(tp2) || {
526+
val alwaysTrue =
527+
// The following condition is carefully formulated to catch all cases
528+
// where the subtype relation is true without needing to add a constraint
529+
// It's tricky because we might need to either approximate tp2 by its
530+
// lower bound or else widen tp1 and check that the result is a subtype of tp2.
531+
// So if the constraint is not yet frozen, we do the same comparison again
532+
// with a frozen constraint, which means that we get a chance to do the
533+
// widening in `fourthTry` before adding to the constraint.
534+
if (frozenConstraint) recur(tp1, bounds(tp2).lo)
535+
else isSubTypeWhenFrozen(tp1, tp2)
536+
alwaysTrue ||
537+
frozenConstraint && (tp1 match {
538+
case tp1: TypeParamRef => constraint.isLess(tp1, tp2)
539+
case _ => false
540+
}) || {
541+
if (canConstrain(tp2) && !approx.low)
542+
addConstraint(tp2, tp1.widenExpr, fromBelow = true)
543+
else fourthTry
544+
}
545+
}
546+
514547
def thirdTry: Boolean = tp2 match {
515548
case tp2 @ AppliedType(tycon2, args2) =>
516549
compareAppliedType2(tp2, tycon2, args2)
517550
case tp2: NamedType =>
518551
thirdTryNamed(tp2)
519552
case tp2: TypeParamRef =>
520-
def compareTypeParamRef =
521-
assumedTrue(tp2) || {
522-
val alwaysTrue =
523-
// The following condition is carefully formulated to catch all cases
524-
// where the subtype relation is true without needing to add a constraint
525-
// It's tricky because we might need to either approximate tp2 by its
526-
// lower bound or else widen tp1 and check that the result is a subtype of tp2.
527-
// So if the constraint is not yet frozen, we do the same comparison again
528-
// with a frozen constraint, which means that we get a chance to do the
529-
// widening in `fourthTry` before adding to the constraint.
530-
if (frozenConstraint) recur(tp1, bounds(tp2).lo)
531-
else isSubTypeWhenFrozen(tp1, tp2)
532-
alwaysTrue ||
533-
frozenConstraint && (tp1 match {
534-
case tp1: TypeParamRef => constraint.isLess(tp1, tp2)
535-
case _ => false
536-
}) || {
537-
if (canConstrain(tp2) && !approx.low)
538-
addConstraint(tp2, tp1.widenExpr, fromBelow = true)
539-
else fourthTry
540-
}
541-
}
542-
compareTypeParamRef
553+
compareTypeParamRef(tp2)
543554
case tp2: RefinedType =>
544555
def compareRefinedSlow: Boolean = {
545556
val name2 = tp2.refinedName
@@ -616,7 +627,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
616627
}
617628
}
618629
compareTypeLambda
619-
case OrType(tp21, tp22) =>
630+
case tp2 as OrType(tp21, tp22) =>
620631
compareAtoms(tp1, tp2) match
621632
case Some(b) => return b
622633
case _ =>
@@ -648,12 +659,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
648659
// solutions. The rewriting delays the point where we have to choose.
649660
tp21 match {
650661
case AndType(tp211, tp212) =>
651-
return recur(tp1, OrType(tp211, tp22)) && recur(tp1, OrType(tp212, tp22))
662+
return recur(tp1, OrType(tp211, tp22, tp2.isSoft)) && recur(tp1, OrType(tp212, tp22, tp2.isSoft))
652663
case _ =>
653664
}
654665
tp22 match {
655666
case AndType(tp221, tp222) =>
656-
return recur(tp1, OrType(tp21, tp221)) && recur(tp1, OrType(tp21, tp222))
667+
return recur(tp1, OrType(tp21, tp221, tp2.isSoft)) && recur(tp1, OrType(tp21, tp222, tp2.isSoft))
657668
case _ =>
658669
}
659670
either(recur(tp1, tp21), recur(tp1, tp22)) || fourthTry
@@ -2123,7 +2134,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
21232134
val t2 = distributeOr(tp2, tp1)
21242135
if (t2.exists) t2
21252136
else if (isErased) erasedLub(tp1, tp2)
2126-
else liftIfHK(tp1, tp2, OrType(_, _), _ | _, _ & _)
2137+
else liftIfHK(tp1, tp2, OrType(_, _, soft = true), _ | _, _ & _)
21272138
}
21282139
}
21292140

compiler/src/dotty/tools/dotc/core/TypeOps.scala

+8-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,14 @@ object TypeOps:
150150
tp.derivedAlias(simplify(tp.alias, theMap))
151151
case AndType(l, r) if !ctx.mode.is(Mode.Type) =>
152152
simplify(l, theMap) & simplify(r, theMap)
153-
case OrType(l, r) if !ctx.mode.is(Mode.Type) =>
153+
case tp as OrType(l, r)
154+
if !ctx.mode.is(Mode.Type)
155+
&& (tp.isSoft || defn.isBottomType(l) || defn.isBottomType(r)) =>
156+
// Normalize A | Null and Null | A to A even if the union is hard (i.e.
157+
// explicitly declared), but not if -Yexplicit-nulls is set. The reason is
158+
// that in this case the normal asSeenFrom machinery is not prepared to deal
159+
// with Nulls (which have no base classes). Under -Yexplicit-nulls, we take
160+
// corrective steps, so no widening is wanted.
154161
simplify(l, theMap) | simplify(r, theMap)
155162
case AnnotatedType(parent, annot)
156163
if !ctx.mode.is(Mode.Type) && annot.symbol == defn.UncheckedVarianceAnnot =>

compiler/src/dotty/tools/dotc/core/Types.scala

+29-28
Original file line numberDiff line numberDiff line change
@@ -1040,7 +1040,7 @@ object Types {
10401040
def safe_& (that: Type)(using Context): Type = (this, that) match {
10411041
case (TypeBounds(lo1, hi1), TypeBounds(lo2, hi2)) =>
10421042
TypeBounds(
1043-
OrType.makeHk(lo1.stripLazyRef, lo2.stripLazyRef),
1043+
OrType.makeHk(lo1.stripLazyRef, lo2.stripLazyRef),
10441044
AndType.makeHk(hi1.stripLazyRef, hi2.stripLazyRef))
10451045
case _ =>
10461046
this & that
@@ -1151,10 +1151,11 @@ object Types {
11511151
case _ => this
11521152
}
11531153

1154-
/** Widen this type and if the result contains embedded union types, replace
1154+
/** Widen this type and if the result contains embedded soft union types, replace
11551155
* them by their joins.
1156-
* "Embedded" means: inside type lambdas, intersections or recursive types, or in prefixes of refined types.
1157-
* If an embedded union is found, we first try to simplify or eliminate it by
1156+
* "Embedded" means: inside type lambdas, intersections or recursive types,
1157+
* in prefixes of refined types, or in hard union types.
1158+
* If an embedded soft union is found, we first try to simplify or eliminate it by
11581159
* re-lubbing it while allowing type parameters to be constrained further.
11591160
* Any remaining union types are replaced by their joins.
11601161
*
@@ -1168,24 +1169,22 @@ object Types {
11681169
* Exception (if `-YexplicitNulls` is set): if this type is a nullable union (i.e. of the form `T | Null`),
11691170
* then the top-level union isn't widened. This is needed so that type inference can infer nullable types.
11701171
*/
1171-
def widenUnion(using Context): Type = widen match {
1172+
def widenUnion(using Context): Type = widen match
11721173
case tp @ OrNull(tp1): OrType =>
11731174
// Don't widen `T|Null`, since otherwise we wouldn't be able to infer nullable unions.
11741175
val tp1Widen = tp1.widenUnionWithoutNull
11751176
if (tp1Widen.isRef(defn.AnyClass)) tp1Widen
11761177
else tp.derivedOrType(tp1Widen, defn.NullType)
11771178
case tp =>
11781179
tp.widenUnionWithoutNull
1179-
}
11801180

1181-
def widenUnionWithoutNull(using Context): Type = widen match {
1182-
case tp @ OrType(lhs, rhs) =>
1183-
TypeComparer.lub(lhs.widenUnionWithoutNull, rhs.widenUnionWithoutNull, canConstrain = true) match {
1181+
def widenUnionWithoutNull(using Context): Type = widen match
1182+
case tp @ OrType(lhs, rhs) if tp.isSoft =>
1183+
TypeComparer.lub(lhs.widenUnionWithoutNull, rhs.widenUnionWithoutNull, canConstrain = true) match
11841184
case union: OrType => union.join
11851185
case res => res
1186-
}
1187-
case tp @ AndType(tp1, tp2) =>
1188-
tp derived_& (tp1.widenUnionWithoutNull, tp2.widenUnionWithoutNull)
1186+
case tp: AndOrType =>
1187+
tp.derivedAndOrType(tp.tp1.widenUnionWithoutNull, tp.tp2.widenUnionWithoutNull)
11891188
case tp: RefinedType =>
11901189
tp.derivedRefinedType(tp.parent.widenUnion, tp.refinedName, tp.refinedInfo)
11911190
case tp: RecType =>
@@ -1194,7 +1193,6 @@ object Types {
11941193
tp.derivedLambdaType(resType = tp.resType.widenUnion)
11951194
case tp =>
11961195
tp
1197-
}
11981196

11991197
/** Widen all top-level singletons reachable by dealiasing
12001198
* and going to the operands of & and |.
@@ -2917,8 +2915,9 @@ object Types {
29172915

29182916
def derivedAndOrType(tp1: Type, tp2: Type)(using Context) =
29192917
if ((tp1 eq this.tp1) && (tp2 eq this.tp2)) this
2920-
else if (isAnd) AndType.make(tp1, tp2, checkValid = true)
2921-
else OrType.make(tp1, tp2)
2918+
else this match
2919+
case tp: OrType => OrType.make(tp1, tp2, tp.isSoft)
2920+
case tp: AndType => AndType.make(tp1, tp2, checkValid = true)
29222921
}
29232922

29242923
abstract case class AndType(tp1: Type, tp2: Type) extends AndOrType {
@@ -2992,6 +2991,7 @@ object Types {
29922991

29932992
abstract case class OrType(tp1: Type, tp2: Type) extends AndOrType {
29942993
def isAnd: Boolean = false
2994+
def isSoft: Boolean
29952995
private var myBaseClassesPeriod: Period = Nowhere
29962996
private var myBaseClasses: List[ClassSymbol] = _
29972997
/** Base classes of are the intersection of the operand base classes. */
@@ -3052,32 +3052,33 @@ object Types {
30523052
myWidened
30533053
}
30543054

3055-
def derivedOrType(tp1: Type, tp2: Type)(using Context): Type =
3056-
if ((tp1 eq this.tp1) && (tp2 eq this.tp2)) this
3057-
else OrType.make(tp1, tp2)
3055+
def derivedOrType(tp1: Type, tp2: Type, soft: Boolean = isSoft)(using Context): Type =
3056+
if ((tp1 eq this.tp1) && (tp2 eq this.tp2) && soft == isSoft) this
3057+
else OrType.make(tp1, tp2, soft)
30583058

3059-
override def computeHash(bs: Binders): Int = doHash(bs, tp1, tp2)
3059+
override def computeHash(bs: Binders): Int =
3060+
doHash(bs, if isSoft then 0 else 1, tp1, tp2)
30603061

30613062
override def eql(that: Type): Boolean = that match {
3062-
case that: OrType => tp1.eq(that.tp1) && tp2.eq(that.tp2)
3063+
case that: OrType => tp1.eq(that.tp1) && tp2.eq(that.tp2) && isSoft == that.isSoft
30633064
case _ => false
30643065
}
30653066
}
30663067

3067-
final class CachedOrType(tp1: Type, tp2: Type) extends OrType(tp1, tp2)
3068+
final class CachedOrType(tp1: Type, tp2: Type, override val isSoft: Boolean) extends OrType(tp1, tp2)
30683069

30693070
object OrType {
3070-
def apply(tp1: Type, tp2: Type)(using Context): OrType = {
3071+
def apply(tp1: Type, tp2: Type, soft: Boolean)(using Context): OrType = {
30713072
assertUnerased()
3072-
unique(new CachedOrType(tp1, tp2))
3073+
unique(new CachedOrType(tp1, tp2, soft))
30733074
}
3074-
def make(tp1: Type, tp2: Type)(using Context): Type =
3075+
def make(tp1: Type, tp2: Type, soft: Boolean)(using Context): Type =
30753076
if (tp1 eq tp2) tp1
3076-
else apply(tp1, tp2)
3077+
else apply(tp1, tp2, soft)
30773078

30783079
/** Like `make`, but also supports higher-kinded types as argument */
30793080
def makeHk(tp1: Type, tp2: Type)(using Context): Type =
3080-
TypeComparer.liftIfHK(tp1, tp2, OrType(_, _), makeHk, _ & _)
3081+
TypeComparer.liftIfHK(tp1, tp2, OrType(_, _, soft = true), makeHk, _ & _)
30813082
}
30823083

30833084
/** An extractor object to pattern match against a nullable union.
@@ -3089,7 +3090,7 @@ object Types {
30893090
*/
30903091
object OrNull {
30913092
def apply(tp: Type)(using Context) =
3092-
OrType(tp, defn.NullType)
3093+
OrType(tp, defn.NullType, soft = false)
30933094
def unapply(tp: Type)(using Context): Option[Type] =
30943095
if (ctx.explicitNulls) {
30953096
val tp1 = tp.stripNull()
@@ -3107,7 +3108,7 @@ object Types {
31073108
*/
31083109
object OrUncheckedNull {
31093110
def apply(tp: Type)(using Context) =
3110-
OrType(tp, defn.UncheckedNullAliasType)
3111+
OrType(tp, defn.UncheckedNullAliasType, soft = false)
31113112
def unapply(tp: Type)(using Context): Option[Type] =
31123113
if (ctx.explicitNulls) {
31133114
val tp1 = tp.stripUncheckedNull

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ class TreeUnpickler(reader: TastyReader,
368368
case ANDtype =>
369369
AndType(readType(), readType())
370370
case ORtype =>
371-
OrType(readType(), readType())
371+
OrType(readType(), readType(), soft = false)
372372
case SUPERtype =>
373373
SuperType(readType(), readType())
374374
case MATCHtype =>
@@ -1222,7 +1222,7 @@ class TreeUnpickler(reader: TastyReader,
12221222
val args = until(end)(readTpt())
12231223
val ownType =
12241224
if (tycon.symbol == defn.andType) AndType(args(0).tpe, args(1).tpe)
1225-
else if (tycon.symbol == defn.orType) OrType(args(0).tpe, args(1).tpe)
1225+
else if (tycon.symbol == defn.orType) OrType(args(0).tpe, args(1).tpe, soft = false)
12261226
else tycon.tpe.safeAppliedTo(args.tpes)
12271227
untpd.AppliedTypeTree(tycon, args).withType(ownType)
12281228
case ANNOTATEDtpt =>

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,10 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
214214
// of AndType and OrType to account for associativity
215215
case AndType(tp1, tp2) =>
216216
toTextInfixType(tpnme.raw.AMP, tp1, tp2) { toText(tpnme.raw.AMP) }
217-
case OrType(tp1, tp2) =>
218-
toTextInfixType(tpnme.raw.BAR, tp1, tp2) { toText(tpnme.raw.BAR) }
217+
case tp as OrType(tp1, tp2) =>
218+
toTextInfixType(tpnme.raw.BAR, tp1, tp2) {
219+
if tp.isSoft && printDebug then toText(tpnme.ZOR) else toText(tpnme.raw.BAR)
220+
}
219221
case tp @ EtaExpansion(tycon)
220222
if !printDebug && appliedText(tp.asInstanceOf[HKLambda].resType).isEmpty =>
221223
// don't eta contract if the application would be printed specially

compiler/src/dotty/tools/dotc/quoted/QuoteContextImpl.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -1823,7 +1823,7 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext:
18231823
end OrTypeTypeTest
18241824

18251825
object OrType extends OrTypeModule:
1826-
def apply(lhs: TypeRepr, rhs: TypeRepr): OrType = Types.OrType(lhs, rhs)
1826+
def apply(lhs: TypeRepr, rhs: TypeRepr): OrType = Types.OrType(lhs, rhs, soft = false)
18271827
def unapply(x: OrType): Option[(TypeRepr, TypeRepr)] = Some((x.left, x.right))
18281828
end OrType
18291829

compiler/src/dotty/tools/dotc/reporting/messages.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -822,14 +822,16 @@ import transform.SymUtils._
822822
|"""
823823
}
824824

825-
class PatternMatchExhaustivity(uncoveredFn: => String)(using Context)
825+
class PatternMatchExhaustivity(uncoveredFn: => String, hasMore: Boolean)(using Context)
826826
extends Message(PatternMatchExhaustivityID) {
827827
def kind = "Pattern Match Exhaustivity"
828828
lazy val uncovered = uncoveredFn
829829
def msg =
830+
val addendum = if hasMore then "(More unmatched cases are elided)" else ""
830831
em"""|${hl("match")} may not be exhaustive.
831832
|
832-
|It would fail on pattern case: $uncovered"""
833+
|It would fail on pattern case: $uncovered
834+
|$addendum"""
833835

834836

835837
def explain =

0 commit comments

Comments
 (0)