Skip to content

Commit 60b0486

Browse files
committed
First implementation of capture polymorphism
1 parent ea182f2 commit 60b0486

File tree

14 files changed

+139
-27
lines changed

14 files changed

+139
-27
lines changed

compiler/src/dotty/tools/dotc/ast/untpd.scala

+8-3
Original file line numberDiff line numberDiff line change
@@ -521,12 +521,17 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
521521
def captureRoot(using Context): Select =
522522
Select(scalaDot(nme.caps), nme.CAPTURE_ROOT)
523523

524-
def captureRootIn(using Context): Select =
525-
Select(scalaDot(nme.caps), nme.capIn)
526-
527524
def makeRetaining(parent: Tree, refs: List[Tree], annotName: TypeName)(using Context): Annotated =
528525
Annotated(parent, New(scalaAnnotationDot(annotName), List(refs)))
529526

527+
def makeCapsOf(id: Ident)(using Context): Tree =
528+
TypeApply(Select(scalaDot(nme.caps), nme.capsOf), id :: Nil)
529+
530+
def makeCapsBound()(using Context): Tree =
531+
makeRetaining(
532+
Select(scalaDot(nme.caps), tpnme.CapSet),
533+
Nil, tpnme.retainsCap)
534+
530535
def makeConstructor(tparams: List[TypeDef], vparamss: List[List[ValDef]], rhs: Tree = EmptyTree)(using Context): DefDef =
531536
DefDef(nme.CONSTRUCTOR, joinParams(tparams, vparamss), TypeTree(), rhs)
532537

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

+11-1
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ extension (tree: Tree)
131131
def toCaptureRef(using Context): CaptureRef = tree match
132132
case ReachCapabilityApply(arg) =>
133133
arg.toCaptureRef.reach
134+
case CapsOfApply(arg) =>
135+
arg.toCaptureRef
134136
case _ => tree.tpe match
135137
case ref: CaptureRef if ref.isTrackableRef =>
136138
ref
@@ -145,7 +147,7 @@ extension (tree: Tree)
145147
case Some(refs) => refs
146148
case None =>
147149
val refs = CaptureSet(tree.retainedElems.map(_.toCaptureRef)*)
148-
.showing(i"toCaptureSet $tree --> $result", capt)
150+
//.showing(i"toCaptureSet $tree --> $result", capt)
149151
tree.putAttachment(Captures, refs)
150152
refs
151153

@@ -526,6 +528,14 @@ object ReachCapabilityApply:
526528
case Apply(reach, arg :: Nil) if reach.symbol == defn.Caps_reachCapability => Some(arg)
527529
case _ => None
528530

531+
/** An extractor for `caps.capsOf[X]`, which is used to express a generic capture set
532+
* as a tree in a @retains annotation.
533+
*/
534+
object CapsOfApply:
535+
def unapply(tree: TypeApply)(using Context): Option[Tree] = tree match
536+
case TypeApply(capsOf, arg :: Nil) if capsOf.symbol == defn.Caps_capsOf => Some(arg)
537+
case _ => None
538+
529539
class AnnotatedCapability(annot: Context ?=> ClassSymbol):
530540
def apply(tp: Type)(using Context) =
531541
AnnotatedType(tp, Annotation(annot, util.Spans.NoSpan))

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -879,7 +879,9 @@ object CaptureSet:
879879
val r1 = tm(r)
880880
val upper = r1.captureSet
881881
def isExact =
882-
upper.isAlwaysEmpty || upper.isConst && upper.elems.size == 1 && upper.elems.contains(r1)
882+
upper.isAlwaysEmpty
883+
|| upper.isConst && upper.elems.size == 1 && upper.elems.contains(r1)
884+
|| r.derivesFrom(defn.Caps_CapSet)
883885
if variance > 0 || isExact then upper
884886
else if variance < 0 then CaptureSet.empty
885887
else upper.maybe

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

