Skip to content

Code refactoring of initialization checker #16066

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 17 additions & 27 deletions compiler/src/dotty/tools/dotc/transform/init/Checker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ import StdNames._
import dotty.tools.dotc.transform._
import Phases._

import scala.collection.mutable

import Semantic._

class Checker extends Phase {
class Checker extends Phase:

override def phaseName: String = Checker.name

Expand All @@ -31,17 +32,23 @@ class Checker extends Phase {

override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] =
val checkCtx = ctx.fresh.setPhase(this.start)
Semantic.checkTasks(using checkCtx) {
val traverser = new InitTreeTraverser()
units.foreach { unit => traverser.traverse(unit.tpdTree) }
}
val traverser = new InitTreeTraverser()
units.foreach { unit => traverser.traverse(unit.tpdTree) }
val classes = traverser.getClasses()

Semantic.checkClasses(classes)(using checkCtx)

units

def run(using Context): Unit = {
def run(using Context): Unit =
// ignore, we already called `Semantic.check()` in `runOn`
}
()

class InitTreeTraverser extends TreeTraverser:
private val classes: mutable.ArrayBuffer[ClassSymbol] = new mutable.ArrayBuffer

def getClasses(): List[ClassSymbol] = classes.toList

class InitTreeTraverser(using WorkList) extends TreeTraverser {
override def traverse(tree: Tree)(using Context): Unit =
traverseChildren(tree)
tree match {
Expand All @@ -53,29 +60,12 @@ class Checker extends Phase {
mdef match
case tdef: TypeDef if tdef.isClassDef =>
val cls = tdef.symbol.asClass
val thisRef = ThisRef(cls)
if shouldCheckClass(cls) then Semantic.addTask(thisRef)
classes.append(cls)
case _ =>

case _ =>
}
}

private def shouldCheckClass(cls: ClassSymbol)(using Context) = {
val instantiable: Boolean =
cls.is(Flags.Module) ||
!cls.isOneOf(Flags.AbstractOrTrait) && {
// see `Checking.checkInstantiable` in typer
val tp = cls.appliedRef
val stp = SkolemType(tp)
val selfType = cls.givenSelfType.asSeenFrom(stp, cls)
!selfType.exists || stp <:< selfType
}

// A concrete class may not be instantiated if the self type is not satisfied
instantiable && cls.enclosingPackageClass != defn.StdLibPatchesPackage.moduleClass
}
}
end InitTreeTraverser

object Checker:
val name: String = "initChecker"
Expand Down
145 changes: 83 additions & 62 deletions compiler/src/dotty/tools/dotc/transform/init/Semantic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1206,72 +1206,49 @@ object Semantic:
cls == defn.AnyValClass ||
cls == defn.ObjectClass

// ----- Work list ---------------------------------------------------
case class Task(value: ThisRef)

class WorkList private[Semantic]():
private val pendingTasks: mutable.ArrayBuffer[Task] = new mutable.ArrayBuffer

def addTask(task: Task): Unit =
if !pendingTasks.contains(task) then pendingTasks.append(task)

/** Process the worklist until done */
final def work()(using Cache, Context): Unit =
for task <- pendingTasks
do doTask(task)

/** Check an individual class
*
* This method should only be called from the work list scheduler.
*/
private def doTask(task: Task)(using Cache, Context): Unit =
val thisRef = task.value
val tpl = thisRef.klass.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template]

@tailrec
def iterate(): Unit = {
given Promoted = Promoted.empty(thisRef.klass)
given Trace = Trace.empty.add(thisRef.klass.defTree)
given reporter: Reporter.BufferedReporter = new Reporter.BufferedReporter
// ----- API --------------------------------

thisRef.ensureFresh()
/** Check an individual class
*
* The class to be checked must be an instantiable concrete class.
*/
private def checkClass(classSym: ClassSymbol)(using Cache, Context): Unit =
val thisRef = ThisRef(classSym)
val tpl = classSym.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template]

// set up constructor parameters
for param <- tpl.constr.termParamss.flatten do
thisRef.updateField(param.symbol, Hot)
@tailrec
def iterate(): Unit = {
given Promoted = Promoted.empty(classSym)
given Trace = Trace.empty.add(classSym.defTree)
given reporter: Reporter.BufferedReporter = new Reporter.BufferedReporter

log("checking " + task) { eval(tpl, thisRef, thisRef.klass) }
reporter.errors.foreach(_.issue)
thisRef.ensureFresh()

if cache.hasChanged && reporter.errors.isEmpty then
// code to prepare cache and heap for next iteration
cache.prepareForNextIteration()
iterate()
else
cache.prepareForNextClass()
}
// set up constructor parameters
for param <- tpl.constr.termParamss.flatten do
thisRef.updateField(param.symbol, Hot)

iterate()
end doTask
end WorkList
inline def workList(using wl: WorkList): WorkList = wl
log("checking " + classSym) { eval(tpl, thisRef, classSym) }
reporter.errors.foreach(_.issue)

// ----- API --------------------------------
if cache.hasChanged && reporter.errors.isEmpty then
// code to prepare cache and heap for next iteration
cache.prepareForNextIteration()
iterate()
else
cache.prepareForNextClass()
}

/** Add a checking task to the work list */
def addTask(thisRef: ThisRef)(using WorkList) = workList.addTask(Task(thisRef))
iterate()
end checkClass

