Skip to content

Commit 790370d

Browse files
committed
Fix #9176: fast check of cyclic object initialization
1 parent 585f424 commit 790370d

File tree

3 files changed

+160
-0
lines changed

3 files changed

+160
-0
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
package dotty.tools.dotc
2+
package transform
3+
package init
4+
5+
import core._
6+
import Flags._
7+
import Contexts._
8+
import Types._
9+
import Symbols._
10+
import Decorators._
11+
import printing.SyntaxHighlighting
12+
import reporting.trace
13+
import config.Printers.init
14+
15+
import ast.Trees._
16+
import ast.tpd
17+
18+
import scala.collection.mutable
19+
20+
21+
/** Check that static objects can be initialized without cycles
22+
*
23+
* For the check to be fast, the algorithm uses coarse approximation.
24+
* We construct a dependency graph as follows:
25+
*
26+
* - if a static object `O` is used in another class/static-object `B`,
27+
* then O -> B
28+
* - if a class `C` is instantiated in a another class/static-object `B`,
29+
* then C -> B
30+
* - if a static-object/class `A` extends another class `B`,
31+
* then A -> B
32+
*
33+
* Given the graph above, we check if there exists cycles.
34+
*
35+
* This check does not need to care about objects in libraries, as separate
36+
* compilation ensures that there cannot be cyles between two separately
37+
* compiled projects.
38+
*/
39+
class CheckGlobal {
40+
case class Dependency(sym: Symbol, source: tpd.Tree)
41+
42+
/** Checking state */
43+
case class State(var visited: Set[Symbol], path: Vector[tpd.Tree], obj: Symbol) {
44+
def cyclicPath(using Context): String = if (path.isEmpty) "" else " Cyclic path:\n" + {
45+
var indentCount = 0
46+
var last: String = ""
47+
val sb = new StringBuilder
48+
path.foreach { tree =>
49+
indentCount += 1
50+
val pos = tree.sourcePos
51+
val prefix = s"${ " " * indentCount }-> "
52+
val line =
53+
if pos.source.exists then
54+
val loc = "[ " + pos.source.file.name + ":" + (pos.line + 1) + " ]"
55+
val code = SyntaxHighlighting.highlight(pos.lineContent.trim)
56+
i"$code\t$loc"
57+
else
58+
tree.show
59+
60+
if (last != line) sb.append(prefix + line + "\n")
61+
62+
last = line
63+
}
64+
sb.toString
65+
}
66+
}
67+
68+
case class Error(state: State) {
69+
def issue(using Context): Unit =
70+
report.warning("Cylic object dependencies detected." + state.cyclicPath, state.obj.defTree.srcPos)
71+
}
72+
73+
/** Summary of dependencies */
74+
private val summaryCache = mutable.Map.empty[Symbol, List[Dependency]]
75+
76+
def check(obj: Symbol)(using Context): Unit = trace("checking " + obj.show, init) {
77+
checkDependencies(obj, State(visited = Set.empty, path = Vector.empty, obj)) match
78+
case Some(err) => err.issue
79+
case _ =>
80+
}
81+
82+
private def check(sym: Symbol, state: State)(using Context): Option[Error] = trace("checking " + sym.show, init) {
83+
if sym == state.obj then
84+
Some(Error(state))
85+
else if state.visited.contains(sym) then
86+
None
87+
else
88+
state.visited = state.visited + sym
89+
checkDependencies(sym, state)
90+
}
91+
92+
private def checkDependencies(sym: Symbol, state: State)(using Context): Option[Error] = trace("checking dependencies of " + sym.show, init) {
93+
val cls = if sym.is(Module) then sym.moduleClass.asClass else sym.asClass
94+
val deps = analyze(cls)
95+
Util.traceIndented("dependencies of " + sym.show + " = " + deps.map(_.sym.show).mkString(","), init)
96+
var res: Option[Error] = None
97+
// TODO: stop early
98+
deps.foreach { dep =>
99+
val state2: State = state.copy(path = state.path :+ dep.source)
100+
if res.isEmpty then res = check(dep.sym, state2)
101+
}
102+
res
103+
}
104+
105+
private def analyze(cls: ClassSymbol)(using Context): List[Dependency] =
106+
def isStaticObjectRef(sym: Symbol) =
107+
sym.isTerm && !sym.is(Package) && sym.is(Module)
108+
&& sym.isStatic && sym.moduleClass != cls
109+
110+
if (cls.defTree.isEmpty) Nil
111+
else if (summaryCache.contains(cls)) summaryCache(cls)
112+
else {
113+
val cdef = cls.defTree.asInstanceOf[tpd.TypeDef]
114+
val tpl = cdef.rhs.asInstanceOf[tpd.Template]
115+
var dependencies: List[Dependency] = Nil
116+
val traverser = new tpd.TreeTraverser {
117+
override def traverse(tree: tpd.Tree)(using Context): Unit =
118+
tree match {
119+
case tree: tpd.RefTree if isStaticObjectRef(tree.symbol) =>
120+
dependencies = Dependency(tree.symbol, tree) :: dependencies
121+
122+
case tdef: tpd.TypeDef =>
123+
// don't go into nested classes
124+
125+
case tree: tpd.New =>
126+
dependencies = Dependency(tree.tpe.classSymbol, tree) :: dependencies
127+
128+
case _ =>
129+
traverseChildren(tree)
130+
}
131+
}
132+
133+
// TODO: the traverser might create duplicate entries for parents
134+
tpl.parents.foreach { tree =>
135+
dependencies = Dependency(tree.tpe.classSymbol, tree) :: dependencies
136+
}
137+
138+
traverser.traverse(tpl)
139+
summaryCache(cls) = dependencies
140+
dependencies
141+
}
142+
143+
def debugCache(using Context) =
144+
summaryCache.map(_.show + " -> " + _.map(_.sym.show).mkString(",")).mkString("\n")
145+
}

compiler/src/dotty/tools/dotc/transform/init/Checker.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ class Checker extends MiniPhase {
2626
// cache of class summary
2727
private val baseEnv = Env(null)
2828

29+
val globalChecker = new CheckGlobal
30+
2931
override val runsAfter = Set(Pickler.name)
3032

3133
override def isEnabled(using Context): Boolean =
@@ -58,6 +60,10 @@ class Checker extends MiniPhase {
5860
)
5961

6062
Checking.checkClassBody(tree)
63+
64+
// check cycles of object dependencies
65+
if cls.is(Flags.Module) && cls.isStatic then
66+
globalChecker.check(cls.sourceModule)
6167
}
6268

6369
tree

tests/init/neg/i9176.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
class Foo(val opposite: Foo)
2+
case object A extends Foo(B) // error
3+
case object B extends Foo(A) // error
4+
object Test {
5+
def main(args: Array[String]): Unit = {
6+
println(A.opposite)
7+
println(B.opposite)
8+
}
9+
}

0 commit comments

Comments
 (0)