+17-9
Original file line numberDiff line numberDiff line change
@@ -123,16 +123,24 @@ object CheckCaptures:
123123
case _: SingletonType =>
124124
report.error(em"Singleton type $parent cannot have capture set", parent.srcPos)
125125
case _ =>
126+
def check(elem: Tree, pos: SrcPos): Unit = elem.tpe match
127+
case ref: CaptureRef =>
128+
if !ref.isTrackableRef then
129+
report.error(em"$elem cannot be tracked since it is not a parameter or local value", pos)
130+
case tpe =>
131+
report.error(em"$elem: $tpe is not a legal element of a capture set", pos)
126132
for elem <- ann.retainedElems do
127-
val elem1 = elem match
128-
case ReachCapabilityApply(arg) => arg
129-
case _ => elem
130-
elem1.tpe match
131-
case ref: CaptureRef =>
132-
if !ref.isTrackableRef then
133-
report.error(em"$elem cannot be tracked since it is not a parameter or local value", elem.srcPos)
134-
case tpe =>
135-
report.error(em"$elem: $tpe is not a legal element of a capture set", elem.srcPos)
133+
elem match
134+
case CapsOfApply(arg) =>
135+
def isLegalCapsOfArg =
136+
arg.symbol.isAbstractOrParamType && arg.symbol.info.derivesFrom(defn.Caps_CapSet)
137+
if !isLegalCapsOfArg then
138+
report.error(
139+
em"""$arg is not a legal prefix for `^` here,
140+
|is must be a type parameter or abstract type with a caps.CapSet upper bound.""",
141+
elem.srcPos)
142+
case ReachCapabilityApply(arg) => check(arg, elem.srcPos)
143+
case _ => check(elem, elem.srcPos)
136144

137145
/** Report an error if some part of `tp` contains the root capability in its capture set
138146
* or if it refers to an unsealed type parameter that could possibly be instantiated with

compiler/src/dotty/tools/dotc/cc/Setup.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
384384
sym.updateInfo(thisPhase, info, newFlagsFor(sym))
385385
toBeUpdated -= sym
386386
sym.namedType match
387-
case ref: CaptureRef => ref.invalidateCaches() // TODO: needed?
387+
case ref: CaptureRef if ref.isTrackableRef => ref.invalidateCaches() // TODO: needed?
388388
case _ =>
389389

390390
extension (sym: Symbol) def nextInfo(using Context): Type =

compiler/src/dotty/tools/dotc/core/Definitions.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -991,8 +991,10 @@ class Definitions {
991991

992992
@tu lazy val CapsModule: Symbol = requiredModule("scala.caps")
993993
@tu lazy val captureRoot: TermSymbol = CapsModule.requiredValue("cap")
994-
@tu lazy val Caps_Capability: ClassSymbol = requiredClass("scala.caps.Capability")
994+
@tu lazy val Caps_Capability: TypeSymbol = CapsModule.requiredType("Capability")
995+
@tu lazy val Caps_CapSet = requiredClass("scala.caps.CapSet")
995996
@tu lazy val Caps_reachCapability: TermSymbol = CapsModule.requiredMethod("reachCapability")
997+
@tu lazy val Caps_capsOf: TermSymbol = CapsModule.requiredMethod("capsOf")
996998
@tu lazy val Caps_Exists = requiredClass("scala.caps.Exists")
997999
@tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe")
9981000
@tu lazy val Caps_unsafeAssumePure: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumePure")

compiler/src/dotty/tools/dotc/core/StdNames.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ object StdNames {
358358
val AppliedTypeTree: N = "AppliedTypeTree"
359359
val ArrayAnnotArg: N = "ArrayAnnotArg"
360360
val CAP: N = "CAP"
361+
val CapSet: N = "CapSet"
361362
val Constant: N = "Constant"
362363
val ConstantType: N = "ConstantType"
363364
val Eql: N = "Eql"
@@ -441,8 +442,8 @@ object StdNames {
441442
val bytes: N = "bytes"
442443
val canEqual_ : N = "canEqual"
443444
val canEqualAny : N = "canEqualAny"
444-
val capIn: N = "capIn"
445445
val caps: N = "caps"
446+
val capsOf: N = "capsOf"
446447
val captureChecking: N = "captureChecking"
447448
val checkInitialized: N = "checkInitialized"
448449
val classOf: N = "classOf"

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -2839,7 +2839,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
28392839
private def existentialVarsConform(tp1: Type, tp2: Type) =
28402840
tp2 match
28412841
case tp2: TermParamRef => tp1 match
2842-
case tp1: CaptureRef => subsumesExistentially(tp2, tp1)
2842+
case tp1: CaptureRef if tp1.isTrackableRef => subsumesExistentially(tp2, tp1)
28432843
case _ => false
28442844
case _ => false
28452845

compiler/src/dotty/tools/dotc/core/Types.scala

+13-3
Original file line numberDiff line numberDiff line change
@@ -2313,7 +2313,11 @@ object Types extends TypeUtils {
23132313

23142314
override def captureSet(using Context): CaptureSet =
23152315
val cs = captureSetOfInfo
2316-
if isTrackableRef && !cs.isAlwaysEmpty then singletonCaptureSet else cs
2316+
if isTrackableRef then
2317+
if cs.isAlwaysEmpty then cs else singletonCaptureSet
2318+
else dealias match
2319+
case _: (TypeRef | TypeParamRef) => CaptureSet.empty
2320+
case _ => cs
23172321

23182322
end CaptureRef
23192323

@@ -3032,7 +3036,7 @@ object Types extends TypeUtils {
30323036

30333037
abstract case class TypeRef(override val prefix: Type,
30343038
private var myDesignator: Designator)
3035-
extends NamedType {
3039+
extends NamedType, CaptureRef {
30363040

30373041
type ThisType = TypeRef
30383042
type ThisName = TypeName
@@ -3081,6 +3085,9 @@ object Types extends TypeUtils {
30813085
/** Hook that can be called from creation methods in TermRef and TypeRef */
30823086
def validated(using Context): this.type =
30833087
this
3088+
3089+
override def isTrackableRef(using Context) =
3090+
symbol.isAbstractOrParamType && derivesFrom(defn.Caps_CapSet)
30843091
}
30853092

