Skip to content

Commit 4851278

Browse files
authored
Properly handle SAM types with wildcards (#18201)
When typing a closure with an expected type containing a wildcard, the closure type itself should not contain wildcards, because it might be expanded to an anonymous class extending the closure type (this happens on non-JVM backends as well as on the JVM itself in situations where a SAM trait does not compile down to a SAM interface). We were already approximating wildcards in the method type returned by the SAMType extractor, but to fix this issue we had to change the extractor to perform the approximation on the expected type itself to generate a valid parent type. The SAMType extractor now returns both the approximated parent type and the type of the method itself. The wildcard approximation analysis relies on a new `VarianceMap` opaque type extracted from Inferencing#variances. Fixes #16065. Fixes #18096.
2 parents ca19b46 + 89735d0 commit 4851278

11 files changed

+216
-142
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

+1
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,7 @@ class Definitions {
744744
@tu lazy val StringContextModule_processEscapes: Symbol = StringContextModule.requiredMethod(nme.processEscapes)
745745

746746
@tu lazy val PartialFunctionClass: ClassSymbol = requiredClass("scala.PartialFunction")
747+
@tu lazy val PartialFunction_apply: Symbol = PartialFunctionClass.requiredMethod(nme.apply)
747748
@tu lazy val PartialFunction_isDefinedAt: Symbol = PartialFunctionClass.requiredMethod(nme.isDefinedAt)
748749
@tu lazy val PartialFunction_applyOrElse: Symbol = PartialFunctionClass.requiredMethod(nme.applyOrElse)
749750

compiler/src/dotty/tools/dotc/core/Types.scala

+128-83
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import CheckRealizable._
2121
import Variances.{Variance, setStructuralVariances, Invariant}
2222
import typer.Nullables
2323
import util.Stats._
24-
import util.SimpleIdentitySet
24+
import util.{SimpleIdentityMap, SimpleIdentitySet}
2525
import ast.tpd._
2626
import ast.TreeTypeMap
2727
import printing.Texts._
@@ -1751,7 +1751,7 @@ object Types {
17511751
t
17521752
case t if defn.isErasedFunctionType(t) =>
17531753
t
1754-
case t @ SAMType(_) =>
1754+
case t @ SAMType(_, _) =>
17551755
t
17561756
case _ =>
17571757
NoType
@@ -5520,105 +5520,119 @@ object Types {
55205520
* A type is a SAM type if it is a reference to a class or trait, which
55215521
*
55225522
* - has a single abstract method with a method type (ExprType
5523-
* and PolyType not allowed!) whose result type is not an implicit function type
5524-
* and which is not marked inline.
5523+
* and PolyType not allowed!) according to `possibleSamMethods`.
55255524
* - can be instantiated without arguments or with just () as argument.
55265525
*
5527-
* The pattern `SAMType(sam)` matches a SAM type, where `sam` is the
5528-
* type of the single abstract method.
5526+
* The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the
5527+
* type of the single abstract method and `samParent` is a subtype of the matched
5528+
* SAM type which has been stripped of wildcards to turn it into a valid parent
5529+
* type.
55295530
*/
55305531
object SAMType {
5531-
def zeroParamClass(tp: Type)(using Context): Type = tp match {
5532+
/** If possible, return a type which is both a subtype of `origTp` and a type
5533+
* application of `samClass` where none of the type arguments are
5534+
* wildcards (thus making it a valid parent type), otherwise return
5535+
* NoType.
5536+
*
5537+
* A wildcard in the original type will be replaced by its upper or lower bound in a way
5538+
* that maximizes the number of possible implementations of `samMeth`. For example,
5539+
* java.util.function defines an interface equivalent to:
5540+
*
5541+
* trait Function[T, R]:
5542+
* def apply(t: T): R
5543+
*
5544+
* and it usually appears with wildcards to compensate for the lack of
5545+
* definition-site variance in Java:
5546+
*
5547+
* (x => x.toInt): Function[? >: String, ? <: Int]
5548+
*
5549+
* When typechecking this lambda, we need to approximate the wildcards to find
5550+
* a valid parent type for our lambda to extend. We can see that in `apply`,
5551+
* `T` only appears contravariantly and `R` only appears covariantly, so by
5552+
* minimizing the first parameter and maximizing the second, we maximize the
5553+
* number of valid implementations of `apply` which lets us implement the lambda
5554+
* with a closure equivalent to:
5555+
*
5556+
* new Function[String, Int] { def apply(x: String): Int = x.toInt }
5557+
*
5558+
* If a type parameter appears invariantly or does not appear at all in `samMeth`, then
5559+
* we arbitrarily pick the upper-bound.
5560+
*/
5561+
def samParent(origTp: Type, samClass: Symbol, samMeth: Symbol)(using Context): Type =
5562+
val tp = origTp.baseType(samClass)
5563+
if !(tp <:< origTp) then NoType
5564+
else tp match
5565+
case tp @ AppliedType(tycon, args) if tp.hasWildcardArg =>
5566+
val accu = new TypeAccumulator[VarianceMap[Symbol]]:
5567+
def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match
5568+
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) =>
5569+
vmap.recordLocalVariance(tp.symbol, variance)
5570+
case _ =>
5571+
foldOver(vmap, t)
5572+
val vmap = accu(VarianceMap.empty, samMeth.info)
5573+
val tparams = tycon.typeParamSymbols
5574+
val args1 = args.zipWithConserve(tparams):
5575+
case (arg @ TypeBounds(lo, hi), tparam) =>
5576+
val v = vmap.computedVariance(tparam)
5577+
if v.uncheckedNN < 0 then lo
5578+
else hi
5579+
case (arg, _) => arg
5580+
tp.derivedAppliedType(tycon, args1)
5581+
case _ =>
5582+
tp
5583+
5584+
def samClass(tp: Type)(using Context): Symbol = tp match
55325585
case tp: ClassInfo =>
5533-
def zeroParams(tp: Type): Boolean = tp.stripPoly match {
5586+
def zeroParams(tp: Type): Boolean = tp.stripPoly match
55345587
case mt: MethodType => mt.paramInfos.isEmpty && !mt.resultType.isInstanceOf[MethodType]
55355588
case et: ExprType => true
55365589
case _ => false
5537-
}
5538-
// `ContextFunctionN` does not have constructors
5539-
val ctor = tp.cls.primaryConstructor
5540-
if (!ctor.exists || zeroParams(ctor.info)) tp
5541-
else NoType
5590+
val cls = tp.cls
5591+
val validCtor =
5592+
val ctor = cls.primaryConstructor
5593+
// `ContextFunctionN` does not have constructors
5594+
!ctor.exists || zeroParams(ctor.info)
5595+
val isInstantiable = !cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType)
5596+
if validCtor && isInstantiable then tp.cls
5597+
else NoSymbol
55425598
case tp: AppliedType =>
5543-
zeroParamClass(tp.superType)
5599+
samClass(tp.superType)
55445600
case tp: TypeRef =>
5545-
zeroParamClass(tp.underlying)
5601+
samClass(tp.underlying)
55465602
case tp: RefinedType =>
5547-
zeroParamClass(tp.underlying)
5603+
samClass(tp.underlying)
55485604
case tp: TypeBounds =>
5549-
zeroParamClass(tp.underlying)
5605+
samClass(tp.underlying)
55505606
case tp: TypeVar =>
5551-
zeroParamClass(tp.underlying)
5607+
samClass(tp.underlying)
55525608
case tp: AnnotatedType =>
5553-
zeroParamClass(tp.underlying)
5554-
case _ =>
5555-
NoType
5556-
}
5557-
def isInstantiatable(tp: Type)(using Context): Boolean = zeroParamClass(tp) match {
5558-
case cinfo: ClassInfo if !cinfo.cls.isOneOf(FinalOrSealed) =>
5559-
val selfType = cinfo.selfType.asSeenFrom(tp, cinfo.cls)
5560-
tp <:< selfType
5609+
samClass(tp.underlying)
55615610
case _ =>
5562-
false
5563-
}
5564-
def unapply(tp: Type)(using Context): Option[MethodType] =
5565-
if (isInstantiatable(tp)) {
5566-
val absMems = tp.possibleSamMethods
5567-
if (absMems.size == 1)
5568-
absMems.head.info match {
5569-
case mt: MethodType if !mt.isParamDependent &&
5570-
mt.resultType.isValueTypeOrWildcard &&
5571-
!defn.isContextFunctionType(mt.resultType) =>
5572-
val cls = tp.classSymbol
5573-
5574-
// Given a SAM type such as:
5575-
//
5576-
// import java.util.function.Function
5577-
// Function[? >: String, ? <: Int]
5578-
//
5579-
// the single abstract method will have type:
5580-
//
5581-
// (x: Function[? >: String, ? <: Int]#T): Function[? >: String, ? <: Int]#R
5582-
//
5583-
// which is not implementable outside of the scope of Function.
5584-
//
5585-
// To avoid this kind of issue, we approximate references to
5586-
// parameters of the SAM type by their bounds, this way in the
5587-
// above example we get:
5588-
//
5589-
// (x: String): Int
5590-
val approxParams = new ApproximatingTypeMap {
5591-
def apply(tp: Type): Type = tp match {
5592-
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) && tp.symbol.owner == cls =>
5593-
tp.info match {
5594-
case info: AliasingBounds =>
5595-
mapOver(info.alias)
5596-
case TypeBounds(lo, hi) =>
5597-
range(atVariance(-variance)(apply(lo)), apply(hi))
5598-
case _ =>
5599-
range(defn.NothingType, defn.AnyType) // should happen only in error cases
5600-
}
5601-
case _ =>
5602-
mapOver(tp)
5603-
}
5604-
}
5605-
val approx =
5606-
if ctx.owner.isContainedIn(cls) then mt
5607-
else approxParams(mt).asInstanceOf[MethodType]
5608-
Some(approx)
5611+
NoSymbol
5612+
5613+
def unapply(tp: Type)(using Context): Option[(MethodType, Type)] =
5614+
val cls = samClass(tp)
5615+
if cls.exists then
5616+
val absMems =
5617+
if tp.isRef(defn.PartialFunctionClass) then
5618+
// To maintain compatibility with 2.x, we treat PartialFunction specially,
5619+
// pretending it is a SAM type. In the future it would be better to merge
5620+
// Function and PartialFunction, have Function1 contain a isDefinedAt method
5621+
// def isDefinedAt(x: T) = true
5622+
// and overwrite that method whenever the function body is a sequence of
5623+
// case clauses.
5624+
List(defn.PartialFunction_apply)
5625+
else
5626+
tp.possibleSamMethods.map(_.symbol)
5627+
if absMems.lengthCompare(1) == 0 then
5628+
val samMethSym = absMems.head
5629+
val parent = samParent(tp, cls, samMethSym)
5630+
samMethSym.asSeenFrom(parent).info match
5631+
case mt: MethodType if !mt.isParamDependent && mt.resultType.isValueTypeOrWildcard =>
5632+
Some(mt, parent)
56095633
case _ =>
56105634
None
5611-
}
5612-
else if (tp isRef defn.PartialFunctionClass)
5613-
// To maintain compatibility with 2.x, we treat PartialFunction specially,
5614-
// pretending it is a SAM type. In the future it would be better to merge
5615-
// Function and PartialFunction, have Function1 contain a isDefinedAt method
5616-
// def isDefinedAt(x: T) = true
5617-
// and overwrite that method whenever the function body is a sequence of
5618-
// case clauses.
5619-
absMems.find(_.symbol.name == nme.apply).map(_.info.asInstanceOf[MethodType])
56205635
else None
5621-
}
56225636
else None
56235637
}
56245638

