Skip to content

Commit b1da91a

Browse files
committed
Avoid forming intersections of capture sets on refined type lookup
1 parent 67d42bd commit b1da91a

File tree

10 files changed

+63
-26
lines changed

10 files changed

+63
-26
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -1073,7 +1073,7 @@ object CaptureSet:
10731073
* return whether this was allowed. By default, recording is allowed
10741074
* but the special state VarState.Separate overrides this.
10751075
*/
1076-
def addHidden(hidden: HiddenSet, elem: CaptureRef): Boolean =
1076+
def addHidden(hidden: HiddenSet, elem: CaptureRef)(using Context): Boolean =
10771077
elemsMap.get(hidden) match
10781078
case None => elemsMap(hidden) = hidden.elems
10791079
case _ =>
@@ -1112,7 +1112,7 @@ object CaptureSet:
11121112
*/
11131113
@sharable
11141114
object Separate extends Closed:
1115-
override def addHidden(hidden: HiddenSet, elem: CaptureRef): Boolean = false
1115+
override def addHidden(hidden: HiddenSet, elem: CaptureRef)(using Context): Boolean = false
11161116

11171117
/** A special state that turns off recording of elements. Used only
11181118
* in `addSub` to prevent cycles in recordings.
@@ -1122,14 +1122,14 @@ object CaptureSet:
11221122
override def putElems(v: Var, refs: Refs) = true
11231123
override def putDeps(v: Var, deps: Deps) = true
11241124
override def rollBack(): Unit = ()
1125-
override def addHidden(hidden: HiddenSet, elem: CaptureRef): Boolean = true
1125+
override def addHidden(hidden: HiddenSet, elem: CaptureRef)(using Context): Boolean = true
11261126

11271127
/** A closed state that turns off recording of hidden elements (but allows
11281128
* adding them). Used in `mightAccountFor`.
11291129
*/
11301130
@sharable
11311131
private[CaptureSet] object ClosedUnrecorded extends Closed:
1132-
override def addHidden(hidden: HiddenSet, elem: CaptureRef): Boolean = true
1132+
override def addHidden(hidden: HiddenSet, elem: CaptureRef)(using Context): Boolean = true
11331133

11341134
end VarState
11351135

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import StdNames.nme
2424
import NameKinds.{DefaultGetterName, WildcardParamName, UniqueNameKind}
2525
import reporting.{trace, Message, OverrideError}
2626
import Existential.derivedExistentialType
27+
import Annotations.Annotation
2728

2829
/** The capture checker */
2930
object CheckCaptures:
@@ -785,7 +786,8 @@ class CheckCaptures extends Recheck, SymTransformer:
785786
for (getterName, argType) <- mt.paramNames.lazyZip(argTypes) do
786787
val getter = cls.info.member(getterName).suchThat(_.isRefiningParamAccessor).symbol
787788
if !getter.is(Private) && getter.hasTrackedParts then
788-
refined = RefinedType(refined, getterName, argType.unboxed) // Yichen you might want to check this
789+
refined = RefinedType(refined, getterName,
790+
AnnotatedType(argType.unboxed, Annotation(defn.RefineOverrideAnnot, util.Spans.NoSpan))) // Yichen you might want to check this
789791
allCaptures ++= argType.captureSet
790792
(refined, allCaptures)
791793

compiler/src/dotty/tools/dotc/core/Definitions.scala

