Skip to content

Commit 1b31f06

Browse files
committed
Merge pull request #639 from dotty-staging/add/trait-parameters
Add/trait parameters
2 parents 9ee21cf + 0369f1e commit 1b31f06

File tree

10 files changed

+186
-37
lines changed

10 files changed

+186
-37
lines changed

src/dotty/tools/dotc/ast/tpd.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
302302
true
303303
case pre: ThisType =>
304304
pre.cls.isStaticOwner ||
305-
tp.symbol.is(ParamOrAccessor) && ctx.owner.enclosingClass == pre.cls
305+
tp.symbol.is(ParamOrAccessor) && !pre.cls.is(Trait) && ctx.owner.enclosingClass == pre.cls
306306
// was ctx.owner.enclosingClass.derivesFrom(pre.cls) which was not tight enough
307307
// and was spuriously triggered in case inner class would inherit from outer one
308308
// eg anonymous TypeMap inside TypeMap.andThen

src/dotty/tools/dotc/core/SymDenotations.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ object SymDenotations {
516516
!isAnonymousFunction &&
517517
!isCompanionMethod
518518

519-
/** Is this a setter? */
519+
/** Is this a getter? */
520520
final def isGetter(implicit ctx: Context) =
521521
(this is Accessor) && !originalName.isSetterName && !originalName.isScala2LocalSuffix
522522

src/dotty/tools/dotc/transform/Mixin.scala

+80-30
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import collection.mutable
2020

2121
/** This phase performs the following transformations:
2222
*
23-
* 1. (done in `traitDefs`) Map every concrete trait getter
23+
* 1. (done in `traitDefs` and `transformSym`) Map every concrete trait getter
2424
*
2525
* <mods> def x(): T = expr
2626
*
@@ -46,32 +46,45 @@ import collection.mutable
4646
* For every trait M directly implemented by the class (see SymUtils.mixin), in
4747
* reverse linearization order, add the following definitions to C:
4848
*
49-
* 3.1 (done in `traitInits`) For every concrete trait getter `<mods> def x(): T` in M,
50-
* in order of textual occurrence, produce the following:
49+
* 3.1 (done in `traitInits`) For every parameter accessor `<mods> def x(): T` in M,
50+
* in order of textual occurrence, add
5151
*
52-
* 3.1.1 If `x` is also a member of `C`, and M is a Dotty trait:
52+
* <mods> def x() = e
53+
*
54+
* where `e` is the constructor argument in C that corresponds to `x`. Issue
55+
* an error if no such argument exists.
56+
*
57+
* 3.2 (done in `traitInits`) For every concrete trait getter `<mods> def x(): T` in M
58+
* which is not a parameter accessor, in order of textual occurrence, produce the following:
59+
*
60+
* 3.2.1 If `x` is also a member of `C`, and M is a Dotty trait:
5361
*
5462
* <mods> def x(): T = super[M].initial$x()
5563
*
56-
* 3.1.2 If `x` is also a member of `C`, and M is a Scala 2.x trait:
64+
* 3.2.2 If `x` is also a member of `C`, and M is a Scala 2.x trait:
5765
*
5866
* <mods> def x(): T = _
5967
*
60-
* 3.1.3 If `x` is not a member of `C`, and M is a Dotty trait:
68+
* 3.2.3 If `x` is not a member of `C`, and M is a Dotty trait:
6169
*
6270
* super[M].initial$x()
6371
*
64-
* 3.1.4 If `x` is not a member of `C`, and M is a Scala2.x trait, nothing gets added.
72+
* 3.2.4 If `x` is not a member of `C`, and M is a Scala2.x trait, nothing gets added.
6573
*
6674
*
67-
* 3.2 (done in `superCallOpt`) The call:
75+
* 3.3 (done in `superCallOpt`) The call:
6876
*
6977
* super[M].<init>
7078
*
71-
* 3.3 (done in `setters`) For every concrete setter `<mods> def x_=(y: T)` in M:
79+
* 3.4 (done in `setters`) For every concrete setter `<mods> def x_=(y: T)` in M:
7280
*
7381
* <mods> def x_=(y: T) = ()
7482
*
83+
* 4. (done in `transformTemplate` and `transformSym`) Drop all parameters from trait
84+
* constructors.
85+
*
86+
* 5. (done in `transformSym`) Drop ParamAccessor flag from all parameter accessors in traits.
87+
*
7588
* Conceptually, this is the second half of the previous mixin phase. It needs to run
7689
* after erasure because it copies references to possibly private inner classes and objects
7790
* into enclosing classes where they are not visible. This can only be done if all references
@@ -86,7 +99,9 @@ class Mixin extends MiniPhaseTransform with SymTransformer { thisTransform =>
8699

87100
override def transformSym(sym: SymDenotation)(implicit ctx: Context): SymDenotation =
88101
if (sym.is(Accessor, butNot = Deferred) && sym.owner.is(Trait))
89-
sym.copySymDenotation(initFlags = sym.flags | Deferred).ensureNotPrivate
102+
sym.copySymDenotation(initFlags = sym.flags &~ ParamAccessor | Deferred).ensureNotPrivate
103+
else if (sym.isConstructor && sym.owner.is(Trait) && sym.info.firstParamTypes.nonEmpty)
104+
sym.copySymDenotation(info = MethodType(Nil, sym.info.resultType))
90105
else
91106
sym
92107

@@ -111,7 +126,7 @@ class Mixin extends MiniPhaseTransform with SymTransformer { thisTransform =>
111126
def traitDefs(stats: List[Tree]): List[Tree] = {
112127
val initBuf = new mutable.ListBuffer[Tree]
113128
stats.flatMap({
114-
case stat: DefDef if stat.symbol.isGetter && !stat.rhs.isEmpty && !stat.symbol.is(Flags.Lazy) =>
129+
case stat: DefDef if stat.symbol.isGetter && !stat.rhs.isEmpty && !stat.symbol.is(Flags.Lazy) =>
115130
// make initializer that has all effects of previous getter,
116131
// replace getter rhs with empty tree.
117132
val vsym = stat.symbol
@@ -131,15 +146,22 @@ class Mixin extends MiniPhaseTransform with SymTransformer { thisTransform =>
131146
}) ++ initBuf
132147
}
133148

134-
def transformSuper(tree: Tree): Tree = {
149+
/** Map constructor call to a pair of a supercall and a list of arguments
150+
* to be used as initializers of trait parameters if the target of the call
151+
* is a trait.
152+
*/
153+
def transformConstructor(tree: Tree): (Tree, List[Tree]) = {
135154
val Apply(sel @ Select(New(_), nme.CONSTRUCTOR), args) = tree
136-
superRef(tree.symbol, tree.pos).appliedToArgs(args)
155+
val (callArgs, initArgs) = if (tree.symbol.owner.is(Trait)) (Nil, args) else (args, Nil)
156+
(superRef(tree.symbol, tree.pos).appliedToArgs(callArgs), initArgs)
137157
}
138158

139-
val superCalls = (
159+
val superCallsAndArgs = (
140160
for (p <- impl.parents if p.symbol.isConstructor)
141-
yield p.symbol.owner -> transformSuper(p)
161+
yield p.symbol.owner -> transformConstructor(p)
142162
).toMap
163+
val superCalls = superCallsAndArgs.mapValues(_._1)
164+
val initArgs = superCallsAndArgs.mapValues(_._2)
143165

144166
def superCallOpt(baseCls: Symbol): List[Tree] = superCalls.get(baseCls) match {
145167
case Some(call) =>
@@ -155,35 +177,63 @@ class Mixin extends MiniPhaseTransform with SymTransformer { thisTransform =>
155177
def wasDeferred(sym: Symbol) =
156178
ctx.atPhase(thisTransform) { implicit ctx => sym is Deferred }
157179

158-
def traitInits(mixin: ClassSymbol): List[Tree] =
180+
def traitInits(mixin: ClassSymbol): List[Tree] = {
181+
var argNum = 0
182+
def nextArgument() = initArgs.get(mixin) match {
183+
case Some(arguments) =>
184+
try arguments(argNum) finally argNum += 1
185+
case None =>
186+
val (msg, pos) = impl.parents.find(_.tpe.typeSymbol == mixin) match {
187+
case Some(parent) => ("lacks argument list", parent.pos)
188+
case None =>
189+
("""is indirectly implemented,
190+
|needs to be implemented directly so that arguments can be passed""".stripMargin,
191+
cls.pos)
192+
}
193+
ctx.error(i"parameterized $mixin $msg", pos)
194+
EmptyTree
195+
}
196+
159197
for (getter <- mixin.info.decls.filter(getr => getr.isGetter && !wasDeferred(getr)).toList) yield {
160198
val isScala2x = mixin.is(Scala2x)
161199
def default = Underscore(getter.info.resultType)
162200
def initial = transformFollowing(superRef(initializer(getter)).appliedToNone)
163-
if (isCurrent(getter) || getter.is(ExpandedName))
201+
202+
/** A call to the implementation of `getter` in `mixin`'s implementation class */
203+
def lazyGetterCall = {
204+
def canbeImplClassGetter(sym: Symbol) = sym.info.firstParamTypes match {
205+
case t :: Nil => t.isDirectRef(mixin)
206+
case _ => false
207+
}
208+
val implClassGetter = mixin.implClass.info.nonPrivateDecl(getter.name)
209+
.suchThat(canbeImplClassGetter).symbol
210+
ref(mixin.implClass).select(implClassGetter).appliedTo(This(cls))
211+
}
212+
213+
if (isCurrent(getter) || getter.is(ExpandedName)) {
214+
val rhs =
215+
if (ctx.atPhase(thisTransform)(implicit ctx => getter.is(ParamAccessor))) nextArgument()
216+
else if (isScala2x)
217+
if (getter.is(Lazy)) lazyGetterCall
218+
else Underscore(getter.info.resultType)
219+
else transformFollowing(superRef(initializer(getter)).appliedToNone)
164220
// transformFollowing call is needed to make memoize & lazy vals run
165-
transformFollowing(
166-
DefDef(implementation(getter.asTerm),
167-
if (isScala2x) {
168-
if (getter.is(Flags.Lazy)) { // lazy vals need to have a rhs that will be the lazy initializer
169-
val sym = mixin.implClass.info.nonPrivateDecl(getter.name).suchThat(_.info.paramTypess match {
170-
case List(List(t: TypeRef)) => t.isDirectRef(mixin)
171-
case _ => false
172-
}).symbol // lazy val can be overloaded
173-
ref(mixin.implClass).select(sym).appliedTo(This(ctx.owner.asClass))
174-
}
175-
else default
176-
} else initial)
177-
)
221+
transformFollowing(DefDef(implementation(getter.asTerm), rhs))
222+
}
178223
else if (isScala2x) EmptyTree
179224
else initial
180225
}
226+
}
181227

182228
def setters(mixin: ClassSymbol): List[Tree] =
183229
for (setter <- mixin.info.decls.filter(setr => setr.isSetter && !wasDeferred(setr)).toList)
184230
yield DefDef(implementation(setter.asTerm), unitLiteral.withPos(cls.pos))
185231

186232
cpy.Template(impl)(
233+
constr =
234+
if (cls.is(Trait) && impl.constr.vparamss.flatten.nonEmpty)
235+
cpy.DefDef(impl.constr)(vparamss = Nil :: Nil)
236+
else impl.constr,
187237
parents = impl.parents.map(p => TypeTree(p.tpe).withPos(p.pos)),
188238
body =
189239
if (cls is Trait) traitDefs(impl.body)

src/dotty/tools/dotc/typer/Checking.scala

+11-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import annotation.unchecked
2020
import util.Positions._
2121
import util.{Stats, SimpleMap}
2222
import util.common._
23+
import transform.SymUtils._
2324
import Decorators._
2425
import Uniques._
2526
import ErrorReporting.{err, errorType, DiagnosticString}
@@ -328,9 +329,15 @@ trait Checking {
328329
}
329330
}
330331

331-
def checkInstantiatable(cls: ClassSymbol, pos: Position): Unit = {
332-
??? // to be done in later phase: check that class `cls` is legal in a new.
333-
}
332+
def checkParentCall(call: Tree, caller: ClassSymbol)(implicit ctx: Context) =
333+
if (!ctx.isAfterTyper) {
334+
val called = call.tpe.classSymbol
335+
if (caller is Trait)
336+
ctx.error(i"$caller may not call constructor of $called", call.pos)
337+
else if (called.is(Trait) && !caller.mixins.contains(called))
338+
ctx.error(i"""$called is already implemented by super${caller.superClass},
339+
|its constructor cannot be called again""".stripMargin, call.pos)
340+
}
334341
}
335342

336343
trait NoChecking extends Checking {
@@ -343,4 +350,5 @@ trait NoChecking extends Checking {
343350
override def checkImplicitParamsNotSingletons(vparamss: List[List[ValDef]])(implicit ctx: Context): Unit = ()
344351
override def checkFeasible(tp: Type, pos: Position, where: => String = "")(implicit ctx: Context): Type = tp
345352
override def checkNoDoubleDefs(cls: Symbol)(implicit ctx: Context): Unit = ()
353+
override def checkParentCall(call: Tree, caller: ClassSymbol)(implicit ctx: Context) = ()
346354
}

src/dotty/tools/dotc/typer/Typer.scala

+1-2
Original file line numberDiff line numberDiff line change
@@ -911,8 +911,7 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
911911
if (tree.isType) typedType(tree)(superCtx)
912912
else {
913913
val result = typedExpr(tree)(superCtx)
914-
if ((cls is Trait) && result.tpe.classSymbol.isRealClass && !ctx.isAfterTyper)
915-
ctx.error(s"trait may not call constructor of ${result.tpe.classSymbol}", tree.pos)
914+
checkParentCall(result, cls)
916915
result
917916
}
918917

test/dotc/tests.scala

+2
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ class tests extends CompilerTest {
138138
@Test def neg_instantiateAbstract = compileFile(negDir, "instantiateAbstract", xerrors = 8)
139139
@Test def neg_selfInheritance = compileFile(negDir, "selfInheritance", xerrors = 5)
140140
@Test def neg_shadowedImplicits = compileFile(negDir, "arrayclone-new", xerrors = 2)
141+
@Test def neg_traitParamsTyper = compileFile(negDir, "traitParamsTyper", xerrors = 5)
142+
@Test def neg_traitParamsMixin = compileFile(negDir, "traitParamsMixin", xerrors = 2)
141143

142144
@Test def run_all = runFiles(runDir)
143145

tests/neg/traitParamsMixin.scala

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
trait T(x: Int) {
2+
def f = x
3+
}
4+
5+
class C extends T // error
6+
7+
trait U extends T
8+
9+
class D extends U { // error
10+
11+
}
12+

tests/neg/traitParamsTyper.scala

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
trait T(x: Int) {
2+
def f = x
3+
}
4+
5+
class C(x: Int) extends T() // error
6+
7+
trait U extends C with T
8+
9+
trait V extends C(1) with T(2) // two errors
10+
11+
trait W extends T(3) // error
12+
13+
14+
class E extends T(0)
15+
class F extends E with T(1) // error
16+

tests/run/traitParamInit.scala

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
object Trace {
2+
private var results = List[Any]()
3+
def apply[A](a: A) = {results ::= a; a}
4+
def fetchAndClear(): Seq[Any] = try results.reverse finally results = Nil
5+
}
6+
trait T(a: Any) {
7+
val ta = a
8+
Trace(s"T.<init>($ta)")
9+
val t_val = Trace("T.val")
10+
}
11+
12+
trait U(a: Any) extends T {
13+
val ua = a
14+
Trace(s"U.<init>($ua)")
15+
}
16+
17+
object Test {
18+
def check(expected: Any) = {
19+
val actual = Trace.fetchAndClear()
20+
if (actual != expected)
21+
sys.error(s"\n$actual\n$expected")
22+
}
23+
def main(args: Array[String]): Unit = {
24+
new T(Trace("ta")) with U(Trace("ua")) {}
25+
check(List("ta", "T.<init>(ta)", "T.val", "ua", "U.<init>(ua)"))
26+
27+
new U(Trace("ua")) with T(Trace("ta")) {}
28+
check(List("ta", "T.<init>(ta)", "T.val", "ua", "U.<init>(ua)"))
29+
}
30+
}

tests/run/traitParams.scala

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
object State {
2+
var s: Int = 0
3+
}
4+
5+
trait T(x: Int, val y: Int) {
6+
def f = x
7+
}
8+
9+
trait U extends T {
10+
State.s += 1
11+
override def f = super.f + y
12+
}
13+
trait U2(a: Any) extends T {
14+
def d = a // okay
15+
val v = a // okay
16+
a // used to crash
17+
}
18+
19+
import State._
20+
class C(x: Int) extends U with T(x, x * x + s)
21+
class C2(x: Int) extends T(x, x * x + s) with U
22+
23+
class D extends C(10) with T
24+
class D2 extends C2(10) with T
25+
26+
object Test {
27+
def main(args: Array[String]): Unit = {
28+
assert(new D().f == 110)
29+
assert(new D2().f == 111)
30+
}
31+
}
32+

0 commit comments

Comments
 (0)