Skip to content
This repository was archived by the owner on Sep 3, 2020. It is now read-only.

Gen return types for single expressions #66

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package scala.tools.refactoring.common

import scala.tools.nsc.Global
import scala.collection.immutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.internal.util.SourceFile
import scala.tools.nsc.ast.parser.Tokens

/*
* FIXME: This class duplicates functionality from org.scalaide.core.compiler.CompilerApiExtensions.
* FIXME: This class duplicates functionality from [[org.scalaide.core.compiler.CompilerApiExtensions]].
*/
trait CompilerApiExtensions {
this: CompilerAccess =>
Expand Down Expand Up @@ -55,4 +58,58 @@ trait CompilerApiExtensions {
}
}
}

/** A helper class to access the lexical tokens of `source`.
*
* Once constructed, instances of this class are thread-safe.
*/
class LexicalStructure(source: SourceFile) {
private val token = new ArrayBuffer[Int]
private val startOffset = new ArrayBuffer[Int]
private val endOffset = new ArrayBuffer[Int]
private val scanner = new syntaxAnalyzer.UnitScanner(new CompilationUnit(source))
scanner.init()

while (scanner.token != Tokens.EOF) {
startOffset += scanner.offset
token += scanner.token
scanner.nextToken
endOffset += scanner.lastOffset
}

/** Return the index of the token that covers `offset`.
*/
private def locateIndex(offset: Int): Int = {
var lo = 0
var hi = token.length - 1
while (lo < hi) {
val mid = (lo + hi + 1) / 2
if (startOffset(mid) <= offset) lo = mid
else hi = mid - 1
}
lo
}

/** Return all tokens between start and end offsets.
*
* The first token may start before `start` and the last token may span after `end`.
*/
def tokensBetween(start: Int, end: Int): immutable.Seq[Token] = {
val startIndex = locateIndex(start)
val endIndex = locateIndex(end)

val tmp = for (i <- startIndex to endIndex)
yield Token(token(i), startOffset(i), endOffset(i))

tmp.toSeq
}
}

/** A Scala token covering [start, end)
*
* @param tokenId one of scala.tools.nsc.ast.parser.Tokens identifiers
* @param start the offset of the first character in this token
* @param end the offset of the first character after this token
*/
case class Token(tokenId: Int, start: Int, end: Int)
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import scala.reflect.internal.util.SourceFile

