Skip to content

Commit a841499

Browse files
author
Lucy Martin
committed
Preventing compilation of a @tailrec method when it does not rewrite, but an inner method does
Adding warnings if there is an annotated def at the top level that is referenced from an inner def Potential options for different handling of defined and implicit inner methods Changes from PR.
1 parent adf089b commit a841499

File tree

8 files changed

+83
-4
lines changed

8 files changed

+83
-4
lines changed

compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala

+1
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMe
208208
case UnstableInlineAccessorID // errorNumber: 192
209209
case VolatileOnValID // errorNumber: 193
210210
case ExtensionNullifiedByMemberID // errorNumber: 194
211+
case TailrecNestedCallID //errorNumber: 195
211212

212213
def errorNumber = ordinal - 1
213214

compiler/src/dotty/tools/dotc/reporting/messages.scala

+14
Original file line numberDiff line numberDiff line change
@@ -1907,6 +1907,20 @@ class TailrecNotApplicable(symbol: Symbol)(using Context)
19071907
def explain(using Context) = ""
19081908
}
19091909

1910+
class TailrecNestedCall(definition: Symbol, innerDef: Symbol)(using Context)
1911+
extends SyntaxMsg(TailrecNestedCallID) {
1912+
def msg(using Context) = {
1913+
s"The tail recursive def ${definition.name} contains a recursive call inside the non-inlined inner def ${innerDef.name}"
1914+
}
1915+
1916+
def explain(using Context) =
1917+
"""Tail recursion is only validated and optimised directly in the definition.
1918+
|Any calls to the recursive method via an inner def cannot be validated as
1919+
|tail recursive, nor optimised if they are. To enable tail recursion from
1920+
|inner calls, mark the inner def as inline.
1921+
|""".stripMargin
1922+
}
1923+
19101924
class FailureToEliminateExistential(tp: Type, tp1: Type, tp2: Type, boundSyms: List[Symbol], classRoot: Symbol)(using Context)
19111925
extends Message(FailureToEliminateExistentialID) {
19121926
def kind = MessageKind.Compatibility

compiler/src/dotty/tools/dotc/transform/TailRec.scala

+29-2
Original file line numberDiff line numberDiff line change
@@ -429,10 +429,23 @@ class TailRec extends MiniPhase {
429429
assert(false, "We should never have gotten inside a pattern")
430430
tree
431431

432-
case tree: ValOrDefDef =>
432+
case tree: ValDef =>
433433
if (isMandatory) noTailTransform(tree.rhs)
434434
tree
435435

436+
case tree: DefDef =>
437+
if (isMandatory)
438+
if (tree.symbol.is(Synthetic))
439+
noTailTransform(tree.rhs)
440+
else
441+
// We can't tail recurse through nested definitions, so don't want to propagate to child nodes
442+
// We don't want to fail if there is a call that would recurse (as this would be a non self recurse), so don't
443+
// want to call noTailTransform
444+
// We can however warn in this case, as its likely in this situation that someone would expect a tail
445+
// recursion optimization and enabling this to optimise would be a simple case of inlining the inner method
446+
new NestedTailRecAlerter(method, tree.symbol).traverse(tree)
447+
tree
448+
436449
case _: Super | _: This | _: Literal | _: TypeTree | _: TypeDef | EmptyTree =>
437450
tree
438451

@@ -446,14 +459,28 @@ class TailRec extends MiniPhase {
446459

447460
case Return(expr, from) =>
448461
val fromSym = from.symbol
449-
val inTailPosition = !fromSym.is(Label) || tailPositionLabeledSyms.contains(fromSym)
462+
val inTailPosition = tailPositionLabeledSyms.contains(fromSym) // Label returns are only tail if the label is in tail position
463+
|| (fromSym eq method) // Method returns are only tail if we are looking at the original method
450464
cpy.Return(tree)(transform(expr, inTailPosition), from)
451465

452466
case _ =>
453467
super.transform(tree)
454468
}
455469
}
456470
}
471+
472+
class NestedTailRecAlerter(method: Symbol, inner: Symbol) extends TreeTraverser {
473+
override def traverse(tree: tpd.Tree)(using Context): Unit =
474+
tree match {
475+
case a: Apply =>
476+
if (a.fun.symbol eq method) {
477+
report.warning(new TailrecNestedCall(method, inner), a.srcPos)
478+
}
479+
traverseChildren(tree)
480+
case _ =>
481+
traverseChildren(tree)
482+
}
483+
}
457484
}
458485

459486
object TailRec {

tests/neg/i20105.check

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
-- [E195] Syntax Warning: tests\neg\i20105.scala:6:9 -------------------------------------------------------------------
2+
6 | foo()
3+
| ^^^^^
4+
| The tail recursive def foo contains a recursive call inside the non-inlined inner def bar
5+
|
6+
| longer explanation available when compiling with `-explain`
7+
-- [E097] Syntax Error: tests\neg\i20105.scala:3:4 ---------------------------------------------------------------------
8+
3 |def foo(): Unit = // error
9+
| ^
10+
| TailRec optimisation not applicable, method foo contains no recursive calls

tests/neg/i20105.scala

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import scala.annotation.tailrec
2+
@tailrec
3+
def foo(): Unit = // error
4+
def bar(): Unit =
5+
if (???)
6+
foo()
7+
else
8+
bar()
9+
bar()

tests/neg/i5397.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ object Test {
1616
rec3 // error: not in tail position
1717
})
1818

19-
@tailrec def rec4: Unit = {
20-
def local = rec4 // error: not in tail position
19+
// This is technically not breaching tail recursion as rec4 does not call itself, local does
20+
// This instead fails due to having no tail recursion at all
21+
@tailrec def rec4: Unit = { // error: no recursive calls
22+
def local = rec4
2123
}
2224

2325
@tailrec def rec5: Int = {

tests/warn/i20105.check

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
-- [E195] Syntax Warning: tests\warn\i20105.scala:6:9 ------------------------------------------------------------------
2+
6 | foo() // warn
3+
| ^^^^^
4+
| The tail recursive def foo contains a recursive call inside the non-inlined inner def bar
5+
|
6+
| longer explanation available when compiling with `-explain`

tests/warn/i20105.scala

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import scala.annotation.tailrec
2+
@tailrec
3+
def foo(): Unit =
4+
def bar(): Unit =
5+
if (???)
6+
foo() // warn
7+
else
8+
bar()
9+
bar()
10+
foo()

0 commit comments

Comments
 (0)