diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index d9da11c561e8..c91412988e82 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -2,7 +2,7 @@ package dotty.tools package dotc package core -import Contexts._, Types._, Symbols._, Names._, Flags._ +import Contexts._, Types._, Symbols._, Names._, NameKinds.*, Flags._ import SymDenotations._ import util.Spans._ import util.Stats @@ -839,24 +839,51 @@ object TypeOps: } } - // Prefix inference, replace `p.C.this.Child` with `X.Child` where `X <: p.C` - // Note: we need to strip ThisType in `p` recursively. + /** Gather GADT symbols and `ThisType`s found in `tp2`, ie. the scrutinee. */ + object TraverseTp2 extends TypeTraverser: + val thisTypes = util.HashSet[ThisType]() + val gadtSyms = new mutable.ListBuffer[Symbol] + + def traverse(tp: Type) = { + val tpd = tp.dealias + if tpd ne tp then traverse(tpd) + else tp match + case tp: ThisType if !tp.tref.symbol.isStaticOwner && !thisTypes.contains(tp) => + thisTypes += tp + traverseChildren(tp.tref) + case tp: TypeRef if tp.symbol.isAbstractOrParamType => + gadtSyms += tp.symbol + traverseChildren(tp) + case _ => + traverseChildren(tp) + } + TraverseTp2.traverse(tp2) + val thisTypes = TraverseTp2.thisTypes + val gadtSyms = TraverseTp2.gadtSyms.toList + + // Prefix inference, given `p.C.this.Child`: + // 1. return it as is, if `C.this` is found in `tp`, i.e. the scrutinee; or + // 2. replace it with `X.Child` where `X <: p.C`, stripping ThisType in `p` recursively. // - // See tests/patmat/i3938.scala + // See tests/patmat/i3938.scala, tests/pos/i15029.more.scala, tests/pos/i16785.scala class InferPrefixMap extends TypeMap { var prefixTVar: Type | Null = null def apply(tp: Type): Type = tp match { - case ThisType(tref: TypeRef) if !tref.symbol.isStaticOwner => + case tp @ ThisType(tref) if !tref.symbol.isStaticOwner => val symbol = tref.symbol - if (symbol.is(Module)) + if thisTypes.contains(tp) then + prefixTVar = tp // e.g. tests/pos/i16785.scala, keep Outer.this + prefixTVar.uncheckedNN + else if symbol.is(Module) then TermRef(this(tref.prefix), symbol.sourceModule) else if (prefixTVar != null) this(tref) else { prefixTVar = WildcardType // prevent recursive call from assigning it - val tvars = tref.typeParams.map { tparam => newTypeVar(tparam.paramInfo.bounds) } + // e.g. tests/pos/i15029.more.scala, create a TypeVar for `Instances`' B, so we can disregard `Ints` + val tvars = tref.typeParams.map { tparam => newTypeVar(tparam.paramInfo.bounds, DepParamName.fresh(tparam.paramName)) } val tref2 = this(tref.applyIfParameterized(tvars)) - prefixTVar = newTypeVar(TypeBounds.upper(tref2)) + prefixTVar = newTypeVar(TypeBounds.upper(tref2), DepParamName.fresh(tref.name)) prefixTVar.uncheckedNN } case tp => mapOver(tp) @@ -864,15 +891,11 @@ object TypeOps: } val inferThisMap = new InferPrefixMap - val tvars = tp1.typeParams.map { tparam => newTypeVar(tparam.paramInfo.bounds) } + val tvars = tp1.typeParams.map { tparam => newTypeVar(tparam.paramInfo.bounds, DepParamName.fresh(tparam.paramName)) } val protoTp1 = inferThisMap.apply(tp1).appliedTo(tvars) - val getAbstractSymbols = new TypeAccumulator[List[Symbol]]: - def apply(xs: List[Symbol], tp: Type) = tp.dealias match - case tp: TypeRef if tp.symbol.exists && !tp.symbol.isClass => foldOver(tp.symbol :: xs, tp) - case tp => foldOver(xs, tp) - val syms2 = getAbstractSymbols(Nil, tp2).reverse - if syms2.nonEmpty then ctx.gadtState.addToConstraint(syms2) + if gadtSyms.nonEmpty then + ctx.gadtState.addToConstraint(gadtSyms) // If parent contains a reference to an abstract type, then we should // refine subtype checking to eliminate abstract types according to diff --git a/tests/patmat/i12408.check b/tests/patmat/i12408.check index ada7b8c21fa8..60acc2cba84e 100644 --- a/tests/patmat/i12408.check +++ b/tests/patmat/i12408.check @@ -1,2 +1,2 @@ -13: Pattern Match Exhaustivity: X[] & (X.this : X[T]).A(_), X[] & (X.this : X[T]).C(_) +13: Pattern Match Exhaustivity: A(_), C(_) 21: Pattern Match diff --git a/tests/pos/i16785.scala b/tests/pos/i16785.scala new file mode 100644 index 000000000000..1cfabf5a4312 --- /dev/null +++ b/tests/pos/i16785.scala @@ -0,0 +1,11 @@ +class VarImpl[Lbl, A] + +class Outer[|*|[_, _], Lbl1]: + type Var[A1] = VarImpl[Lbl1, A1] + + sealed trait Foo[G] + case class Bar[T, U]() + extends Foo[Var[T] |*| Var[U]] + + def go[X](scr: Foo[Var[X]]): Unit = scr match // was: compile hang + case Bar() => ()