diff --git a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/common/CompilerApiExtensions.scala b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/common/CompilerApiExtensions.scala index 3beadc3a..2a5a6aea 100644 --- a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/common/CompilerApiExtensions.scala +++ b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/common/CompilerApiExtensions.scala @@ -1,9 +1,12 @@ package scala.tools.refactoring.common -import scala.tools.nsc.Global +import scala.collection.immutable +import scala.collection.mutable.ArrayBuffer +import scala.reflect.internal.util.SourceFile +import scala.tools.nsc.ast.parser.Tokens /* - * FIXME: This class duplicates functionality from org.scalaide.core.compiler.CompilerApiExtensions. + * FIXME: This class duplicates functionality from [[org.scalaide.core.compiler.CompilerApiExtensions]]. */ trait CompilerApiExtensions { this: CompilerAccess => @@ -55,4 +58,58 @@ trait CompilerApiExtensions { } } } + + /** A helper class to access the lexical tokens of `source`. + * + * Once constructed, instances of this class are thread-safe. + */ + class LexicalStructure(source: SourceFile) { + private val token = new ArrayBuffer[Int] + private val startOffset = new ArrayBuffer[Int] + private val endOffset = new ArrayBuffer[Int] + private val scanner = new syntaxAnalyzer.UnitScanner(new CompilationUnit(source)) + scanner.init() + + while (scanner.token != Tokens.EOF) { + startOffset += scanner.offset + token += scanner.token + scanner.nextToken + endOffset += scanner.lastOffset + } + + /** Return the index of the token that covers `offset`. + */ + private def locateIndex(offset: Int): Int = { + var lo = 0 + var hi = token.length - 1 + while (lo < hi) { + val mid = (lo + hi + 1) / 2 + if (startOffset(mid) <= offset) lo = mid + else hi = mid - 1 + } + lo + } + + /** Return all tokens between start and end offsets. + * + * The first token may start before `start` and the last token may span after `end`. + */ + def tokensBetween(start: Int, end: Int): immutable.Seq[Token] = { + val startIndex = locateIndex(start) + val endIndex = locateIndex(end) + + val tmp = for (i <- startIndex to endIndex) + yield Token(token(i), startOffset(i), endOffset(i)) + + tmp.toSeq + } + } + + /** A Scala token covering [start, end) + * + * @param tokenId one of scala.tools.nsc.ast.parser.Tokens identifiers + * @param start the offset of the first character in this token + * @param end the offset of the first character after this token + */ + case class Token(tokenId: Int, start: Int, end: Int) } diff --git a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/AbstractPrinter.scala b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/AbstractPrinter.scala index 11052248..32e5258a 100644 --- a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/AbstractPrinter.scala +++ b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/AbstractPrinter.scala @@ -9,7 +9,7 @@ import scala.reflect.internal.util.SourceFile trait AbstractPrinter extends CommonPrintUtils { - this: common.Tracing with common.PimpedTrees with Indentations with common.CompilerAccess with Formatting => + this: common.Tracing with common.PimpedTrees with Indentations with common.CompilerAccess with common.CompilerApiExtensions with Formatting => import global._ @@ -18,6 +18,11 @@ trait AbstractPrinter extends CommonPrintUtils { * the context or environment for the current printing. */ case class PrintingContext(ind: Indentation, changeSet: ChangeSet, parent: Tree, file: Option[SourceFile]) { + private lazy val lexical = file map (new LexicalStructure(_)) + + def tokensBetween(start: Int, end: Int): Seq[Token] = + lexical.map(_.tokensBetween(start, end)).getOrElse(Seq()) + lazy val newline: String = { if(file.exists(_.content.containsSlice("\r\n"))) "\r\n" @@ -32,4 +37,4 @@ trait AbstractPrinter extends CommonPrintUtils { def print(t: Tree, ctx: PrintingContext): Fragment -} \ No newline at end of file +} diff --git a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/PrettyPrinter.scala b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/PrettyPrinter.scala index ae82c698..7ca80d92 100644 --- a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/PrettyPrinter.scala +++ b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/PrettyPrinter.scala @@ -13,7 +13,7 @@ import language.implicitConversions trait PrettyPrinter extends TreePrintingTraversals with AbstractPrinter { - outer: common.PimpedTrees with common.CompilerAccess with common.Tracing with Indentations with LayoutHelper with Formatting => + outer: common.PimpedTrees with common.CompilerAccess with common.Tracing with common.CompilerApiExtensions with Indentations with LayoutHelper with Formatting => import global._ @@ -130,7 +130,7 @@ trait PrettyPrinter extends TreePrintingTraversals with AbstractPrinter { case Some(patP(patStr)) if guard == EmptyTree => Fragment(patStr) case _ => p(pat) } - + val arrowReq = new Requisite { def isRequired(l: Layout, r: Layout) = { !(l.contains("=>") || r.contains("=>")) diff --git a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala index 90f57ec9..ff7d75e7 100644 --- a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala +++ b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/ReusingPrinter.scala @@ -11,7 +11,7 @@ import language.implicitConversions trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter { - outer: LayoutHelper with common.Tracing with common.PimpedTrees with common.CompilerAccess with Formatting with Indentations => + outer: LayoutHelper with common.Tracing with common.PimpedTrees with common.CompilerAccess with common.CompilerApiExtensions with Formatting with Indentations => import global._ @@ -202,22 +202,22 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter { override def CaseDef(tree: CaseDef, pat: Tree, guard: Tree, body: Tree)(implicit ctx: PrintingContext) = { val arrowReq = new Requisite { - def isRequired(l: Layout, r: Layout) = { + override def isRequired(l: Layout, r: Layout) = { !(l.contains("=>") || r.contains("=>") || p(body).asText.startsWith("=>")) } // It's just nice to have a whitespace before and after the arrow - def getLayout = Layout(" => ") + override def getLayout = Layout(" => ") } val ifReq = new Requisite { - def isRequired(l: Layout, r: Layout) = { + override def isRequired(l: Layout, r: Layout) = { !(l.contains("if") || r.contains("if")) } // Leading and trailing whitespace is required in some cases! // e.g. `case i if i > 0 => ???` becomes `case iifi > 0 => ???` otherwise - def getLayout = Layout(" if ") + override def getLayout = Layout(" if ") } body match { @@ -921,10 +921,12 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter { val body = p(rhs) val noEqualNeeded = body == EmptyFragment || rhs.tpe == null || (rhs.tpe != null && rhs.tpe.toString == "Unit") + def openingBrace = keepOpeningBrace(tree, tpt, rhs) + if (noEqualNeeded) l ++ mods_ ++ resultType ++ body ++ r else - l ++ mods_ ++ resultType ++ Requisite.anywhere("=", " = ") ++ body ++ r + l ++ mods_ ++ resultType ++ Requisite.anywhere("=", " = ") ++ openingBrace ++ body ++ r } } @@ -985,8 +987,9 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter { case _ => false } - val resultType = - if (body == EmptyFragment && !existsTptInFile) + val isAbstract = body == EmptyFragment + val rawResultType = + if (isAbstract && !existsTptInFile) EmptyFragment else p(tpt, before = Requisite.allowSurroundingWhitespace(":", ": ")) @@ -1000,16 +1003,50 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter { } } - val noEqualNeeded = { - body == EmptyFragment || rhs.tpe == null || (rhs.tpe != null && rhs.tpe.toString == "Unit") + val noEqualNeeded = rawResultType == EmptyFragment || isAbstract + + val resultType = { + def addLeadingSpace = name.isOperatorName || name.endsWith('_') + if (rawResultType != EmptyFragment && addLeadingSpace) Layout(" ") ++ rawResultType else rawResultType } if (noEqualNeeded && !hasEqualInSource) { l ++ modsAndName ++ typeParameters ++ parameters ++ resultType ++ body ++ r } else { - l ++ modsAndName ++ typeParameters ++ parameters ++ resultType ++ Requisite.anywhere("=", " = ") ++ body ++ r + val openingBrace = keepOpeningBrace(tree, tpt, rhs) + // In case a Unit return type is added to a method like `def f {}`, we + // need to remove the whitespace between name and rhs, otherwise the + // result would be `def f : Unit = {}`. + val modsAndNameTrimmed = + if (modsAndName.trailing.asText.trim.isEmpty) + Fragment(modsAndName.leading, modsAndName.center, NoLayout) + else + modsAndName + + l ++ modsAndNameTrimmed ++ typeParameters ++ parameters ++ resultType ++ Requisite.anywhere("=", " = ") ++ openingBrace ++ body ++ r } } + + /** + * In case a `ValOrDefDef` like `def f = {0}` contains a single expression + * in braces, we are screwed. + * In such cases the compiler removes the opening and closing braces from + * the tree. It also removes all the whitespace between the equal sign and + * the opening brace. If a refactoring edits such a definition, we need to + * get the whitespace+opening brace back. + */ + private def keepOpeningBrace(tree: Tree, tpt: Tree, rhs: Tree)(implicit ctx: PrintingContext): String = tpt match { + case tpt: TypeTree if tpt.original != null && tree.pos != NoPosition && rhs.pos != NoPosition => + val tokens = ctx.tokensBetween(tree.pos.point, rhs.pos.start) + tokens find (_.tokenId == tools.nsc.ast.parser.Tokens.EQUALS) map { + case Token(_, start, _) => + val c = tree.pos.source.content + val skippedWsStart = if (c(start+1).isWhitespace) start+2 else start+1 + c.slice(skippedWsStart, rhs.pos.start).mkString + } getOrElse "" + case _ => + "" + } } trait SuperPrinters { diff --git a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/SourceGenerator.scala b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/SourceGenerator.scala index 675b798f..689eeead 100644 --- a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/SourceGenerator.scala +++ b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/sourcegen/SourceGenerator.scala @@ -8,10 +8,11 @@ package sourcegen import common.Tracing import common.Change import common.PimpedTrees +import common.CompilerApiExtensions import scala.tools.refactoring.common.TextChange import scala.reflect.internal.util.SourceFile -trait SourceGenerator extends PrettyPrinter with Indentations with ReusingPrinter with PimpedTrees with LayoutHelper with Formatting with TreeChangesDiscoverer { +trait SourceGenerator extends PrettyPrinter with Indentations with ReusingPrinter with PimpedTrees with CompilerApiExtensions with LayoutHelper with Formatting with TreeChangesDiscoverer { self: Tracing with common.CompilerAccess => diff --git a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/transformation/TreeFactory.scala b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/transformation/TreeFactory.scala index a39de45a..244b7e68 100644 --- a/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/transformation/TreeFactory.scala +++ b/org.scala-refactoring.library/src/main/scala/scala/tools/refactoring/transformation/TreeFactory.scala @@ -64,7 +64,9 @@ trait TreeFactory { def mkReturn(s: List[Symbol]): Tree = s match { case Nil => EmptyTree - case x :: Nil => Ident(x) setType x.tpe + case x :: Nil => + val ident = if (x.isModuleClass) Ident(newTermName(s"${x.name}.type")) else Ident(x) + ident setType x.tpe case xs => typer.typed(gen.mkTuple(xs map (s => Ident(s) setType s.tpe))) match { case t: Apply => @@ -114,7 +116,7 @@ trait TreeFactory { } if (mods != NoMods) valOrVarDef setSymbol NoSymbol.newValue(name, newFlags = mods.flags) else valOrVarDef - } + } def mkParam(name: String, tpe: Type, defaultVal: Tree = EmptyTree): ValDef = { ValDef(Modifiers(Flags.PARAM), newTermName(name), TypeTree(tpe), defaultVal) diff --git a/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala b/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala index 021b47ee..5ccbb361 100644 --- a/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala +++ b/org.scala-refactoring.library/src/test/scala/scala/tools/refactoring/tests/sourcegen/ReusingPrinterTest.scala @@ -25,8 +25,8 @@ class ReusingPrinterTest extends TestHelper with SilentTracing { final implicit class ImplicitTreeHelper(original: Tree) { /** Needs to be executed on the PC thread. */ - def printsTo(expectedOutput: String): Unit = { - val sourceFile = new BatchSourceFile("noname", expectedOutput) + def printsTo(input: String, expectedOutput: String): Unit = { + val sourceFile = new BatchSourceFile("textInput", input) val expected = stripWhitespacePreservers(expectedOutput).trim() val actual = generate(original, sourceFile = Some(sourceFile)).asText.trim() if (actual != expected) @@ -40,7 +40,7 @@ class ReusingPrinterTest extends TestHelper with SilentTracing { def after(trans: Transformation[Tree, Tree]): Unit = ask { () => val t = trans(treeFrom(input._1)) require(t.isDefined, "transformation was not successful") - t foreach (_.printsTo(input._2)) + t foreach (_.printsTo(input._1, input._2)) } } @@ -91,6 +91,207 @@ class ReusingPrinterTest extends TestHelper with SilentTracing { d.copy(tpt = newTpt) replaces d }}} + @Test + def add_return_type_to_val_with_single_expression_in_braces() = """ + package add_return_type_to_val_with_single_expression_in_braces + object X { + val a = { + 0 + } + val b = /* {str */ { + 0 + } + val c = { 0 match { + case i => i + }} + val d = { // {str} + 0 + } + val e = /* {str */ { // {str + 0 + } + val f={0} + } + """ becomes """ + package add_return_type_to_val_with_single_expression_in_braces + object X { + val a: Int = { + 0 + } + val b: Int = /* {str */ { + 0 + } + val c: Int = { 0 match { + case i => i + }} + val d: Int = { // {str} + 0 + } + val e: Int = /* {str */ { // {str + 0 + } + val f: Int = {0} + } + """ after topdown { matchingChildren { transform { + case d @ ValDef(_, _, tpt: TypeTree, _) => + val newTpt = tpt setOriginal mkReturn(List(tpt.tpe.typeSymbol)) + d.copy(tpt = newTpt) replaces d + }}} + + @Test + def add_return_type_to_val_with_multiple_expressions_in_braces() = """ + package add_return_type_to_val_with_multiple_expressions_in_braces + object X { + val foo = { + val a = 0 + a + } + } + """ becomes """ + package add_return_type_to_val_with_multiple_expressions_in_braces + object X { + val foo: Int = { + val a: Int = 0 + a + } + } + """ after topdown { matchingChildren { transform { + case d @ ValDef(_, _, tpt: TypeTree, _) => + val newTpt = tpt setOriginal mkReturn(List(tpt.tpe.typeSymbol)) + d.copy(tpt = newTpt) replaces d + }}} + + @Test + def add_return_type_to_def_with_single_expression_in_braces() = """ + package add_return_type_to_def_with_single_expression_in_braces + object X { + def a = { + 0 + } + def b = /* {str */ { + 0 + } + def c = { 0 match { + case i => i + }} + def d = { // {str} + 0 + } + def e = /* {str */ { // {str + 0 + } + def f={0} + } + """ becomes """ + package add_return_type_to_def_with_single_expression_in_braces + object X { + def a: Int = { + 0 + } + def b: Int = /* {str */ { + 0 + } + def c: Int = { 0 match { + case i => i + }} + def d: Int = { // {str} + 0 + } + def e: Int = /* {str */ { // {str + 0 + } + def f: Int = {0} + } + """ after topdown { matchingChildren { transform { + case d @ DefDef(_, _, _, _, tpt: TypeTree, _) => + val newTpt = tpt setOriginal mkReturn(List(tpt.tpe.typeSymbol)) + d.copy(tpt = newTpt) replaces d + }}} + + @Test + def add_return_type_to_def_with_multiple_expressions_in_braces() = """ + package add_return_type_to_def_with_multiple_expressions_in_braces + object X { + def foo = { + def a = 0 + a + } + } + """ becomes """ + package add_return_type_to_def_with_multiple_expressions_in_braces + object X { + def foo: Int = { + def a: Int = 0 + a + } + } + """ after topdown { matchingChildren { transform { + case d @ DefDef(_, _, _, _, tpt: TypeTree, _) => + val newTpt = tpt setOriginal mkReturn(List(tpt.tpe.typeSymbol)) + d.copy(tpt = newTpt) replaces d + }}} + + @Test + def add_Unit_return_type_to_def_with_single_expression_in_braces() = """ + package add_Unit_return_type_to_def_with_single_expression_in_braces + object X { + def foo { + println + } + def bar {} + def baz = () + } + """ becomes """ + package add_Unit_return_type_to_def_with_single_expression_in_braces + object X { + def foo: Unit = { + println + } + def bar: Unit = {} + def baz: Unit = () + } + """ after topdown { matchingChildren { transform { + case d @ DefDef(_, _, _, _, tpt: TypeTree, _) => + val newTpt = tpt setOriginal mkReturn(List(tpt.tpe.typeSymbol)) + d.copy(tpt = newTpt) replaces d + }}} + + @Test + def add_space_before_return_type_of_def_when_it_ends_with_special_sign() = """ + package add_space_before_return_type_of_def_when_it_ends_with_special_sign + object X { + def foo_ = 0 + def ++ = 0 + } + """ becomes """ + package add_space_before_return_type_of_def_when_it_ends_with_special_sign + object X { + def foo_ : Int = 0 + def ++ : Int = 0 + } + """ after topdown { matchingChildren { transform { + case d @ DefDef(_, _, _, _, tpt: TypeTree, _) => + val newTpt = tpt setOriginal mkReturn(List(tpt.tpe.typeSymbol)) + d.copy(tpt = newTpt) replaces d + }}} + + @Test + def add_type_keyword_to_return_type_when_it_represents_an_object() = """ + package add_type_keyword_to_return_type_when_it_represents_an_object + object X { + def o = X + } + """ becomes """ + package add_type_keyword_to_return_type_when_it_represents_an_object + object X { + def o: X.type = X + } + """ after topdown { matchingChildren { transform { + case d @ DefDef(_, _, _, _, tpt: TypeTree, _) => + val newTpt = tpt setOriginal mkReturn(List(tpt.tpe.typeSymbol)) + d.copy(tpt = newTpt) replaces d + }}} + @Test def add_override_flag() = """ package add_override_flag