Skip to content

Commit eef65e1

Browse files
committed
Use single parameter list for tpd.DefDef methods
1 parent 28e9af3 commit eef65e1

File tree

3 files changed

+69
-8
lines changed

3 files changed

+69
-8
lines changed

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -206,15 +206,17 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
206206
def SyntheticValDef(name: TermName, rhs: Tree)(using Context): ValDef =
207207
ValDef(newSymbol(ctx.owner, name, Synthetic, rhs.tpe.widen, coord = rhs.span), rhs)
208208

209-
def DefDef(sym: TermSymbol, tparams: List[TypeSymbol], vparamss: List[List[TermSymbol]],
209+
def DefDef(sym: TermSymbol, paramss: List[List[Symbol]],
210210
resultType: Type, rhs: Tree)(using Context): DefDef =
211-
sym.setParamss(tparams :: vparamss)
211+
sym.setParamss(paramss)
212212
ta.assignType(
213213
untpd.DefDef(
214214
sym.name,
215-
joinParams(
216-
tparams.map(tparam => TypeDef(tparam).withSpan(tparam.span)),
217-
vparamss.nestedMap(vparam => ValDef(vparam).withSpan(vparam.span))),
215+
paramss.map {
216+
case TypeSymbols(params) => params.map(param => TypeDef(param).withSpan(param.span))
217+
case TermSymbols(params) => params.map(param => ValDef(param).withSpan(param.span))
218+
case _ => ???
219+
},
218220
TypeTree(resultType),
219221
rhs),
220222
sym)
@@ -223,7 +225,57 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
223225
ta.assignType(DefDef(sym, Function.const(rhs) _), sym)
224226

225227
def DefDef(sym: TermSymbol, rhsFn: List[List[Tree]] => Tree)(using Context): DefDef =
226-
polyDefDef(sym, Function.const(rhsFn))
228+
229+
// Map method type `tp` with remaining parameters stored in rawParamss to
230+
// final result type and all (given or synthesized) parameters
231+
def recur(tp: Type, remaining: List[List[Symbol]]): (Type, List[List[Symbol]]) = tp match
232+
case tp: PolyType =>
233+
val (tparams: List[TypeSymbol], remaining1) = remaining match
234+
case tparams :: remaining1 =>
235+
assert(tparams.hasSameLengthAs(tp.paramNames) && tparams.head.isType)
236+
(tparams.asInstanceOf[List[TypeSymbol]], remaining1)
237+
case nil =>
238+
(newTypeParams(sym, tp.paramNames, EmptyFlags, tp.instantiateParamInfos(_)), Nil)
239+
val (rtp, paramss) = recur(tp.instantiate(tparams.map(_.typeRef)), remaining1)
240+
(rtp, tparams :: paramss)
241+
case tp: MethodType =>
242+
val isParamDependent = tp.isParamDependent
243+
val previousParamRefs = if isParamDependent then mutable.ListBuffer[TermRef]() else null
244+
245+
def valueParam(name: TermName, origInfo: Type): TermSymbol =
246+
val maybeImplicit =
247+
if tp.isContextualMethod then Given
248+
else if tp.isImplicitMethod then Implicit
249+
else EmptyFlags
250+
val maybeErased = if tp.isErasedMethod then Erased else EmptyFlags
251+
252+
def makeSym(info: Type) = newSymbol(sym, name, TermParam | maybeImplicit | maybeErased, info, coord = sym.coord)
253+
254+
if isParamDependent then
255+
val sym = makeSym(origInfo.substParams(tp, previousParamRefs.toList))
256+
previousParamRefs += sym.termRef
257+
sym
258+
else makeSym(origInfo)
259+
end valueParam
260+
261+
val (vparams: List[TermSymbol], remaining1) =
262+
if tp.paramNames.isEmpty then (Nil, remaining)
263+
else remaining match
264+
case vparams :: remaining1 =>
265+
assert(vparams.hasSameLengthAs(tp.paramNames) && vparams.head.isTerm)
266+
(vparams.asInstanceOf[List[TermSymbol]], remaining1)
267+
case nil =>
268+
(tp.paramNames.lazyZip(tp.paramInfos).map(valueParam), Nil)
269+
val (rtp, paramss) = recur(tp.instantiate(vparams.map(_.termRef)), remaining1)
270+
(rtp, vparams :: paramss)
271+
case _ =>
272+
assert(remaining.isEmpty)
273+
(tp.widenExpr, Nil)
274+
end recur
275+
276+
val (rtp, paramss) = recur(sym.info, sym.rawParamss)
277+
DefDef(sym, paramss, rtp, rhsFn(paramss.nestedMap(ref)))
278+
end DefDef
227279

228280
/** A DefDef with given method symbol `sym`.
229281
* @rhsFn A function from type parameter types and term parameter references
@@ -285,7 +337,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
285337
val targs = tparams.map(tparam => ref(tparam.typeRef))
286338
val argss = vparamss.nestedMap(vparam => Ident(vparam.termRef))
287339
sym.setParamss(tparams :: vparamss)
288-
DefDef(sym, tparams, vparamss, rtp, rhsFn(targs)(argss))
340+
DefDef(sym, joinSymbols(tparams, vparamss), rtp, rhsFn(targs)(argss))
289341
}
290342

291343
def TypeDef(sym: TypeSymbol)(using Context): TypeDef =

compiler/src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,16 +845,25 @@ object Symbols {
845845
copies
846846
}
847847

848+
/** Matches lists of term symbols, including the empty list.
849+
* All symbols in the list are assumed to be of the same kind.
850+
*/
848851
object TermSymbols:
849852
def unapply(xs: List[Symbol])(using Context): Option[List[TermSymbol]] = xs match
850853
case (x: Symbol) :: _ if x.isType => None
851854
case _ => Some(xs.asInstanceOf[List[TermSymbol]])
852855

856+
/** Matches lists of type symbols, excluding the empty list.
857+
* All symbols in the list are assumed to be of the same kind.
858+
*/
853859
object TypeSymbols:
854860
def unapply(xs: List[Symbol])(using Context): Option[List[TypeSymbol]] = xs match
855861
case (x: Symbol) :: _ if x.isType => Some(xs.asInstanceOf[List[TypeSymbol]])
856862
case _ => None
857863

864+
def joinSymbols(xs: List[Symbol], ys: List[List[Symbol]]): List[List[Symbol]] =
865+
if xs.isEmpty then ys else xs :: ys
866+
858867
// ----- Locating predefined symbols ----------------------------------------
859868

860869
def requiredPackage(path: PreName)(using Context): TermSymbol = {

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ class ExpandSAMs extends MiniPhase:
156156
}
157157

158158
def applyOrElseRhs(paramRefss: List[List[Tree]])(using Context) = {
159-
val List(paramRef, defaultRef) = paramRefss.head
159+
val List(paramRef, defaultRef) = paramRefss(1)
160160
def translateCase(cdef: CaseDef) =
161161
cdef.changeOwner(anonSym, applyOrElseFn)
162162
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)

0 commit comments

Comments
 (0)