@@ -6451,6 +6465,37 @@ object Types {
64516465
}
64526466
}
64536467

6468+
object VarianceMap:
6469+
/** An immutable map representing the variance of keys of type `K` */
6470+
opaque type VarianceMap[K <: AnyRef] <: AnyRef = SimpleIdentityMap[K, Integer]
6471+
def empty[K <: AnyRef]: VarianceMap[K] = SimpleIdentityMap.empty[K]
6472+
extension [K <: AnyRef](vmap: VarianceMap[K])
6473+
/** The backing map used to implement this VarianceMap. */
6474+
inline def underlying: SimpleIdentityMap[K, Integer] = vmap
6475+
6476+
/** Return a new map taking into account that K appears in a
6477+
* {co,contra,in}-variant position if `localVariance` is {positive,negative,zero}.
6478+
*/
6479+
def recordLocalVariance(k: K, localVariance: Int): VarianceMap[K] =
6480+
val previousVariance = vmap(k)
6481+
if previousVariance == null then
6482+
vmap.updated(k, localVariance)
6483+
else if previousVariance == localVariance || previousVariance == 0 then
6484+
vmap
6485+
else
6486+
vmap.updated(k, 0)
6487+
6488+
/** Return the variance of `k`:
6489+
* - A positive value means that `k` appears only covariantly.
6490+
* - A negative value means that `k` appears only contravariantly.
6491+
* - A zero value means that `k` appears both covariantly and
6492+
* contravariantly, or appears invariantly.
6493+
* - A null value means that `k` does not appear at all.
6494+
*/
6495+
def computedVariance(k: K): Integer | Null =
6496+
vmap(k)
6497+
export VarianceMap.VarianceMap
6498+
64546499
// ----- Name Filters --------------------------------------------------
64556500

