Skip to content

Commit ff9798d

Browse files
committed
Fix #9028: Introduce super traits
and eliminate super traits in widenInferred
1 parent 86bd4e9 commit ff9798d

File tree

18 files changed

+124
-44
lines changed

18 files changed

+124
-44
lines changed

compiler/src/dotty/tools/dotc/ast/untpd.scala

+2
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
200200
case class Inline()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.Inline)
201201

202202
case class Transparent()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.EmptyFlags)
203+
204+
case class Super()(implicit @constructorOnly src: SourceFile) extends Mod(Flags.SuperTrait)
203205
}
204206

205207
/** Modifiers and annotations for definitions

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

+31-18
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,9 @@ trait ConstraintHandling[AbstractContext] {
300300
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
301301
* 2. If `inst` is a union type, approximate the union type from above by an intersection
302302
* of all common base types, provided the result is a subtype of `bound`.
303-
* 3. (currently not enabled, see #9028) If `inst` is an intersection with some restricted base types, drop
304-
* the restricted base types from the intersection, provided the result is a subtype of `bound`.
303+
* 3. If `inst` a super trait instance or an intersection with some super trait
304+
* parents, replace all super trait instances with AnyRef (or Any, if the trait
305+
* is a universal trait) as long as the result is a subtype of `bound`.
305306
*
306307
* Don't do these widenings if `bound` is a subtype of `scala.Singleton`.
307308
* Also, if the result of these widenings is a TypeRef to a module class,
@@ -313,21 +314,36 @@ trait ConstraintHandling[AbstractContext] {
313314
*/
314315
def widenInferred(inst: Type, bound: Type)(implicit actx: AbstractContext): Type =
315316

316-
def isRestricted(tp: Type) = tp.typeSymbol == defn.EnumValueClass // for now, to be generalized later
317+
def dropSuperTraits(tp: Type): Type =
318+
var keep: Set[Type] = Set() // types to keep since otherwise bound would not fit
319+
var lastDropped: Type = NoType // the last type dropped in dropOneSuperTrait
320+
321+
def dropOneSuperTrait(tp: Type): Type =
322+
val tpd = tp.dealias
323+
if tpd.typeSymbol.isSuperTrait && !tpd.isLambdaSub && !keep.contains(tpd) then
324+
lastDropped = tpd
325+
if tpd.derivesFrom(defn.ObjectClass) then defn.ObjectType else defn.AnyType
326+
else tpd match
327+
case AndType(tp1, tp2) =>
328+
val tp1w = dropOneSuperTrait(tp1)
329+
if tp1w ne tp1 then tp1w & tp2
330+
else
331+
val tp2w = dropOneSuperTrait(tp2)
332+
if tp2w ne tp2 then tp1 & tp2w
333+
else tpd
334+
case _ =>
335+
tp
317336

318-
def dropRestricted(tp: Type): Type = tp.dealias match
319-
case tpd @ AndType(tp1, tp2) =>
320-
if isRestricted(tp1) then tp2
321-
else if isRestricted(tp2) then tp1
337+
def recur(tp: Type): Type =
338+
val tpw = dropOneSuperTrait(tp)
339+
if tpw eq tp then tp
340+
else if tpw <:< bound then recur(tpw)
322341
else
323-
val tpw = tpd.derivedAndType(dropRestricted(tp1), dropRestricted(tp2))
324-
if tpw ne tpd then tpw else tp
325-
case _ =>
326-
tp
342+
keep += lastDropped
343+
recur(tp)
327344

328-
def widenRestricted(tp: Type) =
329-
val tpw = dropRestricted(tp)
330-
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
345+
recur(tp)
346+
end dropSuperTraits
331347

