@@ -21,7 +21,7 @@ import CheckRealizable._
2121import Variances .{Variance , setStructuralVariances , Invariant }
2222import typer .Nullables
2323import util .Stats ._
24- import util .SimpleIdentitySet
24+ import util .{ SimpleIdentityMap , SimpleIdentitySet }
2525import ast .tpd ._
2626import ast .TreeTypeMap
2727import printing .Texts ._
@@ -1751,7 +1751,7 @@ object Types {
17511751 t
17521752 case t if defn.isErasedFunctionType(t) =>
17531753 t
1754- case t @ SAMType (_) =>
1754+ case t @ SAMType (_, _ ) =>
17551755 t
17561756 case _ =>
17571757 NoType
@@ -5520,104 +5520,119 @@ object Types {
55205520 * A type is a SAM type if it is a reference to a class or trait, which
55215521 *
55225522 * - has a single abstract method with a method type (ExprType
5523- * and PolyType not allowed!) whose result type is not an implicit function type
5524- * and which is not marked inline.
5523+ * and PolyType not allowed!) according to `possibleSamMethods`.
55255524 * - can be instantiated without arguments or with just () as argument.
55265525 *
5527- * The pattern `SAMType(sam)` matches a SAM type, where `sam` is the
5528- * type of the single abstract method.
5526+ * The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the
5527+ * type of the single abstract method and `samParent` is a subtype of the matched
5528+ * SAM type which has been stripped of wildcards to turn it into a valid parent
5529+ * type.
55295530 */
55305531 object SAMType {
5531- def zeroParamClass (tp : Type )(using Context ): Type = tp match {
5532+ /** If possible, return a type which is both a subtype of `origTp` and a type
5533+ * application of `samClass` where none of the type arguments are
5534+ * wildcards (thus making it a valid parent type), otherwise return
5535+ * NoType.
5536+ *
5537+ * A wildcard in the original type will be replaced by its upper or lower bound in a way
5538+ * that maximizes the number of possible implementations of `samMeth`. For example,
5539+ * java.util.function defines an interface equivalent to:
5540+ *
5541+ * trait Function[T, R]:
5542+ * def apply(t: T): R
5543+ *
5544+ * and it usually appears with wildcards to compensate for the lack of
5545+ * definition-site variance in Java:
5546+ *
5547+ * (x => x.toInt): Function[? >: String, ? <: Int]
5548+ *
5549+ * When typechecking this lambda, we need to approximate the wildcards to find
5550+ * a valid parent type for our lambda to extend. We can see that in `apply`,
5551+ * `T` only appears contravariantly and `R` only appears covariantly, so by
5552+ * minimizing the first parameter and maximizing the second, we maximize the
5553+ * number of valid implementations of `apply` which lets us implement the lambda
5554+ * with a closure equivalent to:
5555+ *
5556+ * new Function[String, Int] { def apply(x: String): Int = x.toInt }
5557+ *
5558+ * If a type parameter appears invariantly or does not appear at all in `samMeth`, then
5559+ * we arbitrarily pick the upper-bound.
5560+ */
5561+ def samParent (origTp : Type , samClass : Symbol , samMeth : Symbol )(using Context ): Type =
5562+ val tp = origTp.baseType(samClass)
5563+ if ! (tp <:< origTp) then NoType
5564+ else tp match
5565+ case tp @ AppliedType (tycon, args) if tp.hasWildcardArg =>
5566+ val accu = new TypeAccumulator [VarianceMap [Symbol ]]:
5567+ def apply (vmap : VarianceMap [Symbol ], t : Type ): VarianceMap [Symbol ] = t match
5568+ case tp : TypeRef if tp.symbol.isAllOf(ClassTypeParam ) =>
5569+ vmap.recordLocalVariance(tp.symbol, variance)
5570+ case _ =>
5571+ foldOver(vmap, t)
5572+ val vmap = accu(VarianceMap .empty, samMeth.info)
5573+ val tparams = tycon.typeParamSymbols
5574+ val args1 = args.zipWithConserve(tparams):
5575+ case (arg @ TypeBounds (lo, hi), tparam) =>
5576+ val v = vmap.computedVariance(tparam)
5577+ if v.uncheckedNN < 0 then lo
5578+ else hi
5579+ case (arg, _) => arg
5580+ tp.derivedAppliedType(tycon, args1)
5581+ case _ =>
5582+ tp
5583+
5584+ def samClass (tp : Type )(using Context ): Symbol = tp match
55325585 case tp : ClassInfo =>
5533- def zeroParams (tp : Type ): Boolean = tp.stripPoly match {
5586+ def zeroParams (tp : Type ): Boolean = tp.stripPoly match
55345587 case mt : MethodType => mt.paramInfos.isEmpty && ! mt.resultType.isInstanceOf [MethodType ]
55355588 case et : ExprType => true
55365589 case _ => false
5537- }
5538- // `ContextFunctionN` does not have constructors
5539- val ctor = tp.cls.primaryConstructor
5540- if (! ctor.exists || zeroParams(ctor.info)) tp
5541- else NoType
5590+ val cls = tp.cls
5591+ val validCtor =
5592+ val ctor = cls.primaryConstructor
5593+ // `ContextFunctionN` does not have constructors
5594+ ! ctor.exists || zeroParams(ctor.info)
5595+ val isInstantiable = ! cls.isOneOf(FinalOrSealed ) && (tp.appliedRef <:< tp.selfType)
5596+ if validCtor && isInstantiable then tp.cls
5597+ else NoSymbol
55425598 case tp : AppliedType =>
5543- zeroParamClass (tp.superType)
5599+ samClass (tp.superType)
55445600 case tp : TypeRef =>
5545- zeroParamClass (tp.underlying)
5601+ samClass (tp.underlying)
55465602 case tp : RefinedType =>
5547- zeroParamClass (tp.underlying)
5603+ samClass (tp.underlying)
55485604 case tp : TypeBounds =>
5549- zeroParamClass (tp.underlying)
5605+ samClass (tp.underlying)
55505606 case tp : TypeVar =>
5551- zeroParamClass (tp.underlying)
5607+ samClass (tp.underlying)
55525608 case tp : AnnotatedType =>
5553- zeroParamClass(tp.underlying)
5554- case _ =>
5555- NoType
5556- }
5557- def isInstantiatable (tp : Type )(using Context ): Boolean = zeroParamClass(tp) match {
5558- case cinfo : ClassInfo if ! cinfo.cls.isOneOf(FinalOrSealed ) =>
5559- val selfType = cinfo.selfType.asSeenFrom(tp, cinfo.cls)
5560- tp <:< selfType
5609+ samClass(tp.underlying)
55615610 case _ =>
5562- false
5563- }
5564- def unapply (tp : Type )(using Context ): Option [MethodType ] =
5565- if (isInstantiatable(tp)) {
5566- val absMems = tp.possibleSamMethods
5567- if (absMems.size == 1 )
5568- absMems.head.info match {
5569- case mt : MethodType if ! mt.isParamDependent &&
5570- mt.resultType.isValueTypeOrWildcard =>
5571- val cls = tp.classSymbol
5572-
5573- // Given a SAM type such as:
5574- //
5575- // import java.util.function.Function
5576- // Function[? >: String, ? <: Int]
5577- //
5578- // the single abstract method will have type:
5579- //
5580- // (x: Function[? >: String, ? <: Int]#T): Function[? >: String, ? <: Int]#R
5581- //
5582- // which is not implementable outside of the scope of Function.
5583- //
5584- // To avoid this kind of issue, we approximate references to
5585- // parameters of the SAM type by their bounds, this way in the
5586- // above example we get:
5587- //
5588- // (x: String): Int
5589- val approxParams = new ApproximatingTypeMap {
5590- def apply (tp : Type ): Type = tp match {
5591- case tp : TypeRef if tp.symbol.isAllOf(ClassTypeParam ) && tp.symbol.owner == cls =>
5592- tp.info match {
5593- case info : AliasingBounds =>
5594- mapOver(info.alias)
5595- case TypeBounds (lo, hi) =>
5596- range(atVariance(- variance)(apply(lo)), apply(hi))
5597- case _ =>
5598- range(defn.NothingType , defn.AnyType ) // should happen only in error cases
5599- }
5600- case _ =>
5601- mapOver(tp)
5602- }
5603- }
5604- val approx =
5605- if ctx.owner.isContainedIn(cls) then mt
5606- else approxParams(mt).asInstanceOf [MethodType ]
5607- Some (approx)
5611+ NoSymbol
5612+
5613+ def unapply (tp : Type )(using Context ): Option [(MethodType , Type )] =
5614+ val cls = samClass(tp)
5615+ if cls.exists then
5616+ val absMems =
5617+ if tp.isRef(defn.PartialFunctionClass ) then
5618+ // To maintain compatibility with 2.x, we treat PartialFunction specially,
5619+ // pretending it is a SAM type. In the future it would be better to merge
5620+ // Function and PartialFunction, have Function1 contain a isDefinedAt method
5621+ // def isDefinedAt(x: T) = true
5622+ // and overwrite that method whenever the function body is a sequence of
5623+ // case clauses.
5624+ List (defn.PartialFunction_apply )
5625+ else
5626+ tp.possibleSamMethods.map(_.symbol)
5627+ if absMems.lengthCompare(1 ) == 0 then
5628+ val samMethSym = absMems.head
5629+ val parent = samParent(tp, cls, samMethSym)
5630+ samMethSym.asSeenFrom(parent).info match
5631+ case mt : MethodType if ! mt.isParamDependent && mt.resultType.isValueTypeOrWildcard =>
5632+ Some (mt, parent)
56085633 case _ =>
56095634 None
5610- }
5611- else if (tp isRef defn.PartialFunctionClass )
5612- // To maintain compatibility with 2.x, we treat PartialFunction specially,
5613- // pretending it is a SAM type. In the future it would be better to merge
5614- // Function and PartialFunction, have Function1 contain a isDefinedAt method
5615- // def isDefinedAt(x: T) = true
5616- // and overwrite that method whenever the function body is a sequence of
5617- // case clauses.
5618- absMems.find(_.symbol.name == nme.apply).map(_.info.asInstanceOf [MethodType ])
56195635 else None
5620- }
56215636 else None
56225637 }
56235638
@@ -6450,6 +6465,37 @@ object Types {
64506465 }
64516466 }
64526467
6468+ object VarianceMap :
6469+ /** An immutable map representing the variance of keys of type `K` */
6470+ opaque type VarianceMap [K <: AnyRef ] <: AnyRef = SimpleIdentityMap [K , Integer ]
6471+ def empty [K <: AnyRef ]: VarianceMap [K ] = SimpleIdentityMap .empty[K ]
6472+ extension [K <: AnyRef ](vmap : VarianceMap [K ])
6473+ /** The backing map used to implement this VarianceMap. */
6474+ inline def underlying : SimpleIdentityMap [K , Integer ] = vmap
6475+
6476+ /** Return a new map taking into account that K appears in a
6477+ * {co,contra,in}-variant position if `localVariance` is {positive,negative,zero}.
6478+ */
6479+ def recordLocalVariance (k : K , localVariance : Int ): VarianceMap [K ] =
6480+ val previousVariance = vmap(k)
6481+ if previousVariance == null then
6482+ vmap.updated(k, localVariance)
6483+ else if previousVariance == localVariance || previousVariance == 0 then
6484+ vmap
6485+ else
6486+ vmap.updated(k, 0 )
6487+
6488+ /** Return the variance of `k`:
6489+ * - A positive value means that `k` appears only covariantly.
6490+ * - A negative value means that `k` appears only contravariantly.
6491+ * - A zero value means that `k` appears both covariantly and
6492+ * contravariantly, or appears invariantly.
6493+ * - A null value means that `k` does not appear at all.
6494+ */
6495+ def computedVariance (k : K ): Integer | Null =
6496+ vmap(k)
6497+ export VarianceMap .VarianceMap
6498+
64536499 // ----- Name Filters --------------------------------------------------
64546500
64556501 /** A name filter selects or discards a member name of a type `pre`.
0 commit comments