+5
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,7 @@ class Definitions {
10711071
@tu lazy val UncheckedCapturesAnnot: ClassSymbol = requiredClass("scala.annotation.unchecked.uncheckedCaptures")
10721072
@tu lazy val UntrackedCapturesAnnot: ClassSymbol = requiredClass("scala.caps.untrackedCaptures")
10731073
@tu lazy val UseAnnot: ClassSymbol = requiredClass("scala.caps.use")
1074+
@tu lazy val RefineOverrideAnnot: ClassSymbol = requiredClass("scala.caps.refineOverride")
10741075
@tu lazy val VolatileAnnot: ClassSymbol = requiredClass("scala.volatile")
10751076
@tu lazy val LanguageFeatureMetaAnnot: ClassSymbol = requiredClass("scala.annotation.meta.languageFeature")
10761077
@tu lazy val BeanGetterMetaAnnot: ClassSymbol = requiredClass("scala.annotation.meta.beanGetter")
@@ -1111,6 +1112,10 @@ class Definitions {
11111112
@tu lazy val MetaAnnots: Set[Symbol] =
11121113
NonBeanMetaAnnots + BeanGetterMetaAnnot + BeanSetterMetaAnnot
11131114

1115+
// Set of annotations that are not printed in types except under -Yprint-debug
1116+
@tu lazy val SilentAnnots: Set[Symbol] =
1117+
Set(InlineParamAnnot, ErasedParamAnnot, RefineOverrideAnnot)
1118+
11141119
// A list of annotations that are commonly used to indicate that a field/method argument or return
11151120
// type is not null. These annotations are used by the nullification logic in JavaNullInterop to
11161121
// improve the precision of type nullification.

compiler/src/dotty/tools/dotc/core/Types.scala

+18-15
Original file line numberDiff line numberDiff line change
@@ -860,21 +860,24 @@ object Types extends TypeUtils {
860860
pinfo recoverable_& rinfo
861861
pdenot.asSingleDenotation.derivedSingleDenotation(pdenot.symbol, jointInfo)
862862
}
863-
else
864-
val isRefinedMethod = rinfo.isInstanceOf[MethodOrPoly]
865-
val joint = pdenot.meet(
866-
new JointRefDenotation(NoSymbol, rinfo, Period.allInRun(ctx.runId), pre, isRefinedMethod),
867-
pre,
868-
safeIntersection = ctx.base.pendingMemberSearches.contains(name))
869-
joint match
870-
case joint: SingleDenotation
871-
if isRefinedMethod
872-
&& (rinfo <:< joint.info
873-
|| name == nme.apply && defn.isFunctionType(tp.parent)) =>
874-
// use `rinfo` to keep the right parameter names for named args. See i8516.scala.
875-
joint.derivedSingleDenotation(joint.symbol, rinfo, pre, isRefinedMethod)
876-
case _ =>
877-
joint
863+
else rinfo match
864+
case AnnotatedType(rinfo1, ann) if ann.symbol == defn.RefineOverrideAnnot =>
865+
pdenot.asSingleDenotation.derivedSingleDenotation(pdenot.symbol, rinfo1)
866+
case _ =>
867+
val isRefinedMethod = rinfo.isInstanceOf[MethodOrPoly]
868+
val joint = pdenot.meet(
869+
new JointRefDenotation(NoSymbol, rinfo, Period.allInRun(ctx.runId), pre, isRefinedMethod),
870+
pre,
871+
safeIntersection = ctx.base.pendingMemberSearches.contains(name))
872+
joint match
873+
case joint: SingleDenotation
874+
if isRefinedMethod
875+
&& (rinfo <:< joint.info
876+
|| name == nme.apply && defn.isFunctionType(tp.parent)) =>
877+
// use `rinfo` to keep the right parameter names for named args. See i8516.scala.
878+
joint.derivedSingleDenotation(joint.symbol, rinfo, pre, isRefinedMethod)
879+
case _ =>
880+
joint
878881
}
879882

880883
def goApplied(tp: AppliedType, tycon: HKTypeLambda) =

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,8 @@ class PlainPrinter(_ctx: Context) extends Printer {
310310
toTextGlobal(tp.resultType)
311311
}
312312
case AnnotatedType(tpe, annot) =>
313-
if annot.symbol == defn.InlineParamAnnot || annot.symbol == defn.ErasedParamAnnot
314-
then toText(tpe)
313+
if defn.SilentAnnots.contains(annot.symbol) && !printDebug then
314+
toText(tpe)
315315
else if (annot.symbol == defn.IntoAnnot || annot.symbol == defn.IntoParamAnnot)
316316
&& !printDebug
317317
then atPrec(GlobalPrec)( Str("into ") ~ toText(tpe) )

library/src/scala/caps.scala

+7
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,13 @@ import annotation.{experimental, compileTimeOnly, retainsCap}
7070
*/
7171
final class use extends annotation.StaticAnnotation
7272

73+
/** An annotation placed on a refinement created by capture checking.
74+
* Refinements with this annotation unconditionally override any
75+
* info vfrom the parent type, so no intersection needs to be formed.
76+
* This could be useful for tracked parameters as well.
77+
*/
78+
final class refineOverride extends annotation.StaticAnnotation
79+
7380
object unsafe:
7481

7582
extension [T](x: T)

scala2-library-cc/src/scala/collection/View.scala

+1-3
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,8 @@ object View extends IterableFactory[View] {
150150
object Filter {
151151
def apply[A](underlying: Iterable[A]^, p: A => Boolean, isFlipped: Boolean): Filter[A]^{underlying, p} =
152152
underlying match {
153-
case filter: Filter[A] if filter.isFlipped == isFlipped =>
153+
case filter: Filter[A]^{underlying} if filter.isFlipped == isFlipped =>
154154
new Filter(filter.underlying, a => filter.p(a) && p(a), isFlipped)
155-
.asInstanceOf[Filter[A]^{underlying, p}]
156-
// !!! asInstanceOf needed once paths were added, see path-patmat-should-be-pos.scala for minimization
157155
case _ => new Filter(underlying, p, isFlipped)
158156
}
159157
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt-depfun.scala:10:43 ----------------------------------
2+
10 | val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error // error sepcheck
3+
| ^^^^^^^
4+
| Found: Str^{} ->{ac, y, z} Str^{y, z}
5+
| Required: Str^{y, z} ->{fresh} Str^{y, z}
6+
|
7+
| longer explanation available when compiling with `-explain`
8+
-- Error: tests/neg-custom-args/captures/capt-depfun.scala:10:24 -------------------------------------------------------
9+
10 | val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error // error sepcheck
10+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
11+
| Separation failure: Str^{y, z} => Str^{y, z} captures a root element hiding {ac, y, z}
12+
| and also refers to {y, z}.
13+
| The two sets overlap at {y, z}
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import annotation.retains
2+
import language.future // sepchecks on
23
class C
34
type Cap = C @retains(caps.cap)
45
class Str
56

67
def f(y: Cap, z: Cap) =
78
def g(): C @retains(y, z) = ???
89
val ac: ((x: Cap) => Str @retains(x) => Str @retains(x)) = ???
9-
val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error
10+
val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error // error sepcheck

tests/run-custom-args/captures/colltest5/CollectionStrawManCC5_1.scala

+8
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,14 @@ object CollectionStrawMan5 {
452452
this: Filter[A]^{underlying, p} =>
453453
def iterator: Iterator[A]^{this} = underlying.iterator.filter(p)
454454
}
455+
456+
object Filter:
457+
def apply[A](underlying: Iterable[A]^, pp: A => Boolean, isFlipped: Boolean): Filter[A]^{underlying, pp} =
458+
underlying match
459+
case filter: Filter[A]^{underlying} =>
460+
new Filter(filter.underlying, a => filter.p(a) && pp(a))
461+
case _ => new Filter(underlying, pp)
462+
455463
case class Partition[A](val underlying: Iterable[A]^, p: A => Boolean) {
456464
self: Partition[A]^{underlying, p} =>
457465

0 commit comments

Comments
 (0)