@@ -21,7 +21,7 @@ import CheckRealizable._
21
21
import Variances .{Variance , setStructuralVariances , Invariant }
22
22
import typer .Nullables
23
23
import util .Stats ._
24
- import util .SimpleIdentitySet
24
+ import util .{ SimpleIdentityMap , SimpleIdentitySet }
25
25
import ast .tpd ._
26
26
import ast .TreeTypeMap
27
27
import printing .Texts ._
@@ -1751,7 +1751,7 @@ object Types {
1751
1751
t
1752
1752
case t if defn.isErasedFunctionType(t) =>
1753
1753
t
1754
- case t @ SAMType (_) =>
1754
+ case t @ SAMType (_, _ ) =>
1755
1755
t
1756
1756
case _ =>
1757
1757
NoType
@@ -5520,105 +5520,119 @@ object Types {
5520
5520
* A type is a SAM type if it is a reference to a class or trait, which
5521
5521
*
5522
5522
* - has a single abstract method with a method type (ExprType
5523
- * and PolyType not allowed!) whose result type is not an implicit function type
5524
- * and which is not marked inline.
5523
+ * and PolyType not allowed!) according to `possibleSamMethods`.
5525
5524
* - can be instantiated without arguments or with just () as argument.
5526
5525
*
5527
- * The pattern `SAMType(sam)` matches a SAM type, where `sam` is the
5528
- * type of the single abstract method.
5526
+ * The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the
5527
+ * type of the single abstract method and `samParent` is a subtype of the matched
5528
+ * SAM type which has been stripped of wildcards to turn it into a valid parent
5529
+ * type.
5529
5530
*/
5530
5531
object SAMType {
5531
- def zeroParamClass (tp : Type )(using Context ): Type = tp match {
5532
+ /** If possible, return a type which is both a subtype of `origTp` and a type
5533
+ * application of `samClass` where none of the type arguments are
5534
+ * wildcards (thus making it a valid parent type), otherwise return
5535
+ * NoType.
5536
+ *
5537
+ * A wildcard in the original type will be replaced by its upper or lower bound in a way
5538
+ * that maximizes the number of possible implementations of `samMeth`. For example,
5539
+ * java.util.function defines an interface equivalent to:
5540
+ *
5541
+ * trait Function[T, R]:
5542
+ * def apply(t: T): R
5543
+ *
5544
+ * and it usually appears with wildcards to compensate for the lack of
5545
+ * definition-site variance in Java:
5546
+ *
5547
+ * (x => x.toInt): Function[? >: String, ? <: Int]
5548
+ *
5549
+ * When typechecking this lambda, we need to approximate the wildcards to find
5550
+ * a valid parent type for our lambda to extend. We can see that in `apply`,
5551
+ * `T` only appears contravariantly and `R` only appears covariantly, so by
5552
+ * minimizing the first parameter and maximizing the second, we maximize the
5553
+ * number of valid implementations of `apply` which lets us implement the lambda
5554
+ * with a closure equivalent to:
5555
+ *
5556
+ * new Function[String, Int] { def apply(x: String): Int = x.toInt }
5557
+ *
5558
+ * If a type parameter appears invariantly or does not appear at all in `samMeth`, then
5559
+ * we arbitrarily pick the upper-bound.
5560
+ */
5561
+ def samParent (origTp : Type , samClass : Symbol , samMeth : Symbol )(using Context ): Type =
5562
+ val tp = origTp.baseType(samClass)
5563
+ if ! (tp <:< origTp) then NoType
5564
+ else tp match
5565
+ case tp @ AppliedType (tycon, args) if tp.hasWildcardArg =>
5566
+ val accu = new TypeAccumulator [VarianceMap [Symbol ]]:
5567
+ def apply (vmap : VarianceMap [Symbol ], t : Type ): VarianceMap [Symbol ] = t match
5568
+ case tp : TypeRef if tp.symbol.isAllOf(ClassTypeParam ) =>
5569
+ vmap.recordLocalVariance(tp.symbol, variance)
5570
+ case _ =>
5571
+ foldOver(vmap, t)
5572
+ val vmap = accu(VarianceMap .empty, samMeth.info)
5573
+ val tparams = tycon.typeParamSymbols
5574
+ val args1 = args.zipWithConserve(tparams):
5575
+ case (arg @ TypeBounds (lo, hi), tparam) =>
5576
+ val v = vmap.computedVariance(tparam)
5577
+ if v.uncheckedNN < 0 then lo
5578
+ else hi
5579
+ case (arg, _) => arg
5580
+ tp.derivedAppliedType(tycon, args1)
5581
+ case _ =>
5582
+ tp
5583
+
5584
+ def samClass (tp : Type )(using Context ): Symbol = tp match
5532
5585
case tp : ClassInfo =>
5533
- def zeroParams (tp : Type ): Boolean = tp.stripPoly match {
5586
+ def zeroParams (tp : Type ): Boolean = tp.stripPoly match
5534
5587
case mt : MethodType => mt.paramInfos.isEmpty && ! mt.resultType.isInstanceOf [MethodType ]
5535
5588
case et : ExprType => true
5536
5589
case _ => false
5537
- }
5538
- // `ContextFunctionN` does not have constructors
5539
- val ctor = tp.cls.primaryConstructor
5540
- if (! ctor.exists || zeroParams(ctor.info)) tp
5541
- else NoType
5590
+ val cls = tp.cls
5591
+ val validCtor =
5592
+ val ctor = cls.primaryConstructor
5593
+ // `ContextFunctionN` does not have constructors
5594
+ ! ctor.exists || zeroParams(ctor.info)
5595
+ val isInstantiable = ! cls.isOneOf(FinalOrSealed ) && (tp.appliedRef <:< tp.selfType)
5596
+ if validCtor && isInstantiable then tp.cls
5597
+ else NoSymbol
5542
5598
case tp : AppliedType =>
5543
- zeroParamClass (tp.superType)
5599
+ samClass (tp.superType)
5544
5600
case tp : TypeRef =>
5545
- zeroParamClass (tp.underlying)
5601
+ samClass (tp.underlying)
5546
5602
case tp : RefinedType =>
5547
- zeroParamClass (tp.underlying)
5603
+ samClass (tp.underlying)
5548
5604
case tp : TypeBounds =>
5549
- zeroParamClass (tp.underlying)
5605
+ samClass (tp.underlying)
5550
5606
case tp : TypeVar =>
5551
- zeroParamClass (tp.underlying)
5607
+ samClass (tp.underlying)
5552
5608
case tp : AnnotatedType =>
5553
- zeroParamClass(tp.underlying)
5554
- case _ =>
5555
- NoType
5556
- }
5557
- def isInstantiatable (tp : Type )(using Context ): Boolean = zeroParamClass(tp) match {
5558
- case cinfo : ClassInfo if ! cinfo.cls.isOneOf(FinalOrSealed ) =>
5559
- val selfType = cinfo.selfType.asSeenFrom(tp, cinfo.cls)
5560
- tp <:< selfType
5609
+ samClass(tp.underlying)
5561
5610
case _ =>
5562
- false
5563
- }
5564
- def unapply (tp : Type )(using Context ): Option [MethodType ] =
5565
- if (isInstantiatable(tp)) {
5566
- val absMems = tp.possibleSamMethods
5567
- if (absMems.size == 1 )
5568
- absMems.head.info match {
5569
- case mt : MethodType if ! mt.isParamDependent &&
5570
- mt.resultType.isValueTypeOrWildcard &&
5571
- ! defn.isContextFunctionType(mt.resultType) =>
5572
- val cls = tp.classSymbol
5573
-
5574
- // Given a SAM type such as:
5575
- //
5576
- // import java.util.function.Function
5577
- // Function[? >: String, ? <: Int]
5578
- //
5579
- // the single abstract method will have type:
5580
- //
5581
- // (x: Function[? >: String, ? <: Int]#T): Function[? >: String, ? <: Int]#R
5582
- //
5583
- // which is not implementable outside of the scope of Function.
5584
- //
5585
- // To avoid this kind of issue, we approximate references to
5586
- // parameters of the SAM type by their bounds, this way in the
5587
- // above example we get:
5588
- //
5589
- // (x: String): Int
5590
- val approxParams = new ApproximatingTypeMap {
5591
- def apply (tp : Type ): Type = tp match {
5592
- case tp : TypeRef if tp.symbol.isAllOf(ClassTypeParam ) && tp.symbol.owner == cls =>
5593
- tp.info match {
5594
- case info : AliasingBounds =>
5595
- mapOver(info.alias)
5596
- case TypeBounds (lo, hi) =>
5597
- range(atVariance(- variance)(apply(lo)), apply(hi))
5598
- case _ =>
5599
- range(defn.NothingType , defn.AnyType ) // should happen only in error cases
5600
- }
5601
- case _ =>
5602
- mapOver(tp)
5603
- }
5604
- }
5605
- val approx =
5606
- if ctx.owner.isContainedIn(cls) then mt
5607
- else approxParams(mt).asInstanceOf [MethodType ]
5608
- Some (approx)
5611
+ NoSymbol
5612
+
5613
+ def unapply (tp : Type )(using Context ): Option [(MethodType , Type )] =
5614
+ val cls = samClass(tp)
5615
+ if cls.exists then
5616
+ val absMems =
5617
+ if tp.isRef(defn.PartialFunctionClass ) then
5618
+ // To maintain compatibility with 2.x, we treat PartialFunction specially,
5619
+ // pretending it is a SAM type. In the future it would be better to merge
5620
+ // Function and PartialFunction, have Function1 contain a isDefinedAt method
5621
+ // def isDefinedAt(x: T) = true
5622
+ // and overwrite that method whenever the function body is a sequence of
5623
+ // case clauses.
5624
+ List (defn.PartialFunction_apply )
5625
+ else
5626
+ tp.possibleSamMethods.map(_.symbol)
5627
+ if absMems.lengthCompare(1 ) == 0 then
5628
+ val samMethSym = absMems.head
5629
+ val parent = samParent(tp, cls, samMethSym)
5630
+ samMethSym.asSeenFrom(parent).info match
5631
+ case mt : MethodType if ! mt.isParamDependent && mt.resultType.isValueTypeOrWildcard =>
5632
+ Some (mt, parent)
5609
5633
case _ =>
5610
5634
None
5611
- }
5612
- else if (tp isRef defn.PartialFunctionClass )
5613
- // To maintain compatibility with 2.x, we treat PartialFunction specially,
5614
- // pretending it is a SAM type. In the future it would be better to merge
5615
- // Function and PartialFunction, have Function1 contain a isDefinedAt method
5616
- // def isDefinedAt(x: T) = true
5617
- // and overwrite that method whenever the function body is a sequence of
5618
- // case clauses.
5619
- absMems.find(_.symbol.name == nme.apply).map(_.info.asInstanceOf [MethodType ])
5620
5635
else None
5621
- }
5622
5636
else None
5623
5637
}
5624
5638
@@ -6451,6 +6465,37 @@ object Types {
6451
6465
}
6452
6466
}
6453
6467
6468
+ object VarianceMap :
6469
+ /** An immutable map representing the variance of keys of type `K` */
6470
+ opaque type VarianceMap [K <: AnyRef ] <: AnyRef = SimpleIdentityMap [K , Integer ]
6471
+ def empty [K <: AnyRef ]: VarianceMap [K ] = SimpleIdentityMap .empty[K ]
6472
+ extension [K <: AnyRef ](vmap : VarianceMap [K ])
6473
+ /** The backing map used to implement this VarianceMap. */
6474
+ inline def underlying : SimpleIdentityMap [K , Integer ] = vmap
6475
+
6476
+ /** Return a new map taking into account that K appears in a
6477
+ * {co,contra,in}-variant position if `localVariance` is {positive,negative,zero}.
6478
+ */
6479
+ def recordLocalVariance (k : K , localVariance : Int ): VarianceMap [K ] =
6480
+ val previousVariance = vmap(k)
6481
+ if previousVariance == null then
6482
+ vmap.updated(k, localVariance)
6483
+ else if previousVariance == localVariance || previousVariance == 0 then
6484
+ vmap
6485
+ else
6486
+ vmap.updated(k, 0 )
6487
+
6488
+ /** Return the variance of `k`:
6489
+ * - A positive value means that `k` appears only covariantly.
6490
+ * - A negative value means that `k` appears only contravariantly.
6491
+ * - A zero value means that `k` appears both covariantly and
6492
+ * contravariantly, or appears invariantly.
6493
+ * - A null value means that `k` does not appear at all.
6494
+ */
6495
+ def computedVariance (k : K ): Integer | Null =
6496
+ vmap(k)
6497
+ export VarianceMap .VarianceMap
6498
+
6454
6499
// ----- Name Filters --------------------------------------------------
6455
6500
6456
6501
/** A name filter selects or discards a member name of a type `pre`.
0 commit comments