@@ -10,11 +10,28 @@ import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.*
10
10
import CaptureSet .{Refs , emptySet , HiddenSet }
11
11
import config .Printers .capt
12
12
import StdNames .nme
13
- import util .{SimpleIdentitySet , EqHashMap }
13
+ import util .{SimpleIdentitySet , EqHashMap , SrcPos }
14
+
15
+ object SepChecker :
16
+
17
+ /** Enumerates kinds of captures encountered so far */
18
+ enum Captures :
19
+ case None
20
+ case Explicit // one or more explicitly declared captures
21
+ case Hidden // exacttly one hidden captures
22
+ case NeedsCheck // one hidden capture and one other capture (hidden or declared)
23
+
24
+ def add (that : Captures ): Captures =
25
+ if this == None then that
26
+ else if that == None then this
27
+ else if this == Explicit && that == Explicit then Explicit
28
+ else NeedsCheck
29
+ end Captures
14
30
15
31
class SepChecker (checker : CheckCaptures .CheckerAPI ) extends tpd.TreeTraverser :
16
32
import tpd .*
17
33
import checker .*
34
+ import SepChecker .*
18
35
19
36
/** The set of capabilities that are hidden by a polymorphic result type
20
37
* of some previous definition.
@@ -52,21 +69,17 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
52
69
53
70
private def hidden (using Context ): Refs =
54
71
val seen : util.EqHashSet [CaptureRef ] = new util.EqHashSet
55
-
56
- def hiddenByElem (elem : CaptureRef ): Refs =
57
- if seen.add(elem) then elem match
58
- case Fresh .Cap (hcs) => hcs.elems.filter(! _.isRootCapability) ++ recur(hcs.elems)
59
- case ReadOnlyCapability (ref) => hiddenByElem(ref).map(_.readOnly)
60
- case _ => emptySet
61
- else emptySet
62
-
63
72
def recur (cs : Refs ): Refs =
64
73
(emptySet /: cs): (elems, elem) =>
65
- elems ++ hiddenByElem(elem)
66
-
74
+ if seen.add(elem) then elems ++ hiddenByElem(elem, recur )
75
+ else elems
67
76
recur(refs)
68
77
end hidden
69
78
79
+ private def containsHidden (using Context ): Boolean =
80
+ refs.exists: ref =>
81
+ ! hiddenByElem(ref, _ => emptySet).isEmpty
82
+
70
83
/** Deduct the footprint of `sym` and `sym*` from `refs` */
71
84
private def deductSym (sym : Symbol )(using Context ) =
72
85
val ref = sym.termRef
@@ -79,6 +92,11 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
79
92
refs -- captures(dep).footprint
80
93
end extension
81
94
95
+ private def hiddenByElem (ref : CaptureRef , recur : Refs => Refs )(using Context ): Refs = ref match
96
+ case Fresh .Cap (hcs) => hcs.elems.filter(! _.isRootCapability) ++ recur(hcs.elems)
97
+ case ReadOnlyCapability (ref1) => hiddenByElem(ref1, recur).map(_.readOnly)
98
+ case _ => emptySet
99
+
82
100
/** The captures of an argument or prefix widened to the formal parameter, if
83
101
* the latter contains a cap.
84
102
*/
@@ -186,6 +204,7 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
186
204
for (arg, idx) <- indexedArgs do
187
205
if arg.needsSepCheck then
188
206
val ac = formalCaptures(arg)
207
+ checkType(arg.formalType, arg.srcPos, NoSymbol , " the argument's adapted type" )
189
208
val hiddenInArg = ac.hidden.footprint
190
209
// println(i"check sep $arg: $ac, footprint so far = $footprint, hidden = $hiddenInArg")
191
210
val overlap = hiddenInArg.overlapWith(footprint).deductCapturesOf(deps(arg))
@@ -212,6 +231,105 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
212
231
if ! overlap.isEmpty then
213
232
sepUseError(tree, usedFootprint, overlap)
214
233
234
+ def checkType (tpt : Tree , sym : Symbol )(using Context ): Unit =
235
+ checkType(tpt.nuType, tpt.srcPos, sym, " " )
236
+
237
+ /** Check that all parts of type `tpe` are separated.
238
+ * @param tpe the type to check
239
+ * @param pos position for error reporting
240
+ * @param sym if `tpe` is the (result-) type of a val or def, the symbol of
241
+ * this definition, otherwise NoSymbol. If `sym` exists we
242
+ * deduct its associated direct and reach capabilities everywhere
243
+ * from the capture sets we check.
244
+ * @param what a string describing what kind of type it is
245
+ */
246
+ def checkType (tpe : Type , pos : SrcPos , sym : Symbol , what : String )(using Context ): Unit =
247
+
248
+ def checkParts (parts : List [Type ]): Unit =
249
+ var footprint : Refs = emptySet
250
+ var hiddenSet : Refs = emptySet
251
+ var checked = 0
252
+ for part <- parts do
253
+
254
+ /** Report an error if `current` and `next` overlap.
255
+ * @param current the footprint or hidden set seen so far
256
+ * @param next the footprint or hidden set of the next part
257
+ * @param mapRefs a function over the capture set elements of the next part
258
+ * that returns the references of the same kind as `current`
259
+ * (i.e. the part's footprint or hidden set)
260
+ * @param prevRel a verbal description of current ("references or "hides")
261
+ * @param nextRel a verbal descriiption of next
262
+ */
263
+ def checkSep (current : Refs , next : Refs , mapRefs : Refs => Refs , prevRel : String , nextRel : String ): Unit =
264
+ val globalOverlap = current.overlapWith(next)
265
+ if ! globalOverlap.isEmpty then
266
+ val (prevStr, prevRefs, overlap) = parts.iterator.take(checked)
267
+ .map: prev =>
268
+ val prevRefs = mapRefs(prev.deepCaptureSet.elems).footprint.deductSym(sym)
269
+ (i " , $prev , " , prevRefs, prevRefs.overlapWith(next))
270
+ .dropWhile(_._3.isEmpty)
271
+ .nextOption
272
+ .getOrElse((" " , current, globalOverlap))
273
+ report.error(
274
+ em """ Separation failure in $what type $tpe.
275
+ |One part, $part , $nextRel ${CaptureSet (next)}.
276
+ |A previous part $prevStr $prevRel ${CaptureSet (prevRefs)}.
277
+ |The two sets overlap at ${CaptureSet (overlap)}. """ ,
278
+ pos)
279
+
280
+ val partRefs = part.deepCaptureSet.elems
281
+ val partFootprint = partRefs.footprint.deductSym(sym)
282
+ val partHidden = partRefs.hidden.footprint.deductSym(sym) -- partFootprint
283
+
284
+ checkSep(footprint, partHidden, identity, " references" , " hides" )
285
+ checkSep(hiddenSet, partHidden, _.hidden, " also hides" , " hides" )
286
+ checkSep(hiddenSet, partFootprint, _.hidden, " hides" , " references" )
287
+
288
+ footprint ++= partFootprint
289
+ hiddenSet ++= partHidden
290
+ checked += 1
291
+ end for
292
+ end checkParts
293
+
294
+ object traverse extends TypeAccumulator [Captures ]:
295
+
296
+ /** A stack of part lists to check. We maintain this since immediately
297
+ * checking parts when traversing the type would check innermost to oputermost.
298
+ * But we want to check outermost parts first since this prioritized errors
299
+ * that are more obvious.
300
+ */
301
+ var toCheck : List [List [Type ]] = Nil
302
+
303
+ private val seen = util.HashSet [Symbol ]()
304
+
305
+ def apply (c : Captures , t : Type ) =
306
+ if variance < 0 then c
307
+ else
308
+ val t1 = t.dealias
309
+ t1 match
310
+ case t @ AppliedType (tycon, args) =>
311
+ val c1 = foldOver(Captures .None , t)
312
+ if c1 == Captures .NeedsCheck then
313
+ toCheck = (tycon :: args) :: toCheck
314
+ c.add(c1)
315
+ case t @ CapturingType (parent, cs) =>
316
+ val c1 = this (c, parent)
317
+ if cs.elems.containsHidden then c1.add(Captures .Hidden )
318
+ else if ! cs.elems.isEmpty then c1.add(Captures .Explicit )
319
+ else c1
320
+ case t : TypeRef if t.symbol.isAbstractOrParamType =>
321
+ if seen.contains(t.symbol) then c
322
+ else
323
+ seen += t.symbol
324
+ apply(apply(c, t.prefix), t.info.bounds.hi)
325
+ case t =>
326
+ foldOver(c, t)
327
+
328
+ if ! tpe.hasAnnotation(defn.UntrackedCapturesAnnot ) then
329
+ traverse(Captures .None , tpe)
330
+ traverse.toCheck.foreach(checkParts)
331
+ end checkType
332
+
215
333
private def collectMethodTypes (tp : Type ): List [TermLambda ] = tp match
216
334
case tp : MethodType => tp :: collectMethodTypes(tp.resType)
217
335
case tp : PolyType => collectMethodTypes(tp.resType)
@@ -231,7 +349,7 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
231
349
(formal, arg) <- mt.paramInfos.zip(args)
232
350
dep <- formal.captureSet.elems.toList
233
351
do
234
- val referred = dep match
352
+ val referred = dep.stripReach match
235
353
case dep : TermParamRef =>
236
354
argMap(dep.binder)(dep.paramNum) :: Nil
237
355
case dep : ThisType if dep.cls == fn.symbol.owner =>
@@ -269,11 +387,13 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
269
387
defsShadow = saved
270
388
case tree : ValOrDefDef =>
271
389
traverseChildren(tree)
272
- if previousDefs.nonEmpty && ! tree.symbol.isOneOf(TermParamOrAccessor ) then
273
- capt.println(i " sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}" )
274
- defsShadow ++= captures(tree.tpt).hidden.footprint.deductSym(tree.symbol)
275
- resultType(tree.symbol) = tree.tpt.nuType
276
- previousDefs.head += tree
390
+ if ! tree.symbol.isOneOf(TermParamOrAccessor ) then
391
+ checkType(tree.tpt, tree.symbol)
392
+ if previousDefs.nonEmpty then
393
+ capt.println(i " sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}" )
394
+ defsShadow ++= captures(tree.tpt).hidden.footprint.deductSym(tree.symbol)
395
+ resultType(tree.symbol) = tree.tpt.nuType
396
+ previousDefs.head += tree
277
397
case _ =>
278
398
traverseChildren(tree)
279
399
end SepChecker
0 commit comments