Skip to content

Commit 5e49b12

Browse files
committed
Refine isParametric tests
Mutable variables can appeal to parametricty only if they are not captured. We use "not captured by any closure" as a sound approximation for that, since variables themselves are currently not tracked, so we cannot use soemthing more finegrained.
1 parent cfec1d0 commit 5e49b12

File tree

9 files changed

+173
-34
lines changed

9 files changed

+173
-34
lines changed

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 74 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import typer.ErrorReporting.{Addenda, err}
1616
import typer.ProtoTypes.{AnySelectionProto, LhsProto}
1717
import util.{SimpleIdentitySet, EqHashMap, EqHashSet, SrcPos, Property}
1818
import transform.SymUtils.*
19-
import transform.{Recheck, PreRecheck}
19+
import transform.{Recheck, PreRecheck, CapturedVars}
2020
import Recheck.*
2121
import scala.collection.mutable
2222
import CaptureSet.{withCaptureSetsExplained, IdempotentCaptRefMap, CompareResult}
@@ -149,15 +149,25 @@ object CheckCaptures:
149149

150150
private val seen = new EqHashSet[TypeRef]
151151

152+
/** Check that there is at least one method containing carrier and defined
153+
* in the scope of tparam. E.g. this is OK:
154+
* def f[T] = { ... var x: T ... }
155+
* So is this:
156+
* class C[T] { def f() = { class D { var x: T }}}
157+
* But this is not OK:
158+
* class C[T] { object o { var x: T }}
159+
*/
152160
extension (tparam: Symbol) def isParametricIn(carrier: Symbol): Boolean =
153-
val encl = carrier.maybeOwner.enclosingMethodOrClass
154-
if encl.isClass then tparam.isParametricIn(encl)
155-
else
156-
def recur(encl: Symbol): Boolean =
157-
if tparam.owner == encl then true
158-
else if encl.isStatic || !encl.exists then false
159-
else recur(encl.owner.enclosingMethodOrClass)
160-
recur(encl)
161+
carrier.exists && {
162+
val encl = carrier.owner.enclosingMethodOrClass
163+
if encl.isClass then tparam.isParametricIn(encl)
164+
else
165+
def recur(encl: Symbol): Boolean =
166+
if tparam.owner == encl then true
167+
else if encl.isStatic || !encl.exists then false
168+
else recur(encl.owner.enclosingMethodOrClass)
169+
recur(encl)
170+
}
161171

