Skip to content

Optimize some computation-intensive code #9721

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/backend/jvm/BTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ abstract class BTypes {
case DOUBLE => "D"
case ClassBType(internalName) => "L" + internalName + ";"
case ArrayBType(component) => "[" + component
case MethodBType(args, res) => "(" + args.mkString + ")" + res
case MethodBType(args, res) => args.mkString("(", "", ")" + res)
}

/**
Expand Down
46 changes: 30 additions & 16 deletions compiler/src/dotty/tools/dotc/core/Decorators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ object Decorators {
NoSymbol
}

final val MaxFilterRecursions = 1000
final val MaxFilterRecursions = 10

/** Implements filterConserve, zipWithConserve methods
* on lists that avoid duplication of list nodes where feasible.
Expand Down Expand Up @@ -105,24 +105,38 @@ object Decorators {
}

/** Like `xs filter p` but returns list `xs` itself - instead of a copy -
* if `p` is true for all elements and `xs` is not longer
* than `MaxFilterRecursions`.
* if `p` is true for all elements.
*/
def filterConserve(p: T => Boolean): List[T] = {
def loop(xs: List[T], nrec: Int): List[T] = xs match {
case Nil => xs
def filterConserve(p: T => Boolean): List[T] =

def addAll(buf: ListBuffer[T], from: List[T], until: List[T]): ListBuffer[T] =
if from eq until then buf else addAll(buf += from.head, from.tail, until)

def loopWithBuffer(buf: ListBuffer[T], xs: List[T]): List[T] = xs match
case x :: xs1 =>
if (nrec < MaxFilterRecursions) {
val ys1 = loop(xs1, nrec + 1)
if (p(x))
if (ys1 eq xs1) xs else x :: ys1
if p(x) then buf += x
loopWithBuffer(buf, xs1)
case nil => buf.toList

def loop(keep: List[T], explore: List[T], keepCount: Int, recCount: Int): List[T] =
explore match
case x :: rest =>
if p(x) then
loop(keep, rest, keepCount + 1, recCount)
else if keepCount <= 3 && recCount <= MaxFilterRecursions then
val rest1 = loop(rest, rest, 0, recCount + 1)
keepCount match
case 0 => rest1
case 1 => keep.head :: rest1
case 2 => keep.head :: keep.tail.head :: rest1
case 3 => val tl = keep.tail; keep.head :: tl.head :: tl.tail.head :: rest1
else
ys1
}
else xs filter p
}
loop(xs, 0)
}
loopWithBuffer(addAll(new ListBuffer[T], keep, explore), rest)
case nil =>
keep

loop(xs, xs, 0, 0)
end filterConserve

/** Like `xs.lazyZip(ys).map(f)`, but returns list `xs` itself
* - instead of a copy - if function `f` maps all elements of
Expand Down
22 changes: 13 additions & 9 deletions compiler/src/dotty/tools/dotc/core/NameOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package core

import java.security.MessageDigest
import scala.io.Codec
import Int.MaxValue
import Names._, StdNames._, Contexts._, Symbols._, Flags._, NameKinds._, Types._
import scala.internal.Chars
import Chars.isOperatorPart
import scala.internal.Chars.{isOperatorPart, digit2int}
import Definitions._
import nme._
import Decorators.concat
Expand Down Expand Up @@ -207,14 +207,18 @@ object NameOps {

/** Parsed function arity for function with some specific prefix */
private def functionArityFor(prefix: String): Int =
if (name.startsWith(prefix)) {
val suffix = name.toString.substring(prefix.length)
if (suffix.matches("\\d+"))
suffix.toInt
inline val MaxSafeInt = MaxValue / 10
val first = name.firstPart
def collectDigits(acc: Int, idx: Int): Int =
if idx == first.length then acc
else
-1
}
else -1
val d = digit2int(first(idx), 10)
if d < 0 || acc > MaxSafeInt then -1
else collectDigits(acc * 10 + d, idx + 1)
if first.startsWith(prefix) && prefix.length < first.length then
collectDigits(0, prefix.length)
else
-1

