Skip to content

Commit 131f070

Browse files
committed
Check separation of different parts of a declared type.
1 parent b1da91a commit 131f070

File tree

9 files changed

+222
-48
lines changed

9 files changed

+222
-48
lines changed

compiler/src/dotty/tools/dotc/cc/SepCheck.scala

+137-17
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,28 @@ import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.*
1010
import CaptureSet.{Refs, emptySet, HiddenSet}
1111
import config.Printers.capt
1212
import StdNames.nme
13-
import util.{SimpleIdentitySet, EqHashMap}
13+
import util.{SimpleIdentitySet, EqHashMap, SrcPos}
14+
15+
object SepChecker:
16+
17+
/** Enumerates kinds of captures encountered so far */
18+
enum Captures:
19+
case None
20+
case Explicit // one or more explicitly declared captures
21+
case Hidden // exacttly one hidden captures
22+
case NeedsCheck // one hidden capture and one other capture (hidden or declared)
23+
24+
def add(that: Captures): Captures =
25+
if this == None then that
26+
else if that == None then this
27+
else if this == Explicit && that == Explicit then Explicit
28+
else NeedsCheck
29+
end Captures
1430

1531
class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
1632
import tpd.*
1733
import checker.*
34+
import SepChecker.*
1835

1936
/** The set of capabilities that are hidden by a polymorphic result type
2037
* of some previous definition.
@@ -52,21 +69,17 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
5269

5370
private def hidden(using Context): Refs =
5471
val seen: util.EqHashSet[CaptureRef] = new util.EqHashSet
55-
56-
def hiddenByElem(elem: CaptureRef): Refs =
57-
if seen.add(elem) then elem match
58-
case Fresh.Cap(hcs) => hcs.elems.filter(!_.isRootCapability) ++ recur(hcs.elems)
59-
case ReadOnlyCapability(ref) => hiddenByElem(ref).map(_.readOnly)
60-
case _ => emptySet
61-
else emptySet
62-
6372
def recur(cs: Refs): Refs =
6473
(emptySet /: cs): (elems, elem) =>
65-
elems ++ hiddenByElem(elem)
66-
74+
if seen.add(elem) then elems ++ hiddenByElem(elem, recur)
75+
else elems
6776
recur(refs)
6877
end hidden
6978

79+
private def containsHidden(using Context): Boolean =
80+
refs.exists: ref =>
81+
!hiddenByElem(ref, _ => emptySet).isEmpty
82+
7083
/** Deduct the footprint of `sym` and `sym*` from `refs` */
7184
private def deductSym(sym: Symbol)(using Context) =
7285
val ref = sym.termRef
@@ -79,6 +92,11 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
7992
refs -- captures(dep).footprint
8093
end extension
8194

95+
private def hiddenByElem(ref: CaptureRef, recur: Refs => Refs)(using Context): Refs = ref match
96+
case Fresh.Cap(hcs) => hcs.elems.filter(!_.isRootCapability) ++ recur(hcs.elems)
97+
case ReadOnlyCapability(ref1) => hiddenByElem(ref1, recur).map(_.readOnly)
98+
case _ => emptySet
99+
82100
/** The captures of an argument or prefix widened to the formal parameter, if
83101
* the latter contains a cap.
84102
*/
@@ -186,6 +204,7 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
186204
for (arg, idx) <- indexedArgs do
187205
if arg.needsSepCheck then
188206
val ac = formalCaptures(arg)
207+
checkType(arg.formalType, arg.srcPos, NoSymbol, " the argument's adapted type")
189208
val hiddenInArg = ac.hidden.footprint
190209
//println(i"check sep $arg: $ac, footprint so far = $footprint, hidden = $hiddenInArg")
191210
val overlap = hiddenInArg.overlapWith(footprint).deductCapturesOf(deps(arg))
@@ -212,6 +231,105 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
212231
if !overlap.isEmpty then
213232
sepUseError(tree, usedFootprint, overlap)
214233

