Skip to content

Commit 5885ecf

Browse files
committed
Eliminate polyDefDef def and calls
1 parent eef65e1 commit 5885ecf

File tree

15 files changed

+73
-123
lines changed

15 files changed

+73
-123
lines changed

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 20 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
224224
def DefDef(sym: TermSymbol, rhs: Tree = EmptyTree)(using Context): DefDef =
225225
ta.assignType(DefDef(sym, Function.const(rhs) _), sym)
226226

227+
/** A DefDef with given method symbol `sym`.
228+
* @rhsFn A function from parameter references
229+
* to the method's right-hand side.
230+
* Parameter symbols are taken from the `rawParamss` field of `sym`, or
231+
* are freshly generated if `rawParamss` is empty.
232+
*/
227233
def DefDef(sym: TermSymbol, rhsFn: List[List[Tree]] => Tree)(using Context): DefDef =
228234

229235
// Map method type `tp` with remaining parameters stored in rawParamss to
@@ -277,69 +283,6 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
277283
DefDef(sym, paramss, rtp, rhsFn(paramss.nestedMap(ref)))
278284
end DefDef
279285

280-
/** A DefDef with given method symbol `sym`.
281-
* @rhsFn A function from type parameter types and term parameter references
282-
* to the method's right-hand side.
283-
* Parameter symbols are taken from the `rawParamss` field of `sym`, or
284-
* are freshly generated if `rawParamss` is empty.
285-
*/
286-
def polyDefDef(sym: TermSymbol, rhsFn: List[Tree] => List[List[Tree]] => Tree)(using Context): DefDef = {
287-
288-
val (tparams, existingParamss, mtp) = sym.info match {
289-
case tp: PolyType =>
290-
val (tparams, existingParamss) = sym.rawParamss match
291-
case tparams :: vparamss =>
292-
assert(tparams.hasSameLengthAs(tp.paramNames) && tparams.head.isType)
293-
(tparams.asInstanceOf[List[TypeSymbol]], vparamss)
294-
case _ =>
295-
(newTypeParams(sym, tp.paramNames, EmptyFlags, tp.instantiateParamInfos(_)), Nil)
296-
(tparams, existingParamss, tp.instantiate(tparams map (_.typeRef)))
297-
case tp => (Nil, sym.rawParamss, tp)
298-
}
299-
300-
def valueParamss(tp: Type, existingParamss: List[List[Symbol]]): (List[List[TermSymbol]], Type) = tp match {
301-
case tp: MethodType =>
302-
val isParamDependent = tp.isParamDependent
303-
val previousParamRefs = if (isParamDependent) mutable.ListBuffer[TermRef]() else null
304-
305-
def valueParam(name: TermName, origInfo: Type): TermSymbol = {
306-
val maybeImplicit =
307-
if (tp.isContextualMethod) Given
308-
else if (tp.isImplicitMethod) Implicit
309-
else EmptyFlags
310-
val maybeErased = if (tp.isErasedMethod) Erased else EmptyFlags
311-
312-
def makeSym(info: Type) = newSymbol(sym, name, TermParam | maybeImplicit | maybeErased, info, coord = sym.coord)
313-
314-
if (isParamDependent) {
315-
val sym = makeSym(origInfo.substParams(tp, previousParamRefs.toList))
316-
previousParamRefs += sym.termRef
317-
sym
318-
}
319-
else
320-
makeSym(origInfo)
321-
}
322-
323-
val (params, existingParamss1) =
324-
if tp.paramInfos.isEmpty then (Nil, existingParamss)
325-
else existingParamss match
326-
case vparams :: existingParamss1 =>
327-
assert(vparams.hasSameLengthAs(tp.paramNames) && vparams.head.isTerm)
328-
(vparams.asInstanceOf[List[TermSymbol]], existingParamss1)
329-
case _ =>
330-
(tp.paramNames.lazyZip(tp.paramInfos).map(valueParam), Nil)
331-
val (paramss, rtp) =
332-
valueParamss(tp.instantiate(params map (_.termRef)), existingParamss1)
333-
(params :: paramss, rtp)
334-
case tp => (Nil, tp.widenExpr)
335-
}
336-
val (vparamss, rtp) = valueParamss(mtp, existingParamss)
337-
val targs = tparams.map(tparam => ref(tparam.typeRef))
338-
val argss = vparamss.nestedMap(vparam => Ident(vparam.termRef))
339-
sym.setParamss(tparams :: vparamss)
340-
DefDef(sym, joinSymbols(tparams, vparamss), rtp, rhsFn(targs)(argss))
341-
}
342-
343286
def TypeDef(sym: TypeSymbol)(using Context): TypeDef =
344287
ta.assignType(untpd.TypeDef(sym.name, TypeTree(sym.info)), sym)
345288

