Skip to content

Insert traits with implicit parameters as extra parents of classes #11830

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/TypeUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,10 @@ object TypeUtils {
case self: TypeProxy =>
self.underlying.companionRef
}

/** Is this type a methodic type that takes implicit parameters (both old and new) at some point? */
def takesImplicitParams(using Context): Boolean = self.stripPoly match
case mt: MethodType => mt.isImplicitMethod || mt.resType.takesImplicitParams
case _ => false
}
}
71 changes: 70 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1237,12 +1237,81 @@ class Namer { typer: Typer =>
}
}

/** Ensure that the first type in a list of parent types Ps points to a non-trait class.
* If that's not already the case, add one. The added class type CT is determined as follows.
* First, let C be the unique class such that
* - there is a parent P_i such that P_i derives from C, and
* - for every class D: If some parent P_j, j <= i derives from D, then C derives from D.
* Then, let CT be the smallest type which
* - has C as its class symbol, and
* - for all parents P_i: If P_i derives from C then P_i <:< CT.
*/
def ensureFirstIsClass(parents: List[Type]): List[Type] =

def realClassParent(sym: Symbol): ClassSymbol =
if !sym.isClass then defn.ObjectClass
else if !sym.is(Trait) then sym.asClass
else sym.info.parents match
case parentRef :: _ => realClassParent(parentRef.typeSymbol)
case nil => defn.ObjectClass

def improve(candidate: ClassSymbol, parent: Type): ClassSymbol =
val pcls = realClassParent(parent.classSymbol)
if (pcls derivesFrom candidate) pcls else candidate

parents match
case p :: _ if p.classSymbol.isRealClass => parents
case _ =>
val pcls = parents.foldLeft(defn.ObjectClass)(improve)
typr.println(i"ensure first is class $parents%, % --> ${parents map (_ baseType pcls)}%, %")
val first = TypeComparer.glb(defn.ObjectType :: parents.map(_.baseType(pcls)))
checkFeasibleParent(first, cls.srcPos, em" in inferred superclass $first") :: parents
end ensureFirstIsClass

/** If `parents` contains references to traits that have supertraits with implicit parameters
* add those supertraits in linearization order unless they are already covered by other
* parent types. For instance, in
*
* class A
* trait B(using I) extends A
* trait C extends B
* class D extends A, C
*
* the class declaration of `D` is augmented to
*
* class D extends A, B, C
*
* so that an implicit `I` can be passed to `B`. See i7613.scala for more examples.
*/
def addUsingTraits(parents: List[Type]): List[Type] =
lazy val existing = parents.map(_.classSymbol).toSet
def recur(parents: List[Type]): List[Type] = parents match
case parent :: parents1 =>
val psym = parent.classSymbol
val addedTraits =
if psym.is(Trait) then
psym.asClass.baseClasses.tail.iterator
.takeWhile(_.is(Trait))
.filter(p =>
p.primaryConstructor.info.takesImplicitParams
&& !cls.superClass.isSubClass(p)
&& !existing.contains(p))
.toList.reverse
else Nil
addedTraits.map(parent.baseType) ::: parent :: recur(parents1)
case nil =>
Nil
if cls.isRealClass then recur(parents) else parents
end addUsingTraits

completeConstructor(denot)
denot.info = tempInfo

