Skip to content

Commit 02489af

Browse files
committed
Shorthand for curried functions
Propagate capture sets to the right in curried functions. Example: {x} A -> B -> C is a shorthand for {x} A -> {x} B -> C or: (x: {*} A) -> B -> C is a shorthand for (x: {*} A) -> {x} B -> C or: ({*} A) -> B -> C is a shorthand for (x$0: {*} A) -> {x$0} B -> C Also: allow empty capture sets in types This gives a more convenient override to disable capture set propagation in curried types than wrapping in a type alias. E.g. compare {x} A -> {} B -> C with {x} A -> Protect[B -> C] where type Protect[X] = X Also: refactoring to move setup code from Rechecker and CheckCaptures into a joint class cc.Setup.
1 parent d603c91 commit 02489af

File tree

9 files changed

+468
-323
lines changed

9 files changed

+468
-323
lines changed
Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
package dotty.tools
2+
package dotc
3+
package cc
4+
5+
import core._
6+
import Phases.*, DenotTransformers.*, SymDenotations.*
7+
import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.*
8+
import Types.*, StdNames.*
9+
import config.Printers.capt
10+
import ast.tpd
11+
import transform.Recheck.*
12+
13+
class Setup(
14+
preRecheckPhase: DenotTransformer,
15+
thisPhase: DenotTransformer,
16+
recheckDef: (tpd.ValOrDefDef, Symbol) => Context ?=> Unit)
17+
extends tpd.TreeTraverser:
18+
import tpd.*
19+
20+
private def depFun(tycon: Type, argTypes: List[Type], resType: Type)(using Context): Type =
21+
MethodType.companion(
22+
isContextual = defn.isContextFunctionClass(tycon.classSymbol),
23+
isErased = defn.isErasedFunctionClass(tycon.classSymbol)
24+
)(argTypes, resType)
25+
.toFunctionType(isJava = false, alwaysDependent = true)
26+
27+
private def box(tp: Type)(using Context): Type = tp match
28+
case CapturingType(parent, refs, false) => CapturingType(parent, refs, true)
29+
case _ => tp
30+
31+
private def setBoxed(tp: Type)(using Context) = tp match
32+
case AnnotatedType(_, annot) if annot.symbol == defn.RetainsAnnot =>
33+
annot.tree.setBoxedCapturing()
34+
case _ =>
35+
36+
private def addBoxes(using Context) = new TypeTraverser:
37+
def traverse(t: Type) =
38+
t match
39+
case AppliedType(tycon, args) if !defn.isNonRefinedFunction(t) =>
40+
args.foreach(setBoxed)
41+
case TypeBounds(lo, hi) =>
42+
setBoxed(lo); setBoxed(hi)
43+
case _ =>
44+
traverseChildren(t)
45+
46+
/** Perform the following transformation steps everywhere in a type:
47+
* 1. Drop retains annotations
48+
* 2. Turn plain function types into dependent function types, so that
49+
* we can refer to their parameter in capture sets. Currently this is
50+
* only done at the toplevel, i.e. for function types that are not
51+
* themselves argument types of other function types. Without this restriction
52+
* boxmap-paper.scala fails. Need to figure out why.
53+
* 3. Refine other class types C by adding capture set variables to their parameter getters
54+
* (see addCaptureRefinements)
55+
* 4. Add capture set variables to all types that can be tracked
56+
*
57+
* Polytype bounds are only cleaned using step 1, but not otherwise transformed.
58+
*/
59+
private def mapInferred(using Context) = new TypeMap:
60+
61+
/** Drop @retains annotations everywhere */
62+
object cleanup extends TypeMap:
63+
def apply(t: Type) = t match
64+
case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot =>
65+
apply(parent)
66+
case _ =>
67+
mapOver(t)
68+
69+
/** Refine a possibly applied class type C where the class has tracked parameters
70+
* x_1: T_1, ..., x_n: T_n to C { val x_1: CV_1 T_1, ..., val x_n: CV_n T_n }
71+
* where CV_1, ..., CV_n are fresh capture sets.
72+
*/
73+
def addCaptureRefinements(tp: Type): Type = tp match
74+
case _: TypeRef | _: AppliedType if tp.typeParams.isEmpty =>
75+
tp.typeSymbol match
76+
case cls: ClassSymbol if !defn.isFunctionClass(cls) =>
77+
cls.paramGetters.foldLeft(tp) { (core, getter) =>
78+
if getter.termRef.isTracked then
79+
val getterType = tp.memberInfo(getter).strippedDealias
80+
RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var(), boxed = false))
81+
.showing(i"add capture refinement $tp --> $result", capt)
82+
else
83+
core
84+
}
85+
case _ => tp
86+
case _ => tp
87+
88+
/** Should a capture set variable be added on type `tp`? */
89+
def canHaveInferredCapture(tp: Type): Boolean =
90+
tp.typeParams.isEmpty && tp.match
91+
case tp: (TypeRef | AppliedType) =>
92+
val sym = tp.typeSymbol
93+
if sym.isClass then !sym.isValueClass && sym != defn.AnyClass
94+
else canHaveInferredCapture(tp.superType.dealias)
95+
case tp: (RefinedOrRecType | MatchType) =>
96+
canHaveInferredCapture(tp.underlying)
97+
case tp: AndType =>
98+
canHaveInferredCapture(tp.tp1) && canHaveInferredCapture(tp.tp2)
99+
case tp: OrType =>
100+
canHaveInferredCapture(tp.tp1) || canHaveInferredCapture(tp.tp2)
101+
case _ =>
102+
false
103+
104+
/** Add a capture set variable to `tp` if necessary, or maybe pull out
105+
* an embedded capture set variables from a part of `tp`.
106+
*/
107+
def addVar(tp: Type) = tp match
108+
case tp @ RefinedType(parent @ CapturingType(parent1, refs, boxed), rname, rinfo) =>
109+
CapturingType(tp.derivedRefinedType(parent1, rname, rinfo), refs, boxed)
110+
case tp: RecType =>
111+
tp.parent match
112+
case CapturingType(parent1, refs, boxed) =>
113+
CapturingType(tp.derivedRecType(parent1), refs, boxed)
114+
case _ =>
115+
tp // can return `tp` here since unlike RefinedTypes, RecTypes are never created
116+
// by `mapInferred`. Hence if the underlying type admits capture variables
117+
// a variable was already added, and the first case above would apply.
118+
case AndType(CapturingType(parent1, refs1, boxed1), CapturingType(parent2, refs2, boxed2)) =>
119+
assert(refs1.asVar.elems.isEmpty)
120+
assert(refs2.asVar.elems.isEmpty)
121+
assert(boxed1 == boxed2)
122+
CapturingType(AndType(parent1, parent2), refs1, boxed1)
123+
case tp @ OrType(CapturingType(parent1, refs1, boxed1), CapturingType(parent2, refs2, boxed2)) =>
124+
assert(refs1.asVar.elems.isEmpty)
125+
assert(refs2.asVar.elems.isEmpty)
126+
assert(boxed1 == boxed2)
127+
CapturingType(OrType(parent1, parent2, tp.isSoft), refs1, boxed1)
128+
case tp @ OrType(CapturingType(parent1, refs1, boxed1), tp2) =>
129+
CapturingType(OrType(parent1, tp2, tp.isSoft), refs1, boxed1)
130+
case tp @ OrType(tp1, CapturingType(parent2, refs2, boxed2)) =>
131+
CapturingType(OrType(tp1, parent2, tp.isSoft), refs2, boxed2)
132+
case _ if canHaveInferredCapture(tp) =>
133+
CapturingType(tp, CaptureSet.Var(), boxed = false)
134+
case _ =>
135+
tp
136+
137+
var isTopLevel = true
138+
139+
def mapNested(ts: List[Type]): List[Type] =
140+
val saved = isTopLevel
141+
isTopLevel = false
142+
try ts.mapConserve(this) finally isTopLevel = saved
143+
144+
def apply(t: Type) =
145+
val t1 = t match
146+
case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot =>
147+
apply(parent)
148+
case tp @ AppliedType(tycon, args) =>
149+
val tycon1 = this(tycon)
150+
if defn.isNonRefinedFunction(tp) then
151+
val args1 = mapNested(args.init)
152+
val res1 = this(args.last)
153+
if isTopLevel then
154+
depFun(tycon1, args1, res1)
155+
.showing(i"add function refinement $tp --> $result", capt)
156+
else
157+
tp.derivedAppliedType(tycon1, args1 :+ res1)
158+
else
159+
tp.derivedAppliedType(tycon1, args.mapConserve(arg => box(this(arg))))
160+
case tp @ RefinedType(core, rname, rinfo) if defn.isFunctionType(tp) =>
161+
val rinfo1 = apply(rinfo)
162+
if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true)
163+
else tp
164+
case tp: MethodType =>
165+
tp.derivedLambdaType(
166+
paramInfos = mapNested(tp.paramInfos),
167+
resType = this(tp.resType))
168+
case tp: TypeLambda =>
169+
// Don't recurse into parameter bounds, just cleanup any stray retains annotations
170+
tp.derivedLambdaType(
171+
paramInfos = tp.paramInfos.mapConserve(cleanup(_).bounds),
172+
resType = this(tp.resType))
173+
case _ =>
174+
mapOver(t)
175+
addVar(addCaptureRefinements(t1))
176+
end mapInferred
177+
178+
private def expandAbbreviations(using Context) = new TypeMap:
179+
180+
def propagateMethodResult(tp: Type, outerCs: CaptureSet, deep: Boolean): Type = tp match
181+
case tp: MethodType =>
182+
if deep then
183+
val tp1 = tp.derivedLambdaType(paramInfos = tp.paramInfos.mapConserve(this))
184+
propagateMethodResult(tp1, outerCs, deep = false)
185+
else
186+
val localCs = CaptureSet(tp.paramRefs.filter(_.isTracked)*)
187+
tp.derivedLambdaType(
188+
resType = propagateEnclosing(tp.resType, CaptureSet.empty, outerCs ++ localCs))
189+
190+
def propagateDepFunctionResult(tp: Type, outerCs: CaptureSet, deep: Boolean): Type = tp match
191+
case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) =>
192+
val rinfo1 = propagateMethodResult(rinfo, outerCs, deep)
193+
if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true)
194+
else tp
195+
196+
def propagateEnclosing(tp: Type, currentCs: CaptureSet, outerCs: CaptureSet): Type = tp match
197+
case tp @ AppliedType(tycon, args) if defn.isFunctionClass(tycon.typeSymbol) =>
198+
val tycon1 = this(tycon)
199+
val args1 = args.init.mapConserve(this)
200+
val tp1 =
201+
if args1.exists(!_.captureSet.isAlwaysEmpty) then
202+
val propagated = propagateDepFunctionResult(
203+
depFun(tycon, args1, args.last), currentCs ++ outerCs, deep = false)
204+
propagated match
205+
case RefinedType(_, _, mt: MethodType) =>
206+
val following = mt.resType.captureSet.elems
207+
if mt.paramRefs.exists(following.contains(_)) then propagated
208+
else tp.derivedAppliedType(tycon1, args1 :+ mt.resType)
209+
else
210+
val resType1 = propagateEnclosing(
211+
args.last, CaptureSet.empty, currentCs ++ outerCs)
212+
tp.derivedAppliedType(tycon1, args1 :+ resType1)
213+
tp1.capturing(outerCs)
214+
case tp @ RefinedType(parent, nme.apply, rinfo: MethodType) if defn.isFunctionType(tp) =>
215+
propagateDepFunctionResult(tp, currentCs ++ outerCs, deep = true)
216+
.capturing(outerCs)
217+
case _ =>
218+
mapOver(tp)
219+
220+
def apply(tp: Type): Type = tp match
221+
case CapturingType(parent, cs, boxed) =>
222+
tp.derivedCapturingType(propagateEnclosing(parent, cs, CaptureSet.empty), cs)
223+
case _ =>
224+
propagateEnclosing(tp, CaptureSet.empty, CaptureSet.empty)
225+
end expandAbbreviations
226+
227+
private def transformInferredType(tp: Type, boxed: Boolean)(using Context): Type =
228+
val tp1 = mapInferred(tp)
229+
if boxed then box(tp1) else tp1
230+
231+
private def transformExplicitType(tp: Type, boxed: Boolean)(using Context): Type =
232+
addBoxes.traverse(tp)
233+
if boxed then setBoxed(tp)
234+
if ctx.settings.YccNoAbbrev.value then tp
235+
else expandAbbreviations(tp)
236+
237+
// Substitute parameter symbols in `from` to paramRefs in corresponding
238+
// method or poly types `to`. We use a single BiTypeMap to do everything.
239+
private class SubstParams(from: List[List[Symbol]], to: List[LambdaType])(using Context)
240+
extends DeepTypeMap, BiTypeMap:
241+
242+
def apply(t: Type): Type = t match
243+
case t: NamedType =>
244+
val sym = t.symbol
245+
def outer(froms: List[List[Symbol]], tos: List[LambdaType]): Type =
246+
def inner(from: List[Symbol], to: List[ParamRef]): Type =
247+
if from.isEmpty then outer(froms.tail, tos.tail)
248+
else if sym eq from.head then to.head
249+
else inner(from.tail, to.tail)
250+
if tos.isEmpty then t
251+
else inner(froms.head, tos.head.paramRefs)
252+
outer(from, to)
253+
case _ =>
254+
mapOver(t)
255+
256+
def inverse(t: Type): Type = t match
257+
case t: ParamRef =>
258+
def recur(from: List[LambdaType], to: List[List[Symbol]]): Type =
259+
if from.isEmpty then t
260+
else if t.binder eq from.head then to.head(t.paramNum).namedType
261+
else recur(from.tail, to.tail)
262+
recur(to, from)
263+
case _ =>
264+
mapOver(t)
265+
end SubstParams
266+
267+
private def transformTT(tree: TypeTree, boxed: Boolean)(using Context) =
268+
tree.rememberType(
269+
if tree.isInstanceOf[InferredTypeTree]
270+
then transformInferredType(tree.tpe, boxed)
271+
else transformExplicitType(tree.tpe, boxed))
272+
273+
def traverse(tree: Tree)(using Context) =
274+
tree match
275+
case tree @ ValDef(_, tpt: TypeTree, _) if tree.symbol.is(Mutable) =>
276+
transformTT(tpt, boxed = true)
277+
traverse(tree.rhs)
278+
case _ =>
279+
traverseChildren(tree)
280+
tree match
281+
case tree: TypeTree =>
282+
transformTT(tree, boxed = false)
283+
case tree: ValOrDefDef =>
284+
val sym = tree.symbol
285+
286+
// replace an existing symbol info with inferred types
287+
def integrateRT(
288+
info: Type, // symbol info to replace
289+
psymss: List[List[Symbol]], // the local (type and trem) parameter symbols corresponding to `info`
290+
prevPsymss: List[List[Symbol]], // the local parameter symbols seen previously in reverse order
291+
prevLambdas: List[LambdaType] // the outer method and polytypes generated previously in reverse order
292+
): Type =
293+
info match
294+
case mt: MethodOrPoly =>
295+
val psyms = psymss.head
296+
mt.companion(mt.paramNames)(
297+
mt1 =>
298+
if !psyms.exists(_.isUpdatedAfter(preRecheckPhase)) && !mt.isParamDependent && prevLambdas.isEmpty then
299+
mt.paramInfos
300+
else
301+
val subst = SubstParams(psyms :: prevPsymss, mt1 :: prevLambdas)
302+
psyms.map(psym => subst(psym.info).asInstanceOf[mt.PInfo]),
303+
mt1 =>
304+
integrateRT(mt.resType, psymss.tail, psyms :: prevPsymss, mt1 :: prevLambdas)
305+
)
306+
case info: ExprType =>
307+
info.derivedExprType(resType =
308+
integrateRT(info.resType, psymss, prevPsymss, prevLambdas))
309+
case _ =>
310+
val restp = tree.tpt.knownType
311+
if prevLambdas.isEmpty then restp
312+
else SubstParams(prevPsymss, prevLambdas)(restp)
313+
314+
if tree.tpt.hasRememberedType && !sym.isConstructor then
315+
val newInfo = integrateRT(sym.info, sym.paramSymss, Nil, Nil)
316+
.showing(i"update info $sym: ${sym.info} --> $result", capt)
317+
if newInfo ne sym.info then
318+
val completer = new LazyType:
319+
def complete(denot: SymDenotation)(using Context) =
320+
denot.info = newInfo
321+
recheckDef(tree, sym)
322+
sym.updateInfoBetween(preRecheckPhase, thisPhase, completer)
323+
case tree: Bind =>
324+
val sym = tree.symbol
325+
sym.updateInfoBetween(preRecheckPhase, thisPhase,
326+
transformInferredType(sym.info, boxed = false))
327+
case _ =>
328+
end Setup

