Skip to content

Commit c2aa05f

Browse files
committed
Use Labeled blocks in TailRec, instead of label-defs.
It's easier to first explain on an example. Consider the following tail-recursive method: def fact(n: Int, acc: Int): Int = if (n == 0) acc else fact(n - 1, n * acc) It is now translated as follows by the `tailrec` transform: def fact(n: Int, acc: Int): Int = { var n$tailLocal1: Int = n var acc$tailLocal1: Int = acc while (true) { tailLabel1[Unit]: { return { if (n$tailLocal1 == 0) { acc } else { val n$tailLocal1$tmp1: Int = n$tailLocal1 - 1 val acc$tailLocal1$tmp1: Int = n$tailLocal1 * acc$tailLocal1 n$tailLocal1 = n$tailLocal1$tmp1 acc$tailLocal1 = acc$tailLocal1$tmp1 (return[tailLabel1] ()): Int } } } } throw null // unreachable code } First, we allocate local `var`s for every parameter, as well as `this` if necessary. When we find a tail-recursive call, we evaluate the arguments into temporaries, then assign them to the `var`s. It is necessary to use temporaries in order not to use the new contents of a param local when computing the new value of another param local. We avoid reassigning param locals if their rhs (i.e., the actual argument to the recursive call) is itself, which does happen quite often in practice. In particular, we thus avoid reassigning the local var for `this` if the prefix is empty. We could further optimize this by avoiding the reassignment if the prefix is non-empty but equivalent to `this`. If only one parameter ends up changing value in any particular tail-recursive call, we can avoid the temporaries and directly assign it. This is also a fairly common situation, especially after discarding useless assignments to the local for `this`. After all that, we `return` from a labeled block, which is right inside an infinite `while` loop. The net result is to loop back to the beginning, implementing the jump. The `return` node is explicitly ascribed with the previous result type, so that lubs upstream are not affected (not doing so can cause Ycheck errors). For control flows that do *not* end up in a tail-recursive call, the result value is given to an explicit `return` out of the enclosing method, which prevents the looping. There is one pretty ugly artifact: after the `while` loop, we must insert a `throw null` for the body to still typecheck as an `Int` (the result type of the `def`). This could be avoided if we dared type a `WhileDo(Literal(Constant(true)), body)` as having type `Nothing` rather than `Unit`. This is probably dangerous, though, as we have no guarantee that further transformations will leave the `true` alone, especially in the presence of compiler plugins. If the `true` gets wrapped in any way, the type of the `WhileDo` will be altered, and chaos will ensue. In the future, we could enhance the codegen to avoid emitting that dead code. This should not be too difficult: * emitting a `WhileDo` whose argument is `true` would set the generated `BType` to `Nothing`. * then, when emitting a `Block`, we would drop any statements and expr following a statement whose generated `BType` was `Nothing`. This commit does not go to such lengths, however. This change removes the last source of label-defs in the compiler. After this commit, we will be able to entirely remove label-defs.
1 parent f102203 commit c2aa05f

File tree

2 files changed

+146
-86
lines changed

2 files changed

+146
-86
lines changed

