Skip to content

Fix #10108: Distinguish between soft and hard union types #10112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions compiler/src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ class Compiler {
new sjs.ExplicitJSClasses, // Make all JS classes explicit (Scala.js only)
new ExplicitOuter, // Add accessors to outer classes from nested ones.
new ExplicitSelf, // Make references to non-trivial self types explicit as casts
new StringInterpolatorOpt, // Optimizes raw and s string interpolators by rewriting them to string concatentations
new CrossCastAnd) :: // Normalize selections involving intersection types.
new StringInterpolatorOpt) :: // Optimizes raw and s string interpolators by rewriting them to string concatentations
List(new PruneErasedDefs, // Drop erased definitions from scopes and simplify erased expressions
new InlinePatterns, // Remove placeholders of inlined patterns
new VCInlineMethods, // Inlines calls to value class methods
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ class Definitions {
def AnyKindType: TypeRef = AnyKindClass.typeRef

@tu lazy val andType: TypeSymbol = enterBinaryAlias(tpnme.AND, AndType(_, _))
@tu lazy val orType: TypeSymbol = enterBinaryAlias(tpnme.OR, OrType(_, _))
@tu lazy val orType: TypeSymbol = enterBinaryAlias(tpnme.OR, OrType(_, _, soft = false))

/** Marker method to indicate an argument to a call-by-name parameter.
* Created by byNameClosures and elimByName, eliminated by Erasure,
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Hashable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ trait Hashable {
protected final def doHash(bs: Binders, x1: Int, tp2: Type): Int =
finishHash(bs, hashing.mix(hashSeed, x1), 1, tp2)

protected final def doHash(bs: Binders, x1: Int, tp2: Type, tp3: Type): Int =
finishHash(bs, hashing.mix(hashSeed, x1), 1, tp2, tp3)

protected final def doHash(bs: Binders, tp1: Type, tp2: Type): Int =
finishHash(bs, hashSeed, 0, tp1, tp2)

Expand Down
87 changes: 49 additions & 38 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case Some(b) => return b
case None =>

def widenOK =
(tp2.widenSingletons eq tp2)
&& (tp1.widenSingletons ne tp1)
&& recur(tp1.widenSingletons, tp2)

def joinOK = tp2.dealiasKeepRefiningAnnots match {
case tp2: AppliedType if !tp2.tycon.typeSymbol.isClass =>
// If we apply the default algorithm for `A[X] | B[Y] <: C[Z]` where `C` is a
Expand All @@ -439,25 +444,30 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
false
}

// If LHS is a hard union, constrain any type variables of the RHS with it as lower bound
// before splitting the LHS into its constituents. That way, the RHS variables are
// constraint by the hard union and can be instantiated to it. If we just split and add
// the two parts of the LHS separately to the constraint, the lower bound would become
// a soft union.
def constrainRHSVars(tp2: Type): Boolean = tp2.dealiasKeepRefiningAnnots match
case tp2: TypeParamRef if constraint contains tp2 => compareTypeParamRef(tp2)
case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22)
case _ => true

// An & on the left side loses information. We compensate by also trying the join.
// This is less ad-hoc than it looks since we produce joins in type inference,
// and then need to check that they are indeed supertypes of the original types
// under -Ycheck. Test case is i7965.scala.
def containsAnd(tp: Type): Boolean = tp.dealiasKeepRefiningAnnots match
case tp: AndType => true
case OrType(tp1, tp2) => containsAnd(tp1) || containsAnd(tp2)
case _ => false

def widenOK =
(tp2.widenSingletons eq tp2) &&
(tp1.widenSingletons ne tp1) &&
recur(tp1.widenSingletons, tp2)

widenOK
|| joinOK
|| recur(tp11, tp2) && recur(tp12, tp2)
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
|| containsAnd(tp1) && recur(tp1.join, tp2)
// An & on the left side loses information. Compensate by also trying the join.
// This is less ad-hoc than it looks since we produce joins in type inference,
// and then need to check that they are indeed supertypes of the original types
// under -Ycheck. Test case is i7965.scala.
case tp1: MatchType =>
case tp1: MatchType =>
val reduced = tp1.reduced
if (reduced.exists) recur(reduced, tp2) else thirdTry
case _: FlexType =>
Expand Down Expand Up @@ -511,35 +521,36 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
fourthTry
}

def compareTypeParamRef(tp2: TypeParamRef): Boolean =
assumedTrue(tp2) || {
val alwaysTrue =
// The following condition is carefully formulated to catch all cases
// where the subtype relation is true without needing to add a constraint
// It's tricky because we might need to either approximate tp2 by its
// lower bound or else widen tp1 and check that the result is a subtype of tp2.
// So if the constraint is not yet frozen, we do the same comparison again
// with a frozen constraint, which means that we get a chance to do the
// widening in `fourthTry` before adding to the constraint.
if (frozenConstraint) recur(tp1, bounds(tp2).lo)
else isSubTypeWhenFrozen(tp1, tp2)
alwaysTrue ||
frozenConstraint && (tp1 match {
case tp1: TypeParamRef => constraint.isLess(tp1, tp2)
case _ => false
}) || {
if (canConstrain(tp2) && !approx.low)
addConstraint(tp2, tp1.widenExpr, fromBelow = true)
else fourthTry
}
}

def thirdTry: Boolean = tp2 match {
case tp2 @ AppliedType(tycon2, args2) =>
compareAppliedType2(tp2, tycon2, args2)
case tp2: NamedType =>
thirdTryNamed(tp2)
case tp2: TypeParamRef =>
def compareTypeParamRef =
assumedTrue(tp2) || {
val alwaysTrue =
// The following condition is carefully formulated to catch all cases
// where the subtype relation is true without needing to add a constraint
// It's tricky because we might need to either approximate tp2 by its
// lower bound or else widen tp1 and check that the result is a subtype of tp2.
// So if the constraint is not yet frozen, we do the same comparison again
// with a frozen constraint, which means that we get a chance to do the
// widening in `fourthTry` before adding to the constraint.
if (frozenConstraint) recur(tp1, bounds(tp2).lo)
else isSubTypeWhenFrozen(tp1, tp2)
alwaysTrue ||
frozenConstraint && (tp1 match {
case tp1: TypeParamRef => constraint.isLess(tp1, tp2)
case _ => false
}) || {
if (canConstrain(tp2) && !approx.low)
addConstraint(tp2, tp1.widenExpr, fromBelow = true)
else fourthTry
}
}
compareTypeParamRef
compareTypeParamRef(tp2)
case tp2: RefinedType =>
def compareRefinedSlow: Boolean = {
val name2 = tp2.refinedName
Expand Down Expand Up @@ -616,7 +627,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
}
}
compareTypeLambda
case OrType(tp21, tp22) =>
case tp2 as OrType(tp21, tp22) =>
compareAtoms(tp1, tp2) match
case Some(b) => return b
case _ =>
Expand Down Expand Up @@ -648,12 +659,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
// solutions. The rewriting delays the point where we have to choose.
tp21 match {
case AndType(tp211, tp212) =>
return recur(tp1, OrType(tp211, tp22)) && recur(tp1, OrType(tp212, tp22))
return recur(tp1, OrType(tp211, tp22, tp2.isSoft)) && recur(tp1, OrType(tp212, tp22, tp2.isSoft))
case _ =>
}
tp22 match {
case AndType(tp221, tp222) =>
return recur(tp1, OrType(tp21, tp221)) && recur(tp1, OrType(tp21, tp222))
return recur(tp1, OrType(tp21, tp221, tp2.isSoft)) && recur(tp1, OrType(tp21, tp222, tp2.isSoft))
case _ =>
}
either(recur(tp1, tp21), recur(tp1, tp22)) || fourthTry
Expand Down Expand Up @@ -2123,7 +2134,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
val t2 = distributeOr(tp2, tp1)
if (t2.exists) t2
else if (isErased) erasedLub(tp1, tp2)
else liftIfHK(tp1, tp2, OrType(_, _), _ | _, _ & _)
else liftIfHK(tp1, tp2, OrType(_, _, soft = true), _ | _, _ & _)
}
}

Expand Down
9 changes: 8 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,14 @@ object TypeOps:
tp.derivedAlias(simplify(tp.alias, theMap))
case AndType(l, r) if !ctx.mode.is(Mode.Type) =>
simplify(l, theMap) & simplify(r, theMap)
case OrType(l, r) if !ctx.mode.is(Mode.Type) =>
case tp as OrType(l, r)
if !ctx.mode.is(Mode.Type)
&& (tp.isSoft || defn.isBottomType(l) || defn.isBottomType(r)) =>
// Normalize A | Null and Null | A to A even if the union is hard (i.e.
// explicitly declared), but not if -Yexplicit-nulls is set. The reason is
// that in this case the normal asSeenFrom machinery is not prepared to deal
// with Nulls (which have no base classes). Under -Yexplicit-nulls, we take
// corrective steps, so no widening is wanted.
simplify(l, theMap) | simplify(r, theMap)
case AnnotatedType(parent, annot)
if !ctx.mode.is(Mode.Type) && annot.symbol == defn.UncheckedVarianceAnnot =>
Expand Down
57 changes: 29 additions & 28 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,7 @@ object Types {
def safe_& (that: Type)(using Context): Type = (this, that) match {
case (TypeBounds(lo1, hi1), TypeBounds(lo2, hi2)) =>
TypeBounds(
OrType.makeHk(lo1.stripLazyRef, lo2.stripLazyRef),
OrType.makeHk(lo1.stripLazyRef, lo2.stripLazyRef),
AndType.makeHk(hi1.stripLazyRef, hi2.stripLazyRef))
case _ =>
this & that
Expand Down Expand Up @@ -1151,10 +1151,11 @@ object Types {
case _ => this
}

/** Widen this type and if the result contains embedded union types, replace
/** Widen this type and if the result contains embedded soft union types, replace
* them by their joins.
* "Embedded" means: inside type lambdas, intersections or recursive types, or in prefixes of refined types.
* If an embedded union is found, we first try to simplify or eliminate it by
* "Embedded" means: inside type lambdas, intersections or recursive types,
* in prefixes of refined types, or in hard union types.
* If an embedded soft union is found, we first try to simplify or eliminate it by
* re-lubbing it while allowing type parameters to be constrained further.
* Any remaining union types are replaced by their joins.
*
Expand All @@ -1168,24 +1169,22 @@ object Types {
* Exception (if `-YexplicitNulls` is set): if this type is a nullable union (i.e. of the form `T | Null`),
* then the top-level union isn't widened. This is needed so that type inference can infer nullable types.
*/
def widenUnion(using Context): Type = widen match {
def widenUnion(using Context): Type = widen match
case tp @ OrNull(tp1): OrType =>
// Don't widen `T|Null`, since otherwise we wouldn't be able to infer nullable unions.
val tp1Widen = tp1.widenUnionWithoutNull
if (tp1Widen.isRef(defn.AnyClass)) tp1Widen
else tp.derivedOrType(tp1Widen, defn.NullType)
case tp =>
tp.widenUnionWithoutNull
}

def widenUnionWithoutNull(using Context): Type = widen match {
case tp @ OrType(lhs, rhs) =>
TypeComparer.lub(lhs.widenUnionWithoutNull, rhs.widenUnionWithoutNull, canConstrain = true) match {
def widenUnionWithoutNull(using Context): Type = widen match
case tp @ OrType(lhs, rhs) if tp.isSoft =>
TypeComparer.lub(lhs.widenUnionWithoutNull, rhs.widenUnionWithoutNull, canConstrain = true) match
case union: OrType => union.join
case res => res
}
case tp @ AndType(tp1, tp2) =>
tp derived_& (tp1.widenUnionWithoutNull, tp2.widenUnionWithoutNull)
case tp: AndOrType =>
tp.derivedAndOrType(tp.tp1.widenUnionWithoutNull, tp.tp2.widenUnionWithoutNull)
case tp: RefinedType =>
tp.derivedRefinedType(tp.parent.widenUnion, tp.refinedName, tp.refinedInfo)
case tp: RecType =>
Expand All @@ -1194,7 +1193,6 @@ object Types {
tp.derivedLambdaType(resType = tp.resType.widenUnion)
case tp =>
tp
}

/** Widen all top-level singletons reachable by dealiasing
* and going to the operands of & and |.
Expand Down Expand Up @@ -2917,8 +2915,9 @@ object Types {

def derivedAndOrType(tp1: Type, tp2: Type)(using Context) =
if ((tp1 eq this.tp1) && (tp2 eq this.tp2)) this
else if (isAnd) AndType.make(tp1, tp2, checkValid = true)
else OrType.make(tp1, tp2)
else this match
case tp: OrType => OrType.make(tp1, tp2, tp.isSoft)
case tp: AndType => AndType.make(tp1, tp2, checkValid = true)
}

abstract case class AndType(tp1: Type, tp2: Type) extends AndOrType {
Expand Down Expand Up @@ -2992,6 +2991,7 @@ object Types {

abstract case class OrType(tp1: Type, tp2: Type) extends AndOrType {
def isAnd: Boolean = false
def isSoft: Boolean
private var myBaseClassesPeriod: Period = Nowhere
private var myBaseClasses: List[ClassSymbol] = _
/** Base classes of are the intersection of the operand base classes. */
Expand Down Expand Up @@ -3052,32 +3052,33 @@ object Types {
myWidened
}

def derivedOrType(tp1: Type, tp2: Type)(using Context): Type =
if ((tp1 eq this.tp1) && (tp2 eq this.tp2)) this
else OrType.make(tp1, tp2)
def derivedOrType(tp1: Type, tp2: Type, soft: Boolean = isSoft)(using Context): Type =
if ((tp1 eq this.tp1) && (tp2 eq this.tp2) && soft == isSoft) this
else OrType.make(tp1, tp2, soft)

override def computeHash(bs: Binders): Int = doHash(bs, tp1, tp2)
override def computeHash(bs: Binders): Int =
doHash(bs, if isSoft then 0 else 1, tp1, tp2)

override def eql(that: Type): Boolean = that match {
case that: OrType => tp1.eq(that.tp1) && tp2.eq(that.tp2)
case that: OrType => tp1.eq(that.tp1) && tp2.eq(that.tp2) && isSoft == that.isSoft
case _ => false
}
}

final class CachedOrType(tp1: Type, tp2: Type) extends OrType(tp1, tp2)
final class CachedOrType(tp1: Type, tp2: Type, override val isSoft: Boolean) extends OrType(tp1, tp2)

object OrType {
def apply(tp1: Type, tp2: Type)(using Context): OrType = {
def apply(tp1: Type, tp2: Type, soft: Boolean)(using Context): OrType = {
assertUnerased()
unique(new CachedOrType(tp1, tp2))
unique(new CachedOrType(tp1, tp2, soft))
}
def make(tp1: Type, tp2: Type)(using Context): Type =
def make(tp1: Type, tp2: Type, soft: Boolean)(using Context): Type =
if (tp1 eq tp2) tp1
else apply(tp1, tp2)
else apply(tp1, tp2, soft)

/** Like `make`, but also supports higher-kinded types as argument */
def makeHk(tp1: Type, tp2: Type)(using Context): Type =
TypeComparer.liftIfHK(tp1, tp2, OrType(_, _), makeHk, _ & _)
TypeComparer.liftIfHK(tp1, tp2, OrType(_, _, soft = true), makeHk, _ & _)
}

/** An extractor object to pattern match against a nullable union.
Expand All @@ -3089,7 +3090,7 @@ object Types {
*/
object OrNull {
def apply(tp: Type)(using Context) =
OrType(tp, defn.NullType)
OrType(tp, defn.NullType, soft = false)
def unapply(tp: Type)(using Context): Option[Type] =
if (ctx.explicitNulls) {
val tp1 = tp.stripNull()
Expand All @@ -3107,7 +3108,7 @@ object Types {
*/
object OrUncheckedNull {
def apply(tp: Type)(using Context) =
OrType(tp, defn.UncheckedNullAliasType)
OrType(tp, defn.UncheckedNullAliasType, soft = false)
def unapply(tp: Type)(using Context): Option[Type] =
if (ctx.explicitNulls) {
val tp1 = tp.stripUncheckedNull
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ class TreeUnpickler(reader: TastyReader,
case ANDtype =>
AndType(readType(), readType())
case ORtype =>
OrType(readType(), readType())
OrType(readType(), readType(), soft = false)
case SUPERtype =>
SuperType(readType(), readType())
case MATCHtype =>
Expand Down Expand Up @@ -1222,7 +1222,7 @@ class TreeUnpickler(reader: TastyReader,
val args = until(end)(readTpt())
val ownType =
if (tycon.symbol == defn.andType) AndType(args(0).tpe, args(1).tpe)
else if (tycon.symbol == defn.orType) OrType(args(0).tpe, args(1).tpe)
else if (tycon.symbol == defn.orType) OrType(args(0).tpe, args(1).tpe, soft = false)
else tycon.tpe.safeAppliedTo(args.tpes)
untpd.AppliedTypeTree(tycon, args).withType(ownType)
case ANNOTATEDtpt =>
Expand Down
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,10 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
// of AndType and OrType to account for associativity
case AndType(tp1, tp2) =>
toTextInfixType(tpnme.raw.AMP, tp1, tp2) { toText(tpnme.raw.AMP) }
case OrType(tp1, tp2) =>
toTextInfixType(tpnme.raw.BAR, tp1, tp2) { toText(tpnme.raw.BAR) }
case tp as OrType(tp1, tp2) =>
toTextInfixType(tpnme.raw.BAR, tp1, tp2) {
if tp.isSoft && printDebug then toText(tpnme.ZOR) else toText(tpnme.raw.BAR)
}
case tp @ EtaExpansion(tycon)
if !printDebug && appliedText(tp.asInstanceOf[HKLambda].resType).isEmpty =>
// don't eta contract if the application would be printed specially
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1823,7 +1823,7 @@ class QuoteContextImpl private (ctx: Context) extends QuoteContext:
end OrTypeTypeTest

object OrType extends OrTypeModule:
def apply(lhs: TypeRepr, rhs: TypeRepr): OrType = Types.OrType(lhs, rhs)
def apply(lhs: TypeRepr, rhs: TypeRepr): OrType = Types.OrType(lhs, rhs, soft = false)
def unapply(x: OrType): Option[(TypeRepr, TypeRepr)] = Some((x.left, x.right))
end OrType

Expand Down
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/reporting/messages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -822,14 +822,16 @@ import transform.SymUtils._
|"""
}

class PatternMatchExhaustivity(uncoveredFn: => String)(using Context)
class PatternMatchExhaustivity(uncoveredFn: => String, hasMore: Boolean)(using Context)
extends Message(PatternMatchExhaustivityID) {
def kind = "Pattern Match Exhaustivity"
lazy val uncovered = uncoveredFn
def msg =
val addendum = if hasMore then "(More unmatched cases are elided)" else ""
em"""|${hl("match")} may not be exhaustive.
|
|It would fail on pattern case: $uncovered"""
|It would fail on pattern case: $uncovered
|$addendum"""


def explain =
Expand Down
Loading