Skip to content

Synthesize anonymous mirrors with path dependent prefixes #13502

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -842,11 +842,99 @@ object TypeOps:
def nestedPairs(ts: List[Type])(using Context): Type =
ts.foldRight(defn.EmptyTupleModule.termRef: Type)(defn.PairClass.typeRef.appliedTo(_, _))


class StripTypeVarsMap(using Context) extends TypeMap:
def apply(tp: Type) = mapOver(tp).stripTypeVar

/** Apply [[Type.stripTypeVar]] recursively. */
def stripTypeVars(tp: Type)(using Context): Type =
new StripTypeVarsMap().apply(tp)

/** Converts the type into a form reachable from with the given `prefix` */
def healPrefix(tpe: Type, prefix: Type)(using Context): Either[String, Type] =

final class HealPrefixUnwind(val error: String) extends scala.util.control.ControlThrowable

final class HealPrefixMap(splice: PrefixSplice)(using Context) extends TypeMap {
def apply(tpe: Type): Type = healPrefixImpl(tpe, splice, this)(using mapCtx)
}

case class PrefixSplice(val prefix: Type):
def spliceThisType(cls: Symbol): Option[Type] = PrefixSplice.loopThisType(cls, prefix)
def isEmpty: Boolean = prefix eq NoPrefix

object PrefixSplice:

private def loopThisType(cls: Symbol, pre: Type): Option[Type] = pre match
case outer: ThisType =>
if outer.cls.isSubClass(cls) then Some(pre)
else loopThisType(cls, outer.tref.prefix)

case term: TermRef =>
if term.widen.classSymbol eq cls then Some(pre)
else loopThisType(cls, term.prefix)

case supTpe: SuperType =>
if supTpe.thistpe.classSymbol.isSubClass(cls) then Some(pre)
else None

case _ => None

end PrefixSplice

def healPrefixImpl(tpe: Type, splice: PrefixSplice, theMap: HealPrefixMap | Null)(using Context): Type = tpe match {
case tpe: ThisType =>
val cls = tpe.cls
splice.spliceThisType(cls) match
case Some(spliced) =>
spliced
case _ if cls.is(Module) =>
val pre = healPrefixImpl(tpe.tref.prefix, splice, theMap)
TermRef(pre, cls.sourceModule)
case _ if cls.is(Package) || splice.isEmpty =>
tpe
case _ =>
throw HealPrefixUnwind(
i"prefix $tpe references a non-enclosing class, can not be healed to ${splice.prefix}")
case tpe =>
(if (theMap != null) theMap else new HealPrefixMap(splice))
.mapOver(tpe)
}

try
Right(healPrefixImpl(tpe, PrefixSplice(prefix), null))
catch
case unwind: HealPrefixUnwind => Left(unwind.error)

end healPrefix

/** lift the prefix of a type */
def extractPrefix(tpe: Type)(using Context): Type = tpe match
case tpe: TypeRef =>
val tryDealias = tpe.dealias
if tryDealias ne tpe then extractPrefix(tryDealias)
else tpe.prefix
case tpe: TermRef => tpe.prefix
case tpe: ThisType => extractPrefix(tpe.tref)
case tpe: SingletonType => NoPrefix
case tpe: TypeProxy => extractPrefix(tpe.underlying)
case tpe: AndOrType =>
val p1 = extractPrefix(tpe.tp1)
if p1.exists then
val p2 = extractPrefix(tpe.tp2)
if p2.exists then
if p1.frozen_=:=(p2) then p1
else
def history(pre: Type, acc: List[Type]): List[Type] = pre match
case tpe: NamedType => history(tpe.prefix, tpe :: acc)
case tpe: ThisType => history(tpe.tref, tpe :: acc)
case tpe => tpe :: acc
val h1 = history(p1, Nil)
val h2 = history(p2, Nil)
val (common, _) = h1.lazyZip(h2).takeWhile((p1, p2) => p1.frozen_=:=(p2)).last
common
else NoType
else NoType
case _ => NoType

end TypeOps
22 changes: 15 additions & 7 deletions compiler/src/dotty/tools/dotc/transform/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,16 @@ object SymUtils:
* Excluded are value classes, abstract classes and case classes with more than one
* parameter section.
*/
def whyNotGenericProduct(using Context): String =
def whyNotGenericProduct(pre: Type)(using Context): String =
if (!self.is(CaseClass)) "it is not a case class"
else if (self.is(Abstract)) "it is an abstract class"
else if (self.primaryConstructor.info.paramInfoss.length != 1) "it takes more than one parameter list"
else if (isDerivedValueClass(self)) "it is a value class"
else if (self.is(Scala2x) && !(pre == NoPrefix || pre.typeSymbol.isStaticOwner)) then
"it is not accessible in a static context"
else ""