compiler/src/dotty/tools/dotc/core/NameKinds.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ object NameKinds {
288288
val NonLocalReturnKeyName = new UniqueNameKind("nonLocalReturnKey")
289289
val WildcardParamName = new UniqueNameKind("_$")
290290
val TailLabelName = new UniqueNameKind("tailLabel")
291+
val TailLocalName = new UniqueNameKind("$tailLocal")
292+
val TailTempName = new UniqueNameKind("$tmp")
291293
val ExceptionBinderName = new UniqueNameKind("ex")
292294
val SkolemName = new UniqueNameKind("?")
293295
val LiftedTreeName = new UniqueNameKind("liftedTree")

compiler/src/dotty/tools/dotc/transform/TailRec.scala

Lines changed: 144 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ package transform
44
import ast.Trees._
55
import ast.{TreeTypeMap, tpd}
66
import core._
7+
import Constants.Constant
78
import Contexts.Context
89
import Decorators._
910
import Symbols._
1011
import StdNames.nme
1112
import Types._
12-
import NameKinds.TailLabelName
13+
import NameKinds.{TailLabelName, TailLocalName, TailTempName}
1314
import MegaPhase.MiniPhase
1415
import reporting.diagnostic.messages.TailrecNotApplicable
1516
import util.Property
@@ -18,7 +19,8 @@ import util.Property
1819
* A Tail Rec Transformer
1920
* @author Erik Stenman, Iulian Dragos,
2021
* ported and heavily modified for dotty by Dmitry Petrashko
21-
* moved after erasure by Sébastien Doeraene
22+
* moved after erasure and adapted to emit `Labeled` blocks
23+
* by Sébastien Doeraene
2224
* @version 1.1
2325
*
2426
* What it does:
@@ -33,10 +35,29 @@ import util.Property
3335
* contain such calls are not transformed).
3436
* </p>
3537
* <p>
36-
* Self-recursive calls in tail-position are replaced by jumps to a
37-
* label at the beginning of the method. As the JVM provides no way to
38-
* jump from a method to another one, non-recursive calls in
39-
* tail-position are not optimized.
38+
* When a method contains at least one tail-recursive call, its rhs
39+
* is wrapped in the following structure:
40+
* </p>
41+
* <pre>
42+
* var localForParam1: T1 = param1
43+
* ...
44+
* while (true) {
45+
* tailResult[ResultType]: {
46+
* return {
47+
* // original rhs
48+
* }
49+
* }
50+
* }
51+
* </pre>
52+
* <p>
53+
* Self-recursive calls in tail-position are then replaced by (a)
54+
* reassigning the local `var`s substituting formal parameters and
55+
* (b) a `return` from the `tailResult` labeled block, which has the
56+
* net effect of looping back to the beginning of the method.
57+
* </p>
58+
* <p>
59+
* As the JVM provides no way to jump from a method to another one,
60+
* non-recursive calls in tail-position are not optimized.
4061
* </p>
4162
* <p>
4263
* A method call is self-recursive if it calls the current method and
@@ -47,14 +68,8 @@ import util.Property
4768
* </p>
4869
* <p>
4970
* This phase has been moved after erasure to allow the use of vars
50-
* for the parameters combined with a `WhileDo` (upcoming change).
51-
* This is also beneficial to support polymorphic tail-recursive
52-
* calls.
53-
* </p>
54-
* <p>
55-
* If a method contains self-recursive calls, a label is added to at
56-
* the beginning of its body and the calls are replaced by jumps to
57-
* that label.
71+
* for the parameters combined with a `WhileDo`. This is also
72+
* beneficial to support polymorphic tail-recursive calls.
5873
* </p>
5974
* <p>
6075
* In scalac, if the method had type parameters, the call must contain
@@ -73,25 +88,6 @@ class TailRec extends MiniPhase {
7388

7489
override def runsAfter = Set(Erasure.name) // tailrec assumes erased types
7590

76-
final val labelFlags = Flags.Synthetic | Flags.Label | Flags.Method
77-
78-
private def mkLabel(method: Symbol)(implicit ctx: Context): TermSymbol = {
79-
val name = TailLabelName.fresh()
80-
81-
if (method.owner.isClass) {
82-
val MethodTpe(paramNames, paramInfos, resultType) = method.info
83-
84-
val enclosingClass = method.enclosingClass.asClass
85-
val thisParamType =
86-
if (enclosingClass.is(Flags.Module)) enclosingClass.thisType
87-
else enclosingClass.classInfo.selfType
88-
89-
ctx.newSymbol(method, name.toTermName, labelFlags,
90-
MethodType(nme.SELF :: paramNames, thisParamType :: paramInfos, resultType))
91-
}
92-
else ctx.newSymbol(method, name.toTermName, labelFlags, method.info)
93-
}
94-
9591
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context): tpd.Tree = {
9692
val sym = tree.symbol
9793
tree match {
@@ -101,62 +97,60 @@ class TailRec extends MiniPhase {
10197
cpy.DefDef(dd)(rhs = {
10298
val defIsTopLevel = sym.owner.isClass
10399
val origMeth = sym
104-
val label = mkLabel(sym)
105100
val owner = ctx.owner.enclosingClass.asClass
106101

107-
var rewrote = false
108-
109102
// Note: this can be split in two separate transforms(in different groups),
110103
// than first one will collect info about which transformations and rewritings should be applied
111104
// and second one will actually apply,
112105
// now this speculatively transforms tree and throws away result in many cases
113-
val rhsSemiTransformed = {
114-
val transformer = new TailRecElimination(origMeth, owner, mandatory, label)
115-
val rhs = transformer.transform(dd.rhs)
116-
rewrote = transformer.rewrote
117-
rhs
118-
}
119-
120-
if (rewrote) {
121-
if (tree.symbol.owner.isClass) {
122-
val classSym = tree.symbol.owner.asClass
123-
124-
val labelDef = DefDef(label, vrefss => {
125-
assert(vrefss.size == 1, vrefss)
126-
val vrefs = vrefss.head
127-
val thisRef = vrefs.head
128-
val origMeth = tree.symbol
129-
val origVParams = vparams.map(_.symbol)
106+
val transformer = new TailRecElimination(origMeth, owner, vparams.map(_.symbol), mandatory)
107+
val rhsSemiTransformed = transformer.transform(dd.rhs)
108+
109+
if (transformer.rewrote) {
110+
val varForRewrittenThis = transformer.varForRewrittenThis
111+
val rewrittenParamSyms = transformer.rewrittenParamSyms
112+
val varsForRewrittenParamSyms = transformer.varsForRewrittenParamSyms
113+
114+
val initialValDefs = {
115+
val initialParamValDefs = for ((param, local) <- rewrittenParamSyms.zip(varsForRewrittenParamSyms)) yield {
116+
ValDef(local.asTerm, ref(param))
117+
}
118+
varForRewrittenThis match {
119+
case Some(local) => ValDef(local.asTerm, This(tree.symbol.owner.asClass)) :: initialParamValDefs
120+
case none => initialParamValDefs
121+
}
122+
}
123+
124+
val rhsFullyTransformed = varForRewrittenThis match {
125+
case Some(localThisSym) =>
126+
val classSym = tree.symbol.owner.asClass
127+
val thisRef = localThisSym.termRef
130128
new TreeTypeMap(
131-
typeMap = identity(_)
132-
.substThisUnlessStatic(classSym, thisRef.tpe)
133-
.subst(origVParams, vrefs.tail.map(_.tpe)),
129+
typeMap = _.substThisUnlessStatic(classSym, thisRef)
130+
.subst(rewrittenParamSyms, varsForRewrittenParamSyms.map(_.termRef)),
134131
treeMap = {
135-
case tree: This if tree.symbol == classSym => thisRef
132+
case tree: This if tree.symbol == classSym => Ident(thisRef)
136133
case tree => tree
137-
},
138-
oldOwners = origMeth :: Nil,
139-
newOwners = label :: Nil
134+
}
140135
).transform(rhsSemiTransformed)
141-
})
142-
val callIntoLabel = ref(label).appliedToArgs(This(classSym) :: vparams.map(x => ref(x.symbol)))
143-
Block(List(labelDef), callIntoLabel)
144-
} else { // inner method. Tail recursion does not change `this`
145-
val labelDef = DefDef(label, vrefss => {
146-
assert(vrefss.size == 1, vrefss)
147-
val vrefs = vrefss.head
148-
val origMeth = tree.symbol
149-
val origVParams = vparams.map(_.symbol)
136+
137+
case none =>
150138
new TreeTypeMap(
151-
typeMap = identity(_)
152-
.subst(origVParams, vrefs.map(_.tpe)),
153-
oldOwners = origMeth :: Nil,
154-
newOwners = label :: Nil
139+
typeMap = _.subst(rewrittenParamSyms, varsForRewrittenParamSyms.map(_.termRef))
155140
).transform(rhsSemiTransformed)
156-
})
157-
val callIntoLabel = ref(label).appliedToArgs(vparams.map(x => ref(x.symbol)))
158-
Block(List(labelDef), callIntoLabel)
159-
}} else {
141+
}
142+
143+
Block(
144+
initialValDefs :::
145+
WhileDo(Literal(Constant(true)), {
146+
Labeled(transformer.continueLabel.get.asTerm, {
147+
Return(rhsFullyTransformed, ref(origMeth))
148+
})
149+
}) ::
150+
Nil,
151+
Throw(Literal(Constant(null))) // unreachable code
152+
)
153+
} else {
160154
if (mandatory) ctx.error(
161155
"TailRec optimisation not applicable, method not tail recursive",
162156
// FIXME: want to report this error on `dd.namePos`, but
@@ -175,12 +169,51 @@ class TailRec extends MiniPhase {
175169

176170
}
177171

178-
class TailRecElimination(method: Symbol, enclosingClass: Symbol, isMandatory: Boolean, label: Symbol) extends tpd.TreeMap {
172+
class TailRecElimination(method: Symbol, enclosingClass: Symbol, paramSyms: List[Symbol], isMandatory: Boolean) extends tpd.TreeMap {
179173

180174
import dotty.tools.dotc.ast.tpd._
181175

182176
var rewrote = false
183177

178+
var continueLabel: Option[Symbol] = None
179+
var varForRewrittenThis: Option[Symbol] = None
180+
var rewrittenParamSyms: List[Symbol] = Nil
181+
var varsForRewrittenParamSyms: List[Symbol] = Nil
182+
183+
private def getContinueLabel()(implicit c: Context): Symbol = {
184+
continueLabel match {
185+
case Some(sym) => sym
186+
case none =>
187+
val sym = c.newSymbol(method, TailLabelName.fresh(), Flags.Label, defn.UnitType)
188+
continueLabel = Some(sym)
189+
sym
190+
}
191+
}
192+
193+
private def getVarForRewrittenThis()(implicit c: Context): Symbol = {
194+
varForRewrittenThis match {
195+
case Some(sym) => sym
196+
case none =>
197+
val tpe =
198+
if (enclosingClass.is(Flags.Module)) enclosingClass.thisType
199+
else enclosingClass.asClass.classInfo.selfType
200+
val sym = c.newSymbol(method, nme.SELF, Flags.Synthetic | Flags.Mutable, tpe)
201+
varForRewrittenThis = Some(sym)
202+
sym
203+
}
204+
}
205+
206+
private def getVarForRewrittenParam(param: Symbol)(implicit c: Context): Symbol = {
207+
rewrittenParamSyms.indexOf(param) match {
208+
case -1 =>
209+
val sym = c.newSymbol(method, TailLocalName.fresh(param.name.toTermName), Flags.Synthetic | Flags.Mutable, param.info)
210+
rewrittenParamSyms ::= param
211+
varsForRewrittenParamSyms ::= sym
212+
sym
213+
case index => varsForRewrittenParamSyms(index)
214+
}
215+
}
216+
184217
/** Symbols of Labeled blocks that are in tail position. */
185218
private val tailPositionLabeledSyms = new collection.mutable.HashSet[Symbol]()
186219

@@ -234,15 +267,40 @@ class TailRec extends MiniPhase {
234267
if (ctx.tailPos) {
235268
c.debuglog("Rewriting tail recursive call: " + tree.pos)
236269
rewrote = true
237-
def receiver =
238-
if (prefix eq EmptyTree) This(enclosingClass.asClass)
239-
else noTailTransform(prefix)
240270

241-
val argumentsWithReceiver =
242-
if (this.method.owner.isClass) receiver :: arguments
243-
else arguments
244-
245-
tpd.cpy.Apply(tree)(ref(label), argumentsWithReceiver)
271+
val assignParamPairs = for {
272+
(param, arg) <- paramSyms.zip(arguments)
273+
if (arg match {
274+
case arg: Ident => arg.symbol != param
275+
case _ => true
276+
})
277+
} yield {
278+
(getVarForRewrittenParam(param), arg)
279+
}
280+
281+
val assignThisAndParamPairs = {
282+
if (prefix eq EmptyTree) assignParamPairs
283+
else {
284+
// TODO Opt: also avoid assigning `this` if the prefix is `this.`
285+
(getVarForRewrittenThis(), noTailTransform(prefix)) :: assignParamPairs
286+
}
287+
}
288+
289+
val assignments = assignThisAndParamPairs match {
290+
case (lhs, rhs) :: Nil =>
291+
Assign(ref(lhs), rhs) :: Nil
292+
case _ :: _ =>
293+
val (tempValDefs, assigns) = (for ((lhs, rhs) <- assignThisAndParamPairs) yield {
294+
val temp = c.newSymbol(method, TailTempName.fresh(lhs.name.toTermName), Flags.Synthetic, lhs.info)
295+
(ValDef(temp, rhs), Assign(ref(lhs), ref(temp)).withPos(tree.pos))
296+
}).unzip
297+
tempValDefs ::: assigns
298+
case nil =>
299+
Nil
300+
}
301+
302+
val tpt = TypeTree(method.info.resultType)
303+
Block(assignments, Typed(Return(Literal(Constant(())).withPos(tree.pos), ref(getContinueLabel())), tpt))
246304
}
247305
else fail("it is not in tail position")
248306
} else {

0 commit comments

Comments
 (0)