/** Check the specified tasks
*
* Semantic.checkTasks {
* Semantic.addTask(...)
* }
/**
* Check the specified concrete classes
*/
def checkTasks(using Context)(taskBuilder: WorkList ?=> Unit): Unit =
val workList = new WorkList
val cache = new Cache
taskBuilder(using workList)
workList.work()(using cache, ctx)
def checkClasses(classes: List[ClassSymbol])(using Context): Unit =
given Cache()
for classSym <- classes if isConcreteClass(classSym) do
checkClass(classSym)

// ----- Semantic definition --------------------------------

Expand All @@ -1296,7 +1273,10 @@ object Semantic:
*
* This method only handles cache logic and delegates the work to `cases`.
*
* The parameter `cacheResult` is used to reduce the size of the cache.
* @param expr The expression to be evaluated.
* @param thisV The value for `C.this` where `C` is represented by the parameter `klass`.
* @param klass The enclosing class where the expression is located.
* @param cacheResult It is used to reduce the size of the cache.
*/
def eval(expr: Tree, thisV: Ref, klass: ClassSymbol, cacheResult: Boolean = false): Contextual[Value] = log("evaluating " + expr.show + ", this = " + thisV.show + " in " + klass.show, printer, (_: Value).show) {
cache.get(thisV, expr) match
Expand Down Expand Up @@ -1326,6 +1306,10 @@ object Semantic:
/** Handles the evaluation of different expressions
*
* Note: Recursive call should go to `eval` instead of `cases`.
*
* @param expr The expression to be evaluated.
* @param thisV The value for `C.this` where `C` is represented by the parameter `klass`.
* @param klass The enclosing class where the expression `expr` is located.
*/
def cases(expr: Tree, thisV: Ref, klass: ClassSymbol): Contextual[Value] =
val trace2 = trace.add(expr)
Expand Down Expand Up @@ -1503,7 +1487,14 @@ object Semantic:
report.error("[Internal error] unexpected tree" + Trace.show, expr)
Hot

/** Handle semantics of leaf nodes */
/** Handle semantics of leaf nodes
*
* For leaf nodes, their semantics is determined by their types.
*
* @param tp The type to be evaluated.
* @param thisV The value for `C.this` where `C` is represented by the parameter `klass`.
* @param klass The enclosing class where the type `tp` is located.
*/
def cases(tp: Type, thisV: Ref, klass: ClassSymbol): Contextual[Value] = log("evaluating " + tp.show, printer, (_: Value).show) {
tp match
case _: ConstantType =>
Expand Down Expand Up @@ -1541,7 +1532,12 @@ object Semantic:
Hot
}

/** Resolve C.this that appear in `klass` */
/** Resolve C.this that appear in `klass`
*
* @param target The class symbol for `C` for which `C.this` is to be resolved.
* @param thisV The value for `D.this` where `D` is represented by the parameter `klass`.
* @param klass The enclosing class where the type `C.this` is located.
*/
def resolveThis(target: ClassSymbol, thisV: Value, klass: ClassSymbol): Contextual[Value] = log("resolving " + target.show + ", this = " + thisV.show + " in " + klass.show, printer, (_: Value).show) {
if target == klass then thisV
else if target.is(Flags.Package) then Hot
Expand All @@ -1566,7 +1562,12 @@ object Semantic:

}

/** Compute the outer value that correspond to `tref.prefix` */
/** Compute the outer value that correspond to `tref.prefix`
*
* @param tref The type whose prefix is to be evaluated.
* @param thisV The value for `C.this` where `C` is represented by the parameter `klass`.
* @param klass The enclosing class where the type `tref` is located.
*/
def outerValue(tref: TypeRef, thisV: Ref, klass: ClassSymbol): Contextual[Value] =
val cls = tref.classSymbol.asClass
if tref.prefix == NoPrefix then
Expand All @@ -1577,7 +1578,12 @@ object Semantic:
if cls.isAllOf(Flags.JavaInterface) then Hot
else cases(tref.prefix, thisV, klass)

/** Initialize part of an abstract object in `klass` of the inheritance chain */
/** Initialize part of an abstract object in `klass` of the inheritance chain
*
* @param tpl The class body to be evaluated.
* @param thisV The value of the current object to be initialized.
* @param klass The class to which the template belongs.
*/
def init(tpl: Template, thisV: Ref, klass: ClassSymbol): Contextual[Value] = log("init " + klass.show, printer, (_: Value).show) {
val paramsMap = tpl.constr.termParamss.flatten.map { vdef =>
vdef.name -> thisV.objekt.field(vdef.symbol)
Expand Down Expand Up @@ -1782,3 +1788,18 @@ object Semantic:
if (sym.isEffectivelyFinal || sym.isConstructor) sym
else sym.matchingMember(cls.appliedRef)
}

private def isConcreteClass(cls: ClassSymbol)(using Context) = {
val instantiable: Boolean =
cls.is(Flags.Module) ||
!cls.isOneOf(Flags.AbstractOrTrait) && {
// see `Checking.checkInstantiable` in typer
val tp = cls.appliedRef
val stp = SkolemType(tp)
val selfType = cls.givenSelfType.asSeenFrom(stp, cls)
!selfType.exists || stp <:< selfType
}

// A concrete class may not be instantiated if the self type is not satisfied
instantiable && cls.enclosingPackageClass != defn.StdLibPatchesPackage.moduleClass
}