compiler/src/dotty/tools/dotc/config/ScalaSettings.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ private sealed trait YSettings:
328328
val Yrecheck: Setting[Boolean] = BooleanSetting("-Yrecheck", "Run type rechecks (test only)")
329329
val Ycc: Setting[Boolean] = BooleanSetting("-Ycc", "Check captured references")
330330
val YccDebug: Setting[Boolean] = BooleanSetting("-Ycc-debug", "Debug info for captured references")
331+
val YccNoAbbrev: Setting[Boolean] = BooleanSetting("-Ycc-no-abbrev", "Used in conjunction with -Ycc, suppress type abbreviations")
331332

332333
/** Area-specific debug output */
333334
val YexplainLowlevel: Setting[Boolean] = BooleanSetting("-Yexplain-lowlevel", "When explaining type errors, show types at a lower level.")

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -957,21 +957,21 @@ object Parsers {
957957

958958
def followingIsCaptureSet(): Boolean =
959959
val lookahead = in.LookaheadScanner()
960+
def followingIsTypeStart() =
961+
lookahead.nextToken()
962+
canStartInfixTypeTokens.contains(lookahead.token)
963+
|| lookahead.token == LBRACKET
960964
def recur(): Boolean =
961965
(lookahead.isIdent || lookahead.token == THIS) && {
962966
lookahead.nextToken()
963967
if lookahead.token == COMMA then
964968
lookahead.nextToken()
965969
recur()
966970
else
967-
lookahead.token == RBRACE && {
968-
lookahead.nextToken()
969-
canStartInfixTypeTokens.contains(lookahead.token)
970-
|| lookahead.token == LBRACKET
971-
}
971+
lookahead.token == RBRACE && followingIsTypeStart()
972972
}
973973
lookahead.nextToken()
974-
recur()
974+
if lookahead.token == RBRACE then followingIsTypeStart() else recur()
975975

976976
/* --------- OPERAND/OPERATOR STACK --------------------------------------- */
977977

@@ -1551,7 +1551,9 @@ object Parsers {
15511551
else { accept(TLARROW); typ() }
15521552
}
15531553
else if in.token == LBRACE && followingIsCaptureSet() then
1554-
val refs = inBraces { commaSeparated(captureRef) }
1554+
val refs = inBraces {
1555+
if in.token == RBRACE then Nil else commaSeparated(captureRef)
1556+
}
15551557
val t = typ()
15561558
CapturingTypeTree(refs, t)
15571559
else if (in.token == INDENT) enclosed(INDENT, typ())

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import config.Config
2828

2929
import dotty.tools.dotc.util.SourcePosition
3030
import dotty.tools.dotc.ast.untpd.{MemberDef, Modifiers, PackageDef, RefTree, Template, TypeDef, ValOrDefDef}
31-
import cc.{EventuallyCapturingType, CaptureSet, toCaptureSet, IllegalCaptureRef}
31+
import cc.{CaptureSet, toCaptureSet, IllegalCaptureRef}
3232

3333
class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
3434

0 commit comments

Comments
 (0)