trait AbstractPrinter extends CommonPrintUtils {

this: common.Tracing with common.PimpedTrees with Indentations with common.CompilerAccess with Formatting =>
this: common.Tracing with common.PimpedTrees with Indentations with common.CompilerAccess with common.CompilerApiExtensions with Formatting =>

import global._

Expand All @@ -18,6 +18,11 @@ trait AbstractPrinter extends CommonPrintUtils {
* the context or environment for the current printing.
*/
case class PrintingContext(ind: Indentation, changeSet: ChangeSet, parent: Tree, file: Option[SourceFile]) {
private lazy val lexical = file map (new LexicalStructure(_))

def tokensBetween(start: Int, end: Int): Seq[Token] =
lexical.map(_.tokensBetween(start, end)).getOrElse(Seq())

lazy val newline: String = {
if(file.exists(_.content.containsSlice("\r\n")))
"\r\n"
Expand All @@ -32,4 +37,4 @@ trait AbstractPrinter extends CommonPrintUtils {

def print(t: Tree, ctx: PrintingContext): Fragment

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import language.implicitConversions

trait PrettyPrinter extends TreePrintingTraversals with AbstractPrinter {

outer: common.PimpedTrees with common.CompilerAccess with common.Tracing with Indentations with LayoutHelper with Formatting =>
outer: common.PimpedTrees with common.CompilerAccess with common.Tracing with common.CompilerApiExtensions with Indentations with LayoutHelper with Formatting =>

import global._

Expand Down Expand Up @@ -130,7 +130,7 @@ trait PrettyPrinter extends TreePrintingTraversals with AbstractPrinter {
case Some(patP(patStr)) if guard == EmptyTree => Fragment(patStr)
case _ => p(pat)
}

val arrowReq = new Requisite {
def isRequired(l: Layout, r: Layout) = {
!(l.contains("=>") || r.contains("=>"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import language.implicitConversions

trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter {

outer: LayoutHelper with common.Tracing with common.PimpedTrees with common.CompilerAccess with Formatting with Indentations =>
outer: LayoutHelper with common.Tracing with common.PimpedTrees with common.CompilerAccess with common.CompilerApiExtensions with Formatting with Indentations =>

import global._

Expand Down Expand Up @@ -202,22 +202,22 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter {

override def CaseDef(tree: CaseDef, pat: Tree, guard: Tree, body: Tree)(implicit ctx: PrintingContext) = {
val arrowReq = new Requisite {
def isRequired(l: Layout, r: Layout) = {
override def isRequired(l: Layout, r: Layout) = {
!(l.contains("=>") || r.contains("=>") || p(body).asText.startsWith("=>"))
}

// It's just nice to have a whitespace before and after the arrow
def getLayout = Layout(" => ")
override def getLayout = Layout(" => ")
}

val ifReq = new Requisite {
def isRequired(l: Layout, r: Layout) = {
override def isRequired(l: Layout, r: Layout) = {
!(l.contains("if") || r.contains("if"))
}

// Leading and trailing whitespace is required in some cases!
// e.g. `case i if i > 0 => ???` becomes `case iifi > 0 => ???` otherwise
def getLayout = Layout(" if ")
override def getLayout = Layout(" if ")
}

body match {
Expand Down Expand Up @@ -921,10 +921,12 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter {
val body = p(rhs)
val noEqualNeeded = body == EmptyFragment || rhs.tpe == null || (rhs.tpe != null && rhs.tpe.toString == "Unit")

def openingBrace = keepOpeningBrace(tree, tpt, rhs)

if (noEqualNeeded)
l ++ mods_ ++ resultType ++ body ++ r
else
l ++ mods_ ++ resultType ++ Requisite.anywhere("=", " = ") ++ body ++ r
l ++ mods_ ++ resultType ++ Requisite.anywhere("=", " = ") ++ openingBrace ++ body ++ r
}
}

Expand Down Expand Up @@ -985,8 +987,9 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter {
case _ => false
}

val resultType =
if (body == EmptyFragment && !existsTptInFile)
val isAbstract = body == EmptyFragment
val rawResultType =
if (isAbstract && !existsTptInFile)
EmptyFragment
else
p(tpt, before = Requisite.allowSurroundingWhitespace(":", ": "))
Expand All @@ -1000,16 +1003,50 @@ trait ReusingPrinter extends TreePrintingTraversals with AbstractPrinter {
}
}

val noEqualNeeded = {
body == EmptyFragment || rhs.tpe == null || (rhs.tpe != null && rhs.tpe.toString == "Unit")
val noEqualNeeded = rawResultType == EmptyFragment || isAbstract

val resultType = {
def addLeadingSpace = name.isOperatorName || name.endsWith('_')
if (rawResultType != EmptyFragment && addLeadingSpace) Layout(" ") ++ rawResultType else rawResultType
}

if (noEqualNeeded && !hasEqualInSource) {
l ++ modsAndName ++ typeParameters ++ parameters ++ resultType ++ body ++ r
} else {
l ++ modsAndName ++ typeParameters ++ parameters ++ resultType ++ Requisite.anywhere("=", " = ") ++ body ++ r
val openingBrace = keepOpeningBrace(tree, tpt, rhs)
// In case a Unit return type is added to a method like `def f {}`, we
// need to remove the whitespace between name and rhs, otherwise the
// result would be `def f : Unit = {}`.
val modsAndNameTrimmed =
if (modsAndName.trailing.asText.trim.isEmpty)
Fragment(modsAndName.leading, modsAndName.center, NoLayout)
else
modsAndName

l ++ modsAndNameTrimmed ++ typeParameters ++ parameters ++ resultType ++ Requisite.anywhere("=", " = ") ++ openingBrace ++ body ++ r
}
}

/**
* In case a `ValOrDefDef` like `def f = {0}` contains a single expression
* in braces, we are screwed.
* In such cases the compiler removes the opening and closing braces from
* the tree. It also removes all the whitespace between the equal sign and
* the opening brace. If a refactoring edits such a definition, we need to
* get the whitespace+opening brace back.
*/
private def keepOpeningBrace(tree: Tree, tpt: Tree, rhs: Tree)(implicit ctx: PrintingContext): String = tpt match {
case tpt: TypeTree if tpt.original != null && tree.pos != NoPosition && rhs.pos != NoPosition =>
val tokens = ctx.tokensBetween(tree.pos.point, rhs.pos.start)
tokens find (_.tokenId == tools.nsc.ast.parser.Tokens.EQUALS) map {
case Token(_, start, _) =>
val c = tree.pos.source.content
val skippedWsStart = if (c(start+1).isWhitespace) start+2 else start+1
c.slice(skippedWsStart, rhs.pos.start).mkString
} getOrElse ""
case _ =>
""
}
}

trait SuperPrinters {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ package sourcegen
import common.Tracing
import common.Change
import common.PimpedTrees
import common.CompilerApiExtensions
import scala.tools.refactoring.common.TextChange
import scala.reflect.internal.util.SourceFile

trait SourceGenerator extends PrettyPrinter with Indentations with ReusingPrinter with PimpedTrees with LayoutHelper with Formatting with TreeChangesDiscoverer {
trait SourceGenerator extends PrettyPrinter with Indentations with ReusingPrinter with PimpedTrees with CompilerApiExtensions with LayoutHelper with Formatting with TreeChangesDiscoverer {

self: Tracing with common.CompilerAccess =>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ trait TreeFactory {

def mkReturn(s: List[Symbol]): Tree = s match {
case Nil => EmptyTree
case x :: Nil => Ident(x) setType x.tpe
case x :: Nil =>
val ident = if (x.isModuleClass) Ident(newTermName(s"${x.name}.type")) else Ident(x)
ident setType x.tpe
case xs =>
typer.typed(gen.mkTuple(xs map (s => Ident(s) setType s.tpe))) match {
case t: Apply =>
Expand Down Expand Up @@ -114,7 +116,7 @@ trait TreeFactory {
}

if (mods != NoMods) valOrVarDef setSymbol NoSymbol.newValue(name, newFlags = mods.flags) else valOrVarDef
}
}

def mkParam(name: String, tpe: Type, defaultVal: Tree = EmptyTree): ValDef = {
ValDef(Modifiers(Flags.PARAM), newTermName(name), TypeTree(tpe), defaultVal)
Expand Down
Loading