From e4b0c22af1eb2e8c799bfda84af3c02d7d770684 Mon Sep 17 00:00:00 2001 From: EnzeXing Date: Tue, 24 Aug 2021 18:51:28 -0400 Subject: [PATCH 1/2] Adding fix-point computation --- .../tools/dotc/transform/init/Checker.scala | 22 ++++++-- .../tools/dotc/transform/init/Semantic.scala | 56 ++++++++++++------- 2 files changed, 54 insertions(+), 24 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/init/Checker.scala b/compiler/src/dotty/tools/dotc/transform/init/Checker.scala index d13c4df04547..17ed120978c5 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Checker.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Checker.scala @@ -17,6 +17,7 @@ import Phases._ import scala.collection.mutable +import util.EqHashMap class Checker extends Phase { @@ -78,11 +79,24 @@ class Checker extends Phase { val paramValues = tpl.constr.termParamss.flatten.map(param => param.symbol -> Hot).toMap - given Promoted = Promoted.empty - given Trace = Trace.empty - given Env = Env(paramValues) + // A wrapper for eval that uses two-cache method to compute fix-point after evaluating a class body + def fixPointEval(expr: Tree, thisV: Addr, klass: ClassSymbol, inputCache: evalCache, outputCache: evalCache): Result = { + import semantic.Heap._ + given Promoted = Promoted.empty + given Trace = Trace.empty + given Env = Env(paramValues) + given (evalCache, evalCache) = (inputCache, outputCache) + val res = eval(expr, thisV, klass) + if !res.errors.isEmpty then res + else if inputCache.equal(outputCache) then + inputCache.commitEvalCache(thisV) + res + else + thisV.emptyField() + fixPointEval(expr, thisV, klass, outputCache, new evalCache) + } - val res = eval(tpl, thisRef, cls) + val res = fixPointEval(tpl, thisRef, cls, new evalCache, new evalCache) res.errors.foreach(_.issue) } diff --git a/compiler/src/dotty/tools/dotc/transform/init/Semantic.scala b/compiler/src/dotty/tools/dotc/transform/init/Semantic.scala index 1aa9656a2084..2bcedcc0ac66 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Semantic.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Semantic.scala @@ -124,7 +124,7 @@ class Semantic { */ def updateField(field: Symbol, value: Value): Contextual[Unit] = val fields = heap(ref).fields - assert(!fields.contains(field), field.show + " already init, new = " + value + ", ref =" + ref) + // assert(!fields.contains(field), field.show + " already init, new = " + value + ", ref =" + ref) fields(field) = value /** Update the immediate outer of the given `klass` of the abstract object @@ -133,6 +133,11 @@ class Semantic { */ def updateOuter(klass: ClassSymbol, value: Value): Contextual[Unit] = heap(ref).outers(klass) = value + + /** Eliminate the field is necessary when computing fix-point + * + */ + def emptyField(): Contextual[Unit] = heap(ref).fields.clear() end extension } type Heap = Heap.Heap @@ -268,7 +273,7 @@ class Semantic { } /** The state that threads through the interpreter */ - type Contextual[T] = (Env, Context, Trace, Promoted) ?=> T + type Contextual[T] = (Env, (evalCache, evalCache), Context, Trace, Promoted) ?=> T // ----- Error Handling ----------------------------------- @@ -337,7 +342,7 @@ class Semantic { if target.is(Flags.Lazy) then given Trace = trace1 val rhs = target.defTree.asInstanceOf[ValDef].rhs - eval(rhs, addr, target.owner.asClass, cacheResult = true) + eval(rhs, addr, target.owner.asClass) else val obj = heap(addr) if obj.fields.contains(target) then @@ -351,7 +356,7 @@ class Semantic { Result(Hot, Nil) else if target.hasSource then val rhs = target.defTree.asInstanceOf[ValOrDefDef].rhs - eval(rhs, addr, target.owner.asClass, cacheResult = true) + eval(rhs, addr, target.owner.asClass) else val error = CallUnknown(field, source, trace.toVector) Result(Hot, error :: Nil) @@ -404,15 +409,15 @@ class Semantic { if target.isPrimaryConstructor then given Env = env2 val tpl = cls.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template] - val res = withTrace(trace.add(cls.defTree)) { eval(tpl, addr, cls, cacheResult = true) } + val res = withTrace(trace.add(cls.defTree)) { eval(tpl, addr, cls) } Result(addr, res.errors) else if target.isConstructor then given Env = env2 - eval(ddef.rhs, addr, cls, cacheResult = true) + eval(ddef.rhs, addr, cls) else // normal method call withEnv(if isLocal then env else Env.empty) { - eval(ddef.rhs, addr, cls, cacheResult = true) ++ checkArgs + eval(ddef.rhs, addr, cls) ++ checkArgs } else if addr.canIgnoreMethodCall(target) then Result(Hot, Nil) @@ -433,7 +438,7 @@ class Semantic { if meth.name.toString == "tupled" then Result(value, Nil) // a call like `fun.tupled` else withEnv(env) { - eval(body, thisV, klass, cacheResult = true) ++ checkArgs + eval(body, thisV, klass) ++ checkArgs } case RefSet(refs) => @@ -702,19 +707,30 @@ class Semantic { * * This method only handles cache logic and delegates the work to `cases`. */ - def eval(expr: Tree, thisV: Addr, klass: ClassSymbol, cacheResult: Boolean = false): Contextual[Result] = log("evaluating " + expr.show + ", this = " + thisV.show, printer, res => res.asInstanceOf[Result].show) { + + type evalCache = EqHashMap[Tree, Result] + def inputCache(using caches: (evalCache, evalCache)) = caches._1 + def outputCache(using caches: (evalCache, evalCache)) = caches._2 + extension (evalcache: evalCache) + def equal(evalcache2: evalCache): Boolean = + evalcache.toSeq.forall((key, result) => evalcache2.contains(key) && evalcache2(key) == result) + + def commitEvalCache(thisV: Addr): Unit = + val innerMap = cache.getOrElseUpdate(thisV, new EqHashMap[Tree, Value]) + evalcache.toSeq.foreach((key, result) => innerMap(key) = result.value) + end extension + + def eval(expr: Tree, thisV: Addr, klass: ClassSymbol): Contextual[Result] = log("evaluating " + expr.show + ", this = " + thisV.show, printer, res => res.asInstanceOf[Result].show) { val innerMap = cache.getOrElseUpdate(thisV, new EqHashMap[Tree, Value]) - if (innerMap.contains(expr)) Result(innerMap(expr), Errors.empty) - else { - // no need to compute fix-point, because - // 1. the result is decided by `cfg` for a legal program - // (heap change is irrelevant thanks to monotonicity) - // 2. errors will have been reported for an illegal program - innerMap(expr) = Hot + if innerMap.contains(expr) then Result(innerMap(expr), Errors.empty) + else if outputCache.contains(expr) then outputCache(expr) + else + // need to compute fix-point for soundness + if !inputCache.contains(expr) then inputCache(expr) = Result(Hot, Errors.empty) + outputCache(expr) = inputCache(expr) val res = cases(expr, thisV, klass) - if cacheResult then innerMap(expr) = res.value else innerMap.remove(expr) + outputCache(expr) = res res - } } /** Evaluate a list of expressions */ @@ -889,7 +905,7 @@ class Semantic { case vdef : ValDef => // local val definition // TODO: support explicit @cold annotation for local definitions - eval(vdef.rhs, thisV, klass, cacheResult = true) + eval(vdef.rhs, thisV, klass) case ddef : DefDef => // local method @@ -1109,7 +1125,7 @@ class Semantic { tpl.body.foreach { case vdef : ValDef if !vdef.symbol.is(Flags.Lazy) && !vdef.rhs.isEmpty => given Env = Env.empty - val res = eval(vdef.rhs, thisV, klass, cacheResult = true) + val res = eval(vdef.rhs, thisV, klass) errorBuffer ++= res.errors thisV.updateField(vdef.symbol, res.value) fieldsChanged = true From 9da85fd22ff9b8bcb5c3bef6bd9cd2d914168eb8 Mon Sep 17 00:00:00 2001 From: EnzeXing Date: Tue, 24 Aug 2021 22:38:01 -0400 Subject: [PATCH 2/2] Update tests --- tests/init/neg/enum-desugared.check | 4 ++-- tests/init/neg/inner-loop.scala | 2 +- tests/init/neg/local-warm4.check | 5 ++--- tests/init/neg/unsound1.check | 4 ++++ tests/init/neg/unsound1.scala | 11 +++++++++++ tests/init/neg/unsound2.check | 6 ++++++ tests/init/neg/unsound2.scala | 10 ++++++++++ tests/init/neg/unsound3.check | 5 +++++ tests/init/neg/unsound3.scala | 13 +++++++++++++ tests/init/neg/unsound4.check | 6 ++++++ tests/init/neg/unsound4.scala | 4 ++++ 11 files changed, 64 insertions(+), 6 deletions(-) create mode 100644 tests/init/neg/unsound1.check create mode 100644 tests/init/neg/unsound1.scala create mode 100644 tests/init/neg/unsound2.check create mode 100644 tests/init/neg/unsound2.scala create mode 100644 tests/init/neg/unsound3.check create mode 100644 tests/init/neg/unsound3.scala create mode 100644 tests/init/neg/unsound4.check create mode 100644 tests/init/neg/unsound4.scala diff --git a/tests/init/neg/enum-desugared.check b/tests/init/neg/enum-desugared.check index 567b8104c154..2b090378e877 100644 --- a/tests/init/neg/enum-desugared.check +++ b/tests/init/neg/enum-desugared.check @@ -13,6 +13,6 @@ | Cannot prove that the value is fully-initialized. May only use initialized value as method arguments. | | The unsafe promotion may cause the following problem: - | Calling the external method method ordinal may cause initialization errors. Calling trace: + | Calling the external method method name may cause initialization errors. Calling trace: | -> Array(this.LazyErrorId, this.NoExplanationID) // error // error [ enum-desugared.scala:17 ] - | -> def errorNumber: Int = this.ordinal() - 2 [ enum-desugared.scala:8 ] + | -> override def productPrefix: String = this.name() [ enum-desugared.scala:29 ] diff --git a/tests/init/neg/inner-loop.scala b/tests/init/neg/inner-loop.scala index d2b6a1fae8e0..c6d5c615580c 100644 --- a/tests/init/neg/inner-loop.scala +++ b/tests/init/neg/inner-loop.scala @@ -1,6 +1,6 @@ class Outer { outer => class Inner extends Outer { - val x = 5 + outer.n // error + val x = 5 + outer.n } val inner = new Inner val n = 6 // error diff --git a/tests/init/neg/local-warm4.check b/tests/init/neg/local-warm4.check index fda1ee1b928c..664900b3cffa 100644 --- a/tests/init/neg/local-warm4.check +++ b/tests/init/neg/local-warm4.check @@ -6,6 +6,5 @@ | -> class A(x: Int) extends Foo(x) { [ local-warm4.scala:6 ] | -> val b = new B(y) [ local-warm4.scala:10 ] | -> class B(x: Int) extends A(x) { [ local-warm4.scala:13 ] - | -> class A(x: Int) extends Foo(x) { [ local-warm4.scala:6 ] - | -> increment() [ local-warm4.scala:9 ] - | -> updateA() [ local-warm4.scala:21 ] + | -> if y < 10 then increment() [ local-warm4.scala:23 ] + | -> updateA() [ local-warm4.scala:21 ] diff --git a/tests/init/neg/unsound1.check b/tests/init/neg/unsound1.check new file mode 100644 index 000000000000..54e24546845c --- /dev/null +++ b/tests/init/neg/unsound1.check @@ -0,0 +1,4 @@ +-- Error: tests/init/neg/unsound1.scala:2:35 --------------------------------------------------------------------------- +2 | if (m > 0) println(foo(m - 1).a2.n) // error + | ^^^^^^^^^^^^^^^ + | Access field A.this.foo(A.this.m.-(1)).a2.n on a value with an unknown initialization status. diff --git a/tests/init/neg/unsound1.scala b/tests/init/neg/unsound1.scala new file mode 100644 index 000000000000..3854504c8478 --- /dev/null +++ b/tests/init/neg/unsound1.scala @@ -0,0 +1,11 @@ +class A(m: Int) { + if (m > 0) println(foo(m - 1).a2.n) // error + def foo(n: Int): B = + if (n % 2 == 0) + new B(new A(n - 1), foo(n - 1).a1) + else + new B(this, new A(n - 1)) + var n: Int = 10 +} + +class B(val a1: A, val a2: A) \ No newline at end of file diff --git a/tests/init/neg/unsound2.check b/tests/init/neg/unsound2.check new file mode 100644 index 000000000000..346caec4cb1b --- /dev/null +++ b/tests/init/neg/unsound2.check @@ -0,0 +1,6 @@ +-- Error: tests/init/neg/unsound2.scala:5:26 --------------------------------------------------------------------------- +5 | def getN: Int = a.n // error + | ^^^ + | Access field B.this.a.n on a value with an unknown initialization status. Calling trace: + | -> println(foo(x).getB) [ unsound2.scala:8 ] + | -> def foo(y: Int): B = if (y > 10) then B(bar(y - 1), foo(y - 1).getN) else B(bar(y), 10) [ unsound2.scala:2 ] diff --git a/tests/init/neg/unsound2.scala b/tests/init/neg/unsound2.scala new file mode 100644 index 000000000000..5ae0c624c32e --- /dev/null +++ b/tests/init/neg/unsound2.scala @@ -0,0 +1,10 @@ +case class A(x: Int) { + def foo(y: Int): B = if (y > 10) then B(bar(y - 1), foo(y - 1).getN) else B(bar(y), 10) + def bar(y: Int): A = if (y > 10) then A(y - 1) else this + class B(a: A, b: Int) { + def getN: Int = a.n // error + def getB: Int = b + } + println(foo(x).getB) + val n: Int = 10 +} \ No newline at end of file diff --git a/tests/init/neg/unsound3.check b/tests/init/neg/unsound3.check new file mode 100644 index 000000000000..71766cf2d10b --- /dev/null +++ b/tests/init/neg/unsound3.check @@ -0,0 +1,5 @@ +-- Error: tests/init/neg/unsound3.scala:10:38 -------------------------------------------------------------------------- +10 | if (x < 12) then foo().getC().b else newB // error + | ^^^^^^^^^^^^^^ + | Access field C.this.foo().getC().b on a value with an unknown initialization status. Calling trace: + | -> val b = foo() [ unsound3.scala:12 ] diff --git a/tests/init/neg/unsound3.scala b/tests/init/neg/unsound3.scala new file mode 100644 index 000000000000..9ede5c7f97d0 --- /dev/null +++ b/tests/init/neg/unsound3.scala @@ -0,0 +1,13 @@ +class B(c: C) { + def getC() = c +} + +class C { + var x = 10 + def foo(): B = { + x += 1 + val newB = new B(this) + if (x < 12) then foo().getC().b else newB // error + } + val b = foo() +} \ No newline at end of file diff --git a/tests/init/neg/unsound4.check b/tests/init/neg/unsound4.check new file mode 100644 index 000000000000..4ed254444928 --- /dev/null +++ b/tests/init/neg/unsound4.check @@ -0,0 +1,6 @@ +-- Error: tests/init/neg/unsound4.scala:3:8 ---------------------------------------------------------------------------- +3 | val aAgain = foo(5) // error + | ^ + | Access non-initialized value aAgain. Calling trace: + | -> val aAgain = foo(5) // error [ unsound4.scala:3 ] + | -> def foo(x: Int): A = if (x < 5) then this else foo(x - 1).aAgain [ unsound4.scala:2 ] diff --git a/tests/init/neg/unsound4.scala b/tests/init/neg/unsound4.scala new file mode 100644 index 000000000000..8a6e26fe8a6b --- /dev/null +++ b/tests/init/neg/unsound4.scala @@ -0,0 +1,4 @@ +class A { + def foo(x: Int): A = if (x < 5) then this else foo(x - 1).aAgain + val aAgain = foo(5) // error +} \ No newline at end of file