64566501
/** A name filter selects or discards a member name of a type `pre`.

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

+2-11
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ class ExpandSAMs extends MiniPhase:
5050
tree // it's a plain function
5151
case tpe if defn.isContextFunctionType(tpe) =>
5252
tree
53-
case tpe @ SAMType(_) if tpe.isRef(defn.PartialFunctionClass) =>
53+
case SAMType(_, tpe) if tpe.isRef(defn.PartialFunctionClass) =>
5454
val tpe1 = checkRefinements(tpe, fn)
5555
toPartialFunction(tree, tpe1)
56-
case tpe @ SAMType(_) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) =>
56+
case SAMType(_, tpe) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) =>
5757
checkRefinements(tpe, fn)
5858
tree
5959
case tpe =>
@@ -66,13 +66,6 @@ class ExpandSAMs extends MiniPhase:
6666
tree
6767
}
6868

69-
private def checkNoContextFunction(tpt: Tree)(using Context): Unit =
70-
if defn.isContextFunctionType(tpt.tpe) then
71-
report.error(
72-
em"""Implementation restriction: cannot convert this expression to
73-
|partial function with context function result type $tpt""",
74-
tpt.srcPos)
75-
7669
/** A partial function literal:
7770
*
7871
* ```
@@ -115,8 +108,6 @@ class ExpandSAMs extends MiniPhase:
115108
private def toPartialFunction(tree: Block, tpe: Type)(using Context): Tree = {
116109
val closureDef(anon @ DefDef(_, List(List(param)), _, _)) = tree: @unchecked
117110

118-
checkNoContextFunction(anon.tpt)
119-
120111
// The right hand side from which to construct the partial function. This is always a Match.
121112
// If the original rhs is already a Match (possibly in braces), return that.
122113
// Otherwise construct a match `x match case _ => rhs` where `x` is the parameter of the closure.

compiler/src/dotty/tools/dotc/typer/Applications.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,7 @@ trait Applications extends Compatibility {
696696

697697
def SAMargOK =
698698
defn.isFunctionNType(argtpe1) && formal.match
699-
case SAMType(sam) => argtpe <:< sam.toFunctionType(isJava = formal.classSymbol.is(JavaDefined))
699+
case SAMType(samMeth, samParent) => argtpe <:< samMeth.toFunctionType(isJava = samParent.classSymbol.is(JavaDefined))
700700
case _ => false
701701

702702
isCompatible(argtpe, formal)
@@ -2074,7 +2074,7 @@ trait Applications extends Compatibility {
20742074
* new java.io.ObjectOutputStream(f)
20752075
*/
20762076
pt match {
2077-
case SAMType(mtp) =>
2077+
case SAMType(mtp, _) =>
20782078
narrowByTypes(alts, mtp.paramInfos, mtp.resultType)
20792079
case _ =>
20802080
// pick any alternatives that are not methods since these might be convertible

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

+9-16
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ object Inferencing {
407407
val vs = variances(tp)
408408
val patternBindings = new mutable.ListBuffer[(Symbol, TypeParamRef)]
409409
val gadtBounds = ctx.gadt.symbols.map(ctx.gadt.bounds(_).nn)
410-
vs foreachBinding { (tvar, v) =>
410+
vs.underlying foreachBinding { (tvar, v) =>
411411
if !tvar.isInstantiated then
412412
// if the tvar is covariant/contravariant (v == 1/-1, respectively) in the input type tp
413413
// then it is safe to instantiate if it doesn't occur in any of the GADT bounds.
@@ -440,8 +440,6 @@ object Inferencing {
440440
res
441441
}
442442

443-
type VarianceMap = SimpleIdentityMap[TypeVar, Integer]
444-
445443
/** All occurrences of type vars in `tp` that satisfy predicate
446444
* `include` mapped to their variances (-1/0/1) in both `tp` and
447445
* `pt.finalResultType`, where
@@ -465,23 +463,18 @@ object Inferencing {
465463
*
466464
* we want to instantiate U to x.type right away. No need to wait further.
467465
*/
468-
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap = {
466+
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap[TypeVar] = {
469467
Stats.record("variances")
470468
val constraint = ctx.typerState.constraint
471469

472-
object accu extends TypeAccumulator[VarianceMap] {
470+
object accu extends TypeAccumulator[VarianceMap[TypeVar]]:
473471
def setVariance(v: Int) = variance = v
474-
def apply(vmap: VarianceMap, t: Type): VarianceMap = t match {
472+
def apply(vmap: VarianceMap[TypeVar], t: Type): VarianceMap[TypeVar] = t match
475473
case t: TypeVar
476474
if !t.isInstantiated && accCtx.typerState.constraint.contains(t) =>
477-
val v = vmap(t)
478-
if (v == null) vmap.updated(t, variance)
479-
else if (v == variance || v == 0) vmap
480-
else vmap.updated(t, 0)
475+
vmap.recordLocalVariance(t, variance)
481476
case _ =>
482477
foldOver(vmap, t)
483-
}
484-
}
485478

486479
/** Include in `vmap` type variables occurring in the constraints of type variables
487480
* already in `vmap`. Specifically:
@@ -493,10 +486,10 @@ object Inferencing {
493486
* bounds as non-variant.
494487
* Do this in a fixpoint iteration until `vmap` stabilizes.
495488
*/
496-
def propagate(vmap: VarianceMap): VarianceMap = {
489+
def propagate(vmap: VarianceMap[TypeVar]): VarianceMap[TypeVar] = {
497490
var vmap1 = vmap
498491
def traverse(tp: Type) = { vmap1 = accu(vmap1, tp) }
499-
vmap.foreachBinding { (tvar, v) =>
492+
vmap.underlying.foreachBinding { (tvar, v) =>
500493
val param = tvar.origin
501494
constraint.entry(param) match
502495
case TypeBounds(lo, hi) =>
@@ -512,7 +505,7 @@ object Inferencing {
512505
if (vmap1 eq vmap) vmap else propagate(vmap1)
513506
}
514507

515-
propagate(accu(accu(SimpleIdentityMap.empty, tp), pt.finalResultType))
508+
propagate(accu(accu(VarianceMap.empty, tp), pt.finalResultType))
516509
}
517510

518511
/** Run the transformation after dealiasing but return the original type if it was a no-op. */
@@ -638,7 +631,7 @@ trait Inferencing { this: Typer =>
638631
if !tvar.isInstantiated then
639632
// isInstantiated needs to be checked again, since previous interpolations could already have
640633
// instantiated `tvar` through unification.
641-
val v = vs(tvar)
634+
val v = vs.computedVariance(tvar)
642635
if v == null then buf += ((tvar, 0))
643636
else if v.intValue != 0 then buf += ((tvar, v.intValue))
644637
else comparing(cmp =>

0 commit comments

Comments
 (0)