Skip to content

Commit 4c8a50f

Browse files
committed
Separation checking for applications
Check separation from source 3.7 on. We currently only check applications, other areas of separation checking are still to be implemented.
1 parent ad11819 commit 4c8a50f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+712
-115
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ object ccConfig:
5454

5555
/** If true, turn on separation checking */
5656
def useFresh(using Context): Boolean =
57-
Feature.sourceVersion.stable.isAtLeast(SourceVersion.`future`)
57+
Feature.sourceVersion.stable.isAtLeast(SourceVersion.`3.7`)
5858

5959
end ccConfig
6060

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

+23
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,17 @@ object CheckCaptures:
242242

243243
/** Was a new type installed for this tree? */
244244
def hasNuType: Boolean
245+
246+
/** Is this tree passed to a parameter or assigned to a value with a type
247+
* that contains cap in no-flip covariant position, which will necessite
248+
* a separation check?
249+
*/
250+
def needsSepCheck: Boolean
251+
252+
/** If a tree is an argument for which needsSepCheck is true,
253+
* the type of the formal paremeter corresponding to the argument.
254+
*/
255+
def formalType: Type
245256
end CheckerAPI
246257

247258
class CheckCaptures extends Recheck, SymTransformer:
@@ -282,6 +293,15 @@ class CheckCaptures extends Recheck, SymTransformer:
282293
*/
283294
private val todoAtPostCheck = new mutable.ListBuffer[() => Unit]
284295

296+
/** Maps trees that need a separation check because they are arguments to
297+
* polymorphic parameters. The trees are mapped to the formal parameter type.
298+
*/
299+
private val sepCheckFormals = util.EqHashMap[Tree, Type]()
300+
301+
extension [T <: Tree](tree: T)
302+
def needsSepCheck: Boolean = sepCheckFormals.contains(tree)
303+
def formalType: Type = sepCheckFormals.getOrElse(tree, NoType)
304+
285305
/** Instantiate capture set variables appearing contra-variantly to their
286306
* upper approximation.
287307
*/
@@ -662,6 +682,8 @@ class CheckCaptures extends Recheck, SymTransformer:
662682
// The @use annotation is added to `formal` by `prepareFunction`
663683
capt.println(i"charging deep capture set of $arg: ${argType} = ${argType.deepCaptureSet}")
664684
markFree(argType.deepCaptureSet, arg.srcPos)
685+
if formal.containsCap then
686+
sepCheckFormals(arg) = freshenedFormal
665687
argType
666688