@@ -404,7 +347,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
404347
def forwarder(fn: TermSymbol, name: TermName) = {
405348
val fwdMeth = fn.copy(cls, name, Synthetic | Method | Final).entered.asTerm
406349
if (fwdMeth.allOverriddenSymbols.exists(!_.is(Deferred))) fwdMeth.setFlag(Override)
407-
polyDefDef(fwdMeth, tprefs => prefss => ref(fn).appliedToTypeTrees(tprefs).appliedToArgss(prefss))
350+
DefDef(fwdMeth, ref(fn).appliedToArgss(_))
408351
}
409352
val forwarders = fns.lazyZip(methNames).map(forwarder)
410353
val cdef = ClassDef(cls, DefDef(constr), forwarders)
@@ -1285,12 +1228,21 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
12851228
Ident(defn.ScalaRuntimeModule.requiredMethod(name).termRef).appliedToTermArgs(args)
12861229

12871230
/** An extractor that pulls out type arguments */
1288-
object MaybePoly {
1289-
def unapply(tree: Tree): Option[(Tree, List[Tree])] = tree match {
1231+
object MaybePoly:
1232+
def unapply(tree: Tree): Option[(Tree, List[Tree])] = tree match
12901233
case TypeApply(tree, targs) => Some(tree, targs)
12911234
case _ => Some(tree, Nil)
1292-
}
1293-
}
1235+
1236+
object TypeArgs:
1237+
def unapply(ts: List[Tree]): Option[List[Tree]] =
1238+
if ts.nonEmpty && ts.head.isType then Some(ts) else None
1239+
1240+
/** Split argument clauses into a leading type argument clause if it exists and
1241+
* remaining clauses
1242+
*/
1243+
def splitArgs(argss: List[List[Tree]]): (List[Tree], List[List[Tree]]) = argss match
1244+
case TypeArgs(targs) :: argss1 => (targs, argss1)
1245+
case _ => (Nil, argss)
12941246

12951247
/** A key to be used in a context property that tracks enclosing inlined calls */
12961248
private val InlinedCalls = Property.Key[List[Tree]]()

