@@ -346,17 +346,19 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
346
346
case t =>
347
347
foldOver(c, t)
348
348
349
- def checkParams (refsToCheck : Refs , descr : => String ) =
349
+ def checkRefs (refsToCheck : Refs , descr : => String ) =
350
350
val badParams = mutable.ListBuffer [Symbol ]()
351
351
def currentOwner = kind.dclSym.orElse(ctx.owner)
352
- for hiddenRef <- prune(refsToCheck.footprint) do
353
- val refSym = hiddenRef.termSymbol
354
- if refSym.is(TermParam )
355
- && ! refSym.hasAnnotation(defn.ConsumeAnnot )
356
- && ! refSym.info.derivesFrom(defn.Caps_SharedCapability )
357
- && currentOwner.isContainedIn(refSym.owner)
358
- then
359
- badParams += refSym
352
+ for hiddenRef <- prune(refsToCheck) do
353
+ val refSym = hiddenRef.pathRoot.termSymbol // TODO also hangle ThisTypes as pathRoots
354
+ if refSym.exists && ! refSym.info.derivesFrom(defn.Caps_SharedCapability ) then
355
+ if currentOwner.enclosingMethodOrClass.isProperlyContainedIn(refSym.owner.enclosingMethodOrClass) then
356
+ report.error(em """ Separation failure: $descr non-local $refSym""" , pos)
357
+ else if refSym.is(TermParam )
358
+ && ! refSym.hasAnnotation(defn.ConsumeAnnot )
359
+ && currentOwner.isContainedIn(refSym.owner)
360
+ then
361
+ badParams += refSym
360
362
if badParams.nonEmpty then
361
363
def paramsStr (params : List [Symbol ]): String = (params : @ unchecked) match
362
364
case p :: Nil => i " ${p.name}"
@@ -368,25 +370,28 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
368
370
|The parameter $pluralS need $singleS to be annotated with @consume to allow this. """ ,
369
371
pos)
370
372
371
- def checkParameters () = kind match
373
+ def checkLegalRefs () = kind match
372
374
case TypeKind .Result (sym, _) =>
373
375
if ! sym.isAnonymousFunction // we don't check return types of anonymous functions
374
376
&& ! sym.is(Case ) // We don't check so far binders in patterns since they
375
377
// have inferred universal types. TODO come back to this;
376
378
// either infer more precise types for such binders or
377
379
// "see through them" when we look at hidden sets.
378
- then checkParams(tpe.deepCaptureSet.elems.hidden, i " $typeDescr type $tpe hides " )
380
+ then
381
+ val refs = tpe.deepCaptureSet.elems
382
+ val toCheck = refs.hidden.footprint -- refs.footprint
383
+ checkRefs(toCheck, i " $typeDescr type $tpe hides " )
379
384
case TypeKind .Argument (arg) =>
380
385
if tpe.hasAnnotation(defn.ConsumeAnnot ) then
381
386
val capts = captures(arg)
382
387
def descr (verb : String ) = i " argument to @consume parameter with type ${arg.nuType} $verb"
383
- checkParams (capts, descr(" refers to" ))
384
- checkParams (capts.hidden, descr(" hides" ))
388
+ checkRefs (capts.footprint , descr(" refers to" ))
389
+ checkRefs (capts.hidden.footprint , descr(" hides" ))
385
390
386
391
if ! tpe.hasAnnotation(defn.UntrackedCapturesAnnot ) then
387
392
traverse(Captures .None , tpe)
388
393
traverse.toCheck.foreach(checkParts)
389
- checkParameters ()
394
+ checkLegalRefs ()
390
395
end checkType
391
396
392
397
private def collectMethodTypes (tp : Type ): List [TermLambda ] = tp match
@@ -426,10 +431,12 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
426
431
if argss.nestedExists(_.needsSepCheck) then
427
432
checkApply(tree, argss.flatten, dependencies(tree, argss))
428
433
434
+ def isUnsafeAssumeSeparate (tree : Tree )(using Context ): Boolean = tree match
435
+ case tree : Apply => tree.symbol == defn.Caps_unsafeAssumeSeparate
436
+ case _ => false
437
+
429
438
def traverse (tree : Tree )(using Context ): Unit =
430
- tree match
431
- case tree : Apply if tree.symbol == defn.Caps_unsafeAssumeSeparate => return
432
- case _ =>
439
+ if isUnsafeAssumeSeparate(tree) then return
433
440
checkUse(tree)
434
441
tree match
435
442
case tree : GenericApply =>
@@ -446,7 +453,7 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
446
453
defsShadow = saved
447
454
case tree : ValOrDefDef =>
448
455
traverseChildren(tree)
449
- if ! tree.symbol.isOneOf(TermParamOrAccessor ) then
456
+ if ! tree.symbol.isOneOf(TermParamOrAccessor ) && ! isUnsafeAssumeSeparate(tree.rhs) then
450
457
checkType(tree.tpt, tree.symbol)
451
458
if previousDefs.nonEmpty then
452
459
capt.println(i " sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}" )
@@ -460,5 +467,3 @@ end SepChecker
460
467
461
468
462
469
463
-
464
-
0 commit comments