667689
/** Map existential captures in result to `cap` and implement the following
@@ -1786,6 +1808,7 @@ class CheckCaptures extends Recheck, SymTransformer:
17861808
end checker
17871809

17881810
checker.traverse(unit)(using ctx.withOwner(defn.RootClass))
1811+
if ccConfig.useFresh then SepChecker(this).traverse(unit)
17891812
if !ctx.reporter.errorsReported then
17901813
// We dont report errors here if previous errors were reported, because other
17911814
// errors often result in bad applied types, but flagging these bad types gives
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
package dotty.tools
2+
package dotc
3+
package cc
4+
import ast.tpd
5+
import collection.mutable
6+
7+
import core.*
8+
import Symbols.*, Types.*
9+
import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.*
10+
import CaptureSet.{Refs, emptySet}
11+
import config.Printers.capt
12+
import StdNames.nme
13+
14+
class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
15+
import tpd.*
16+
import checker.*
17+
18+
extension (refs: Refs)
19+
private def footprint(using Context): Refs =
20+
def recur(elems: Refs, newElems: List[CaptureRef]): Refs = newElems match
21+
case newElem :: newElems1 =>
22+
val superElems = newElem.captureSetOfInfo.elems.filter: superElem =>
23+
!superElem.isMaxCapability && !elems.contains(superElem)
24+
recur(elems ++ superElems, newElems1 ++ superElems.toList)
25+
case Nil => elems
26+
val elems: Refs = refs.filter(!_.isMaxCapability)
27+
recur(elems, elems.toList)
28+
29+
private def overlapWith(other: Refs)(using Context): Refs =
30+
val refs1 = refs
31+
val refs2 = other
32+
def common(refs1: Refs, refs2: Refs) =
33+
refs1.filter: ref =>
34+
ref.isExclusive && refs2.exists(_.stripReadOnly eq ref)
35+
common(refs, other) ++ common(other, refs)
36+
37+
private def hidden(refs: Refs)(using Context): Refs =
38+
val seen: util.EqHashSet[CaptureRef] = new util.EqHashSet
39+
40+
def hiddenByElem(elem: CaptureRef): Refs =
41+
if seen.add(elem) then elem match
42+
case Fresh.Cap(hcs) => hcs.elems.filter(!_.isRootCapability) ++ recur(hcs.elems)
43+
case ReadOnlyCapability(ref) => hiddenByElem(ref).map(_.readOnly)
44+
case _ => emptySet
45+
else emptySet
46+
47+
def recur(cs: Refs): Refs =
48+
(emptySet /: cs): (elems, elem) =>
49+
elems ++ hiddenByElem(elem)
50+
51+
recur(refs)
52+
end hidden
53+
54+
/** The captures of an argument or prefix widened to the formal parameter, if
55+
* the latter contains a cap.
56+
*/
57+
private def formalCaptures(arg: Tree)(using Context): Refs =
58+
val argType = arg.formalType.orElse(arg.nuType)
59+
(if arg.nuType.hasUseAnnot then argType.deepCaptureSet else argType.captureSet)
60+
.elems
61+
62+
/** The captures of an argument of prefix. No widening takes place */
63+
private def actualCaptures(arg: Tree)(using Context): Refs =
64+
val argType = arg.nuType
65+
(if argType.hasUseAnnot then argType.deepCaptureSet else argType.captureSet)
66+
.elems
67+
68+
private def sepError(fn: Tree, args: List[Tree], argIdx: Int,
69+
overlap: Refs, hiddenInArg: Refs, footprints: List[(Refs, Int)],
70+
deps: collection.Map[Tree, List[Tree]])(using Context): Unit =
71+
val arg = args(argIdx)
72+
def paramName(mt: Type, idx: Int): Option[Name] = mt match
73+
case mt @ MethodType(pnames) =>
74+
if idx < pnames.length then Some(pnames(idx)) else paramName(mt.resType, idx - pnames.length)
75+
case mt: PolyType => paramName(mt.resType, idx)
76+
case _ => None
77+
def formalName = paramName(fn.nuType.widen, argIdx) match
78+
case Some(pname) => i"$pname "
79+
case _ => ""
80+
def whatStr = if overlap.size == 1 then "this capability is" else "these capabilities are"
81+
def funStr =
82+
if fn.symbol.exists then i"${fn.symbol}: ${fn.symbol.info}"
83+
else i"a function of type ${fn.nuType.widen}"
84+
val clashIdx = footprints
85+
.collect:
86+
case (fp, idx) if !hiddenInArg.overlapWith(fp).isEmpty => idx
87+
.head
88+
def whereStr = clashIdx match
89+
case 0 => "function prefix"
90+
case 1 => "first argument "
91+
case 2 => "second argument"
92+
case 3 => "third argument "
93+
case n => s"${n}th argument "
94+
def clashTree =
95+
if clashIdx == 0 then methPart(fn).asInstanceOf[Select].qualifier
96+
else args(clashIdx - 1)
97+
def clashType = clashTree.nuType
98+
def clashCaptures = actualCaptures(clashTree)
99+
def hiddenCaptures = hidden(formalCaptures(arg))
100+
def clashFootprint = clashCaptures.footprint
101+
def hiddenFootprint = hiddenCaptures.footprint
102+
def declaredFootprint = deps(arg).map(actualCaptures(_)).foldLeft(emptySet)(_ ++ _).footprint
103+
def footprintOverlap = hiddenFootprint.overlapWith(clashFootprint) -- declaredFootprint
104+
report.error(
105+
em"""Separation failure: argument of type ${arg.nuType}
106+
|to $funStr
107+
|corresponds to capture-polymorphic formal parameter ${formalName}of type ${arg.formalType}
108+
|and captures ${CaptureSet(overlap)}, but $whatStr also passed separately
109+
|in the ${whereStr.trim} with type $clashType.
110+
|
111+
| Capture set of $whereStr : ${CaptureSet(clashCaptures)}
112+
| Hidden set of current argument : ${CaptureSet(hiddenCaptures)}
113+
| Footprint of $whereStr : ${CaptureSet(clashFootprint)}
114+
| Hidden footprint of current argument : ${CaptureSet(hiddenFootprint)}
115+
| Declared footprint of current argument: ${CaptureSet(declaredFootprint)}
116+
| Undeclared overlap of footprints : ${CaptureSet(footprintOverlap)}""",
117+
arg.srcPos)
118+
end sepError
119+
120+
private def checkApply(fn: Tree, args: List[Tree], deps: collection.Map[Tree, List[Tree]])(using Context): Unit =
121+
val fnCaptures = methPart(fn) match
122+
case Select(qual, _) => qual.nuType.captureSet
123+
case _ => CaptureSet.empty
124+
capt.println(i"check separate $fn($args), fnCaptures = $fnCaptures, argCaptures = ${args.map(arg => CaptureSet(formalCaptures(arg)))}, deps = ${deps.toList}")
125+
var footprint = fnCaptures.elems.footprint
126+
val footprints = mutable.ListBuffer[(Refs, Int)]((footprint, 0))
127+
val indexedArgs = args.zipWithIndex
128+
129+
def subtractDeps(elems: Refs, arg: Tree): Refs =
130+
deps(arg).foldLeft(elems): (elems, dep) =>
131+
elems -- actualCaptures(dep).footprint
132+
133+
for (arg, idx) <- indexedArgs do
134+
if !arg.needsSepCheck then
135+
footprint = footprint ++ subtractDeps(actualCaptures(arg).footprint, arg)
136+
footprints += ((footprint, idx + 1))
137+
for (arg, idx) <- indexedArgs do
138+
if arg.needsSepCheck then
139+
val ac = formalCaptures(arg)
140+
val hiddenInArg = hidden(ac).footprint
141+
//println(i"check sep $arg: $ac, footprint so far = $footprint, hidden = $hiddenInArg")
142+
val overlap = subtractDeps(hiddenInArg.overlapWith(footprint), arg)
143+
if !overlap.isEmpty then
144+
sepError(fn, args, idx, overlap, hiddenInArg, footprints.toList, deps)
145+
footprint ++= actualCaptures(arg).footprint
146+
footprints += ((footprint, idx + 1))
147+
end checkApply
148+
149+
private def collectMethodTypes(tp: Type): List[TermLambda] = tp match
150+
case tp: MethodType => tp :: collectMethodTypes(tp.resType)
151+
case tp: PolyType => collectMethodTypes(tp.resType)
152+
case _ => Nil
153+
154+
private def dependencies(fn: Tree, argss: List[List[Tree]])(using Context): collection.Map[Tree, List[Tree]] =
155+
val mtpe =
156+
if fn.symbol.exists then fn.symbol.info
157+
else fn.tpe.widen // happens for PolyFunction applies
158+
val mtps = collectMethodTypes(mtpe)
159+
assert(mtps.hasSameLengthAs(argss), i"diff for $fn: ${fn.symbol} /// $mtps /// $argss")
160+
val mtpsWithArgs = mtps.zip(argss)
161+
val argMap = mtpsWithArgs.toMap
162+
val deps = mutable.HashMap[Tree, List[Tree]]().withDefaultValue(Nil)
163+
for
164+
(mt, args) <- mtpsWithArgs
165+
(formal, arg) <- mt.paramInfos.zip(args)
166+
dep <- formal.captureSet.elems.toList
167+
do
168+
val referred = dep match
169+
case dep: TermParamRef =>
170+
argMap(dep.binder)(dep.paramNum) :: Nil
171+
case dep: ThisType if dep.cls == fn.symbol.owner =>
172+
val Select(qual, _) = fn: @unchecked
173+
qual :: Nil
174+
case _ =>
175+
Nil
176+
deps(arg) ++= referred
177+
deps
178+
179+
private def traverseApply(tree: Tree, argss: List[List[Tree]])(using Context): Unit = tree match
180+
case Apply(fn, args) => traverseApply(fn, args :: argss)
181+
case TypeApply(fn, args) => traverseApply(fn, argss) // skip type arguments
182+
case _ =>
183+
if argss.nestedExists(_.needsSepCheck) then
184+
checkApply(tree, argss.flatten, dependencies(tree, argss))
185+
186+
def traverse(tree: Tree)(using Context): Unit =
187+
tree match
188+
case tree: GenericApply =>
189+
if tree.symbol != defn.Caps_unsafeAssumeSeparate then
190+
tree.tpe match
191+
case _: MethodOrPoly =>
192+
case _ => traverseApply(tree, Nil)
193+
traverseChildren(tree)
194+
case _ =>
195+
traverseChildren(tree)
196+
end SepChecker
197+
198+
199+
200+
201+
202+

