Skip to content
This repository was archived by the owner on Sep 3, 2020. It is now read-only.

Commit 98c0a55

Browse files
committed
Add return type to ValOrDefDef with single expressions in braces
Fixing this was complicated because the compiler removes braces from the tree when the code would compile without them. Therefore we manually have to restore them. To get the braces (and whitespace) back, we search for an equal sign in the area of the DefDef (which indicates the start of the rhs). Once it is found, all contents between the equal sign and the start and the expression of the rhs are put back into the tree. Because an equal sign can also occur in a comment, we have to parse the region instead of simply looking for the equal sign. This adds some overhead, but hopefully not too much. Furthermore, as reviewers suggested, some variable names are renamed. This also fixes a bug in the test suite where the source file is passed to the refactoring logic.
1 parent 4f5dbd8 commit 98c0a55

File tree

6 files changed

+151
-31
lines changed

6 files changed

+151
-31
lines changed

org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/common/CompilerApiExtensions.scala

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package scala.tools.refactoring.common
22

3-
import scala.tools.nsc.Global
3+
import scala.collection.immutable
4+
import scala.collection.mutable.ArrayBuffer
5+
import scala.reflect.internal.util.SourceFile
6+
import scala.tools.nsc.ast.parser.Tokens
47

58
/*
6-
* FIXME: This class duplicates functionality from org.scalaide.core.compiler.CompilerApiExtensions.
9+
* FIXME: This class duplicates functionality from [[org.scalaide.core.compiler.CompilerApiExtensions]].
710
*/
811
trait CompilerApiExtensions {
912
this: CompilerAccess =>
@@ -55,4 +58,58 @@ trait CompilerApiExtensions {
5558
}
5659
}
5760
}
61+
62+
/** A helper class to access the lexical tokens of `source`.
63+
*
64+
* Once constructed, instances of this class are thread-safe.
65+
*/
66+
class LexicalStructure(source: SourceFile) {
67+
private val token = new ArrayBuffer[Int]
68+
private val startOffset = new ArrayBuffer[Int]
69+
private val endOffset = new ArrayBuffer[Int]
70+
private val scanner = new syntaxAnalyzer.UnitScanner(new CompilationUnit(source))
71+
scanner.init()
72+
73+
while (scanner.token != Tokens.EOF) {
74+
startOffset += scanner.offset
75+
token += scanner.token
76+
scanner.nextToken
77+
endOffset += scanner.lastOffset
78+
}
79+
80+
/** Return the index of the token that covers `offset`.
81+
*/
82+
private def locateIndex(offset: Int): Int = {
83+
var lo = 0
84+
var hi = token.length - 1
85+
while (lo < hi) {
86+
val mid = (lo + hi + 1) / 2
87+
if (startOffset(mid) <= offset) lo = mid
88+
else hi = mid - 1
89+
}
90+
lo
91+
}
92+
93+
/** Return all tokens between start and end offsets.
94+
*
95+
* The first token may start before `start` and the last token may span after `end`.
96+
*/
97+
def tokensBetween(start: Int, end: Int): immutable.Seq[Token] = {
98+
val startIndex = locateIndex(start)
99+
val endIndex = locateIndex(end)
100+
101+
val tmp = for (i <- startIndex to endIndex)
102+
yield Token(token(i), startOffset(i), endOffset(i))
103+
104+
tmp.toSeq
105+
}
106+
}
107+
108+
/** A Scala token covering [start, end)
109+
*
110+
* @param tokenId one of scala.tools.nsc.ast.parser.Tokens identifiers
111+
* @param start the offset of the first character in this token
112+
* @param end the offset of the first character after this token
113+
*/
114+
case class Token(tokenId: Int, start: Int, end: Int)
58115
}

org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/AbstractPrinter.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import scala.reflect.internal.util.SourceFile
99

