Skip to content

Commit e7763eb

Browse files
authored
Treat more closure parameter types as inferred (#21583)
This is necessary for types that contain possibly illegal @retains annotations since those annotations are only removed before pickling for InferredTypes. Fixes #21437
2 parents 61adf81 + 45728af commit e7763eb

File tree

4 files changed

+37
-17
lines changed

4 files changed

+37
-17
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

+5-1
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,13 @@ sealed abstract class CaptureSet extends Showable:
158158
* as frozen.
159159
*/
160160
def accountsFor(x: CaptureRef)(using Context): Boolean =
161-
reporting.trace(i"$this accountsFor $x, ${x.captureSetOfInfo}?", show = true):
161+
def debugInfo(using Context) = i"$this accountsFor $x, which has capture set ${x.captureSetOfInfo}"
162+
def test(using Context) = reporting.trace(debugInfo):
162163
elems.exists(_.subsumes(x))
163164
|| !x.isMaxCapability && x.captureSetOfInfo.subCaptures(this, frozen = true).isOK
165+
comparer match
166+
case comparer: ExplainingTypeComparer => comparer.traceIndented(debugInfo)(test)
167+
case _ => test
164168

165169
/** A more optimistic version of accountsFor, which does not take variable supersets
166170
* of the `x` reference into account. A set might account for `x` if it accounts

compiler/src/dotty/tools/dotc/core/TypeOps.scala

+12-14
Original file line numberDiff line numberDiff line change
@@ -691,20 +691,18 @@ object TypeOps:
691691
val hiBound = instantiate(bounds.hi, skolemizedArgTypes)
692692
val loBound = instantiate(bounds.lo, skolemizedArgTypes)
693693

694-
def check(tp1: Type, tp2: Type, which: String, bound: Type)(using Context) = {
695-
val isSub = TypeComparer.explaining { cmp =>
696-
val isSub = cmp.isSubType(tp1, tp2)
697-
if !isSub then
698-
if !ctx.typerState.constraint.domainLambdas.isEmpty then
699-
typr.println(i"${ctx.typerState.constraint}")
700-
if !ctx.gadt.symbols.isEmpty then
701-
typr.println(i"${ctx.gadt}")
702-
typr.println(cmp.lastTrace(i"checkOverlapsBounds($lo, $hi, $arg, $bounds)($which)"))
703-
//trace.dumpStack()
704-
isSub
705-
}//(using ctx.fresh.setSetting(ctx.settings.verbose, true)) // uncomment to enable moreInfo in ExplainingTypeComparer recur
706-
if !isSub then violations += ((arg, which, bound))
707-
}
694+
def check(tp1: Type, tp2: Type, which: String, bound: Type)(using Context) =
695+
val isSub = TypeComparer.isSubType(tp1, tp2)
696+
if !isSub then
697+
// inContext(ctx.fresh.setSetting(ctx.settings.verbose, true)): // uncomment to enable moreInfo in ExplainingTypeComparer
698+
TypeComparer.explaining: cmp =>
699+
if !ctx.typerState.constraint.domainLambdas.isEmpty then
700+
typr.println(i"${ctx.typerState.constraint}")
701+
if !ctx.gadt.symbols.isEmpty then
702+
typr.println(i"${ctx.gadt}")
703+
typr.println(cmp.lastTrace(i"checkOverlapsBounds($lo, $hi, $arg, $bounds)($which)"))
704+
violations += ((arg, which, bound))
705+
708706
check(lo, hiBound, "upper", hiBound)(using checkCtx)
709707
check(loBound, hi, "lower", loBound)(using checkCtx)
710708
}

compiler/src/dotty/tools/dotc/typer/Typer.scala

+9-2
Original file line numberDiff line numberDiff line change
@@ -1903,9 +1903,16 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19031903
if knownFormal then formal0
19041904
else errorType(AnonymousFunctionMissingParamType(param, tree, inferredType = formal, expectedType = pt), param.srcPos)
19051905
)
1906+
val untpdTpt = formal match
1907+
case _: WildcardType =>
1908+
// In this case we have a situation like f(_), where we expand in the end to
1909+
// (x: T) => f(x) and `T` is taken from `f`'s declared parameters. In this case
1910+
// we treat the type as declared instead of inferred. InferredType is used for
1911+
// types that are inferred from the context.
1912+
untpd.TypeTree()
1913+
case _ => InferredTypeTree()
19061914
val paramTpt = untpd.TypedSplice(
1907-
(if knownFormal then InferredTypeTree() else untpd.TypeTree())
1908-
.withType(paramType.translateFromRepeated(toArray = false))
1915+
untpdTpt.withType(paramType.translateFromRepeated(toArray = false))
19091916
.withSpan(param.span.endPos)
19101917
)
19111918
val param0 = cpy.ValDef(param)(tpt = paramTpt)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
//> using scala 3.6.0-RC1-bin-SNAPSHOT
2+
3+
import language.experimental.captureChecking
4+
5+
class Box[Cap^] {}
6+
7+
def run[Cap^](f: Box[Cap]^{Cap^} => Unit): Box[Cap]^{Cap^} = ???
8+
9+
def main() =
10+
val b = run(_ => ())
11+
// val b = run[caps.CapSet](_ => ()) // this compiles

0 commit comments

Comments
 (0)