compiler/src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -861,9 +861,6 @@ object Symbols {
861861
case (x: Symbol) :: _ if x.isType => Some(xs.asInstanceOf[List[TypeSymbol]])
862862
case _ => None
863863

864-
def joinSymbols(xs: List[Symbol], ys: List[List[Symbol]]): List[List[Symbol]] =
865-
if xs.isEmpty then ys else xs :: ys
866-
867864
// ----- Locating predefined symbols ----------------------------------------
868865

869866
def requiredPackage(path: PreName)(using Context): TermSymbol = {

compiler/src/dotty/tools/dotc/transform/AccessProxies.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,18 @@ abstract class AccessProxies {
3737
*/
3838
private def accessorDefs(cls: Symbol)(using Context): Iterator[DefDef] =
3939
for (accessor <- cls.info.decls.iterator; accessed <- accessedBy.remove(accessor).toOption) yield
40-
polyDefDef(accessor.asTerm, tps => argss => {
40+
DefDef(accessor.asTerm, prefss => {
4141
def numTypeParams = accessed.info match {
4242
case info: PolyType => info.paramNames.length
4343
case _ => 0
4444
}
45+
val (targs, argss) = splitArgs(prefss)
4546
val (accessRef, forwardedTpts, forwardedArgss) =
4647
if (passReceiverAsArg(accessor.name))
47-
(argss.head.head.select(accessed), tps.takeRight(numTypeParams), argss.tail)
48+
(argss.head.head.select(accessed), targs.takeRight(numTypeParams), argss.tail)
4849
else
4950
(if (accessed.isStatic) ref(accessed) else ref(TermRef(cls.thisType, accessed)),
50-
tps, argss)
51+
targs, argss)
5152
val rhs =
5253
if (accessor.name.isSetterName &&
5354
forwardedArgss.nonEmpty && forwardedArgss.head.nonEmpty) // defensive conditions

compiler/src/dotty/tools/dotc/transform/ElimRepeated.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,13 +222,12 @@ class ElimRepeated extends MiniPhase with InfoTransformer { thisPhase =>
222222
.get
223223
.symbol.asTerm
224224
// Generate the method
225-
val forwarderDef = polyDefDef(forwarderSym, trefs => vrefss => {
226-
val init :+ (last :+ vararg) = vrefss
225+
val forwarderDef = DefDef(forwarderSym, prefss => {
226+
val init :+ (last :+ vararg) = prefss
227227
// Can't call `.argTypes` here because the underlying array type is of the
228228
// form `Array[? <: SomeType]`, so we need `.argInfos` to get the `TypeBounds`.
229229
val elemtp = vararg.tpe.widen.argInfos.head
230230
ref(sym.termRef)
231-
.appliedToTypeTrees(trefs)
232231
.appliedToArgss(init)
233232
.appliedToTermArgs(last :+ wrapArray(vararg, elemtp))
234233
})

compiler/src/dotty/tools/dotc/transform/FirstTransform.scala

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,17 +117,14 @@ class FirstTransform extends MiniPhase with InfoTransformer { thisPhase =>
117117
override def transformTemplate(impl: Template)(using Context): Tree =
118118
cpy.Template(impl)(self = EmptyValDef)
119119

120-
override def transformDefDef(ddef: DefDef)(using Context): Tree = {
120+
override def transformDefDef(ddef: DefDef)(using Context): Tree =
121121
val meth = ddef.symbol.asTerm
122-
if (meth.hasAnnotation(defn.NativeAnnot)) {
122+
if meth.hasAnnotation(defn.NativeAnnot) then
123123
meth.resetFlag(Deferred)
124-
polyDefDef(meth,
125-
_ => _ => ref(defn.Sys_error.termRef).withSpan(ddef.span)
124+
DefDef(meth, _ =>
125+
ref(defn.Sys_error.termRef).withSpan(ddef.span)
126126
.appliedTo(Literal(Constant(s"native method stub"))))
127-
}
128-
129127
else ddef
130-
}
131128

132129
override def transformStats(trees: List[Tree])(using Context): List[Tree] =
133130
ast.Trees.flatten(atPhase(thisPhase.next)(reorderAndComplete(trees)))

compiler/src/dotty/tools/dotc/transform/FullParameterization.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ trait FullParameterization {
148148
* of class that contained original defDef
149149
*/
150150
def fullyParameterizedDef(derived: TermSymbol, originalDef: DefDef, abstractOverClass: Boolean = true)(using Context): Tree =
151-
polyDefDef(derived, trefs => vrefss => {
151+
DefDef(derived, prefss => {
152+
val (trefs, vrefss) = splitArgs(prefss)
152153
val origMeth = originalDef.symbol
153154
val origClass = origMeth.enclosingClass.asClass
154155
val origLeadingTypeParamSyms = allInstanceTypeParams(originalDef, abstractOverClass)

compiler/src/dotty/tools/dotc/transform/HoistSuperArgs.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,9 @@ class HoistSuperArgs extends MiniPhase with IdentityDenotTransformer { thisPhase
129129
cpy.Apply(arg)(fn, hoistSuperArg(arg1, cdef) :: Nil)
130130
case _ if arg.existsSubTree(needsHoist) =>
131131
val superMeth = newSuperArgMethod(arg.tpe)
132-
val superArgDef = polyDefDef(superMeth, trefs => vrefss => {
133-
val paramSyms = trefs.map(_.tpe.typeSymbol) ::: vrefss.flatten.map(_.symbol)
132+
val superArgDef = DefDef(superMeth, prefss => {
133+
val paramSyms = prefss.flatten.map(pref =>
134+
if pref.isType then pref.tpe.typeSymbol else pref.symbol)
134135
val tmap = new TreeTypeMap(
135136
typeMap = new TypeMap {
136137
lazy val origToParam = origParams.zip(paramSyms).toMap

compiler/src/dotty/tools/dotc/transform/Mixin.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ class Mixin extends MiniPhase with SymTransformer { thisPhase =>
285285
for (meth <- mixin.info.decls.toList if needsMixinForwarder(meth))
286286
yield {
287287
util.Stats.record("mixin forwarders")
288-
transformFollowing(polyDefDef(mkForwarderSym(meth.asTerm, Bridge), forwarderRhsFn(meth)))
288+
transformFollowing(DefDef(mkForwarderSym(meth.asTerm, Bridge), forwarderRhsFn(meth)))
289289
}
290290

291291
cpy.Template(impl)(

compiler/src/dotty/tools/dotc/transform/MixinOps.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,18 +77,17 @@ class MixinOps(cls: ClassSymbol, thisPhase: DenotTransformer)(using Context) {
7777
final val PrivateOrAccessor: FlagSet = Private | Accessor
7878
final val PrivateOrAccessorOrDeferred: FlagSet = Private | Accessor | Deferred
7979

80-
def forwarderRhsFn(target: Symbol): List[Tree] => List[List[Tree]] => Tree = {
81-
targs => vrefss =>
80+
def forwarderRhsFn(target: Symbol): List[List[Tree]] => Tree =
81+
prefss =>
82+
val (targs, vargss) = splitArgs(prefss)
8283
val tapp = superRef(target).appliedToTypeTrees(targs)
83-
vrefss match {
84+
vargss match
8485
case Nil | List(Nil) =>
8586
// Overriding is somewhat loose about `()T` vs `=> T`, so just pick
8687
// whichever makes sense for `target`
8788
tapp.ensureApplied
8889
case _ =>
89-
tapp.appliedToArgss(vrefss)
90-
}
91-
}
90+
tapp.appliedToArgss(vargss)
9291

9392
private def competingMethodsIterator(meth: Symbol): Iterator[Symbol] =
9493
cls.baseClasses.iterator

compiler/src/dotty/tools/dotc/transform/ResolveSuper.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class ResolveSuper extends MiniPhase with IdentityDenotTransformer { thisPhase =
4949
for (superAcc <- mixin.info.decls.filter(_.isSuperAccessor))
5050
yield {
5151
util.Stats.record("super accessors")
52-
polyDefDef(mkForwarderSym(superAcc.asTerm), forwarderRhsFn(rebindSuper(cls, superAcc)))
52+
DefDef(mkForwarderSym(superAcc.asTerm), forwarderRhsFn(rebindSuper(cls, superAcc)))
5353
}
5454

5555
val overrides = mixins.flatMap(superAccessors)
@@ -64,7 +64,7 @@ class ResolveSuper extends MiniPhase with IdentityDenotTransformer { thisPhase =
6464
val cls = meth.owner.asClass
6565
val ops = new MixinOps(cls, thisPhase)
6666
import ops._
67-
polyDefDef(meth, forwarderRhsFn(rebindSuper(cls, meth)))
67+
DefDef(meth, forwarderRhsFn(rebindSuper(cls, meth)))
6868
}
6969
else ddef
7070
}

0 commit comments

Comments
 (0)