/** The name of the generic runtime operation corresponding to an array operation */
def genericArrayOp: TermName = name match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
if isRemovable(param.binder) then remove(param.binder)
else updateEntry(this, param, replacement)

def removeParam(ps: List[TypeParamRef]) = ps.filter(param ne _)
def removeParam(ps: List[TypeParamRef]) = ps.filterConserve(param ne _)

def replaceParam(tp: Type, atPoly: TypeLambda, atIdx: Int): Type =
current.ensureNonCyclic(atPoly.paramRefs(atIdx), tp.substParam(param, replacement))
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2501,7 +2501,6 @@ object TypeComparer {
case _ => String.valueOf(res)
}


/** The approximation state indicates how the pair of types currently compared
* relates to the types compared originally.
* - `None` : They are still the same types
Expand Down
11 changes: 8 additions & 3 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -829,8 +829,8 @@ object Types {
*/
final def possibleSamMethods(using Context): Seq[SingleDenotation] = {
record("possibleSamMethods")
abstractTermMembers
.filterNot(m => m.symbol.matchingMember(defn.ObjectType).exists || m.symbol.isSuperAccessor)
abstractTermMembers.toList.filterConserve(m =>
!m.symbol.matchingMember(defn.ObjectType).exists && !m.symbol.isSuperAccessor)
}

