@@ -28,6 +28,16 @@ object SepChecker:
28
28
else NeedsCheck
29
29
end Captures
30
30
31
+ /** The kind of checked type, used for composing error messages */
32
+ enum TypeKind :
33
+ case Result (sym : Symbol , inferred : Boolean )
34
+ case Argument
35
+
36
+ def dclSym = this match
37
+ case Result (sym, _) => sym
38
+ case _ => NoSymbol
39
+ end TypeKind
40
+
31
41
class SepChecker (checker : CheckCaptures .CheckerAPI ) extends tpd.TreeTraverser :
32
42
import tpd .*
33
43
import checker .*
@@ -204,7 +214,7 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
204
214
for (arg, idx) <- indexedArgs do
205
215
if arg.needsSepCheck then
206
216
val ac = formalCaptures(arg)
207
- checkType(arg.formalType, arg.srcPos, NoSymbol , " the argument's adapted type " )
217
+ checkType(arg.formalType, arg.srcPos, TypeKind . Argument )
208
218
val hiddenInArg = ac.hidden.footprint
209
219
// println(i"check sep $arg: $ac, footprint so far = $footprint, hidden = $hiddenInArg")
210
220
val overlap = hiddenInArg.overlapWith(footprint).deductCapturesOf(deps(arg))
@@ -232,18 +242,29 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
232
242
sepUseError(tree, usedFootprint, overlap)
233
243
234
244
def checkType (tpt : Tree , sym : Symbol )(using Context ): Unit =
235
- checkType(tpt.nuType, tpt.srcPos, sym, " " )
236
-
237
- /** Check that all parts of type `tpe` are separated.
238
- * @param tpe the type to check
239
- * @param pos position for error reporting
240
- * @param sym if `tpe` is the (result-) type of a val or def, the symbol of
241
- * this definition, otherwise NoSymbol. If `sym` exists we
242
- * deduct its associated direct and reach capabilities everywhere
243
- * from the capture sets we check.
244
- * @param what a string describing what kind of type it is
245
- */
246
- def checkType (tpe : Type , pos : SrcPos , sym : Symbol , what : String )(using Context ): Unit =
245
+ checkType(tpt.nuType, tpt.srcPos,
246
+ TypeKind .Result (sym, inferred = tpt.isInstanceOf [InferredTypeTree ]))
247
+
248
+ /** Check that all parts of type `tpe` are separated. */
249
+ def checkType (tpe : Type , pos : SrcPos , kind : TypeKind )(using Context ): Unit =
250
+
251
+ def typeDescr = kind match
252
+ case TypeKind .Result (sym, inferred) =>
253
+ def inferredStr = if inferred then " inferred" else " "
254
+ def resultStr = if sym.info.isInstanceOf [MethodicType ] then " result" else " "
255
+ i " $sym's $inferredStr$resultStr"
256
+ case TypeKind .Argument =>
257
+ " the argument's adapted type"
258
+
259
+ def explicitRefs (tp : Type ): Refs = tp match
260
+ case tp : (TermRef | ThisType ) => SimpleIdentitySet (tp)
261
+ case AnnotatedType (parent, _) => explicitRefs(parent)
262
+ case AndType (tp1, tp2) => explicitRefs(tp1) ++ explicitRefs(tp2)
263
+ case OrType (tp1, tp2) => explicitRefs(tp1) ** explicitRefs(tp2)
264
+ case _ => emptySet
265
+
266
+ def prune (refs : Refs ): Refs =
267
+ refs.deductSym(kind.dclSym) -- explicitRefs(tpe)
247
268
248
269
def checkParts (parts : List [Type ]): Unit =
249
270
var footprint : Refs = emptySet
@@ -265,21 +286,21 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
265
286
if ! globalOverlap.isEmpty then
266
287
val (prevStr, prevRefs, overlap) = parts.iterator.take(checked)
267
288
.map: prev =>
268
- val prevRefs = mapRefs(prev.deepCaptureSet.elems).footprint.deductSym(sym )
289
+ val prevRefs = prune( mapRefs(prev.deepCaptureSet.elems).footprint)
269
290
(i " , $prev , " , prevRefs, prevRefs.overlapWith(next))
270
291
.dropWhile(_._3.isEmpty)
271
292
.nextOption
272
293
.getOrElse((" " , current, globalOverlap))
273
294
report.error(
274
- em """ Separation failure in $what type $tpe.
295
+ em """ Separation failure in $typeDescr type $tpe.
275
296
|One part, $part , $nextRel ${CaptureSet (next)}.
276
297
|A previous part $prevStr $prevRel ${CaptureSet (prevRefs)}.
277
298
|The two sets overlap at ${CaptureSet (overlap)}. """ ,
278
299
pos)
279
300
280
301
val partRefs = part.deepCaptureSet.elems
281
- val partFootprint = partRefs.footprint.deductSym(sym )
282
- val partHidden = partRefs.hidden.footprint.deductSym(sym ) -- partFootprint
302
+ val partFootprint = prune( partRefs.footprint)
303
+ val partHidden = prune( partRefs.hidden.footprint) -- partFootprint
283
304
284
305
checkSep(footprint, partHidden, identity, " references" , " hides" )
285
306
checkSep(hiddenSet, partHidden, _.hidden, " also hides" , " hides" )
@@ -325,9 +346,43 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
325
346
case t =>
326
347
foldOver(c, t)
327
348
349
+ def checkParameters () =
350
+ val badParams = mutable.ListBuffer [Symbol ]()
351
+ def currentOwner = kind.dclSym.orElse(ctx.owner)
352
+ for hiddenRef <- prune(tpe.deepCaptureSet.elems.hidden.footprint) do
353
+ val refSym = hiddenRef.termSymbol
354
+ if refSym.is(TermParam )
355
+ && ! refSym.hasAnnotation(defn.ConsumeAnnot )
356
+ && ! refSym.info.derivesFrom(defn.Caps_SharedCapability )
357
+ && currentOwner.isContainedIn(refSym.owner)
358
+ then
359
+ badParams += refSym
360
+ if badParams.nonEmpty then
361
+ def paramsStr (params : List [Symbol ]): String = (params : @ unchecked) match
362
+ case p :: Nil => i " ${p.name}"
363
+ case p :: p2 :: Nil => i " ${p.name} and ${p2.name}"
364
+ case p :: ps => i " ${p.name}, ${paramsStr(ps)}"
365
+ val (pluralS, singleS) = if badParams.tail.isEmpty then (" " , " s" ) else (" s" , " " )
366
+ report.error(
367
+ em """ Separation failure: $typeDescr type $tpe hides parameter $pluralS ${paramsStr(badParams.toList)}
368
+ |The parameter $pluralS need $singleS to be annotated with @consume to allow this. """ ,
369
+ pos)
370
+
371
+ def flagHiddenParams =
372
+ kind match
373
+ case TypeKind .Result (sym, _) =>
374
+ ! sym.isAnonymousFunction // we don't check return types of anonymous functions
375
+ && ! sym.is(Case ) // We don't check so far binders in patterns since they
376
+ // have inferred universal types. TODO come back to this;
377
+ // either infer more precise types for such binders or
378
+ // "see through them" when we look at hidden sets.
379
+ case TypeKind .Argument =>
380
+ false
381
+
328
382
if ! tpe.hasAnnotation(defn.UntrackedCapturesAnnot ) then
329
383
traverse(Captures .None , tpe)
330
384
traverse.toCheck.foreach(checkParts)
385
+ if flagHiddenParams then checkParameters()
331
386
end checkType
332
387
333
388
private def collectMethodTypes (tp : Type ): List [TermLambda ] = tp match
0 commit comments