1010
trait AbstractPrinter extends CommonPrintUtils {
1111

12-
this: common.Tracing with common.PimpedTrees with Indentations with common.CompilerAccess with Formatting =>
12+
this: common.Tracing with common.PimpedTrees with Indentations with common.CompilerAccess with common.CompilerApiExtensions with Formatting =>
1313

1414
import global._
1515

@@ -18,6 +18,11 @@ trait AbstractPrinter extends CommonPrintUtils {
1818
* the context or environment for the current printing.
1919
*/
2020
case class PrintingContext(ind: Indentation, changeSet: ChangeSet, parent: Tree, file: Option[SourceFile]) {
21+
private lazy val lexical = file map (new LexicalStructure(_))
22+
23+
def tokensBetween(start: Int, end: Int): Seq[Token] =
24+
lexical.map(_.tokensBetween(start, end)).getOrElse(Seq())
25+
2126
lazy val newline: String = {
2227
if(file.exists(_.content.containsSlice("\r\n")))
2328
"\r\n"
@@ -32,4 +37,4 @@ trait AbstractPrinter extends CommonPrintUtils {
3237

3338
def print(t: Tree, ctx: PrintingContext): Fragment
3439

35-
}
40+
}

org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/PrettyPrinter.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import language.implicitConversions
1313

1414
trait PrettyPrinter extends TreePrintingTraversals with AbstractPrinter {
1515

16-
outer: common.PimpedTrees with common.CompilerAccess with common.Tracing with Indentations with LayoutHelper with Formatting =>
16+
outer: common.PimpedTrees with common.CompilerAccess with common.Tracing with common.CompilerApiExtensions with Indentations with LayoutHelper with Formatting =>
1717

1818
import global._
1919

@@ -130,7 +130,7 @@ trait PrettyPrinter extends TreePrintingTraversals with AbstractPrinter {
130130
case Some(patP(patStr)) if guard == EmptyTree => Fragment(patStr)
131131
case _ => p(pat)
132132
}
133-
133+
134134
val arrowReq = new Requisite {
135135
def isRequired(l: Layout, r: Layout) = {
136136
!(l.contains("=>") || r.contains("=>"))

org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import language.implicitConversions
1111

1212
trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter {
1313

14-
outer: LayoutHelper with common.Tracing with common.PimpedTrees with common.CompilerAccess with Formatting with Indentations =>
14+
outer: LayoutHelper with common.Tracing with common.PimpedTrees with common.CompilerAccess with common.CompilerApiExtensions with Formatting with Indentations =>
1515

1616
import global._
1717

@@ -988,7 +988,7 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter {
988988
}
989989

990990
val isAbstract = body == EmptyFragment
991-
val resultType =
991+
val rawResultType =
992992
if (isAbstract && !existsTptInFile)
993993
EmptyFragment
994994
else
@@ -1003,42 +1003,47 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter {
10031003
}
10041004
}
10051005

1006-
val noEqualNeeded = resultType == EmptyFragment || isAbstract
1006+
val noEqualNeeded = rawResultType == EmptyFragment || isAbstract
10071007

1008-
val resultType2 = {
1008+
val resultType = {
10091009
def addLeadingSpace = name.isOperatorName || name.endsWith('_')
1010-
if (resultType != EmptyFragment && addLeadingSpace) Layout(" ") ++ resultType else resultType
1010+
if (rawResultType != EmptyFragment && addLeadingSpace) Layout(" ") ++ rawResultType else rawResultType
10111011
}
10121012

10131013
if (noEqualNeeded && !hasEqualInSource) {
1014-
l ++ modsAndName ++ typeParameters ++ parameters ++ resultType2 ++ body ++ r
1014+
l ++ modsAndName ++ typeParameters ++ parameters ++ resultType ++ body ++ r
10151015
} else {
10161016
val openingBrace = keepOpeningBrace(tree, tpt, rhs)
10171017
// In case a Unit return type is added to a method like `def f {}`, we
10181018
// need to remove the whitespace between name and rhs, otherwise the
10191019
// result would be `def f : Unit = {}`.
1020-
val modsAndName2 =
1020+
val modsAndNameTrimmed =
10211021
if (modsAndName.trailing.asText.trim.isEmpty)
10221022
Fragment(modsAndName.leading, modsAndName.center, NoLayout)
10231023
else
10241024
modsAndName
10251025

1026-
l ++ modsAndName2 ++ typeParameters ++ parameters ++ resultType2 ++ Requisite.anywhere("=", " = ") ++ openingBrace ++ body ++ r
1026+
l ++ modsAndNameTrimmed ++ typeParameters ++ parameters ++ resultType ++ Requisite.anywhere("=", " = ") ++ openingBrace ++ body ++ r
10271027
}
10281028
}
10291029

10301030
/**
1031-
* In case a definition like `def f = {0}` contains a single expression in
1032-
* braces, we need to find the braces manually because they are no part of
1033-
* the tree.
1031+
* In case a `ValOrDefDef` like `def f = {0}` contains a single expression
1032+
* in braces, we are screwed.
1033+
* In such cases the compiler removes the opening and closing braces from
1034+
* the tree. It also removes all the whitespace between the equal sign and
1035+
* the opening brace. If a refactoring edits such a definition, we need to
1036+
* get the whitespace+opening brace back.
10341037
*/
1035-
private def keepOpeningBrace(tree: Tree, tpt: Tree, rhs: Tree): String = tpt match {
1038+
private def keepOpeningBrace(tree: Tree, tpt: Tree, rhs: Tree)(implicit ctx: PrintingContext): String = tpt match {
10361039
case tpt: TypeTree if tpt.original != null && tree.pos != NoPosition && rhs.pos != NoPosition =>
1037-
val OpeningBrace = "(?s).*(\\{.*)".r
1038-
Layout(tree.pos.source, tree.pos.point, rhs.pos.start).asText match {
1039-
case OpeningBrace(brace) => brace
1040-
case _ => ""
1041-
}
1040+
val tokens = ctx.tokensBetween(tree.pos.point, rhs.pos.start)
1041+
tokens find (_.tokenId == tools.nsc.ast.parser.Tokens.EQUALS) map {
1042+
case Token(_, start, _) =>
1043+
val c = tree.pos.source.content
1044+
val skippedWsStart = if (c(start+1).isWhitespace) start+2 else start+1
1045+
c.slice(skippedWsStart, rhs.pos.start).mkString
1046+
} getOrElse ""
10421047
case _ =>
10431048
""
10441049
}

org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/SourceGenerator.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@ package sourcegen
88
import common.Tracing
99
import common.Change
1010
import common.PimpedTrees
11+
import common.CompilerApiExtensions
1112
import scala.tools.refactoring.common.TextChange
1213
import scala.reflect.internal.util.SourceFile
1314

14-
trait SourceGenerator extends PrettyPrinter with Indentations with ReusingPrinter with PimpedTrees with LayoutHelper with Formatting with TreeChangesDiscoverer {
15+
trait SourceGenerator extends PrettyPrinter with Indentations with ReusingPrinter with PimpedTrees with CompilerApiExtensions with LayoutHelper with Formatting with TreeChangesDiscoverer {
1516

1617
self: Tracing with common.CompilerAccess =>
1718

org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ class ReusingPrinterTest extends TestHelper with SilentTracing {
2525

2626
final implicit class ImplicitTreeHelper(original: Tree) {
2727
/** Needs to be executed on the PC thread. */
28-
def printsTo(expectedOutput: String): Unit = {
29-
val sourceFile = new BatchSourceFile("noname", expectedOutput)
28+
def printsTo(input: String, expectedOutput: String): Unit = {
29+
val sourceFile = new BatchSourceFile("textInput", input)
3030
val expected = stripWhitespacePreservers(expectedOutput).trim()
3131
val actual = generate(original, sourceFile = Some(sourceFile)).asText.trim()
3232
if (actual != expected)
@@ -40,7 +40,7 @@ class ReusingPrinterTest extends TestHelper with SilentTracing {
4040
def after(trans: Transformation[Tree, Tree]): Unit = ask { () =>
4141
val t = trans(treeFrom(input._1))
4242
require(t.isDefined, "transformation was not successful")
43-
t foreach (_.printsTo(input._2))
43+
t foreach (_.printsTo(input._1, input._2))
4444
}
4545
}
4646

@@ -95,16 +95,42 @@ class ReusingPrinterTest extends TestHelper with SilentTracing {
9595
def add_return_type_to_val_with_single_expression_in_braces() = """
9696
package add_return_type_to_val_with_single_expression_in_braces
9797
object X {
98-
val foo = {
98+
val a = {
99+
0
100+
}
101+
val b = /* {str */ {
102+
0
103+
}
104+
val c = { 0 match {
105+
case i => i
106+
}}
107+
val d = { // {str}
108+
0
109+
}
110+
val e = /* {str */ { // {str
99111
0
100112
}
113+
val f={0}
101114
}
102115
""" becomes """
103116
package add_return_type_to_val_with_single_expression_in_braces
104117
object X {
105-
val foo: Int = {
118+
val a: Int = {
119+
0
120+
}
121+
val b: Int = /* {str */ {
122+
0
123+
}
124+
val c: Int = { 0 match {
125+
case i => i
126+
}}
127+
val d: Int = { // {str}
128+
0
129+
}
130+
val e: Int = /* {str */ { // {str
106131
0
107132
}
133+
val f: Int = {0}
108134
}
109135
""" after topdown { matchingChildren { transform {
110136
case d @ ValDef(_, _, tpt: TypeTree, _) =>
@@ -139,16 +165,42 @@ class ReusingPrinterTest extends TestHelper with SilentTracing {
139165
def add_return_type_to_def_with_single_expression_in_braces() = """
140166
package add_return_type_to_def_with_single_expression_in_braces
141167
object X {
142-
def foo = {
168+
def a = {
169+
0
170+
}
171+
def b = /* {str */ {
172+
0
173+
}
174+
def c = { 0 match {
175+
case i => i
176+
}}
177+
def d = { // {str}
178+
0
179+
}
180+
def e = /* {str */ { // {str
143181
0
144182
}
183+
def f={0}
145184
}
146185
""" becomes """
147186
package add_return_type_to_def_with_single_expression_in_braces
148187
object X {
149-
def foo: Int = {
188+
def a: Int = {
189+
0
190+
}
191+
def b: Int = /* {str */ {
192+
0
193+
}
194+
def c: Int = { 0 match {
195+
case i => i
196+
}}
197+
def d: Int = { // {str}
198+
0
199+
}
200+
def e: Int = /* {str */ { // {str
150201
0
151202
}
203+
def f: Int = {0}
152204
}
153205
""" after topdown { matchingChildren { transform {
154206
case d @ DefDef(_, _, _, _, tpt: TypeTree, _) =>

0 commit comments

Comments
 (0)