@@ -1040,7 +1040,7 @@ object Types {
1040
1040
def safe_& (that : Type )(using Context ): Type = (this , that) match {
1041
1041
case (TypeBounds (lo1, hi1), TypeBounds (lo2, hi2)) =>
1042
1042
TypeBounds (
1043
- OrType .makeHk(lo1.stripLazyRef, lo2.stripLazyRef),
1043
+ OrType .makeHk(lo1.stripLazyRef, lo2.stripLazyRef),
1044
1044
AndType .makeHk(hi1.stripLazyRef, hi2.stripLazyRef))
1045
1045
case _ =>
1046
1046
this & that
@@ -1151,10 +1151,11 @@ object Types {
1151
1151
case _ => this
1152
1152
}
1153
1153
1154
- /** Widen this type and if the result contains embedded union types, replace
1154
+ /** Widen this type and if the result contains embedded soft union types, replace
1155
1155
* them by their joins.
1156
- * "Embedded" means: inside type lambdas, intersections or recursive types, or in prefixes of refined types.
1157
- * If an embedded union is found, we first try to simplify or eliminate it by
1156
+ * "Embedded" means: inside type lambdas, intersections or recursive types,
1157
+ * in prefixes of refined types, or in hard union types.
1158
+ * If an embedded soft union is found, we first try to simplify or eliminate it by
1158
1159
* re-lubbing it while allowing type parameters to be constrained further.
1159
1160
* Any remaining union types are replaced by their joins.
1160
1161
*
@@ -1168,24 +1169,22 @@ object Types {
1168
1169
* Exception (if `-YexplicitNulls` is set): if this type is a nullable union (i.e. of the form `T | Null`),
1169
1170
* then the top-level union isn't widened. This is needed so that type inference can infer nullable types.
1170
1171
*/
1171
- def widenUnion (using Context ): Type = widen match {
1172
+ def widenUnion (using Context ): Type = widen match
1172
1173
case tp @ OrNull (tp1): OrType =>
1173
1174
// Don't widen `T|Null`, since otherwise we wouldn't be able to infer nullable unions.
1174
1175
val tp1Widen = tp1.widenUnionWithoutNull
1175
1176
if (tp1Widen.isRef(defn.AnyClass )) tp1Widen
1176
1177
else tp.derivedOrType(tp1Widen, defn.NullType )
1177
1178
case tp =>
1178
1179
tp.widenUnionWithoutNull
1179
- }
1180
1180
1181
- def widenUnionWithoutNull (using Context ): Type = widen match {
1182
- case tp @ OrType (lhs, rhs) =>
1183
- TypeComparer .lub(lhs.widenUnionWithoutNull, rhs.widenUnionWithoutNull, canConstrain = true ) match {
1181
+ def widenUnionWithoutNull (using Context ): Type = widen match
1182
+ case tp @ OrType (lhs, rhs) if tp.isSoft =>
1183
+ TypeComparer .lub(lhs.widenUnionWithoutNull, rhs.widenUnionWithoutNull, canConstrain = true ) match
1184
1184
case union : OrType => union.join
1185
1185
case res => res
1186
- }
1187
- case tp @ AndType (tp1, tp2) =>
1188
- tp derived_& (tp1.widenUnionWithoutNull, tp2.widenUnionWithoutNull)
1186
+ case tp : AndOrType =>
1187
+ tp.derivedAndOrType(tp.tp1.widenUnionWithoutNull, tp.tp2.widenUnionWithoutNull)
1189
1188
case tp : RefinedType =>
1190
1189
tp.derivedRefinedType(tp.parent.widenUnion, tp.refinedName, tp.refinedInfo)
1191
1190
case tp : RecType =>
@@ -1194,7 +1193,6 @@ object Types {
1194
1193
tp.derivedLambdaType(resType = tp.resType.widenUnion)
1195
1194
case tp =>
1196
1195
tp
1197
- }
1198
1196
1199
1197
/** Widen all top-level singletons reachable by dealiasing
1200
1198
* and going to the operands of & and |.
@@ -2917,8 +2915,9 @@ object Types {
2917
2915
2918
2916
def derivedAndOrType (tp1 : Type , tp2 : Type )(using Context ) =
2919
2917
if ((tp1 eq this .tp1) && (tp2 eq this .tp2)) this
2920
- else if (isAnd) AndType .make(tp1, tp2, checkValid = true )
2921
- else OrType .make(tp1, tp2)
2918
+ else this match
2919
+ case tp : OrType => OrType .make(tp1, tp2, tp.isSoft)
2920
+ case tp : AndType => AndType .make(tp1, tp2, checkValid = true )
2922
2921
}
2923
2922
2924
2923
abstract case class AndType (tp1 : Type , tp2 : Type ) extends AndOrType {
@@ -2992,6 +2991,7 @@ object Types {
2992
2991
2993
2992
abstract case class OrType (tp1 : Type , tp2 : Type ) extends AndOrType {
2994
2993
def isAnd : Boolean = false
2994
+ def isSoft : Boolean
2995
2995
private var myBaseClassesPeriod : Period = Nowhere
2996
2996
private var myBaseClasses : List [ClassSymbol ] = _
2997
2997
/** Base classes of are the intersection of the operand base classes. */
@@ -3052,32 +3052,33 @@ object Types {
3052
3052
myWidened
3053
3053
}
3054
3054
3055
- def derivedOrType (tp1 : Type , tp2 : Type )(using Context ): Type =
3056
- if ((tp1 eq this .tp1) && (tp2 eq this .tp2)) this
3057
- else OrType .make(tp1, tp2)
3055
+ def derivedOrType (tp1 : Type , tp2 : Type , soft : Boolean = isSoft )(using Context ): Type =
3056
+ if ((tp1 eq this .tp1) && (tp2 eq this .tp2) && soft == isSoft ) this
3057
+ else OrType .make(tp1, tp2, soft )
3058
3058
3059
- override def computeHash (bs : Binders ): Int = doHash(bs, tp1, tp2)
3059
+ override def computeHash (bs : Binders ): Int =
3060
+ doHash(bs, if isSoft then 0 else 1 , tp1, tp2)
3060
3061
3061
3062
override def eql (that : Type ): Boolean = that match {
3062
- case that : OrType => tp1.eq(that.tp1) && tp2.eq(that.tp2)
3063
+ case that : OrType => tp1.eq(that.tp1) && tp2.eq(that.tp2) && isSoft == that.isSoft
3063
3064
case _ => false
3064
3065
}
3065
3066
}
3066
3067
3067
- final class CachedOrType (tp1 : Type , tp2 : Type ) extends OrType (tp1, tp2)
3068
+ final class CachedOrType (tp1 : Type , tp2 : Type , override val isSoft : Boolean ) extends OrType (tp1, tp2)
3068
3069
3069
3070
object OrType {
3070
- def apply (tp1 : Type , tp2 : Type )(using Context ): OrType = {
3071
+ def apply (tp1 : Type , tp2 : Type , soft : Boolean )(using Context ): OrType = {
3071
3072
assertUnerased()
3072
- unique(new CachedOrType (tp1, tp2))
3073
+ unique(new CachedOrType (tp1, tp2, soft ))
3073
3074
}
3074
- def make (tp1 : Type , tp2 : Type )(using Context ): Type =
3075
+ def make (tp1 : Type , tp2 : Type , soft : Boolean )(using Context ): Type =
3075
3076
if (tp1 eq tp2) tp1
3076
- else apply(tp1, tp2)
3077
+ else apply(tp1, tp2, soft )
3077
3078
3078
3079
/** Like `make`, but also supports higher-kinded types as argument */
3079
3080
def makeHk (tp1 : Type , tp2 : Type )(using Context ): Type =
3080
- TypeComparer .liftIfHK(tp1, tp2, OrType (_, _), makeHk, _ & _)
3081
+ TypeComparer .liftIfHK(tp1, tp2, OrType (_, _, soft = true ), makeHk, _ & _)
3081
3082
}
3082
3083
3083
3084
/** An extractor object to pattern match against a nullable union.
@@ -3089,7 +3090,7 @@ object Types {
3089
3090
*/
3090
3091
object OrNull {
3091
3092
def apply (tp : Type )(using Context ) =
3092
- OrType (tp, defn.NullType )
3093
+ OrType (tp, defn.NullType , soft = false )
3093
3094
def unapply (tp : Type )(using Context ): Option [Type ] =
3094
3095
if (ctx.explicitNulls) {
3095
3096
val tp1 = tp.stripNull()
@@ -3107,7 +3108,7 @@ object Types {
3107
3108
*/
3108
3109
object OrUncheckedNull {
3109
3110
def apply (tp : Type )(using Context ) =
3110
- OrType (tp, defn.UncheckedNullAliasType )
3111
+ OrType (tp, defn.UncheckedNullAliasType , soft = false )
3111
3112
def unapply (tp : Type )(using Context ): Option [Type ] =
3112
3113
if (ctx.explicitNulls) {
3113
3114
val tp1 = tp.stripUncheckedNull
0 commit comments