30863093
final class CachedTermRef(prefix: Type, designator: Designator, hc: Int) extends TermRef(prefix, designator) {
@@ -4841,7 +4848,8 @@ object Types extends TypeUtils {
48414848
/** Only created in `binder.paramRefs`. Use `binder.paramRefs(paramNum)` to
48424849
* refer to `TypeParamRef(binder, paramNum)`.
48434850
*/
4844-
abstract case class TypeParamRef(binder: TypeLambda, paramNum: Int) extends ParamRef {
4851+
abstract case class TypeParamRef(binder: TypeLambda, paramNum: Int)
4852+
extends ParamRef, CaptureRef {
48454853
type BT = TypeLambda
48464854
def kindString: String = "Type"
48474855
def copyBoundType(bt: BT): Type = bt.paramRefs(paramNum)
@@ -4861,6 +4869,8 @@ object Types extends TypeUtils {
48614869
case bound: OrType => occursIn(bound.tp1, fromBelow) || occursIn(bound.tp2, fromBelow)
48624870
case _ => false
48634871
}
4872+
4873+
override def isTrackableRef(using Context) = derivesFrom(defn.Caps_CapSet)
48644874
}
48654875

48664876
private final class TypeParamRefImpl(binder: TypeLambda, paramNum: Int) extends TypeParamRef(binder, paramNum)

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+13-4
Original file line numberDiff line numberDiff line change
@@ -1541,7 +1541,7 @@ object Parsers {
15411541
case _ => None
15421542
}
15431543

1544-
/** CaptureRef ::= ident | `this` | `cap` [`[` ident `]`]
1544+
/** CaptureRef ::= ident [`*` | `^`] | `this`
15451545
*/
15461546
def captureRef(): Tree =
15471547
if in.token == THIS then simpleRef()
@@ -1551,6 +1551,10 @@ object Parsers {
15511551
in.nextToken()
15521552
atSpan(startOffset(id)):
15531553
PostfixOp(id, Ident(nme.CC_REACH))
1554+
else if isIdent(nme.UPARROW) then
1555+
in.nextToken()
1556+
atSpan(startOffset(id)):
1557+
makeCapsOf(cpy.Ident(id)(id.name.toTypeName))
15541558
else id
15551559

15561560
/** CaptureSet ::= `{` CaptureRef {`,` CaptureRef} `}` -- under captureChecking
@@ -1968,7 +1972,7 @@ object Parsers {
19681972
}
19691973

19701974
/** SimpleType ::= SimpleLiteral
1971-
* | ‘?’ SubtypeBounds
1975+
* | ‘?’ TypeBounds
19721976
* | SimpleType1
19731977
* | SimpleType ‘(’ Singletons ‘)’ -- under language.experimental.dependent, checked in Typer
19741978
* Singletons ::= Singleton {‘,’ Singleton}
@@ -2188,9 +2192,15 @@ object Parsers {
21882192
inBraces(refineStatSeq())
21892193

21902194
/** TypeBounds ::= [`>:' Type] [`<:' Type]
2195+
* | `^` -- under captureChecking
21912196
*/
21922197
def typeBounds(): TypeBoundsTree =
2193-
atSpan(in.offset) { TypeBoundsTree(bound(SUPERTYPE), bound(SUBTYPE)) }
2198+
atSpan(in.offset):
2199+
if in.isIdent(nme.UPARROW) && Feature.ccEnabled then
2200+
in.nextToken()
2201+
TypeBoundsTree(EmptyTree, makeCapsBound())
2202+
else
2203+
TypeBoundsTree(bound(SUPERTYPE), bound(SUBTYPE))
21942204

21952205
private def bound(tok: Int): Tree =
21962206
if (in.token == tok) { in.nextToken(); toplevelTyp() }
@@ -3384,7 +3394,6 @@ object Parsers {
33843394
*
33853395
* DefTypeParamClause::= ‘[’ DefTypeParam {‘,’ DefTypeParam} ‘]’
33863396
* DefTypeParam ::= {Annotation}
3387-
* [`sealed`] -- under captureChecking
33883397
* id [HkTypeParamClause] TypeParamBounds
33893398
*
33903399
* TypTypeParamClause::= ‘[’ TypTypeParam {‘,’ TypTypeParam} ‘]’

compiler/src/dotty/tools/dotc/typer/Typer.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -2328,7 +2328,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
23282328
val res = Throw(expr1).withSpan(tree.span)
23292329
if Feature.ccEnabled && !cap.isEmpty && !ctx.isAfterTyper then
23302330
// Record access to the CanThrow capabulity recovered in `cap` by wrapping
2331-
// the type of the `throw` (i.e. Nothing) in a `@requiresCapability` annotatoon.
2331+
// the type of the `throw` (i.e. Nothing) in a `@requiresCapability` annotation.
23322332
Typed(res,
23332333
TypeTree(
23342334
AnnotatedType(res.tpe,

library/src/scala/caps.scala

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package scala
22

3-
import annotation.experimental
3+
import annotation.{experimental, compileTimeOnly}
44

55
@experimental object caps:
66

@@ -16,6 +16,12 @@ import annotation.experimental
1616
@deprecated("Use `Capability` instead")
1717
type Cap = Capability
1818

19+
/** Carrier trait for capture set type parameters */
20+
trait CapSet extends Any
21+
22+
@compileTimeOnly("Should be be used only internally by the Scala compiler")
23+
def capsOf[CS]: Any = ???
24+
1925
/** Reach capabilities x* which appear as terms in @retains annotations are encoded
2026
* as `caps.reachCapability(x)`. When converted to CaptureRef types in capture sets
2127
* they are represented as `x.type @annotation.internal.reachCapability`.

tests/pos/cc-poly-1.scala

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import language.experimental.captureChecking
2+
import annotation.experimental
3+
import caps.{CapSet, Capability}
4+
5+
@experimental object Test:
6+
7+
class C extends Capability
8+
class D
9+
10+
def f[X^](x: D^{X^}): D^{X^} = x
11+
def g[X^](x: D^{X^}, y: D^{X^}): D^{X^} = x
12+
13+
def test(c1: C, c2: C) =
14+
val d: D^{c1, c2} = D()
15+
val x = f[CapSet^{c1, c2}](d)
16+
val _: D^{c1, c2} = x
17+
val d1: D^{c1} = D()
18+
val d2: D^{c2} = D()
19+
val y = g(d1, d2)
20+
val _: D^{d1, d2} = y
21+
val _: D^{c1, c2} = y
22+
23+

tests/pos/cc-poly-source.scala

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import language.experimental.captureChecking
2+
import annotation.experimental
3+
import caps.{CapSet, Capability}
4+
5+
@experimental object Test:
6+
7+
class Label //extends Capability
8+
9+
class Listener
10+
11+
class Source[X^]:
12+
private var listeners: Set[Listener^{X^}] = Set.empty
13+
def register(x: Listener^{X^}): Unit =
14+
listeners += x
15+
16+
def allListeners: Set[Listener^{X^}] = listeners
17+
18+
def test1(lbl1: Label^, lbl2: Label^) =
19+
val src = Source[CapSet^{lbl1, lbl2}]
20+
def l1: Listener^{lbl1} = ???
21+
val l2: Listener^{lbl2} = ???
22+
src.register{l1}
23+
src.register{l2}
24+
val ls = src.allListeners
25+
val _: Set[Listener^{lbl1, lbl2}] = ls
26+
27+
def test2(lbls: List[Label^]) =
28+
def makeListener(lbl: Label^): Listener^{lbl} = ???
29+
val listeners = lbls.map(makeListener)
30+
val src = Source[CapSet^{lbls*}]
31+
for l <- listeners do
32+
src.register(l)
33+
val ls = src.allListeners
34+
val _: Set[Listener^{lbls*}] = ls
35+
36+

0 commit comments

Comments
 (0)