compiler/src/dotty/tools/dotc/cc/Synthetics.scala

+6-2
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,9 @@ object Synthetics:
132132
val (pt: PolyType) = info: @unchecked
133133
val (mt: MethodType) = pt.resType: @unchecked
134134
val (enclThis: ThisType) = owner.thisType: @unchecked
135+
val paramCaptures = CaptureSet(enclThis, defn.captureRoot.termRef)
135136
pt.derivedLambdaType(resType = MethodType(mt.paramNames)(
136-
mt1 => mt.paramInfos.map(_.capturing(CaptureSet.universal)),
137+
mt1 => mt.paramInfos.map(_.capturing(paramCaptures)),
137138
mt1 => CapturingType(mt.resType, CaptureSet(enclThis, mt1.paramRefs.head))))
138139

139140
def transformCurriedTupledCaptures(info: Type, owner: Symbol) =
@@ -148,7 +149,10 @@ object Synthetics:
148149
ExprType(mapFinalResult(et.resType, CapturingType(_, CaptureSet(enclThis))))
149150

150151
def transformCompareCaptures =
151-
MethodType(defn.ObjectType.capturing(CaptureSet.universal) :: Nil, defn.BooleanType)
152+
val (enclThis: ThisType) = symd.owner.thisType: @unchecked
153+
MethodType(
154+
defn.ObjectType.capturing(CaptureSet(defn.captureRoot.termRef, enclThis)) :: Nil,
155+
defn.BooleanType)
152156

