Skip to content

Commit 0d8b6a9

Browse files
Lucy MartinWojciechMazur
Lucy Martin
authored andcommitted
#20105: Adding a warning to the case where nested named definitions contain non-tail recursive calls.
Code will now compile where a child def calls the parent def in a non-tail position (with the warning). Code will no longer compile if all calls to a @tailrec method are in named child methods (as these do not tail recurse). [Cherry-picked 01ada74]
1 parent b7b0d9b commit 0d8b6a9

File tree

8 files changed

+83
-4
lines changed

8 files changed

+83
-4
lines changed

compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala

+1
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMe
212212
case ContextBoundCompanionNotValueID // errorNumber: 196 - unused in LTS
213213
case InlinedAnonClassWarningID // errorNumber: 197
214214
case UnusedSymbolID // errorNumber: 198
215+
case TailrecNestedCallID //errorNumber: 199
215216

216217
def errorNumber = ordinal - 1
217218

compiler/src/dotty/tools/dotc/reporting/messages.scala

+14
Original file line numberDiff line numberDiff line change
@@ -1899,6 +1899,20 @@ class TailrecNotApplicable(symbol: Symbol)(using Context)
18991899
def explain(using Context) = ""
19001900
}
19011901

1902+
class TailrecNestedCall(definition: Symbol, innerDef: Symbol)(using Context)
1903+
extends SyntaxMsg(TailrecNestedCallID) {
1904+
def msg(using Context) = {
1905+
s"The tail recursive def ${definition.name} contains a recursive call inside the non-inlined inner def ${innerDef.name}"
1906+
}
1907+
1908+
def explain(using Context) =
1909+
"""Tail recursion is only validated and optimised directly in the definition.
1910+
|Any calls to the recursive method via an inner def cannot be validated as
1911+
|tail recursive, nor optimised if they are. To enable tail recursion from
1912+
|inner calls, mark the inner def as inline.
1913+
|""".stripMargin
1914+
}
1915+
19021916
class FailureToEliminateExistential(tp: Type, tp1: Type, tp2: Type, boundSyms: List[Symbol], classRoot: Symbol)(using Context)
19031917
extends Message(FailureToEliminateExistentialID) {
19041918
def kind = MessageKind.Compatibility

compiler/src/dotty/tools/dotc/transform/TailRec.scala

+29-2
Original file line numberDiff line numberDiff line change
@@ -427,10 +427,23 @@ class TailRec extends MiniPhase {
427427
assert(false, "We should never have gotten inside a pattern")
428428
tree
429429

430-
case tree: ValOrDefDef =>
430+
case tree: ValDef =>
431431
if (isMandatory) noTailTransform(tree.rhs)
432432
tree
433433

434+
case tree: DefDef =>
435+
if (isMandatory)
436+
if (tree.symbol.is(Synthetic))
437+
noTailTransform(tree.rhs)
438+
else
439+
// We can't tail recurse through nested definitions, so don't want to propagate to child nodes
440+
// We don't want to fail if there is a call that would recurse (as this would be a non self recurse), so don't
441+
// want to call noTailTransform
442+
// We can however warn in this case, as its likely in this situation that someone would expect a tail
443+
// recursion optimization and enabling this to optimise would be a simple case of inlining the inner method
444+
new NestedTailRecAlerter(method, tree.symbol).traverse(tree)
445+
tree
446+
434447
case _: Super | _: This | _: Literal | _: TypeTree | _: TypeDef | EmptyTree =>
435448
tree
436449

@@ -444,14 +457,28 @@ class TailRec extends MiniPhase {
444457

445458
case Return(expr, from) =>
446459
val fromSym = from.symbol
447-
val inTailPosition = !fromSym.is(Label) || tailPositionLabeledSyms.contains(fromSym)
460+
val inTailPosition = tailPositionLabeledSyms.contains(fromSym) // Label returns are only tail if the label is in tail position
461+
|| (fromSym eq method) // Method returns are only tail if we are looking at the original method
448462
cpy.Return(tree)(transform(expr, inTailPosition), from)
449463

450464
case _ =>
451465
super.transform(tree)
452466
}
453467
}
454468
}
469+
470+
class NestedTailRecAlerter(method: Symbol, inner: Symbol) extends TreeTraverser {
471+
override def traverse(tree: tpd.Tree)(using Context): Unit =
472+
tree match {
473+
case a: Apply =>
474+
if (a.fun.symbol eq method) {
475+
report.warning(new TailrecNestedCall(method, inner), a.srcPos)
476+
}
477+
traverseChildren(tree)
478+
case _ =>
479+
traverseChildren(tree)
480+
}
481+
}
455482
}
456483

457484
object TailRec {

tests/neg/i20105.check

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
-- [E199] Syntax Warning: tests/neg/i20105.scala:6:9 -------------------------------------------------------------------
2+
6 | foo()
3+
| ^^^^^
4+
| The tail recursive def foo contains a recursive call inside the non-inlined inner def bar
5+
|
6+
| longer explanation available when compiling with `-explain`
7+
-- [E097] Syntax Error: tests/neg/i20105.scala:3:4 ---------------------------------------------------------------------
8+
3 |def foo(): Unit = // error
9+
| ^
10+
| TailRec optimisation not applicable, method foo contains no recursive calls

tests/neg/i20105.scala

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import scala.annotation.tailrec
2+
@tailrec
3+
def foo(): Unit = // error
4+
def bar(): Unit =
5+
if (???)
6+
foo()
7+
else
8+
bar()
9+
bar()

tests/neg/i5397.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ object Test {
1616
rec3 // error: not in tail position
1717
})
1818

19-
@tailrec def rec4: Unit = {
20-
def local = rec4 // error: not in tail position
19+
// This is technically not breaching tail recursion as rec4 does not call itself, local does
20+
// This instead fails due to having no tail recursion at all
21+
@tailrec def rec4: Unit = { // error: no recursive calls
22+
def local = rec4
2123
}
2224

2325
@tailrec def rec5: Int = {

tests/warn/i20105.check

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
-- [E199] Syntax Warning: tests/warn/i20105.scala:6:9 ------------------------------------------------------------------
2+
6 | foo() // warn
3+
| ^^^^^
4+
| The tail recursive def foo contains a recursive call inside the non-inlined inner def bar
5+
|
6+
| longer explanation available when compiling with `-explain`

tests/warn/i20105.scala

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import scala.annotation.tailrec
2+
@tailrec
3+
def foo(): Unit =
4+
def bar(): Unit =
5+
if (???)
6+
foo() // warn
7+
else
8+
bar()
9+
bar()
10+
foo()

0 commit comments

Comments
 (0)