332348
def widenOr(tp: Type) =
333349
val tpw = tp.widenUnion
@@ -343,10 +359,7 @@ trait ConstraintHandling[AbstractContext] {
343359

344360
val wideInst =
345361
if isSingleton(bound) then inst
346-
else /*widenRestricted*/(widenOr(widenSingle(inst)))
347-
// widenRestricted is currently not called since it's special cased in `dropEnumValue`
348-
// in `Namer`. It's left in here in case we want to generalize the scheme to other
349-
// "protected inheritance" classes.
362+
else dropSuperTraits(widenOr(widenSingle(inst)))
350363
wideInst match
351364
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
352365
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)

compiler/src/dotty/tools/dotc/core/Definitions.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,6 @@ class Definitions {
639639
@tu lazy val EnumClass: ClassSymbol = ctx.requiredClass("scala.Enum")
640640
@tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal)
641641

642-
@tu lazy val EnumValueClass: ClassSymbol = ctx.requiredClass("scala.runtime.EnumValue")
643642
@tu lazy val EnumValuesClass: ClassSymbol = ctx.requiredClass("scala.runtime.EnumValues")
644643
@tu lazy val ProductClass: ClassSymbol = ctx.requiredClass("scala.Product")
645644
@tu lazy val Product_canEqual : Symbol = ProductClass.requiredMethod(nme.canEqual_)
@@ -1308,6 +1307,9 @@ class Definitions {
13081307
def isInfix(sym: Symbol)(implicit ctx: Context): Boolean =
13091308
(sym eq Object_eq) || (sym eq Object_ne)
13101309

1310+
@tu lazy val assumedSuperTraits =
1311+
Set(ComparableClass, JavaSerializableClass, ProductClass, SerializableClass)
1312+
13111313
// ----- primitive value class machinery ------------------------------------------
13121314

13131315
/** This class would also be obviated by the implicit function type design */

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

+4
Original file line numberDiff line numberDiff line change
@@ -1166,6 +1166,10 @@ object SymDenotations {
11661166
final def isEffectivelySealed(using Context): Boolean =
11671167
isOneOf(FinalOrSealed) || isClass && !isOneOf(EffectivelyOpenFlags)
11681168

1169+
final def isSuperTrait(using Context): Boolean =
1170+
isClass
1171+
&& (is(SuperTrait) || defn.assumedSuperTraits.contains(symbol.asClass))
1172+
11691173
/** The class containing this denotation which has the given effective name. */
11701174
final def enclosingClassNamed(name: Name)(implicit ctx: Context): Symbol = {
11711175
val cls = enclosingClass

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

+1
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,7 @@ class TreePickler(pickler: TastyPickler) {
705705
if (flags.is(Sealed)) writeModTag(SEALED)
706706
if (flags.is(Abstract)) writeModTag(ABSTRACT)
707707
if (flags.is(Trait)) writeModTag(TRAIT)
708+
if flags.is(SuperTrait) then writeModTag(SUPERTRAIT)
708709
if (flags.is(Covariant)) writeModTag(COVARIANT)
709710
if (flags.is(Contravariant)) writeModTag(CONTRAVARIANT)
710711
if (flags.is(Opaque)) writeModTag(OPAQUE)

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

+1
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ class TreeUnpickler(reader: TastyReader,
639639
case STATIC => addFlag(JavaStatic)
640640
case OBJECT => addFlag(Module)
641641
case TRAIT => addFlag(Trait)
642+
case SUPERTRAIT => addFlag(SuperTrait)
642643
case ENUM => addFlag(Enum)
643644
case LOCAL => addFlag(Local)
644645
case SYNTHETIC => addFlag(Synthetic)

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -3434,7 +3434,7 @@ object Parsers {
34343434
}
34353435
}
34363436

3437-
/** TmplDef ::= ([‘case’] ‘class’ | ‘trait’) ClassDef
3437+
/** TmplDef ::= ([‘case’] ‘class’ | [‘super’] ‘trait’) ClassDef
34383438
* | [‘case’] ‘object’ ObjectDef
34393439
* | ‘enum’ EnumDef
34403440
* | ‘given’ GivenDef
@@ -3444,6 +3444,8 @@ object Parsers {
34443444
in.token match {
34453445
case TRAIT =>
34463446
classDef(start, posMods(start, addFlag(mods, Trait)))
3447+
case SUPERTRAIT =>
3448+
classDef(start, posMods(start, addFlag(mods, Trait | SuperTrait)))
34473449
case CLASS =>
34483450
classDef(start, posMods(start, mods))
34493451
case CASECLASS =>

compiler/src/dotty/tools/dotc/parsing/Scanners.scala

+6-1
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,8 @@ object Scanners {
586586
currentRegion = r.outer
587587
case _ =>
588588

589-
/** - Join CASE + CLASS => CASECLASS, CASE + OBJECT => CASEOBJECT, SEMI + ELSE => ELSE, COLON + <EOL> => COLONEOL
589+
/** - Join CASE + CLASS => CASECLASS, CASE + OBJECT => CASEOBJECT, SUPER + TRAIT => SUPERTRAIT
590+
* SEMI + ELSE => ELSE, COLON + <EOL> => COLONEOL
590591
* - Insert missing OUTDENTs at EOF
591592
*/
592593
def postProcessToken(): Unit = {
@@ -602,6 +603,10 @@ object Scanners {
602603
if (token == CLASS) fuse(CASECLASS)
603604
else if (token == OBJECT) fuse(CASEOBJECT)
604605
else reset()
606+
case SUPER =>
607+
lookAhead()
608+
if token == TRAIT then fuse(SUPERTRAIT)
609+
else reset()
605610
case SEMI =>
606611
lookAhead()
607612
if (token != ELSE) reset()

compiler/src/dotty/tools/dotc/parsing/Tokens.scala

+3-2
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ object Tokens extends TokensCommon {
184184
final val ERASED = 63; enter(ERASED, "erased")
185185
final val GIVEN = 64; enter(GIVEN, "given")
186186
final val EXPORT = 65; enter(EXPORT, "export")
187-
final val MACRO = 66; enter(MACRO, "macro") // TODO: remove
187+
final val SUPERTRAIT = 66; enter(SUPERTRAIT, "super trait")
188+
final val MACRO = 67; enter(MACRO, "macro") // TODO: remove
188189

189190
/** special symbols */
190191
final val NEWLINE = 78; enter(NEWLINE, "end of statement", "new line")
@@ -233,7 +234,7 @@ object Tokens extends TokensCommon {
233234
final val canStartTypeTokens: TokenSet = literalTokens | identifierTokens | BitSet(
234235
THIS, SUPER, USCORE, LPAREN, AT)
235236

236-
final val templateIntroTokens: TokenSet = BitSet(CLASS, TRAIT, OBJECT, ENUM, CASECLASS, CASEOBJECT)
237+
final val templateIntroTokens: TokenSet = BitSet(CLASS, TRAIT, OBJECT, ENUM, CASECLASS, CASEOBJECT, SUPERTRAIT)
237238

238239
final val dclIntroTokens: TokenSet = BitSet(DEF, VAL, VAR, TYPE, GIVEN)
239240

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

+7-2
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
729729
}
730730

731731
private def Modifiers(sym: Symbol): Modifiers = untpd.Modifiers(
732-
sym.flags & (if (sym.isType) ModifierFlags | VarianceFlags else ModifierFlags),
732+
sym.flags & (if (sym.isType) ModifierFlags | VarianceFlags | SuperTrait else ModifierFlags),
733733
if (sym.privateWithin.exists) sym.privateWithin.asType.name else tpnme.EMPTY,
734734
sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree))
735735

@@ -839,7 +839,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
839839
}
840840

841841
protected def templateText(tree: TypeDef, impl: Template): Text = {
842-
val decl = modText(tree.mods, tree.symbol, keywordStr(if (tree.mods.is(Trait)) "trait" else "class"), isType = true)
842+
val kw =
843+
if tree.mods.is(SuperTrait) then "super trait"
844+
else if tree.mods.is(Trait) then "trait"
845+
else "class"
846+
val decl = modText(tree.mods, tree.symbol, keywordStr(kw), isType = true)
843847
( decl ~~ typeText(nameIdText(tree)) ~ withEnclosingDef(tree) { toTextTemplate(impl) }
844848
// ~ (if (tree.hasType && printDebug) i"[decls = ${tree.symbol.info.decls}]" else "") // uncomment to enable
845849
)
@@ -945,6 +949,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
945949
else if (sym.isPackageObject) "package object"
946950
else if (flags.is(Module) && flags.is(Case)) "case object"
947951
else if (sym.isClass && flags.is(Case)) "case class"
952+
else if sym.isClass && flags.is(SuperTrait) then "super trait"
948953
else super.keyString(sym)
949954
}
950955

compiler/src/dotty/tools/dotc/typer/Namer.scala

+1-16
Original file line numberDiff line numberDiff line change
@@ -1458,19 +1458,6 @@ class Namer { typer: Typer =>
14581458
// println(s"owner = ${sym.owner}, decls = ${sym.owner.info.decls.show}")
14591459
def isInlineVal = sym.isOneOf(FinalOrInline, butNot = Method | Mutable)
14601460

1461-
def isEnumValue(tp: Type) = tp.typeSymbol == defn.EnumValueClass
1462-
1463-
// Drop EnumValue parents from inferred types of enum constants
1464-
def dropEnumValue(tp: Type): Type = tp.dealias match
1465-
case tpd @ AndType(tp1, tp2) =>
1466-
if isEnumValue(tp1) then tp2
1467-
else if isEnumValue(tp2) then tp1
1468-
else
1469-
val tpw = tpd.derivedAndType(dropEnumValue(tp1), dropEnumValue(tp2))
1470-
if tpw ne tpd then tpw else tp
1471-
case _ =>
1472-
tp
1473-
14741461
// Widen rhs type and eliminate `|' but keep ConstantTypes if
14751462
// definition is inline (i.e. final in Scala2) and keep module singleton types
14761463
// instead of widening to the underlying module class types.
@@ -1479,9 +1466,7 @@ class Namer { typer: Typer =>
14791466
def widenRhs(tp: Type): Type =
14801467
tp.widenTermRefExpr.simplified match
14811468
case ctp: ConstantType if isInlineVal => ctp
1482-
case tp =>
1483-
val tp1 = ctx.typeComparer.widenInferred(tp, rhsProto)
1484-
if sym.is(Enum) then dropEnumValue(tp1) else tp1
1469+
case tp => ctx.typeComparer.widenInferred(tp, rhsProto)
14851470

14861471
// Replace aliases to Unit by Unit itself. If we leave the alias in
14871472
// it would be erased to BoxedUnit.

docs/docs/internals/syntax.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ VarDef ::= PatDef
388388
DefDef ::= DefSig [‘:’ Type] ‘=’ Expr DefDef(_, name, tparams, vparamss, tpe, expr)
389389
| ‘this’ DefParamClause DefParamClauses ‘=’ ConstrExpr DefDef(_, <init>, Nil, vparamss, EmptyTree, expr | Block)
390390
391-
TmplDef ::= ([‘case’] ‘class’ | ‘trait’) ClassDef
391+
TmplDef ::= ([‘case’] ‘class’ | [‘super’] ‘trait’) ClassDef
392392
| [‘case’] ‘object’ ObjectDef
393393
| ‘enum’ EnumDef
394394
| ‘given’ GivenDef
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package scala.runtime
2+
3+
super trait EnumValue extends Product, Serializable:
4+
override def canEqual(that: Any) = this eq that.asInstanceOf[AnyRef]
5+
override def productArity: Int = 0
6+
override def productPrefix: String = toString
7+
override def productElement(n: Int): Any =
8+
throw IndexOutOfBoundsException(n.toString)
9+
override def productElementName(n: Int): String =
10+
throw IndexOutOfBoundsException(n.toString)

tasty/src/dotty/tools/tasty/TastyFormat.scala

+5-1
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ Standard-Section: "ASTs" TopLevelStat*
189189
STATIC -- Mapped to static Java member
190190
OBJECT -- An object or its class
191191
TRAIT -- A trait
192+
SUPERTRAIT -- A super trait
192193
ENUM -- A enum class or enum case
193194
LOCAL -- private[this] or protected[this], used in conjunction with PRIVATE or PROTECTED
194195
SYNTHETIC -- Generated by Scala compiler
@@ -359,6 +360,7 @@ object TastyFormat {
359360
final val OPEN = 40
360361
final val PARAMEND = 41
361362
final val PARAMalias = 42
363+
final val SUPERTRAIT = 43
362364

363365
// Cat. 2: tag Nat
364366

@@ -473,7 +475,7 @@ object TastyFormat {
473475

474476
/** Useful for debugging */
475477
def isLegalTag(tag: Int): Boolean =
476-
firstSimpleTreeTag <= tag && tag <= PARAMalias ||
478+
firstSimpleTreeTag <= tag && tag <= SUPERTRAIT ||
477479
firstNatTreeTag <= tag && tag <= RENAMED ||
478480
firstASTTreeTag <= tag && tag <= BOUNDED ||
479481
firstNatASTTreeTag <= tag && tag <= NAMEDARG ||
@@ -502,6 +504,7 @@ object TastyFormat {
502504
| STATIC
503505
| OBJECT
504506
| TRAIT
507+
| SUPERTRAIT
505508
| ENUM
506509
| LOCAL
507510
| SYNTHETIC
@@ -562,6 +565,7 @@ object TastyFormat {
562565
case STATIC => "STATIC"
563566
case OBJECT => "OBJECT"
564567
case TRAIT => "TRAIT"
568+
case SUPERTRAIT => "SUPERTRAIT"
565569
case ENUM => "ENUM"
566570
case LOCAL => "LOCAL"
567571
case SYNTHETIC => "SYNTHETIC"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
sealed super trait TA
2+
sealed super trait TB
3+
case object a extends TA, TB
4+
case object b extends TA, TB
5+
6+
object Test:
7+
8+
def choose0[X](x: X, y: X): X = x
9+
def choose1[X <: TA](x: X, y: X): X = x
10+
def choose2[X <: TB](x: X, y: X): X = x
11+
def choose3[X <: Product](x: X, y: X): X = x
12+
def choose4[X <: TA & TB](x: X, y: X): X = x
13+
14+
choose0(a, b) match
15+
case _: TA => ???
16+
case _: TB => ???
17+
18+
choose1(a, b) match
19+
case _: TA => ???
20+
case _: TB => ??? // error: unreachable
21+
22+
choose2(a, b) match
23+
case _: TB => ???
24+
case _: TA => ??? // error: unreachable
25+
26+
choose3(a, b) match
27+
case _: Product => ???
28+
case _: TA => ??? // error: unreachable
29+
30+
choose4(a, b) match
31+
case _: (TA & TB) => ???
32+
case _: Product => ??? // error: unreachable

tests/neg/supertraits.scala

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
super trait S
2+
trait A
3+
class B extends A, S
4+
class C extends A, S
5+
6+
val x = if ??? then B() else C()
7+
val x1: S = x // error
8+
9+
case object a
10+
case object b
11+
val y = if ??? then a else b
12+
val y1: Product = y // error
13+
val y2: Serializable = y // error
+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
object Test {
22
def main(args: Array[String]): Unit = {
33
val a = new A_1
4-
val x = new java.io.Serializable {}
4+
val x: java.io.Serializable = new java.io.Serializable {}
55
a.foo(x)
66
}
77
}

0 commit comments

Comments
 (0)