153157
symd.copySymDenotation(info = symd.name match
154158
case DefaultGetterName(nme.copy, n) =>

compiler/src/dotty/tools/dotc/core/Definitions.scala

+1
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,7 @@ class Definitions {
10061006
@tu lazy val Caps_Exists: ClassSymbol = requiredClass("scala.caps.Exists")
10071007
@tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe")
10081008
@tu lazy val Caps_unsafeAssumePure: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumePure")
1009+
@tu lazy val Caps_unsafeAssumeSeparate: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumeSeparate")
10091010
@tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Contains")
10101011
@tu lazy val Caps_containsImpl: TermSymbol = CapsModule.requiredMethod("containsImpl")
10111012
@tu lazy val Caps_Mutable: ClassSymbol = requiredClass("scala.caps.Mutable")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package scala.annotation
2+
package internal
3+
4+
/** An annotation used internally for fresh capability wrappers of `cap`
5+
*/
6+
class freshCapability extends StaticAnnotation
7+

library/src/scala/caps.scala

+5
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,9 @@ import annotation.{experimental, compileTimeOnly, retainsCap}
7979
*/
8080
def unsafeAssumePure: T = x
8181

82+
/** A wrapper around code for which separation checks are suppressed.
83+
*/
84+
def unsafeAssumeSeparate[T](op: T): T = op
85+
8286
end unsafe
87+
end caps

scala2-library-cc/src/scala/collection/IterableOnce.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A]^ =>
805805
case _ => Some(reduceLeft(op))
806806
}
807807
private final def reduceLeftOptionIterator[B >: A](op: (B, A) => B): Option[B] = reduceOptionIterator[A, B](iterator)(op)
808-
private final def reduceOptionIterator[X >: A, B >: X](it: Iterator[X]^)(op: (B, X) => B): Option[B] = {
808+
private final def reduceOptionIterator[X >: A, B >: X](it: Iterator[X]^{this, caps.cap})(op: (B, X) => B): Option[B] = {
809809
if (it.hasNext) {
810810
var acc: B = it.next()
811811
while (it.hasNext)

scala2-library-cc/src/scala/collection/immutable/LazyListIterable.scala

+9-5
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import scala.runtime.Statics
2525
import language.experimental.captureChecking
2626
import annotation.unchecked.uncheckedCaptures
2727
import caps.untrackedCaptures
28+
import caps.unsafe.unsafeAssumeSeparate
2829

2930
/** This class implements an immutable linked list. We call it "lazy"
3031
* because it computes its elements only when they are needed.
@@ -879,6 +880,7 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz
879880
if (!cursor.stateDefined) b.append(sep).append("<not computed>")
880881
} else {
881882
@inline def same(a: LazyListIterable[A]^, b: LazyListIterable[A]^): Boolean = (a eq b) || (a.state eq b.state)
883+
// !!!CC with qualifiers, same should have cap.rd parameters
882884
// Cycle.
883885
// If we have a prefix of length P followed by a cycle of length C,
884886
// the scout will be at position (P%C) in the cycle when the cursor
@@ -890,7 +892,7 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz
890892
// the start of the loop.
891893
var runner = this
892894
var k = 0
893-
while (!same(runner, scout)) {
895+
while (!unsafeAssumeSeparate(same(runner, scout))) {
894896
runner = runner.tail
895897
scout = scout.tail
896898
k += 1
@@ -900,11 +902,11 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz
900902
// everything once. If cursor is already at beginning, we'd better
901903
// advance one first unless runner didn't go anywhere (in which case
902904
// we've already looped once).
903-
if (same(cursor, scout) && (k > 0)) {
905+
if (unsafeAssumeSeparate(same(cursor, scout)) && (k > 0)) {
904906
appendCursorElement()
905907
cursor = cursor.tail
906908
}
907-
while (!same(cursor, scout)) {
909+
while (!unsafeAssumeSeparate(same(cursor, scout))) {
908910
appendCursorElement()
909911
cursor = cursor.tail
910912
}
@@ -1052,7 +1054,9 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
10521054
val head = it.next()
10531055
rest = rest.tail
10541056
restRef = rest // restRef.elem = rest
1055-
sCons(head, newLL(stateFromIteratorConcatSuffix(it)(flatMapImpl(rest, f).state)))
1057+
sCons(head, newLL(
1058+
unsafeAssumeSeparate(
1059+
stateFromIteratorConcatSuffix(it)(flatMapImpl(rest, f).state))))
10561060
} else State.Empty
10571061
}
10581062
}
@@ -1181,7 +1185,7 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
11811185
def iterate[A](start: => A)(f: A => A): LazyListIterable[A]^{start, f} =
11821186
newLL {
11831187
val head = start
1184-
sCons(head, iterate(f(head))(f))
1188+
sCons(head, unsafeAssumeSeparate(iterate(f(head))(f)))
11851189
}
11861190

11871191
/**

0 commit comments

Comments
 (0)