val parentTypes = defn.adjustForTuple(cls, cls.typeParams,
defn.adjustForBoxedUnit(cls,
ensureFirstIsClass(parents.map(checkedParentType(_)), cls.span)
addUsingTraits(
ensureFirstIsClass(parents.map(checkedParentType(_)))
)
)
)
typr.println(i"completing $denot, parents = $parents%, %, parentTypes = $parentTypes%, %")
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/ReTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ class ReTyper extends Typer with ReChecking {

override def completeAnnotations(mdef: untpd.MemberDef, sym: Symbol)(using Context): Unit = ()

override def ensureConstrCall(cls: ClassSymbol, parents: List[Tree])(using Context): List[Tree] =
parents
override def ensureConstrCall(cls: ClassSymbol, parent: Tree)(using Context): Tree =
parent

override def handleUnexpectedFunType(tree: untpd.Apply, fun: Tree)(using Context): Tree = fun.tpe match {
case mt: MethodType =>
Expand Down
90 changes: 41 additions & 49 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2212,20 +2212,18 @@ class Typer extends Namer
* @param psym Its type symbol
* @param cinfo The info of its constructor
*/
def maybeCall(ref: Tree, psym: Symbol, cinfo: Type): Tree = cinfo.stripPoly match {
def maybeCall(ref: Tree, psym: Symbol): Tree = psym.primaryConstructor.info.stripPoly match
case cinfo @ MethodType(Nil) if cinfo.resultType.isImplicitMethod =>
typedExpr(untpd.New(untpd.TypedSplice(ref)(using superCtx), Nil))(using superCtx)
case cinfo @ MethodType(Nil) if !cinfo.resultType.isInstanceOf[MethodType] =>
ref
case cinfo: MethodType =>
if (!ctx.erasedTypes) { // after constructors arguments are passed in super call.
if !ctx.erasedTypes then // after constructors arguments are passed in super call.
typr.println(i"constr type: $cinfo")
report.error(ParameterizedTypeLacksArguments(psym), ref.srcPos)
}
ref
case _ =>
ref
}

val seenParents = mutable.Set[Symbol]()

Expand All @@ -2250,14 +2248,35 @@ class Typer extends Namer
if (tree.isType) {
checkSimpleKinded(result) // Not needed for constructor calls, as type arguments will be inferred.
if (psym.is(Trait) && !cls.is(Trait) && !cls.superClass.isSubClass(psym))
result = maybeCall(result, psym, psym.primaryConstructor.info)
result = maybeCall(result, psym)
}
else checkParentCall(result, cls)
checkTraitInheritance(psym, cls, tree.srcPos)
if (cls is Case) checkCaseInheritance(psym, cls, tree.srcPos)
result
}

/** Augment `ptrees` to have the same class symbols as `parents`. Generate TypeTrees
* or New trees to fill in any parents for which no tree exists yet.
*/
def parentTrees(parents: List[Type], ptrees: List[Tree]): List[Tree] = parents match
case parent :: parents1 =>
val psym = parent.classSymbol
def hasSameParent(ptree: Tree) = ptree.tpe.classSymbol == psym
ptrees match
case ptree :: ptrees1 if hasSameParent(ptree) =>
ptree :: parentTrees(parents1, ptrees1)
case ptree :: ptrees1 if ptrees1.exists(hasSameParent) =>
ptree :: parentTrees(parents, ptrees1)
case _ =>
var added: Tree = TypeTree(parent).withSpan(cdef.nameSpan.focus)
if psym.is(Trait) && psym.primaryConstructor.info.takesImplicitParams then
// classes get a constructor separately using a different context
added = ensureConstrCall(cls, added)
added :: parentTrees(parents1, ptrees)
case _ =>
ptrees

/** Checks if one of the decls is a type with the same name as class type member in selfType */
def classExistsOnSelf(decls: Scope, self: tpd.ValDef): Boolean = {
val selfType = self.tpt.tpe
Expand All @@ -2278,8 +2297,10 @@ class Typer extends Namer

completeAnnotations(cdef, cls)
val constr1 = typed(constr).asInstanceOf[DefDef]
val parentsWithClass = ensureFirstTreeIsClass(parents.mapconserve(typedParent).filterConserve(!_.isEmpty), cdef.nameSpan)
val parents1 = ensureConstrCall(cls, parentsWithClass)(using superCtx)
val parents0 = parentTrees(
cls.classInfo.declaredParents,
parents.mapconserve(typedParent).filterConserve(!_.isEmpty))
val parents1 = ensureConstrCall(cls, parents0)(using superCtx)
val firstParentTpe = parents1.head.tpe.dealias
val firstParent = firstParentTpe.typeSymbol

Expand Down Expand Up @@ -2348,52 +2369,23 @@ class Typer extends Namer
protected def addAccessorDefs(cls: Symbol, body: List[Tree])(using Context): List[Tree] =
ctx.compilationUnit.inlineAccessors.addAccessorDefs(cls, body)

/** Ensure that the first type in a list of parent types Ps points to a non-trait class.
* If that's not already the case, add one. The added class type CT is determined as follows.
* First, let C be the unique class such that
* - there is a parent P_i such that P_i derives from C, and
* - for every class D: If some parent P_j, j <= i derives from D, then C derives from D.
* Then, let CT be the smallest type which
* - has C as its class symbol, and
* - for all parents P_i: If P_i derives from C then P_i <:< CT.
/** If this is a real class, make sure its first parent is a
* constructor call. Cannot simply use a type. Overridden in ReTyper.
*/
def ensureFirstIsClass(parents: List[Type], span: Span)(using Context): List[Type] = {
def realClassParent(cls: Symbol): ClassSymbol =
if (!cls.isClass) defn.ObjectClass
else if (!cls.is(Trait)) cls.asClass
else cls.info.parents match {
case parentRef :: _ => realClassParent(parentRef.typeSymbol)
case nil => defn.ObjectClass
}
def improve(candidate: ClassSymbol, parent: Type): ClassSymbol = {
val pcls = realClassParent(parent.classSymbol)
if (pcls derivesFrom candidate) pcls else candidate
}
parents match {
case p :: _ if p.classSymbol.isRealClass => parents
case _ =>
val pcls = parents.foldLeft(defn.ObjectClass)(improve)
typr.println(i"ensure first is class $parents%, % --> ${parents map (_ baseType pcls)}%, %")
val first = TypeComparer.glb(defn.ObjectType :: parents.map(_.baseType(pcls)))
checkFeasibleParent(first, ctx.source.atSpan(span), em" in inferred superclass $first") :: parents
}
}
def ensureConstrCall(cls: ClassSymbol, parents: List[Tree])(using Context): List[Tree] = parents match
case parents @ (first :: others) =>
parents.derivedCons(ensureConstrCall(cls, first), others)
case parents =>
parents

/** Ensure that first parent tree refers to a real class. */
def ensureFirstTreeIsClass(parents: List[Tree], span: Span)(using Context): List[Tree] = parents match {
case p :: ps if p.tpe.classSymbol.isRealClass => parents
case _ => TypeTree(ensureFirstIsClass(parents.tpes, span).head).withSpan(span.focus) :: parents
}

/** If this is a real class, make sure its first parent is a
/** If this is a real class, make sure its first parent is a
* constructor call. Cannot simply use a type. Overridden in ReTyper.
*/
def ensureConstrCall(cls: ClassSymbol, parents: List[Tree])(using Context): List[Tree] = {
val firstParent :: otherParents = parents
if (firstParent.isType && !cls.is(Trait) && !cls.is(JavaDefined))
typed(untpd.New(untpd.TypedSplice(firstParent), Nil)) :: otherParents
else parents
}
def ensureConstrCall(cls: ClassSymbol, parent: Tree)(using Context): Tree =
if (parent.isType && !cls.is(Trait) && !cls.is(JavaDefined))
typed(untpd.New(untpd.TypedSplice(parent), Nil))
else
parent

def localDummy(cls: ClassSymbol, impl: untpd.Template)(using Context): Symbol =
newLocalDummy(cls, impl.span)
Expand Down
30 changes: 30 additions & 0 deletions docs/docs/reference/other-new-features/trait-parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,36 @@ The correct way to write `E` is to extend both `Greeting` and
class E extends Greeting("Bob"), FormalGreeting
```

### Traits With Context Parameters

This "explicit extension required" rule is relaxed if the missing trait contains only
[context parameters](../contextual/using-clauses). In that case the trait reference is
implicitly inserted as an additional parent with inferred arguments. For instance,
here's a variant of greetings where the addressee is a context parameter of type
`ImpliedName`:

```scala
case class ImpliedName(name: String):
override def toString = name

trait ImpliedGreeting(using val iname: ImpliedName):
def msg = s"How are you, $iname"

trait ImpliedFormalGreeting extends ImpliedGreeting:
override def msg = s"How do you do, $iname"

class F(using iname: ImpliedName) extends ImpliedFormalGreeting
```

The definition of `F` in the last line is implicitly expanded to
```scala
class F(using iname: ImpliedName) extends
Object,
ImpliedGreeting(using iname),
ImpliedFormalGreeting(using iname)
```
Note the inserted reference to the super trait `ImpliedGreeting`, which was not mentioned explicitly.

## Reference

For more information, see [Scala SIP 25](http://docs.scala-lang.org/sips/pending/trait-parameters.html).
2 changes: 1 addition & 1 deletion tests/neg/i6060.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
class I1(i2: Int) {
def apply(i3: Int) = 1
new I1(1)(2) {} // error: too many arguments in parent constructor
new I1(1)(2) {} // error: too many arguments in parent constructor // error
}

class I0(i1: Int) {
Expand Down
8 changes: 8 additions & 0 deletions tests/neg/i7613.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Error: tests/neg/i7613.scala:10:16 ----------------------------------------------------------------------------------
10 | new BazLaws[A] {} // error // error
| ^
| no implicit argument of type Baz[A] was found for parameter x$1 of constructor BazLaws in trait BazLaws
-- Error: tests/neg/i7613.scala:10:2 -----------------------------------------------------------------------------------
10 | new BazLaws[A] {} // error // error
| ^
| no implicit argument of type Bar[A] was found for parameter x$1 of constructor BarLaws in trait BarLaws
11 changes: 11 additions & 0 deletions tests/neg/i7613.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
trait Foo[A]
trait Bar[A] extends Foo[A]
trait Baz[A] extends Bar[A]

trait FooLaws[A](using Foo[A])
trait BarLaws[A](using Bar[A]) extends FooLaws[A]
trait BazLaws[A](using Baz[A]) extends BarLaws[A]

def instance[A](using Foo[A]): BazLaws[A] =
new BazLaws[A] {} // error // error

9 changes: 9 additions & 0 deletions tests/pos/reference/trait-parameters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,13 @@ class E extends Greeting("Bob") with FormalGreeting

// class D2 extends C with Greeting("Bill") // error

case class ImpliedName(name: String):
override def toString = name

trait ImpliedGreeting(using val iname: ImpliedName):
def msg = s"How are you, $iname"

trait ImpliedFormalGreeting extends ImpliedGreeting:
override def msg = s"How do you do, $iname"

class F(using iname: ImpliedName) extends ImpliedFormalGreeting
5 changes: 5 additions & 0 deletions tests/run/i7613.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
D: B1
superD: B1
E: B2
F: B1
F: B2
29 changes: 29 additions & 0 deletions tests/run/i7613.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
trait Foo[A]
trait Bar[A] extends Foo[A]
trait Baz[A] extends Bar[A]

trait FooLaws[A](using Foo[A])
trait BarLaws[A](using Bar[A]) extends FooLaws[A]
trait BazLaws[A](using Baz[A]) extends BarLaws[A]

def instance1[A](using Baz[A]): BazLaws[A] =
new FooLaws[A] with BarLaws[A] with BazLaws[A] {}

def instance2[A](using Baz[A]): BazLaws[A] =
new BazLaws[A] {}

trait I:
def show(x: String): Unit
class A
trait B1(using I) extends A { summon[I].show("B1") }
trait B2(using I) extends B1 { summon[I].show("B2") }
trait C1 extends B1
trait C2 extends B2
class D(using I) extends A, C1
class E(using I) extends D(using new I { def show(x: String) = println(s"superD: $x")}), C2
class F(using I) extends A, C2

@main def Test =
D(using new I { def show(x: String) = println(s"D: $x")})
E(using new I { def show(x: String) = println(s"E: $x")})
F(using new I { def show(x: String) = println(s"F: $x")})