@@ -249,8 +249,7 @@ sealed abstract class CaptureSet extends Showable:
249
249
if this .subCaptures(that, frozen = true ).isOK then that
250
250
else if that.subCaptures(this , frozen = true ).isOK then this
251
251
else if this .isConst && that.isConst then Const (this .elems ++ that.elems)
252
- else Var (initialElems = this .elems ++ that.elems)
253
- .addAsDependentTo(this ).addAsDependentTo(that)
252
+ else Union (this , that)
254
253
255
254
/** The smallest superset (via <:<) of this capture set that also contains `ref`.
256
255
*/
@@ -263,7 +262,7 @@ sealed abstract class CaptureSet extends Showable:
263
262
if this .subCaptures(that, frozen = true ).isOK then this
264
263
else if that.subCaptures(this , frozen = true ).isOK then that
265
264
else if this .isConst && that.isConst then Const (elemIntersection(this , that))
266
- else Intersected (this , that)
265
+ else Intersection (this , that)
267
266
268
267
/** The largest subset (via <:<) of this capture set that does not account for
269
268
* any of the elements in the constant capture set `that`
@@ -816,7 +815,29 @@ object CaptureSet:
816
815
class Diff (source : Var , other : Const )(using Context )
817
816
extends Filtered (source, ! other.accountsFor(_))
818
817
819
- class Intersected (cs1 : CaptureSet , cs2 : CaptureSet )(using Context )
818
+ class Union (cs1 : CaptureSet , cs2 : CaptureSet )(using Context )
819
+ extends Var (initialElems = cs1.elems ++ cs2.elems):
820
+ addAsDependentTo(cs1)
821
+ addAsDependentTo(cs2)
822
+
823
+ override def tryInclude (elem : CaptureRef , origin : CaptureSet )(using Context , VarState ): CompareResult =
824
+ if accountsFor(elem) then CompareResult .OK
825
+ else
826
+ val res = super .tryInclude(elem, origin)
827
+ // If this is the union of a constant and a variable,
828
+ // propagate `elem` to the variable part to avoid slack
829
+ // between the operands and the union.
830
+ if res.isOK && (origin ne cs1) && (origin ne cs2) then
831
+ if cs1.isConst then cs2.tryInclude(elem, origin)
832
+ else if cs2.isConst then cs1.tryInclude(elem, origin)
833
+ else res
834
+ else res
835
+
836
+ override def propagateSolved ()(using Context ) =
837
+ if cs1.isConst && cs2.isConst && ! isConst then markSolved()
838
+ end Union
839
+
840
+ class Intersection (cs1 : CaptureSet , cs2 : CaptureSet )(using Context )
820
841
extends Var (initialElems = elemIntersection(cs1, cs2)):
821
842
addAsDependentTo(cs1)
822
843
addAsDependentTo(cs2)
@@ -841,7 +862,7 @@ object CaptureSet:
841
862
842
863
override def propagateSolved ()(using Context ) =
843
864
if cs1.isConst && cs2.isConst && ! isConst then markSolved()
844
- end Intersected
865
+ end Intersection
845
866
846
867
def elemIntersection (cs1 : CaptureSet , cs2 : CaptureSet )(using Context ): Refs =
847
868
cs1.elems.filter(cs2.mightAccountFor) ++ cs2.elems.filter(cs1.mightAccountFor)
0 commit comments