234+
def checkType(tpt: Tree, sym: Symbol)(using Context): Unit =
235+
checkType(tpt.nuType, tpt.srcPos, sym, "")
236+
237+
/** Check that all parts of type `tpe` are separated.
238+
* @param tpe the type to check
239+
* @param pos position for error reporting
240+
* @param sym if `tpe` is the (result-) type of a val or def, the symbol of
241+
* this definition, otherwise NoSymbol. If `sym` exists we
242+
* deduct its associated direct and reach capabilities everywhere
243+
* from the capture sets we check.
244+
* @param what a string describing what kind of type it is
245+
*/
246+
def checkType(tpe: Type, pos: SrcPos, sym: Symbol, what: String)(using Context): Unit =
247+
248+
def checkParts(parts: List[Type]): Unit =
249+
var footprint: Refs = emptySet
250+
var hiddenSet: Refs = emptySet
251+
var checked = 0
252+
for part <- parts do
253+
254+
/** Report an error if `current` and `next` overlap.
255+
* @param current the footprint or hidden set seen so far
256+
* @param next the footprint or hidden set of the next part
257+
* @param mapRefs a function over the capture set elements of the next part
258+
* that returns the references of the same kind as `current`
259+
* (i.e. the part's footprint or hidden set)
260+
* @param prevRel a verbal description of current ("references or "hides")
261+
* @param nextRel a verbal descriiption of next
262+
*/
263+
def checkSep(current: Refs, next: Refs, mapRefs: Refs => Refs, prevRel: String, nextRel: String): Unit =
264+
val globalOverlap = current.overlapWith(next)
265+
if !globalOverlap.isEmpty then
266+
val (prevStr, prevRefs, overlap) = parts.iterator.take(checked)
267+
.map: prev =>
268+
val prevRefs = mapRefs(prev.deepCaptureSet.elems).footprint.deductSym(sym)
269+
(i", $prev , ", prevRefs, prevRefs.overlapWith(next))
270+
.dropWhile(_._3.isEmpty)
271+
.nextOption
272+
.getOrElse(("", current, globalOverlap))
273+
report.error(
274+
em"""Separation failure in$what type $tpe.
275+
|One part, $part , $nextRel ${CaptureSet(next)}.
276+
|A previous part$prevStr $prevRel ${CaptureSet(prevRefs)}.
277+
|The two sets overlap at ${CaptureSet(overlap)}.""",
278+
pos)
279+
280+
val partRefs = part.deepCaptureSet.elems
281+
val partFootprint = partRefs.footprint.deductSym(sym)
282+
val partHidden = partRefs.hidden.footprint.deductSym(sym) -- partFootprint
283+
284+
checkSep(footprint, partHidden, identity, "references", "hides")
285+
checkSep(hiddenSet, partHidden, _.hidden, "also hides", "hides")
286+
checkSep(hiddenSet, partFootprint, _.hidden, "hides", "references")
287+
288+
footprint ++= partFootprint
289+
hiddenSet ++= partHidden
290+
checked += 1
291+
end for
292+
end checkParts
293+
294+
object traverse extends TypeAccumulator[Captures]:
295+
296+
/** A stack of part lists to check. We maintain this since immediately
297+
* checking parts when traversing the type would check innermost to oputermost.
298+
* But we want to check outermost parts first since this prioritized errors
299+
* that are more obvious.
300+
*/
301+
var toCheck: List[List[Type]] = Nil
302+
303+
private val seen = util.HashSet[Symbol]()
304+
305+
def apply(c: Captures, t: Type) =
306+
if variance < 0 then c
307+
else
308+
val t1 = t.dealias
309+
t1 match
310+
case t @ AppliedType(tycon, args) =>
311+
val c1 = foldOver(Captures.None, t)
312+
if c1 == Captures.NeedsCheck then
313+
toCheck = (tycon :: args) :: toCheck
314+
c.add(c1)
315+
case t @ CapturingType(parent, cs) =>
316+
val c1 = this(c, parent)
317+
if cs.elems.containsHidden then c1.add(Captures.Hidden)
318+
else if !cs.elems.isEmpty then c1.add(Captures.Explicit)
319+
else c1
320+
case t: TypeRef if t.symbol.isAbstractOrParamType =>
321+
if seen.contains(t.symbol) then c
322+
else
323+
seen += t.symbol
324+
apply(apply(c, t.prefix), t.info.bounds.hi)
325+
case t =>
326+
foldOver(c, t)
327+
328+
if !tpe.hasAnnotation(defn.UntrackedCapturesAnnot) then
329+
traverse(Captures.None, tpe)
330+
traverse.toCheck.foreach(checkParts)
331+
end checkType
332+
215333
private def collectMethodTypes(tp: Type): List[TermLambda] = tp match
216334
case tp: MethodType => tp :: collectMethodTypes(tp.resType)
217335
case tp: PolyType => collectMethodTypes(tp.resType)
@@ -231,7 +349,7 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
231349
(formal, arg) <- mt.paramInfos.zip(args)
232350
dep <- formal.captureSet.elems.toList
233351
do
234-
val referred = dep match
352+
val referred = dep.stripReach match
235353
case dep: TermParamRef =>
236354
argMap(dep.binder)(dep.paramNum) :: Nil
237355
case dep: ThisType if dep.cls == fn.symbol.owner =>
@@ -269,11 +387,13 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
269387
defsShadow = saved
270388
case tree: ValOrDefDef =>
271389
traverseChildren(tree)
272-
if previousDefs.nonEmpty && !tree.symbol.isOneOf(TermParamOrAccessor) then
273-
capt.println(i"sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}")
274-
defsShadow ++= captures(tree.tpt).hidden.footprint.deductSym(tree.symbol)
275-
resultType(tree.symbol) = tree.tpt.nuType
276-
previousDefs.head += tree
390+
if !tree.symbol.isOneOf(TermParamOrAccessor) then
391+
checkType(tree.tpt, tree.symbol)
392+
if previousDefs.nonEmpty then
393+
capt.println(i"sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}")
394+
defsShadow ++= captures(tree.tpt).hidden.footprint.deductSym(tree.symbol)
395+
resultType(tree.symbol) = tree.tpt.nuType
396+
previousDefs.head += tree
277397
case _ =>
278398
traverseChildren(tree)
279399
end SepChecker
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
11
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt-depfun.scala:10:43 ----------------------------------
2-
10 | val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error // error sepcheck
2+
10 | val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error
33
| ^^^^^^^
44
| Found: Str^{} ->{ac, y, z} Str^{y, z}
55
| Required: Str^{y, z} ->{fresh} Str^{y, z}
66
|
77
| longer explanation available when compiling with `-explain`
8-
-- Error: tests/neg-custom-args/captures/capt-depfun.scala:10:24 -------------------------------------------------------
9-
10 | val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error // error sepcheck
10-
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
11-
| Separation failure: Str^{y, z} => Str^{y, z} captures a root element hiding {ac, y, z}
12-
| and also refers to {y, z}.
13-
| The two sets overlap at {y, z}

