@@ -1151,10 +1151,11 @@ object Types {
1151
1151
case _ => this
1152
1152
}
1153
1153
1154
- /** Widen this type and if the result contains embedded union types, replace
1154
+ /** Widen this type and if the result contains embedded soft union types, replace
1155
1155
* them by their joins.
1156
- * "Embedded" means: inside type lambdas, intersections or recursive types, or in prefixes of refined types.
1157
- * If an embedded union is found, we first try to simplify or eliminate it by
1156
+ * "Embedded" means: inside type lambdas, intersections or recursive types,
1157
+ * in prefixes of refined types, or in hard union types.
1158
+ * If an embedded soft union is found, we first try to simplify or eliminate it by
1158
1159
* re-lubbing it while allowing type parameters to be constrained further.
1159
1160
* Any remaining union types are replaced by their joins.
1160
1161
*
@@ -1165,36 +1166,78 @@ object Types {
1165
1166
* is approximated by constraining `A` to be =:= to `Int` and returning `ArrayBuffer[Int]`
1166
1167
* instead of `ArrayBuffer[? >: Int | A <: Int & A]`
1167
1168
*
1169
+ * Hard unions inside soft ones are treated specially. For illustration assume we
1170
+ * want to widen the type `(A | C) \/ (B | C)` where `\/` means soft union and `|`
1171
+ * means hard union. In that case, the hard unions `A | C` and `B | C` are treated
1172
+ * in an asymmetric way. Only the first parts `A` and `B` are joined and the rest
1173
+ * is added again with a hard union to the result. So
1174
+ *
1175
+ * widenUnion[ (A | C) \/ (B | C) ]
1176
+ * = widenUnion[ A \/ B ] | C | C
1177
+ * = D | C | C
1178
+ * = D | C
1179
+ *
1180
+ * In general, If a hard union A | B_1 | ... | B_n is part of of a soft union,
1181
+ * only A forms part of the join, and B_1, ..., B_n are pushed out, just `C` is
1182
+ * pushed out above. All types that are pushed out are recombined with the result
1183
+ * of the join with a lub, but that lub yields again a hard union, not a soft one.
1184
+ *
1168
1185
* Exception (if `-YexplicitNulls` is set): if this type is a nullable union (i.e. of the form `T | Null`),
1169
1186
* then the top-level union isn't widened. This is needed so that type inference can infer nullable types.
1170
1187
*/
1171
- def widenUnion (using Context ): Type = widen match {
1188
+ def widenUnion (using Context ): Type = widen. match {
1172
1189
case tp @ OrNull (tp1): OrType =>
1173
1190
// Don't widen `T|Null`, since otherwise we wouldn't be able to infer nullable unions.
1174
1191
val tp1Widen = tp1.widenUnionWithoutNull
1175
1192
if (tp1Widen.isRef(defn.AnyClass )) tp1Widen
1176
1193
else tp.derivedOrType(tp1Widen, defn.NullType )
1177
1194
case tp =>
1178
1195
tp.widenUnionWithoutNull
1179
- }
1196
+ }.reporting( i " widenUnion( $this ) = $result " )
1180
1197
1181
- def widenUnionWithoutNull (using Context ): Type = widen match {
1182
- case tp @ OrType (lhs, rhs) =>
1183
- TypeComparer .lub(lhs.widenUnionWithoutNull, rhs.widenUnionWithoutNull, canConstrain = true ) match {
1184
- case union : OrType => union.join
1185
- case res => res
1186
- }
1187
- case tp @ AndType (tp1, tp2) =>
1188
- tp derived_& (tp1.widenUnionWithoutNull, tp2.widenUnionWithoutNull)
1189
- case tp : RefinedType =>
1190
- tp.derivedRefinedType(tp.parent.widenUnion, tp.refinedName, tp.refinedInfo)
1191
- case tp : RecType =>
1192
- tp.rebind(tp.parent.widenUnion)
1193
- case tp : HKTypeLambda =>
1194
- tp.derivedLambdaType(resType = tp.resType.widenUnion)
1195
- case tp =>
1196
- tp
1197
- }
1198
+ def widenUnionWithoutNull (using Context ): Type =
1199
+
1200
+ // Split hard union `A | B1 | ... | Bn` into leftmost part `A` and list of
1201
+ // pushed out parts `B1, ..., Bn`.
1202
+ def splitAlts (tp : Type , follow : List [Type ]): (Type , List [Type ]) = tp match
1203
+ case tp as OrType (lhs, rhs) if ! tp.isSoft =>
1204
+ splitAlts(lhs, rhs :: follow)
1205
+ case _ =>
1206
+ (tp, follow)
1207
+
1208
+ // Convert any soft unions in result of lub to hard ones */
1209
+ def harden (tp : Type ): Type = tp match
1210
+ case tp as OrType (tp1, tp2) if tp.isSoft =>
1211
+ OrType (harden(tp1), harden(tp2), soft = false )
1212
+ case _ =>
1213
+ tp
1214
+
1215
+ def recombine (tp1 : Type , tp2 : Type ) = harden(TypeComparer .lub(tp1, tp2))
1216
+
1217
+ widen match
1218
+ case tp @ OrType (lhs, rhs) =>
1219
+ if tp.isSoft then
1220
+ val (lhsCore, lhsExtras) = splitAlts(lhs.widenUnionWithoutNull, Nil )
1221
+ val (rhsCore, rhsExtras) = splitAlts(rhs.widenUnionWithoutNull, Nil )
1222
+ val core = TypeComparer .lub(lhsCore, rhsCore, canConstrain = true ) match
1223
+ case union : OrType => union.join
1224
+ case res => res
1225
+ rhsExtras.foldLeft(lhsExtras.foldLeft(core)(recombine))(recombine)
1226
+ else
1227
+ val lhs1 = lhs.widenUnionWithoutNull
1228
+ val rhs1 = rhs.widenUnionWithoutNull
1229
+ if (lhs1 eq lhs) && (rhs1 eq rhs) then tp else recombine(lhs1, rhs1)
1230
+ case tp @ AndType (tp1, tp2) =>
1231
+ tp derived_& (tp1.widenUnionWithoutNull, tp2.widenUnionWithoutNull)
1232
+ case tp : RefinedType =>
1233
+ tp.derivedRefinedType(tp.parent.widenUnion, tp.refinedName, tp.refinedInfo)
1234
+ case tp : RecType =>
1235
+ tp.rebind(tp.parent.widenUnion)
1236
+ case tp : HKTypeLambda =>
1237
+ tp.derivedLambdaType(resType = tp.resType.widenUnion)
1238
+ case tp =>
1239
+ tp
1240
+ end widenUnionWithoutNull
1198
1241
1199
1242
/** Widen all top-level singletons reachable by dealiasing
1200
1243
* and going to the operands of & and |.
@@ -3044,9 +3087,9 @@ object Types {
3044
3087
myWidened
3045
3088
}
3046
3089
3047
- def derivedOrType (tp1 : Type , tp2 : Type )(using Context ): Type =
3048
- if ((tp1 eq this .tp1) && (tp2 eq this .tp2)) this
3049
- else OrType .make(tp1, tp2, isSoft )
3090
+ def derivedOrType (tp1 : Type , tp2 : Type , soft : Boolean = isSoft )(using Context ): Type =
3091
+ if ((tp1 eq this .tp1) && (tp2 eq this .tp2) && soft == isSoft ) this
3092
+ else OrType .make(tp1, tp2, soft )
3050
3093
3051
3094
override def computeHash (bs : Binders ): Int =
3052
3095
doHash(bs, if isSoft then 0 else 1 , tp1, tp2)
0 commit comments