Skip to content

Commit 3d0f02c

Browse files
Backport "Reimplement support for type aliases in SAM types" to 3.3.4 (#21553)
Backports #18317 to 3.3.4-RC2 LTS. The PR fixes a regression introduced in 3.3.2
2 parents 1a5bff6 + db126c1 commit 3d0f02c

File tree

4 files changed

+85
-45
lines changed

4 files changed

+85
-45
lines changed

compiler/src/dotty/tools/dotc/ast/tpd.scala

+13-10
Original file line numberDiff line numberDiff line change
@@ -344,24 +344,27 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
344344

345345
/** An anonymous class
346346
*
347-
* new parents { forwarders }
347+
* new parents { termForwarders; typeAliases }
348348
*
349-
* where `forwarders` contains forwarders for all functions in `fns`.
350-
* @param parents a non-empty list of class types
351-
* @param fns a non-empty of functions for which forwarders should be defined in the class.
352-
* The class has the same owner as the first function in `fns`.
353-
* Its position is the union of all functions in `fns`.
349+
* @param parents a non-empty list of class types
350+
* @param termForwarders a non-empty list of forwarding definitions specified by their name and the definition they forward to.
351+
* @param typeMembers a possibly-empty list of type members specified by their name and their right hand side.
352+
*
353+
* The class has the same owner as the first function in `termForwarders`.
354+
* Its position is the union of all symbols in `termForwarders`.
354355
*/
355-
def AnonClass(parents: List[Type], fns: List[TermSymbol], methNames: List[TermName])(using Context): Block = {
356-
AnonClass(fns.head.owner, parents, fns.map(_.span).reduceLeft(_ union _)) { cls =>
357-
def forwarder(fn: TermSymbol, name: TermName) = {
356+
def AnonClass(parents: List[Type], termForwarders: List[(TermName, TermSymbol)],
357+
typeMembers: List[(TypeName, TypeBounds)] = Nil)(using Context): Block = {
358+
AnonClass(termForwarders.head._2.owner, parents, termForwarders.map(_._2.span).reduceLeft(_ union _)) { cls =>
359+
def forwarder(name: TermName, fn: TermSymbol) = {
358360
val fwdMeth = fn.copy(cls, name, Synthetic | Method | Final).entered.asTerm
359361
for overridden <- fwdMeth.allOverriddenSymbols do
360362
if overridden.is(Extension) then fwdMeth.setFlag(Extension)
361363
if !overridden.is(Deferred) then fwdMeth.setFlag(Override)
362364
DefDef(fwdMeth, ref(fn).appliedToArgss(_))
363365
}
364-
fns.lazyZip(methNames).map(forwarder)
366+
termForwarders.map((name, sym) => forwarder(name, sym)) ++
367+
typeMembers.map((name, info) => TypeDef(newSymbol(cls, name, Synthetic, info).entered))
365368
}
366369
}
367370

compiler/src/dotty/tools/dotc/core/Types.scala

+38-21
Original file line numberDiff line numberDiff line change
@@ -5534,13 +5534,16 @@ object Types extends TypeUtils {
55345534
* and PolyType not allowed!) according to `possibleSamMethods`.
55355535
* - can be instantiated without arguments or with just () as argument.
55365536
*
5537+
* Additionally, a SAM type may contain type aliases refinements if they refine
5538+
* an existing type member.
5539+
*
55375540
* The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the
55385541
* type of the single abstract method and `samParent` is a subtype of the matched
55395542
* SAM type which has been stripped of wildcards to turn it into a valid parent
55405543
* type.
55415544
*/
55425545
object SAMType {
5543-
/** If possible, return a type which is both a subtype of `origTp` and a type
5546+
/** If possible, return a type which is both a subtype of `origTp` and a (possibly refined) type
55445547
* application of `samClass` where none of the type arguments are
55455548
* wildcards (thus making it a valid parent type), otherwise return
55465549
* NoType.
@@ -5570,27 +5573,41 @@ object Types extends TypeUtils {
55705573
* we arbitrarily pick the upper-bound.
55715574
*/
55725575
def samParent(origTp: Type, samClass: Symbol, samMeth: Symbol)(using Context): Type =
5573-
val tp = origTp.baseType(samClass)
5576+
val tp0 = origTp.baseType(samClass)
5577+
5578+
/** Copy type aliases refinements to `toTp` from `fromTp` */
5579+
def withRefinements(toType: Type, fromTp: Type): Type = fromTp.dealias match
5580+
case RefinedType(fromParent, name, info: TypeAlias) if tp0.member(name).exists =>
5581+
val parent1 = withRefinements(toType, fromParent)
5582+
RefinedType(toType, name, info)
5583+
case _ => toType
5584+
val tp = withRefinements(tp0, origTp)
5585+
55745586
if !(tp <:< origTp) then NoType
5575-
else tp match
5576-
case tp @ AppliedType(tycon, args) if tp.hasWildcardArg =>
5577-
val accu = new TypeAccumulator[VarianceMap[Symbol]]:
5578-
def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match
5579-
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) =>
5580-
vmap.recordLocalVariance(tp.symbol, variance)
5581-
case _ =>
5582-
foldOver(vmap, t)
5583-
val vmap = accu(VarianceMap.empty, samMeth.info)
5584-
val tparams = tycon.typeParamSymbols
5585-
val args1 = args.zipWithConserve(tparams):
5586-
case (arg @ TypeBounds(lo, hi), tparam) =>
5587-
val v = vmap.computedVariance(tparam)
5588-
if v.uncheckedNN < 0 then lo
5589-
else hi
5590-
case (arg, _) => arg
5591-
tp.derivedAppliedType(tycon, args1)
5592-
case _ =>
5593-
tp
5587+
else
5588+
def approxWildcardArgs(tp: Type): Type = tp match
5589+
case tp @ AppliedType(tycon, args) if tp.hasWildcardArg =>
5590+
val accu = new TypeAccumulator[VarianceMap[Symbol]]:
5591+
def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match
5592+
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) =>
5593+
vmap.recordLocalVariance(tp.symbol, variance)
5594+
case _ =>
5595+
foldOver(vmap, t)
5596+
val vmap = accu(VarianceMap.empty, samMeth.info)
5597+
val tparams = tycon.typeParamSymbols
5598+
val args1 = args.zipWithConserve(tparams):
5599+
case (arg @ TypeBounds(lo, hi), tparam) =>
5600+
val v = vmap.computedVariance(tparam)
5601+
if v.uncheckedNN < 0 then lo
5602+
else hi
5603+
case (arg, _) => arg
5604+
tp.derivedAppliedType(tycon, args1)
5605+
case tp @ RefinedType(parent, name, info) =>
5606+
tp.derivedRefinedType(approxWildcardArgs(parent), name, info)
5607+
case _ =>
5608+
tp
5609+
approxWildcardArgs(tp)
5610+
end samParent
55945611

55955612
def samClass(tp: Type)(using Context): Symbol = tp match
55965613
case tp: ClassInfo =>

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

+19-14
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import core.*
66
import Scopes.newScope
77
import Contexts.*, Symbols.*, Types.*, Flags.*, Decorators.*, StdNames.*, Constants.*
88
import MegaPhase.*
9+
import Names.TypeName
10+
import Symbols.*
911
import NullOpsDecorator.*
1012
import ast.untpd
1113

@@ -50,16 +52,28 @@ class ExpandSAMs extends MiniPhase:
5052
case tpe if defn.isContextFunctionType(tpe) =>
5153
tree
5254
case SAMType(_, tpe) if tpe.isRef(defn.PartialFunctionClass) =>
53-
val tpe1 = checkRefinements(tpe, fn)
54-
toPartialFunction(tree, tpe1)
55+
toPartialFunction(tree, tpe)
5556
case SAMType(_, tpe) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) =>
56-
checkRefinements(tpe, fn)
5757
tree
5858
case tpe =>
59-
val tpe1 = checkRefinements(tpe.stripNull, fn)
59+
// A SAM type is allowed to have type aliases refinements (see
60+
// SAMType#samParent) which must be converted into type members if
61+
// the closure is desugared into a class.
62+
val refinements = collection.mutable.ListBuffer[(TypeName, TypeAlias)]()
63+
def collectAndStripRefinements(tp: Type): Type = tp match
64+
case RefinedType(parent, name, info: TypeAlias) =>
65+
val res = collectAndStripRefinements(parent)
66+
refinements += ((name.asTypeName, info))
67+
res
68+
case _ => tp
69+
val tpe1 = collectAndStripRefinements(tpe)
6070
val Seq(samDenot) = tpe1.possibleSamMethods
6171
cpy.Block(tree)(stats,
62-
AnonClass(tpe1 :: Nil, fn.symbol.asTerm :: Nil, samDenot.symbol.asTerm.name :: Nil))
72+
AnonClass(List(tpe1),
73+
List(samDenot.symbol.asTerm.name -> fn.symbol.asTerm),
74+
refinements.toList
75+
)
76+
)
6377
}
6478
case _ =>
6579
tree
@@ -170,13 +184,4 @@ class ExpandSAMs extends MiniPhase:
170184
List(isDefinedAtDef, applyOrElseDef)
171185
}
172186
}
173-
174-
private def checkRefinements(tpe: Type, tree: Tree)(using Context): Type = tpe.dealias match {
175-
case RefinedType(parent, name, _) =>
176-
if (name.isTermName && tpe.member(name).symbol.ownersIterator.isEmpty) // if member defined in the refinement
177-
report.error(em"Lambda does not define $name", tree.srcPos)
178-
checkRefinements(parent, tree)
179-
case tpe =>
180-
tpe
181-
}
182187
end ExpandSAMs

tests/run/i18315.scala

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
trait Sam1:
2+
type T
3+
def apply(x: T): T
4+
5+
trait Sam2:
6+
var x: Int = 1 // To force anonymous class generation
7+
type T
8+
def apply(x: T): T
9+
10+
object Test:
11+
def main(args: Array[String]): Unit =
12+
val s1: Sam1 { type T = String } = x => x.trim
13+
s1.apply("foo")
14+
val s2: Sam2 { type T = Int } = x => x + 1
15+
s2.apply(1)

0 commit comments

Comments
 (0)