tests/neg-custom-args/captures/capt-depfun.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ class Str
77
def f(y: Cap, z: Cap) =
88
def g(): C @retains(y, z) = ???
99
val ac: ((x: Cap) => Str @retains(x) => Str @retains(x)) = ???
10-
val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error // error sepcheck
10+
val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error
+25-10
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,25 @@
1-
-- Error: tests/neg-custom-args/captures/reaches2.scala:8:10 -----------------------------------------------------------
2-
8 | ps.map((x, y) => compose1(x, y)) // error // error
3-
| ^
4-
|reference ps* is not included in the allowed capture set {}
5-
|of an enclosing function literal with expected type ((box A ->{ps*} A, box A ->{ps*} A)) -> box (x$0: A^?) ->? A^?
6-
-- Error: tests/neg-custom-args/captures/reaches2.scala:8:13 -----------------------------------------------------------
7-
8 | ps.map((x, y) => compose1(x, y)) // error // error
8-
| ^
9-
|reference ps* is not included in the allowed capture set {}
10-
|of an enclosing function literal with expected type ((box A ->{ps*} A, box A ->{ps*} A)) -> box (x$0: A^?) ->? A^?
1+
-- Error: tests/neg-custom-args/captures/reaches2.scala:10:10 ----------------------------------------------------------
2+
10 | ps.map((x, y) => compose1(x, y)) // error // error // error
3+
| ^
4+
|reference ps* is not included in the allowed capture set {}
5+
|of an enclosing function literal with expected type ((box A ->{ps*} A, box A ->{ps*} A)) -> box (x$0: A^?) ->? A^?
6+
-- Error: tests/neg-custom-args/captures/reaches2.scala:10:13 ----------------------------------------------------------
7+
10 | ps.map((x, y) => compose1(x, y)) // error // error // error
8+
| ^
9+
|reference ps* is not included in the allowed capture set {}
10+
|of an enclosing function literal with expected type ((box A ->{ps*} A, box A ->{ps*} A)) -> box (x$0: A^?) ->? A^?
11+
-- Error: tests/neg-custom-args/captures/reaches2.scala:10:31 ----------------------------------------------------------
12+
10 | ps.map((x, y) => compose1(x, y)) // error // error // error
13+
| ^
14+
| Separation failure: argument of type (x$0: A) ->{y} box A^?
15+
| to method compose1: [A, B, C](f: A => B, g: B => C): A ->{f, g} C
16+
| corresponds to capture-polymorphic formal parameter g of type box A^? => box A^?
17+
| and captures {ps*}, but this capability is also passed separately
18+
| in the first argument with type (x$0: A) ->{x} box A^?.
19+
|
20+
| Capture set of first argument : {x}
21+
| Hidden set of current argument : {y}
22+
| Footprint of first argument : {x, ps*}
23+
| Hidden footprint of current argument : {y, ps*}
24+
| Declared footprint of current argument: {}
25+
| Undeclared overlap of footprints : {ps*}
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import language.`3.8` // sepchecks on
2+
13
class List[+A]:
24
def map[B](f: A -> B): List[B] = ???
35