162172
def traverse(t: Type) =
163173
t.dealiasKeepAnnots match
@@ -168,9 +178,12 @@ object CheckCaptures:
168178
t.info match
169179
case TypeBounds(_, hi) if !t.isSealed && !t.symbol.isParametricIn(carrier) =>
170180
if hi.isAny then
181+
val detailStr =
182+
if t eq tp then "variable"
183+
else i"refers to the type variable $t, which"
171184
report.error(
172185
em"""$what cannot $have $tp since
173-
|that type refers to the type variable $t, which is not sealed.
186+
|that type $detailStr is not sealed.
174187
|$addendum""",
175188
pos)
176189
else
@@ -549,7 +562,7 @@ class CheckCaptures extends Recheck, SymTransformer:
549562
for case (arg: TypeTree, formal, pname) <- args.lazyZip(polyType.paramRefs).lazyZip((polyType.paramNames)) do
550563
if formal.isSealed then
551564
def where = if fn.symbol.exists then i" in an argument of ${fn.symbol}" else ""
552-
disallowRootCapabilitiesIn(arg.knownType, fn.symbol,
565+
disallowRootCapabilitiesIn(arg.knownType, NoSymbol,
553566
i"Sealed type variable $pname", "be instantiated to",
554567
i"This is often caused by a local capability$where\nleaking as part of its result.",
555568
tree.srcPos)
@@ -590,13 +603,58 @@ class CheckCaptures extends Recheck, SymTransformer:
590603
openClosures = openClosures.tail
591604
end recheckClosureBlock
592605

606+
/** Maps mutable variables to the symbols that capture them (in the
607+
* CheckCaptures sense, i.e. symbol is referred to from a different method
608+
* than the one it is defined in).
609+
*/
610+
private val capturedBy = util.HashMap[Symbol, Symbol]()
611+
612+
/** Maps anonymous functions appearing as function arguments to
613+
* the function that is called.
614+
*/
615+
private val anonFunCallee = util.HashMap[Symbol, Symbol]()
616+
617+
/** Populates `capturedBy` and `anonFunCallee`. Called by `checkUnit`.
618+
*/
619+
private def collectCapturedMutVars(using Context) = new TreeTraverser:
620+
def traverse(tree: Tree)(using Context) = tree match
621+
case id: Ident =>
622+
val sym = id.symbol
623+
if sym.is(Mutable, butNot = Method) && sym.owner.isTerm then
624+
val enclMeth = ctx.owner.enclosingMethod
625+
if sym.enclosingMethod != enclMeth then
626+
capturedBy(sym) = enclMeth
627+
case Apply(fn, args) =>
628+
for case closureDef(mdef) <- args do
629+
anonFunCallee(mdef.symbol) = fn.symbol
630+
traverseChildren(tree)
631+
case Inlined(_, bindings, expansion) =>
632+
traverse(bindings)
633+
traverse(expansion)
634+
case mdef: DefDef =>
635+
if !mdef.symbol.isInlineMethod then traverseChildren(tree)
636+
case _ =>
637+
traverseChildren(tree)
638+
593639
override def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Type =
594640
try
595641
if sym.is(Module) then sym.info // Modules are checked by checking the module class
596642
else
597643
if sym.is(Mutable) && !sym.hasAnnotation(defn.UncheckedCapturesAnnot) then
598-
disallowRootCapabilitiesIn(tree.tpt.knownType, sym,
599-
i"mutable $sym", "have type", "", sym.srcPos)
644+
val (carrier, addendum) = capturedBy.get(sym) match
645+
case Some(encl) =>
646+
val enclStr =
647+
if encl.isAnonymousFunction then
648+
val location = anonFunCallee.get(encl) match
649+
case Some(meth) if meth.exists => i" argument in a call to $meth"
650+
case _ => ""
651+
s"an anonymous function$location"
652+
else encl.show
653+
(NoSymbol, i"\nNote that $sym does not count as local since it is captured by $enclStr")
654+
case _ =>
655+
(sym, "")
656+
disallowRootCapabilitiesIn(
657+
tree.tpt.knownType, carrier, i"Mutable $sym", "have type", addendum, sym.srcPos)
600658
checkInferredResult(super.recheckValDef(tree, sym), tree)
601659
finally
602660
if !sym.is(Param) then
@@ -1170,11 +1228,12 @@ class CheckCaptures extends Recheck, SymTransformer:
11701228
private val setup: SetupAPI = thisPhase.prev.asInstanceOf[Setup]
11711229

11721230
override def checkUnit(unit: CompilationUnit)(using Context): Unit =
1173-
setup.setupUnit(ctx.compilationUnit.tpdTree, completeDef)
1231+
setup.setupUnit(unit.tpdTree, completeDef)
1232+
collectCapturedMutVars.traverse(unit.tpdTree)
11741233

11751234
if ctx.settings.YccPrintSetup.value then
11761235
val echoHeader = "[[syntax tree at end of cc setup]]"
1177-
val treeString = show(ctx.compilationUnit.tpdTree)
1236+
val treeString = show(unit.tpdTree)
11781237
report.echo(s"$echoHeader\n$treeString\n")
11791238

11801239
withCaptureSetsExplained:

tests/neg-custom-args/captures/buffers.check

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
-- Error: tests/neg-custom-args/captures/buffers.scala:11:6 ------------------------------------------------------------
22
11 | var elems: Array[A] = new Array[A](10) // error // error
33
| ^
4-
| mutable variable elems cannot have type Array[A] since
4+
| Mutable variable elems cannot have type Array[A] since
55
| that type refers to the type variable A, which is not sealed.
66
-- Error: tests/neg-custom-args/captures/buffers.scala:16:38 -----------------------------------------------------------
77
16 | def make[A: ClassTag](xs: A*) = new ArrayBuffer: // error
@@ -14,13 +14,13 @@
1414
11 | var elems: Array[A] = new Array[A](10) // error // error
1515
| ^^^^^^^^
1616
| Array cannot have element type A since
17-
| that type refers to the type variable A, which is not sealed.
17+
| that type variable is not sealed.
1818
| Since arrays are mutable, they have to be treated like variables,
1919
| so their element type must be sealed.
2020
-- Error: tests/neg-custom-args/captures/buffers.scala:22:9 ------------------------------------------------------------
2121
22 | val x: Array[A] = new Array[A](10) // error
2222
| ^^^^^^^^
2323
| Array cannot have element type A since
24-
| that type refers to the type variable A, which is not sealed.
24+
| that type variable is not sealed.
2525
| Since arrays are mutable, they have to be treated like variables,
2626
| so their element type must be sealed.

tests/neg-custom-args/captures/levels.check

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
-- Error: tests/neg-custom-args/captures/levels.scala:6:16 -------------------------------------------------------------
22
6 | private var v: T = init // error
33
| ^
4-
| mutable variable v cannot have type T since
5-
| that type refers to the type variable T, which is not sealed.
4+
| Mutable variable v cannot have type T since
5+
| that type variable is not sealed.
66
-- Error: tests/neg-custom-args/captures/levels.scala:17:13 ------------------------------------------------------------
77
17 | val _ = Ref[String => String]((x: String) => x) // error
88
| ^^^^^^^^^^^^^^^^^^^^^
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
-- [E129] Potential Issue Warning: tests/neg-custom-args/captures/sealed-leaks.scala:31:6 ------------------------------
2+
31 | ()
3+
| ^^
4+
| A pure expression does nothing in statement position
5+
|
6+
| longer explanation available when compiling with `-explain`
7+
-- Error: tests/neg-custom-args/captures/sealed-leaks.scala:12:27 ------------------------------------------------------
8+
12 | val later2 = usingLogFile[(() => Unit) | Null] { f => () => f.write(0) } // error
9+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
10+
| Sealed type variable T cannot be instantiated to (() => Unit) | Null since
11+
| that type captures the root capability `cap`.
12+
| This is often caused by a local capability in an argument of method usingLogFile
13+
| leaking as part of its result.
14+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/sealed-leaks.scala:19:26 ---------------------------------
15+
19 | usingLogFile { f => x = f } // error
16+
| ^
17+
| Found: (f : java.io.FileOutputStream^)
18+
| Required: (java.io.FileOutputStream | Null)^{cap[Test2]}
19+
|
20+
| longer explanation available when compiling with `-explain`
21+
-- Error: tests/neg-custom-args/captures/sealed-leaks.scala:30:10 ------------------------------------------------------
22+
30 | var x: T = y // error
23+
| ^
24+
| Mutable variable x cannot have type T since
25+
| that type variable is not sealed.
26+
-- Error: tests/neg-custom-args/captures/sealed-leaks.scala:39:8 -------------------------------------------------------
27+
39 | var x: T = y // error
28+
| ^
29+
| Mutable variable x cannot have type T since
30+
| that type variable is not sealed.
31+
|
32+
| Note that variable x does not count as local since it is captured by an anonymous function
33+
-- Error: tests/neg-custom-args/captures/sealed-leaks.scala:43:8 -------------------------------------------------------
34+
43 | var x: T = y // error
35+
| ^
36+
|Mutable variable x cannot have type T since
37+
|that type variable is not sealed.
38+
|
39+
|Note that variable x does not count as local since it is captured by an anonymous function argument in a call to method identity
40+
-- Error: tests/neg-custom-args/captures/sealed-leaks.scala:47:8 -------------------------------------------------------
41+
47 | var x: T = y // error
42+
| ^
43+
| Mutable variable x cannot have type T since
44+
| that type variable is not sealed.
45+
|
46+
| Note that variable x does not count as local since it is captured by method foo
47+
-- Error: tests/neg-custom-args/captures/sealed-leaks.scala:11:14 ------------------------------------------------------
48+
11 | val later = usingLogFile { f => () => f.write(0) } // error
49+
| ^^^^^^^^^^^^
50+
| local reference f leaks into outer capture set of type parameter T of method usingLogFile

tests/neg-custom-args/captures/sealed-leaks.scala

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,34 @@ def Test2 =
1818

1919
usingLogFile { f => x = f } // error
2020

21-
later()
21+
later()
22+
23+
def Test3 =
24+
def f[T](y: T) =
25+
var x: T = y
26+
()
27+
28+
class C[T](y: T):
29+
object o:
30+
var x: T = y // error
31+
()
32+
33+
class C2[T](y: T):
34+
def f =
35+
var x: T = y // ok
36+
()
37+
38+
def g1[T](y: T): T => Unit =
39+
var x: T = y // error
40+
y => x = y
41+
42+
def g2[T](y: T): T => Unit =
43+
var x: T = y // error
44+
identity(y => x = y)
45+
46+
def g3[T](y: T): Unit =
47+
var x: T = y // error
48+
def foo =
49+
x = y
50+
()
51+

tests/pos-special/stdlib/collection/Iterator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -868,7 +868,7 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
868868
*/
869869
def duplicate: (Iterator[A]^{this}, Iterator[A]^{this}) = {
870870
val gap = new scala.collection.mutable.Queue[A @uncheckedCaptures]
871-
var ahead: Iterator[A] = null
871+
var ahead: Iterator[A @uncheckedCaptures] = null // ahead is captured by Partner, so A is not recognized as parametric
872872
class Partner extends AbstractIterator[A] {
873873
override def knownSize: Int = self.synchronized {
874874
val thisSize = self.knownSize

tests/pos-special/stdlib/collection/immutable/LazyListIterable.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -852,9 +852,9 @@ final class LazyListIterable[+A] private(private[this] var lazyState: () => Lazy
852852
else if (!isEmpty) {
853853
b.append(head)
854854
var cursor = this
855-
@inline def appendCursorElement(): Unit = b.append(sep).append(cursor.head)
855+
inline def appendCursorElement(): Unit = b.append(sep).append(cursor.head)
856856
var scout = tail
857-
@inline def scoutNonEmpty: Boolean = scout.stateDefined && !scout.isEmpty
857+
inline def scoutNonEmpty: Boolean = scout.stateDefined && !scout.isEmpty
858858
if ((cursor ne scout) && (!scout.stateDefined || (cursor.state ne scout.state))) {
859859
cursor = scout
860860
if (scoutNonEmpty) {
@@ -998,7 +998,7 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
998998

999999
private def filterImpl[A](ll: LazyListIterable[A]^, p: A => Boolean, isFlipped: Boolean): LazyListIterable[A]^{ll, p} = {
10001000
// DO NOT REFERENCE `ll` ANYWHERE ELSE, OR IT WILL LEAK THE HEAD
1001-
var restRef = ll // val restRef = new ObjectRef(ll)
1001+
var restRef: LazyListIterable[A @uncheckedCaptures]^{cap[filterImpl]} = ll // restRef is captured by closure arg to newLL, so A is not recognized as parametric
10021002
newLL {
10031003
var elem: A = null.asInstanceOf[A]
10041004
var found = false
@@ -1015,7 +1015,7 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
10151015

10161016
private def collectImpl[A, B](ll: LazyListIterable[A]^, pf: PartialFunction[A, B]^): LazyListIterable[B]^{ll, pf} = {
10171017
// DO NOT REFERENCE `ll` ANYWHERE ELSE, OR IT WILL LEAK THE HEAD
1018-
var restRef = ll // val restRef = new ObjectRef(ll)
1018+
var restRef: LazyListIterable[A @uncheckedCaptures]^{cap[collectImpl]} = ll // restRef is captured by closure arg to newLL, so A is not recognized as parametric
10191019
newLL {
10201020
val marker = Statics.pfMarker
10211021
val toMarker = anyToMarker.asInstanceOf[A => B] // safe because Function1 is erased
@@ -1034,9 +1034,9 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
10341034

10351035
private def flatMapImpl[A, B](ll: LazyListIterable[A]^, f: A => IterableOnce[B]^): LazyListIterable[B]^{ll, f} = {
10361036
// DO NOT REFERENCE `ll` ANYWHERE ELSE, OR IT WILL LEAK THE HEAD
1037-
var restRef = ll // val restRef = new ObjectRef(ll)
1037+
var restRef: LazyListIterable[A @uncheckedCaptures]^{cap[flatMapImpl]} = ll // restRef is captured by closure arg to newLL, so A is not recognized as parametric
10381038
newLL {
1039-
var it: Iterator[B]^{ll, f} = null
1039+
var it: Iterator[B @uncheckedCaptures]^{ll, f} = null
10401040
var itHasNext = false
10411041
var rest = restRef // var rest = restRef.elem
10421042
while (!itHasNext && !rest.isEmpty) {
@@ -1058,7 +1058,7 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
10581058

10591059
private def dropImpl[A](ll: LazyListIterable[A]^, n: Int): LazyListIterable[A]^{ll} = {
10601060
// DO NOT REFERENCE `ll` ANYWHERE ELSE, OR IT WILL LEAK THE HEAD
1061-
var restRef = ll // val restRef = new ObjectRef(ll)
1061+
var restRef: LazyListIterable[A @uncheckedCaptures]^{cap[dropImpl]} = ll // restRef is captured by closure arg to newLL, so A is not recognized as parametric
10621062
var iRef = n // val iRef = new IntRef(n)
10631063
newLL {
10641064
var rest = restRef // var rest = restRef.elem
@@ -1075,7 +1075,7 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
10751075

10761076
private def dropWhileImpl[A](ll: LazyListIterable[A]^, p: A => Boolean): LazyListIterable[A]^{ll, p} = {
10771077
// DO NOT REFERENCE `ll` ANYWHERE ELSE, OR IT WILL LEAK THE HEAD
1078-
var restRef = ll // val restRef = new ObjectRef(ll)
1078+
var restRef: LazyListIterable[A @uncheckedCaptures]^{cap[dropWhileImpl]} = ll // restRef is captured by closure arg to newLL, so A is not recognized as parametric
10791079
newLL {
10801080
var rest = restRef // var rest = restRef.elem
10811081
while (!rest.isEmpty && p(rest.head)) {
@@ -1088,8 +1088,8 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
10881088

10891089
private def takeRightImpl[A](ll: LazyListIterable[A]^, n: Int): LazyListIterable[A]^{ll} = {
10901090
// DO NOT REFERENCE `ll` ANYWHERE ELSE, OR IT WILL LEAK THE HEAD
1091-
var restRef = ll // val restRef = new ObjectRef(ll)
1092-
var scoutRef = ll // val scoutRef = new ObjectRef(ll)
1091+
var restRef: LazyListIterable[A @uncheckedCaptures]^{cap[takeRightImpl]} = ll // restRef is captured by closure arg to newLL, so A is not recognized as parametric
1092+
var scoutRef: LazyListIterable[A @uncheckedCaptures]^{cap[takeRightImpl]} = ll // same situation
10931093
var remainingRef = n // val remainingRef = new IntRef(n)
10941094
newLL {
10951095
var scout = scoutRef // var scout = scoutRef.elem

tests/pos-special/stdlib/collection/immutable/TreeSeqMap.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ object TreeSeqMap extends MapFactory[TreeSeqMap] {
609609
}
610610

611611
final def splitAt(n: Int): (Ordering[T], Ordering[T]) = {
612-
var rear = Ordering.empty[T]
612+
var rear: Ordering[T @uncheckedCaptures] = Ordering.empty[T]
613613
var i = n
614614
(modifyOrRemove { (o, v) =>
615615
i -= 1

tests/pos-special/stdlib/collection/immutable/Vector.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ sealed abstract class Vector[+A] private[immutable] (private[immutable] final va
229229
// k >= 0, k = suffix.knownSize
230230
val tinyAppendLimit = 4 + vectorSliceCount
231231
if (k < tinyAppendLimit) {
232-
var v: Vector[B] = this
232+
var v: Vector[B @uncheckedCaptures] = this
233233
suffix match {
234234
case it: Iterable[_] => it.asInstanceOf[Iterable[B]].foreach(x => v = v.appended(x))
235235
case _ => suffix.iterator.foreach(x => v = v.appended(x))

0 commit comments

Comments
 (0)