/** The set of abstract type members of this type. */
Expand Down Expand Up @@ -3180,7 +3180,12 @@ object Types {
private var myParamRefs: List[ParamRefType] = null

def paramRefs: List[ParamRefType] = {
if (myParamRefs == null) myParamRefs = paramNames.indices.toList.map(newParamRef)
if myParamRefs == null then
def recur(paramNames: List[ThisName], i: Int): List[ParamRefType] =
paramNames match
case _ :: rest => newParamRef(i) :: recur(rest, i + 1)
case _ => Nil
myParamRefs = recur(paramNames, 0)
myParamRefs
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/transform/Erasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ object Erasure {
if (takesBridges(ctx.owner)) new Bridges(ctx.owner.asClass, erasurePhase).add(stats0)
else stats0
val (stats2, finalCtx) = super.typedStats(stats1, exprOwner)
(stats2.filter(!_.isEmpty), finalCtx)
(stats2.filterConserve(!_.isEmpty), finalCtx)
}

override def adapt(tree: Tree, pt: Type, locked: TypeVars, tryGadtHealing: Boolean)(using Context): Tree =
Expand Down
13 changes: 7 additions & 6 deletions compiler/src/dotty/tools/dotc/transform/OverridingPairs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ object OverridingPairs {
* pair has already been treated in a parent class.
* This may be refined in subclasses. @see Bridges for a use case.
*/
protected def parents: Array[Symbol] = base.info.parents.toArray map (_.typeSymbol)
protected def parents: Array[Symbol] = base.info.parents.toArray.map(_.typeSymbol)

/** Does `sym1` match `sym2` so that it qualifies as overriding.
* Types always match. Term symbols match if their membertypes
Expand Down Expand Up @@ -64,11 +64,12 @@ object OverridingPairs {
decls
}

private val subParents =
val subParents = MutableSymbolMap[BitSet]()
for (bc <- base.info.baseClasses)
subParents(bc) = BitSet(parents.indices.filter(parents(_).derivesFrom(bc)): _*)
subParents
private val subParents = MutableSymbolMap[BitSet]()
for bc <- base.info.baseClasses do
var bits = BitSet.empty
for i <- 0 until parents.length do
if parents(i).derivesFrom(bc) then bits += i
subParents(bc) = bits

private def hasCommonParentAsSubclass(cls1: Symbol, cls2: Symbol): Boolean =
(subParents(cls1) intersect subParents(cls2)).nonEmpty
Expand Down
12 changes: 8 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/patmat/Space.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import reporting._
import config.Printers.{exhaustivity => debug}
import util.SrcPos
import NullOpsDecorator._
import collection.mutable

/** Space logic for checking exhaustivity and unreachability of pattern matching
*
Expand Down Expand Up @@ -123,10 +124,13 @@ trait SpaceLogic {
else if (canDecompose(tp) && decompose(tp).isEmpty) Empty
else sp
case Or(spaces) =>
val set = spaces.map(simplify(_)).flatMap {
case Or(ss) => ss
case s => Seq(s)
} filter (_ != Empty)
val buf = new mutable.ListBuffer[Space]
def include(s: Space) = if s != Empty then buf += s
for space <- spaces do
simplify(space) match
case Or(ss) => ss.foreach(include)
case s => include(s)
val set = buf.toList

if (set.isEmpty) Empty
else if (set.size == 1) set.toList(0)
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1832,7 +1832,7 @@ trait Applications extends Compatibility {
}

def narrowBySize(alts: List[TermRef]): List[TermRef] =
alts.filter(sizeFits(_))
alts.filterConserve(sizeFits(_))

def narrowByShapes(alts: List[TermRef]): List[TermRef] =
val normArgs = args.mapWithIndexConserve(normArg(alts, _, _))
Expand All @@ -1843,11 +1843,11 @@ trait Applications extends Compatibility {
alts

def narrowByTrees(alts: List[TermRef], args: List[Tree], resultType: Type): List[TermRef] = {
val alts2 = alts.filter(alt =>
val alts2 = alts.filterConserve(alt =>
isDirectlyApplicableMethodRef(alt, args, resultType)
)
if (alts2.isEmpty && !ctx.isAfterTyper)
alts.filter(alt =>
alts.filterConserve(alt =>
isApplicableMethodRef(alt, args, resultType, keepConstraint = false)
)
else
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1205,7 +1205,7 @@ trait Implicits:
case retained: SearchSuccess =>
val newPending =
if (retained eq found) || remaining.isEmpty then remaining
else remaining.filter(cand =>
else remaining.filterConserve(cand =>
compareCandidate(retained, cand.ref, cand.level) <= 0)
rank(newPending, retained, rfailures)
case fail: SearchFailure =>
Expand Down
20 changes: 13 additions & 7 deletions compiler/src/dotty/tools/dotc/typer/Nullables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -399,14 +399,20 @@ object Nullables:
tree match
case Block(stats, expr) =>
var shadowed: Set[(Name, List[Span])] = Set.empty
for case (stat: ValDef) <- stats if stat.mods.is(Mutable) do
for prevSpans <- candidates.put(stat.name, Nil) do
shadowed += (stat.name -> prevSpans)
reachable += stat.name
for stat <- stats do
stat match
case stat: ValDef if stat.mods.is(Mutable) =>
for prevSpans <- candidates.put(stat.name, Nil) do
shadowed += (stat.name -> prevSpans)
reachable += stat.name
case _ =>
traverseChildren(tree)
for case (stat: ValDef) <- stats if stat.mods.is(Mutable) do
for spans <- candidates.remove(stat.name) do
tracked += (stat.nameSpan.start -> spans) // candidates that survive until here are tracked
for stat <- stats do
stat match
case stat: ValDef if stat.mods.is(Mutable) =>
for spans <- candidates.remove(stat.name) do
tracked += (stat.nameSpan.start -> spans) // candidates that survive until here are tracked
case _ =>
candidates ++= shadowed
case Assign(Ident(name), rhs) =>
candidates.get(name) match
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ trait TypeAssigner {
}

def avoidingType(expr: Tree, bindings: List[Tree])(using Context): Type =
TypeOps.avoid(expr.tpe, localSyms(bindings).filter(_.isTerm))
TypeOps.avoid(expr.tpe, localSyms(bindings).filterConserve(_.isTerm))

def avoidPrivateLeaks(sym: Symbol)(using Context): Type =
if sym.owner.isClass && !sym.isOneOf(JavaOrPrivateOrSynthetic)
Expand Down