46
def compose1[A, B, C](f: A => B, g: B => C): A ->{f, g} C =
57
z => g(f(z))
68

79
def mapCompose[A](ps: List[(A => A, A => A)]): List[A ->{ps*} A] =
8-
ps.map((x, y) => compose1(x, y)) // error // error
10+
ps.map((x, y) => compose1(x, y)) // error // error // error
911

Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:7:10 ---------------------------------------------------------
2-
7 | println(c) // error
3-
| ^
4-
| Separation failure: Illegal access to {c} which is hidden by the previous definition
5-
| of value xs with type List[box () => Unit].
6-
| This type hides capabilities {xs*, c}
7-
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:10:33 --------------------------------------------------------
8-
10 | foo((() => println(c)) :: Nil, c) // error
1+
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:10:10 --------------------------------------------------------
2+
10 | println(c) // error
3+
| ^
4+
| Separation failure: Illegal access to {c} which is hidden by the previous definition
5+
| of value xs with type List[box () => Unit].
6+
| This type hides capabilities {xs*, c}
7+
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:13:33 --------------------------------------------------------
8+
13 | foo((() => println(c)) :: Nil, c) // error
99
| ^
1010
| Separation failure: argument of type (c : Object^)
1111
| to method foo: (xs: List[box () => Unit], y: Object^): Nothing
@@ -19,3 +19,24 @@
1919
| Hidden footprint of current argument : {c}
2020
| Declared footprint of current argument: {}
2121
| Undeclared overlap of footprints : {c}
22+
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:14:10 --------------------------------------------------------
23+
14 | val x1: (Object^, Object^) = (c, c) // error
24+
| ^^^^^^^^^^^^^^^^^^
25+
| Separation failure in type (box Object^, box Object^).
26+
| One part, box Object^ , hides {c}.
27+
| A previous part, box Object^ , also hides {c}.
28+
| The two sets overlap at {c}.
29+
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:15:10 --------------------------------------------------------
30+
15 | val x2: (Object^, Object^{d}) = (d, d) // error
31+
| ^^^^^^^^^^^^^^^^^^^^^
32+
| Separation failure in type (box Object^, box Object^{d}).
33+
| One part, box Object^{d} , references {d}.
34+
| A previous part, box Object^ , hides {d}.
35+
| The two sets overlap at {d}.
36+
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:27:6 ---------------------------------------------------------
37+
27 | bar((c, c)) // error
38+
| ^^^^^^
39+
| Separation failure in the argument's adapted type type (box Object^, box Object^).
40+
| One part, box Object^ , hides {c}.
41+
| A previous part, box Object^ , also hides {c}.
42+
| The two sets overlap at {c}.
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,28 @@
11
import language.future // sepchecks on
22

