Skip to content
This repository was archived by the owner on Sep 3, 2020. It is now read-only.

Organize Imports in Defs avoiding tree printer #142

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ trait CompilationUnitDependencies extends CompilerApiExtensions with ScalaVersio
*/
def dependencies(t: Tree): List[Select] = {
val wholeTree = t
val isNotInImports = new IsNotInImports(CompilationUnitDependencies.this.global, CompilationUnitDependencies.this, t);

def qualifierIsEnclosingPackage(t: Select) = {
enclosingPackage(wholeTree, t.pos) match {
Expand Down Expand Up @@ -386,7 +387,8 @@ trait CompilationUnitDependencies extends CompilerApiExtensions with ScalaVersio
&& hasStableQualifier(t)
&& !t.symbol.isLocal
&& !isRelativeToLocalImports(t)
&& !isDefinedLocallyAndQualifiedWithEnclosingPackage(t)) {
&& !isDefinedLocallyAndQualifiedWithEnclosingPackage(t)
&& isNotInImports.isSelectNotInRelativeImports(t)) {
addToResult(t)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package scala.tools.refactoring
package analysis

import tools.nsc.interactive.Global

class IsNotInImports(val global: Global, val cuDependenciesInstance: CompilationUnitDependencies with common.EnrichedTrees, tree: Global#Tree) {
import global._
val wholeTree = tree.asInstanceOf[Tree]

private def collectPotentialOwners(of: Select): List[Symbol] = {
val upToPosition = of.pos.start
def isThisSelectTree(fromWholeTree: Tree): Boolean =
fromWholeTree.pos.isRange && fromWholeTree.pos.start == upToPosition
var owners = List.empty[Symbol]
val collectPotentialOwners = new Traverser {
var owns = List.empty[Symbol]
override def traverse(t: Tree) = {
owns = currentOwner :: owns
t match {
case t if !t.pos.isRange || t.pos.start > upToPosition =>
case potential if isThisSelectTree(potential) =>
owners = owns.distinct
case t =>
super.traverse(t)
owns = owns.tail
}
}
}
collectPotentialOwners.traverse(wholeTree)
owners
}

def isSelectNotInRelativeImports(tested: Global#Select): Boolean = {
val doesNameFitInTested = compareNameWith(tested) _
val nonPackageOwners = collectPotentialOwners(tested.asInstanceOf[Select]).filterNot { _.hasPackageFlag }
def isValidPosition(t: Import): Boolean = t.pos.isRange && t.pos.start < tested.pos.start
val isImportForTested = new Traverser {
var found = false
override def traverse(t: Tree) = t match {
case imp: Import if isValidPosition(imp) && doesNameFitInTested(imp) && nonPackageOwners.contains(currentOwner) =>
found = true
case t => super.traverse(t)
}
}
isImportForTested.traverse(wholeTree.asInstanceOf[Tree])
!isImportForTested.found
}

private def compareNameWith(tested: Global#Select)(that: Global#Import): Boolean = {
import cuDependenciesInstance.additionalTreeMethodsForPositions
val Select(testedQual, testedName) = tested
val testedQName = List(testedQual.asInstanceOf[cuDependenciesInstance.global.Tree].nameString, testedName).mkString(".")
val Import(thatQual, thatSels) = that
val impNames = thatSels.map { sel =>
if (sel.name == nme.WILDCARD) thatQual.asInstanceOf[cuDependenciesInstance.global.Tree].nameString
else List(thatQual.asInstanceOf[cuDependenciesInstance.global.Tree].nameString, sel.name).mkString(".")
}
impNames.exists { testedQName.startsWith }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,12 @@ object OrganizeImports {
*
*
*/
abstract class OrganizeImports extends MultiStageRefactoring with TreeFactory with TreeTraverser with UnusedImportsFinder with analysis.CompilationUnitDependencies with common.InteractiveScalaCompiler with common.TreeExtractors {
abstract class OrganizeImports extends MultiStageRefactoring with TreeFactory
with TreeTraverser
with UnusedImportsFinder
with analysis.CompilationUnitDependencies
with common.InteractiveScalaCompiler
with common.TreeExtractors {

import OrganizeImports.Algos
import global._
Expand Down Expand Up @@ -230,7 +235,7 @@ abstract class OrganizeImports extends MultiStageRefactoring with TreeFactory wi
def removeDuplicates(l: List[ImportSelector]) = {
l.groupBy(_.name.toString).map(_._2.head).toList
}
imp.copy(selectors = removeDuplicates(selectors).sortBy(_.name.toString))
imp.copy(selectors = removeDuplicates(selectors).sortBy(_.name.toString)).setPos(imp.pos)
}
}
}
Expand Down Expand Up @@ -463,99 +468,25 @@ abstract class OrganizeImports extends MultiStageRefactoring with TreeFactory wi
case (p, existingImports, others) =>
val imports = scala.Function.chain(participants)(existingImports)
p copy (stats = imports ::: others) replaces p
} &> transformation[Tree, Tree] {
case p: PackageDef =>
InnerImports.organizeImportsInMethodBlocks(p).replaces(p)
}

Right(transformFile(selection.file, organizeImports |> topdown(matchingChildren(organizeImports))))
}

object InnerImports {
class RemoveUnused(block: Tree) extends Participant {
private def treeWithoutImports(tree: Tree) = new Transformer {
override def transform(tree: Tree): Tree = tree match {
case Import(_, _) => EmptyTree
case t => super.transform(t)
}
}.transform(tree)

private lazy val allSelects = {
import scala.collection.mutable
val selects = mutable.ListBuffer[Select]()
val selectsTraverser = new Traverser {
override def traverse(tree: Tree): Unit = tree match {
case s @ Select(qual, _) =>
selects += s
traverse(qual)
case t => super.traverse(t)
}
}
selectsTraverser.traverse(treeWithoutImports(block))
selects.toList
}

protected def doApply(trees: List[Import]) = trees collect {
case imp @ Import(importQualifier: Select, importSelections) =>
val usedSelectors = importSelections filter { importSel =>
val importName = importSel.name.toString
val importSym = importQualifier.symbol
val isWildcard = importSel.name == nme.WILDCARD

allSelects.exists { foundSel =>
val foundName = foundSel.symbol.nameString
val foundSym = foundSel.qualifier.symbol
(isWildcard || foundName == importName) && foundSym == importSym
}
}
usedSelectors match {
case Nil => Import(EmptyTree, Nil)
case _ => imp.copy(selectors = usedSelectors)
}
val rootTree = abstractFileToTree(selection.file)
import oimports.DefImportsOrganizer
import oimports.NotPackageImportParticipants
val notPackageParticipants = new NotPackageImportParticipants(global, this)
import notPackageParticipants.RemoveDuplicatedByWildcard
import notPackageParticipants.{ RemoveUnused => NPRemovedUnused }
val regions = new DefImportsOrganizer(global).transformTreeToRegions(rootTree).map {
_.transform { i =>
scala.Function.chain { RemoveDuplicatedByWildcard.asInstanceOf[Participant] ::
(new NPRemovedUnused(rootTree)).asInstanceOf[Participant] ::
RemoveDuplicates ::
SortImportSelectors ::
SortImports ::
Nil }(i.asInstanceOf[List[Import]])
}
}

object RemoveDuplicatedByWildcard extends Participant {
protected def doApply(trees: List[Import]) = trees.map { imp =>
val wild = imp.selectors.find(_.name == nme.WILDCARD)
if (wild.nonEmpty)
imp.copy(selectors = wild.toList)
else
imp
}.groupBy {
_.expr.toString
}.collect {
case (_, imports) =>
val (wild, rest) = imports.partition(_.selectors.exists(_.name == nme.WILDCARD))
if (wild.nonEmpty)
wild
else
rest
}.flatten.toList
}

private def organizeImportsIfNoImportInSameLine(imports: List[Import])(organizeImports: List[Import] => List[Import]): List[Import] = {
val importsWithPosition = imports.filter { _.pos.isDefined }
if (importsWithPosition.nonEmpty &&
importsWithPosition.size != importsWithPosition.map { _.pos.line }.distinct.size)
imports
else organizeImports(imports)
}

def organizeImportsInMethodBlocks(tree: Tree): Tree = new Transformer {
override def transform(t: Tree) = t match {
case b @ Block(stats, _) if currentOwner.isMethod && !currentOwner.isLazy =>
val (rawImports, others) = stats.partition { _.isInstanceOf[Import] }
val imports = rawImports.asInstanceOf[List[Import]]
val importsOrganizer = scala.Function.chain(new RemoveUnused(b) :: RemoveDuplicatedByWildcard ::
RemoveDuplicates :: SortImportSelectors :: SortImports :: Nil)
val visitedOthers = others.map { t =>
transform(t).replaces(t)
}
b.copy(stats = organizeImportsIfNoImportInSameLine(imports)(importsOrganizer) ::: visitedOthers).replaces(b)
case skipPlainText: PlainText => skipPlainText
case t => super.transform(t)
}
}.transform(tree)
val changes = regions.map { _.print }
Right(transformFile(selection.file, organizeImports |> topdown(matchingChildren(organizeImports))) ::: changes)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package scala.tools.refactoring
package implementations.oimports

import scala.tools.nsc.interactive.Global
import scala.tools.refactoring.common.Change
import scala.tools.refactoring.common.TextChange
import scala.util.Properties

class DefImportsOrganizer(val global: Global) {
import global._

private def noAnyTwoImportsInSameLine(importsGroup: List[Global#Import]): Boolean =
importsGroup.size == importsGroup.map { _.pos.line }.distinct.size

private def importsGroupsFromTree(tree: Global#Tree): List[List[Global#Import]] = {
val impTraverser = new Traverser {
var groupedImports = List.empty[List[Import]]
var group = List.empty[Import]
override def traverse(tree: Tree) = tree match {
case imp: Import =>
group = group :+ imp
case t =>
if (group.nonEmpty) {
groupedImports = groupedImports :+ group
group = List.empty[Import]
}
super.traverse(t)
}
}
impTraverser.traverse(tree.asInstanceOf[Tree])
impTraverser.groupedImports
}

private val util = new TreeToolbox(global)
import util.forTreesOfKind

private def forTreesOfBlocks(tree: Global#Tree) = forTreesOfKind[Global#Block](tree) { (collected, currentOwner) => {
case b: util.global.Block if currentOwner.isMethod && !currentOwner.isLazy =>
collected += b
}
}

private def toRegions(groupedImports: List[List[Global#Import]]): List[Region] =
groupedImports.map {
case imports @ h :: _ => Some(Region(imports)(global))
case _ => None
}.filter {
_.nonEmpty
}.map { _.get }

def transformTreeToRegions(tree: Global#Tree): List[Region] = toRegions(forTreesOfBlocks(tree).flatMap { block =>
importsGroupsFromTree(block).filter {
noAnyTwoImportsInSameLine
}
})
}

class TreeToolbox(val global: Global) {
import scala.collection._
import global._

private class TreeCollector[T <: Global#Tree](traverserBody: (mutable.ListBuffer[T], Global#Symbol) => PartialFunction[Tree, Unit]) extends Traverser {
val collected = mutable.ListBuffer.empty[T]
override def traverse(tree: Tree): Unit = traverserBody(collected, currentOwner).orElse[Tree, Unit] {
case t => super.traverse(t)
}(tree)
}

def forTreesOfKind[T <: Global#Tree](tree: Global#Tree)(traverserBody: (mutable.ListBuffer[T], Global#Symbol) => PartialFunction[Tree, Unit]): List[T] = {
val treeTraverser = new TreeCollector[T](traverserBody)
treeTraverser.traverse(tree.asInstanceOf[Tree])
treeTraverser.collected.toList
}
}

import scala.reflect.internal.util.SourceFile
case class Region private (imports: List[Global#Import], startPos: Global#Position, endPos: Global#Position,
source: SourceFile, indentation: String, printImport: Global#Import => String) {
def transform(transformation: List[Global#Import] => List[Global#Import]): Region =
copy(imports = transformation(imports))

private def printEmptyImports: Change = {
val fromBeginningOfLine = source.lineToOffset(source.offsetToLine(startPos.start))
val toEndOfLine = endPos.end + Properties.lineSeparator.length
TextChange(source, fromBeginningOfLine, toEndOfLine, "")
}

def print: Change = if (imports.nonEmpty) printNonEmptyImports else printEmptyImports

private def printNonEmptyImports: Change = {
val from = startPos.pos.start
val to = endPos.pos.end
val text = imports.zipWithIndex.foldLeft("") { (acc, imp) =>
def isLast(idx: Int) = idx == imports.size - 1
imp match {
case (imp, 0) if isLast(0) =>
acc + printImport(imp)
case (imp, 0) =>
acc + printImport(imp) + Properties.lineSeparator
case (imp, idx) if isLast(idx) =>
acc + indentation + printImport(imp)
case (imp, _) =>
acc + indentation + printImport(imp) + Properties.lineSeparator
}
}
TextChange(source, from, to, text)
}
}

object Region {
private def indentation(imp: Global#Import): String = {
val sourceFile = imp.pos.source
sourceFile.lineToString(sourceFile.offsetToLine(imp.pos.start)).takeWhile { _.isWhitespace }
}

def apply(imports: List[Global#Import])(global: Global): Region = {
assert(imports.nonEmpty)
val source = imports.head.pos.source
def printImport(imp: Global#Import): String = {
import global._
val RenameArrow = " => "
val prefix = source.content.slice(imp.pos.start, imp.pos.end).mkString.reverse.dropWhile { _ != '.' }.reverse
val suffix = imp.selectors.map { sel =>
if (sel.name == sel.rename || sel.name == nme.WILDCARD)
sel.name.toString
else
sel.name + RenameArrow + sel.rename
}
val areBracesNeeded = suffix.size > 1 || suffix.exists { _ contains RenameArrow }
prefix + suffix.mkString(if (areBracesNeeded) "{" else "", ", ", if (areBracesNeeded) "}" else "")
}
Region(imports, imports.head.pos, imports.last.pos, source, indentation(imports.head), printImport)
}
}
Loading