Skip to content

Commit 196ef31

Browse files
committed
Fix #4446: Inline implementation of PF methods into its anonymous class
1 parent b4db175 commit 196ef31

File tree

2 files changed

+33
-12
lines changed

2 files changed

+33
-12
lines changed

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,18 @@ class ExpandSAMs extends MiniPhase {
6666
case PartialFunctionRHS(pf) =>
6767
val anonSym = anon.symbol
6868

69+
val parents = List(defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos), defn.SerializableType)
70+
val pfSym = ctx.newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS, Synthetic | Final, parents, coord = tree.pos)
71+
6972
def overrideSym(sym: Symbol) = sym.copy(
70-
owner = anonSym.owner,
71-
flags = Synthetic | Method | Final,
73+
owner = pfSym,
74+
flags = Synthetic | Method | Final | Override,
7275
info = tpe.memberInfo(sym),
73-
coord = tree.pos).asTerm
76+
coord = tree.pos).asTerm.entered
7477
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
7578
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)
7679

77-
def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree) = {
80+
def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(implicit ctx: Context) = {
7881
val selector = tree.selector
7982
val selectorTpe = selector.tpe.widen
8083
val defaultSym = ctx.newSymbol(pfParam.owner, nme.WILDCARD, Synthetic, selectorTpe)
@@ -91,7 +94,7 @@ class ExpandSAMs extends MiniPhase {
9194
// And we need to update all references to 'param'
9295
}
9396

94-
def isDefinedAtRhs(paramRefss: List[List[Tree]]) = {
97+
def isDefinedAtRhs(paramRefss: List[List[Tree]])(implicit ctx: Context) = {
9598
val tru = Literal(Constant(true))
9699
def translateCase(cdef: CaseDef) =
97100
cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
@@ -100,20 +103,19 @@ class ExpandSAMs extends MiniPhase {
100103
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
101104
}
102105

103-
def applyOrElseRhs(paramRefss: List[List[Tree]]) = {
106+
def applyOrElseRhs(paramRefss: List[List[Tree]])(implicit ctx: Context) = {
104107
val List(paramRef, defaultRef) = paramRefss.head
105108
def translateCase(cdef: CaseDef) =
106109
cdef.changeOwner(anonSym, applyOrElseFn)
107110
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
108111
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
109112
}
110113

111-
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)))
112-
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)))
113-
114-
val parents = List(defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos), defn.SerializableType)
115-
val anonCls = AnonClass(parents, List(isDefinedAtFn, applyOrElseFn), List(nme.isDefinedAt, nme.applyOrElse))
116-
cpy.Block(tree)(List(isDefinedAtDef, applyOrElseDef), anonCls)
114+
val constr = ctx.newConstructor(pfSym, Synthetic, Nil, Nil).entered
115+
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)(ctx.withOwner(isDefinedAtFn))))
116+
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)(ctx.withOwner(applyOrElseFn))))
117+
val pfDef = ClassDef(pfSym, DefDef(constr), List(isDefinedAtDef, applyOrElseDef))
118+
cpy.Block(tree)(pfDef :: Nil, New(pfSym.typeRef, Nil))
117119

118120
case _ =>
119121
val found = tpe.baseType(defn.FunctionClass(1))

tests/run/i4446.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
class Foo {
2+
def foo: PartialFunction[Int, Int] = { case x => x + 1 }
3+
}
4+
5+
object Test {
6+
def serializeDeserialize[T <: AnyRef](obj: T): T = {
7+
import java.io._
8+
val buffer = new ByteArrayOutputStream
9+
val out = new ObjectOutputStream(buffer)
10+
out.writeObject(obj)
11+
val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
12+
in.readObject.asInstanceOf[T]
13+
}
14+
15+
def main(args: Array[String]): Unit = {
16+
val adder = serializeDeserialize((new Foo).foo)
17+
assert(adder(1) == 2)
18+
}
19+
}

0 commit comments

Comments
 (0)