def isGenericProduct(using Context): Boolean = whyNotGenericProduct.isEmpty
def isGenericProduct(pre: Type)(using Context): Boolean = whyNotGenericProduct(pre: Type).isEmpty

/** Is this an old style implicit conversion?
* @param directOnly only consider explicitly written methods
Expand Down Expand Up @@ -145,7 +147,7 @@ object SymUtils:
* and also the location of the generated mirror.
* - all of its children are generic products, singletons, or generic sums themselves.
*/
def whyNotGenericSum(declScope: Symbol)(using Context): String =
def whyNotGenericSum(pre: Type, declScope: Symbol)(using Context): String =
if (!self.is(Sealed))
s"it is not a sealed ${self.kindString}"
else if (!self.isOneOf(AbstractOrTrait))
Expand All @@ -157,17 +159,19 @@ object SymUtils:
def problem(child: Symbol) = {

def isAccessible(sym: Symbol): Boolean =
(self.isContainedIn(sym) && (companionMirror || declScope.isContainedIn(sym)))
def declRef = if declScope.isTerm then declScope.termRef else declScope.typeRef
self.isContainedIn(sym) &&
(companionMirror || sym.isAccessibleFrom(declRef))
|| sym.is(Module) && isAccessible(sym.owner)

if (child == self) "it has anonymous or inaccessible subclasses"
else if (!isAccessible(child.owner)) i"its child $child is not accessible"
else if (!child.isClass) ""
else {
val s = child.whyNotGenericProduct
val s = child.whyNotGenericProduct(pre)
if (s.isEmpty) s
else if (child.is(Sealed)) {
val s = child.whyNotGenericSum(if child.useCompanionAsSumMirror then child.linkedClass else ctx.owner)
val s = child.whyNotGenericSum(pre, if child.useCompanionAsSumMirror then child.linkedClass else ctx.owner)
if (s.isEmpty) s
else i"its child $child is not a generic sum because $s"
} else i"its child $child is not a generic product because $s"
Expand All @@ -177,7 +181,8 @@ object SymUtils:
else children.map(problem).find(!_.isEmpty).getOrElse("")
}

def isGenericSum(declScope: Symbol)(using Context): Boolean = whyNotGenericSum(declScope).isEmpty
def isGenericSum(pre: Type, declScope: Symbol)(using Context): Boolean =
whyNotGenericSum(pre, declScope).isEmpty

/** If this is a constructor, its owner: otherwise this. */
final def skipConstructor(using Context): Symbol =
Expand Down Expand Up @@ -286,6 +291,9 @@ object SymUtils:
def reachableRawTypeRef(using Context) =
self.reachableTypeRef.appliedTo(self.typeParams.map(_ => TypeBounds.emptyPolyKind))

def rawTypeRef(using Context) =
self.typeRef.appliedTo(self.typeParams.map(_ => TypeBounds.emptyPolyKind))

/** Is symbol a quote operation? */
def isQuote(using Context): Boolean =
self == defn.QuotedRuntime_exprQuote || self == defn.QuotedTypeModule_of
Expand Down
36 changes: 26 additions & 10 deletions compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import SymUtils._
import util.Property
import config.Printers.derive
import NullOpsDecorator._
import dotty.tools.dotc.util.SrcPos

object SyntheticMembers {

Expand All @@ -26,6 +27,8 @@ object SyntheticMembers {

/** Attachment recording that an anonymous class should extend Mirror.Sum */
val ExtendsSumMirror: Property.StickyKey[Unit] = new Property.StickyKey

val AnonymousMirrorPrefix: Property.StickyKey[Type] = new Property.StickyKey
}

/** Synthetic method implementations for case classes, case objects,
Expand Down Expand Up @@ -526,12 +529,16 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
* a wildcard for each type parameter. The normalized type of an object
* O is O.type.
*/
def ordinalBody(cls: Symbol, param: Tree)(using Context): Tree =
def ordinalBody(cls: Symbol, param: Tree, pre: Type, pos: SrcPos)(using Context): Tree =
if (cls.is(Enum)) param.select(nme.ordinal).ensureApplied
else {
val cases =
for ((child, idx) <- cls.children.zipWithIndex) yield {
val patType = if (child.isTerm) child.reachableTermRef else child.reachableRawTypeRef
val patType0 = if child.isTerm then child.termRef else child.rawTypeRef
def toErr(err: String) =
report.error(err, pos)
ErrorType(err)
val patType: Type = TypeOps.healPrefix(patType0, pre).fold(toErr, identity)
val pat = Typed(untpd.Ident(nme.WILDCARD).withType(patType), TypeTree(patType))
CaseDef(pat, EmptyTree, Literal(Constant(idx)))
}
Expand Down Expand Up @@ -568,12 +575,20 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
}
}
val linked = clazz.linkedClass

lazy val pre = impl.removeAttachment(AnonymousMirrorPrefix).getOrElse(NoPrefix)

lazy val monoType = {
val existing = clazz.info.member(tpnme.MirroredMonoType).symbol
if (existing.exists && !existing.is(Deferred)) existing
else {
val monoType =
newSymbol(clazz, tpnme.MirroredMonoType, Synthetic, TypeAlias(linked.reachableRawTypeRef), coord = clazz.coord)
val rawRef = linked.rawTypeRef
def toErr(err: String) =
report.error(err, impl.srcPos)
ErrorType(err)
val monoType = TypeOps.healPrefix(rawRef, pre).fold(toErr, identity)
newSymbol(clazz, tpnme.MirroredMonoType, Synthetic, TypeAlias(monoType), coord = clazz.coord)
newBody = newBody :+ TypeDef(monoType).withSpan(ctx.owner.span.focus)
monoType.enteredAfter(thisPhase)
}
Expand All @@ -585,25 +600,26 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
addMethod(nme.fromProduct, MethodType(defn.ProductClass.typeRef :: Nil, monoType.typeRef), cls,
fromProductBody(_, _).ensureConforms(monoType.typeRef)) // t4758.scala or i3381.scala are examples where a cast is needed
}
def makeSumMirror(cls: Symbol) = {
def makeSumMirror(cls: Symbol, pre: Type, pos: SrcPos) = {
addParent(defn.Mirror_SumClass.typeRef)
addMethod(nme.ordinal, MethodType(monoType.typeRef :: Nil, defn.IntType), cls,
ordinalBody(_, _))
ordinalBody(_, _, pre, pos))
}

if (clazz.is(Module)) {
if (clazz.is(Case)) makeSingletonMirror()
else if (linked.isGenericProduct) makeProductMirror(linked)
else if (linked.isGenericSum(clazz)) makeSumMirror(linked)
else if (linked.is(Sealed))
derive.println(i"$linked is not a sum because ${linked.whyNotGenericSum(clazz)}")
else if linked.exists then
if (linked.isGenericProduct(pre)) makeProductMirror(linked)
else if (linked.isGenericSum(pre, clazz)) makeSumMirror(linked, pre, impl.srcPos)
else if (linked.is(Sealed))
derive.println(i"$linked is not a sum because ${linked.whyNotGenericSum(pre, clazz)}")
}
else if (impl.removeAttachment(ExtendsSingletonMirror).isDefined)
makeSingletonMirror()
else if (impl.removeAttachment(ExtendsProductMirror).isDefined)
makeProductMirror(monoType.typeRef.dealias.classSymbol)
else if (impl.removeAttachment(ExtendsSumMirror).isDefined)
makeSumMirror(monoType.typeRef.dealias.classSymbol)
makeSumMirror(monoType.typeRef.dealias.classSymbol, pre, impl.srcPos)

cpy.Template(impl)(parents = newParents, body = newBody)
}
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/TypeUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ object TypeUtils {
val r2 = tp2.mirrorCompanionRef
assert(r1.symbol == r2.symbol, em"mirrorCompanionRef mismatch for $self: $r1, $r2 did not have the same symbol")
r1
case AndType(tp1, tp2) =>
val c1 = tp1.classSymbol
val c2 = tp2.classSymbol
if c1.isSubClass(c2) then tp1.mirrorCompanionRef
else tp2.mirrorCompanionRef // precondition: the parts of the AndType have already been checked to be non-overlapping
case self @ TypeRef(prefix, _) if self.symbol.isClass =>
prefix.select(self.symbol.companionModule).asInstanceOf[TermRef]
case self: TypeProxy =>
Expand Down
Loading