@@ -21,7 +21,7 @@ import CheckRealizable._
2121import Variances .{Variance , setStructuralVariances , Invariant }
2222import typer .Nullables
2323import util .Stats ._
24- import util .SimpleIdentitySet
24+ import util .{ SimpleIdentityMap , SimpleIdentitySet }
2525import ast .tpd ._
2626import ast .TreeTypeMap
2727import printing .Texts ._
@@ -1746,7 +1746,7 @@ object Types {
17461746 t
17471747 case t if defn.isErasedFunctionType(t) =>
17481748 t
1749- case t @ SAMType (_) =>
1749+ case t @ SAMType (_, _ ) =>
17501750 t
17511751 case _ =>
17521752 NoType
@@ -5505,104 +5505,119 @@ object Types {
55055505 * A type is a SAM type if it is a reference to a class or trait, which
55065506 *
55075507 * - has a single abstract method with a method type (ExprType
5508- * and PolyType not allowed!) whose result type is not an implicit function type
5509- * and which is not marked inline.
5508+ * and PolyType not allowed!) according to `possibleSamMethods`.
55105509 * - can be instantiated without arguments or with just () as argument.
55115510 *
5512- * The pattern `SAMType(sam)` matches a SAM type, where `sam` is the
5513- * type of the single abstract method.
5511+ * The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the
5512+ * type of the single abstract method and `samParent` is a subtype of the matched
5513+ * SAM type which has been stripped of wildcards to turn it into a valid parent
5514+ * type.
55145515 */
55155516 object SAMType {
5516- def zeroParamClass (tp : Type )(using Context ): Type = tp match {
5517+ /** If possible, return a type which is both a subtype of `origTp` and a type
5518+ * application of `samClass` where none of the type arguments are
5519+ * wildcards (thus making it a valid parent type), otherwise return
5520+ * NoType.
5521+ *
5522+ * A wildcard in the original type will be replaced by its upper or lower bound in a way
5523+ * that maximizes the number of possible implementations of `samMeth`. For example,
5524+ * java.util.function defines an interface equivalent to:
5525+ *
5526+ * trait Function[T, R]:
5527+ * def apply(t: T): R
5528+ *
5529+ * and it usually appears with wildcards to compensate for the lack of
5530+ * definition-site variance in Java:
5531+ *
5532+ * (x => x.toInt): Function[? >: String, ? <: Int]
5533+ *
5534+ * When typechecking this lambda, we need to approximate the wildcards to find
5535+ * a valid parent type for our lambda to extend. We can see that in `apply`,
5536+ * `T` only appears contravariantly and `R` only appears covariantly, so by
5537+ * minimizing the first parameter and maximizing the second, we maximize the
5538+ * number of valid implementations of `apply` which lets us implement the lambda
5539+ * with a closure equivalent to:
5540+ *
5541+ * new Function[String, Int] { def apply(x: String): Int = x.toInt }
5542+ *
5543+ * If a type parameter appears invariantly or does not appear at all in `samMeth`, then
5544+ * we arbitrarily pick the upper-bound.
5545+ */
5546+ def samParent (origTp : Type , samClass : Symbol , samMeth : Symbol )(using Context ): Type =
5547+ val tp = origTp.baseType(samClass)
5548+ if ! (tp <:< origTp) then NoType
5549+ else tp match
5550+ case tp @ AppliedType (tycon, args) if tp.hasWildcardArg =>
5551+ val accu = new TypeAccumulator [VarianceMap [Symbol ]]:
5552+ def apply (vmap : VarianceMap [Symbol ], t : Type ): VarianceMap [Symbol ] = t match
5553+ case tp : TypeRef if tp.symbol.isAllOf(ClassTypeParam ) =>
5554+ vmap.recordLocalVariance(tp.symbol, variance)
5555+ case _ =>
5556+ foldOver(vmap, t)
5557+ val vmap = accu(VarianceMap .empty, samMeth.info)
5558+ val tparams = tycon.typeParamSymbols
5559+ val args1 = args.zipWithConserve(tparams):
5560+ case (arg @ TypeBounds (lo, hi), tparam) =>
5561+ val v = vmap.computedVariance(tparam)
5562+ if v.uncheckedNN < 0 then lo
5563+ else hi
5564+ case (arg, _) => arg
5565+ tp.derivedAppliedType(tycon, args1)
5566+ case _ =>
5567+ tp
5568+
5569+ def samClass (tp : Type )(using Context ): Symbol = tp match
55175570 case tp : ClassInfo =>
5518- def zeroParams (tp : Type ): Boolean = tp.stripPoly match {
5571+ def zeroParams (tp : Type ): Boolean = tp.stripPoly match
55195572 case mt : MethodType => mt.paramInfos.isEmpty && ! mt.resultType.isInstanceOf [MethodType ]
55205573 case et : ExprType => true
55215574 case _ => false
5522- }
5523- // `ContextFunctionN` does not have constructors
5524- val ctor = tp.cls.primaryConstructor
5525- if (! ctor.exists || zeroParams(ctor.info)) tp
5526- else NoType
5575+ val cls = tp.cls
5576+ val validCtor =
5577+ val ctor = cls.primaryConstructor
5578+ // `ContextFunctionN` does not have constructors
5579+ ! ctor.exists || zeroParams(ctor.info)
5580+ val isInstantiable = ! cls.isOneOf(FinalOrSealed ) && (tp.appliedRef <:< tp.selfType)
5581+ if validCtor && isInstantiable then tp.cls
5582+ else NoSymbol
55275583 case tp : AppliedType =>
5528- zeroParamClass (tp.superType)
5584+ samClass (tp.superType)
55295585 case tp : TypeRef =>
5530- zeroParamClass (tp.underlying)
5586+ samClass (tp.underlying)
55315587 case tp : RefinedType =>
5532- zeroParamClass (tp.underlying)
5588+ samClass (tp.underlying)
55335589 case tp : TypeBounds =>
5534- zeroParamClass (tp.underlying)
5590+ samClass (tp.underlying)
55355591 case tp : TypeVar =>
5536- zeroParamClass (tp.underlying)
5592+ samClass (tp.underlying)
55375593 case tp : AnnotatedType =>
5538- zeroParamClass(tp.underlying)
5539- case _ =>
5540- NoType
5541- }
5542- def isInstantiatable (tp : Type )(using Context ): Boolean = zeroParamClass(tp) match {
5543- case cinfo : ClassInfo if ! cinfo.cls.isOneOf(FinalOrSealed ) =>
5544- val selfType = cinfo.selfType.asSeenFrom(tp, cinfo.cls)
5545- tp <:< selfType
5594+ samClass(tp.underlying)
55465595 case _ =>
5547- false
5548- }
5549- def unapply (tp : Type )(using Context ): Option [MethodType ] =
5550- if (isInstantiatable(tp)) {
5551- val absMems = tp.possibleSamMethods
5552- if (absMems.size == 1 )
5553- absMems.head.info match {
5554- case mt : MethodType if ! mt.isParamDependent &&
5555- mt.resultType.isValueTypeOrWildcard =>
5556- val cls = tp.classSymbol
5557-
5558- // Given a SAM type such as:
5559- //
5560- // import java.util.function.Function
5561- // Function[? >: String, ? <: Int]
5562- //
5563- // the single abstract method will have type:
5564- //
5565- // (x: Function[? >: String, ? <: Int]#T): Function[? >: String, ? <: Int]#R
5566- //
5567- // which is not implementable outside of the scope of Function.
5568- //
5569- // To avoid this kind of issue, we approximate references to
5570- // parameters of the SAM type by their bounds, this way in the
5571- // above example we get:
5572- //
5573- // (x: String): Int
5574- val approxParams = new ApproximatingTypeMap {
5575- def apply (tp : Type ): Type = tp match {
5576- case tp : TypeRef if tp.symbol.isAllOf(ClassTypeParam ) && tp.symbol.owner == cls =>
5577- tp.info match {
5578- case info : AliasingBounds =>
5579- mapOver(info.alias)
5580- case TypeBounds (lo, hi) =>
5581- range(atVariance(- variance)(apply(lo)), apply(hi))
5582- case _ =>
5583- range(defn.NothingType , defn.AnyType ) // should happen only in error cases
5584- }
5585- case _ =>
5586- mapOver(tp)
5587- }
5588- }
5589- val approx =
5590- if ctx.owner.isContainedIn(cls) then mt
5591- else approxParams(mt).asInstanceOf [MethodType ]
5592- Some (approx)
5596+ NoSymbol
5597+
5598+ def unapply (tp : Type )(using Context ): Option [(MethodType , Type )] =
5599+ val cls = samClass(tp)
5600+ if cls.exists then
5601+ val absMems =
5602+ if tp.isRef(defn.PartialFunctionClass ) then
5603+ // To maintain compatibility with 2.x, we treat PartialFunction specially,
5604+ // pretending it is a SAM type. In the future it would be better to merge
5605+ // Function and PartialFunction, have Function1 contain a isDefinedAt method
5606+ // def isDefinedAt(x: T) = true
5607+ // and overwrite that method whenever the function body is a sequence of
5608+ // case clauses.
5609+ List (defn.PartialFunction_apply )
5610+ else
5611+ tp.possibleSamMethods.map(_.symbol)
5612+ if absMems.lengthCompare(1 ) == 0 then
5613+ val samMethSym = absMems.head
5614+ val parent = samParent(tp, cls, samMethSym)
5615+ samMethSym.asSeenFrom(parent).info match
5616+ case mt : MethodType if ! mt.isParamDependent && mt.resultType.isValueTypeOrWildcard =>
5617+ Some (mt, parent)
55935618 case _ =>
55945619 None
5595- }
5596- else if (tp isRef defn.PartialFunctionClass )
5597- // To maintain compatibility with 2.x, we treat PartialFunction specially,
5598- // pretending it is a SAM type. In the future it would be better to merge
5599- // Function and PartialFunction, have Function1 contain a isDefinedAt method
5600- // def isDefinedAt(x: T) = true
5601- // and overwrite that method whenever the function body is a sequence of
5602- // case clauses.
5603- absMems.find(_.symbol.name == nme.apply).map(_.info.asInstanceOf [MethodType ])
56045620 else None
5605- }
56065621 else None
56075622 }
56085623
@@ -6435,6 +6450,37 @@ object Types {
64356450 }
64366451 }
64376452
6453+ object VarianceMap :
6454+ /** An immutable map representing the variance of keys of type `K` */
6455+ opaque type VarianceMap [K <: AnyRef ] <: AnyRef = SimpleIdentityMap [K , Integer ]
6456+ def empty [K <: AnyRef ]: VarianceMap [K ] = SimpleIdentityMap .empty[K ]
6457+ extension [K <: AnyRef ](vmap : VarianceMap [K ])
6458+ /** The backing map used to implement this VarianceMap. */
6459+ inline def underlying : SimpleIdentityMap [K , Integer ] = vmap
6460+
6461+ /** Return a new map taking into account that K appears in a
6462+ * {co,contra,in}-variant position if `localVariance` is {positive,negative,zero}.
6463+ */
6464+ def recordLocalVariance (k : K , localVariance : Int ): VarianceMap [K ] =
6465+ val previousVariance = vmap(k)
6466+ if previousVariance == null then
6467+ vmap.updated(k, localVariance)
6468+ else if previousVariance == localVariance || previousVariance == 0 then
6469+ vmap
6470+ else
6471+ vmap.updated(k, 0 )
6472+
6473+ /** Return the variance of `k`:
6474+ * - A positive value means that `k` appears only covariantly.
6475+ * - A negative value means that `k` appears only contravariantly.
6476+ * - A zero value means that `k` appears both covariantly and
6477+ * contravariantly, or appears invariantly.
6478+ * - A null value means that `k` does not appear at all.
6479+ */
6480+ def computedVariance (k : K ): Integer | Null =
6481+ vmap(k)
6482+ export VarianceMap .VarianceMap
6483+
64386484 // ----- Name Filters --------------------------------------------------
64396485
64406486 /** A name filter selects or discards a member name of a type `pre`.
0 commit comments