3+
34
def foo(xs: List[() => Unit], y: Object^) = ???
45

6+
def bar(x: (Object^, Object^)): Unit = ???
7+
58
def Test(c: Object^) =
69
val xs: List[() => Unit] = (() => println(c)) :: Nil
710
println(c) // error
811

9-
def Test2(c: Object^) =
12+
def Test2(c: Object^, d: Object^): Unit =
1013
foo((() => println(c)) :: Nil, c) // error
14+
val x1: (Object^, Object^) = (c, c) // error
15+
val x2: (Object^, Object^{d}) = (d, d) // error
16+
17+
def Test3(c: Object^, d: Object^) =
18+
val x: (Object^, Object^) = (c, d) // ok
19+
20+
def Test4(c: Object^, d: Object^) =
21+
val x: (Object^, Object^{c}) = (d, c) // ok
22+
23+
def Test5(c: Object^, d: Object^): Unit =
24+
bar((c, d)) // ok
25+
26+
def Test6(c: Object^, d: Object^): Unit =
27+
bar((c, c)) // error
28+

tests/pos-custom-args/captures/i15749a.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import caps.cap
22
import caps.use
3+
import language.`3.7` // sepchecks on
34

45
class Unit
56
object u extends Unit
@@ -13,7 +14,7 @@ def test =
1314
def wrapper[T](x: T): Wrapper[T] = Wrapper:
1415
[X] => (op: T ->{cap} X) => op(x)
1516

16-
def strictMap[A <: Top, B <: Top](mx: Wrapper[A])(f: A ->{cap} B): Wrapper[B] =
17+
def strictMap[A <: Top, B <: Top](mx: Wrapper[A])(f: A ->{cap, mx*} B): Wrapper[B] =
1718
mx.value((x: A) => wrapper(f(x)))
1819

1920
def force[A](thunk: Unit ->{cap} A): A = thunk(u)

tests/run-custom-args/captures/colltest5/CollectionStrawManCC5_1.scala

+5-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import scala.reflect.ClassTag
66
import annotation.unchecked.{uncheckedVariance, uncheckedCaptures}
77
import annotation.tailrec
88
import caps.cap
9+
import caps.untrackedCaptures
910
import language.`3.7` // sepchecks on
1011

1112
/** A strawman architecture for new collections. It contains some
@@ -68,11 +69,13 @@ object CollectionStrawMan5 {
6869
/** Base trait for strict collections */
6970
trait Buildable[+A] extends Iterable[A] {
7071
protected def newBuilder: Builder[A, Repr] @uncheckedVariance
71-
override def partition(p: A => Boolean): (Repr, Repr) = {
72+
override def partition(p: A => Boolean): (Repr, Repr) @untrackedCaptures =
73+
// Without untrackedCaptures this fails SepChecks.checkType.
74+
// But this is probably an error in the hiding logic.
75+
// TODO remove @untrackedCaptures and investigate
7276
val l, r = newBuilder
7377
iterator.foreach(x => (if (p(x)) l else r) += x)
7478
(l.result, r.result)
75-
}
7679
// one might also override other transforms here to avoid generating
7780
// iterators if it helps efficiency.
7881
}

0 commit comments

Comments
 (0)