Skip to content

Commit b35309c

Browse files
authored
Merge pull request #4604 from dotty-staging/fix-4446
Fix #4446: Inline implementation of PF methods into its anonymous class
2 parents fa15cf6 + fdc8b26 commit b35309c

File tree

5 files changed

+100
-22
lines changed

5 files changed

+100
-22
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ class Compiler {
6060
protected def transformPhases: List[List[Phase]] =
6161
List(new FirstTransform, // Some transformations to put trees into a canonical form
6262
new CheckReentrant, // Internal use only: Check that compiled program has no data races involving global vars
63-
new ProtectedAccessors, // Add accessors for protected members
6463
new ElimPackagePrefixes) :: // Eliminate references to package prefixes in Select nodes
6564
List(new CheckStatic, // Check restrictions that apply to @static members
6665
new ElimRepeated, // Rewrite vararg parameters and arguments
6766
new NormalizeFlags, // Rewrite some definition flags
68-
new ExtensionMethods, // Expand methods of value classes with extension methods
6967
new ExpandSAMs, // Expand single abstract method closures to anonymous classes
68+
new ProtectedAccessors, // Add accessors for protected members
69+
new ExtensionMethods, // Expand methods of value classes with extension methods
7070
new ShortcutImplicits, // Allow implicit functions without creating closures
7171
new TailRec, // Rewrite tail recursion to loops
7272
new ByNameClosures, // Expand arguments to by-name parameters to closures

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,47 @@ class ExpandSAMs extends MiniPhase {
5656
tree
5757
}
5858

59+
/** A partial function literal:
60+
*
61+
* ```
62+
* val x: PartialFunction[A, B] = { case C1 => E1; ...; case Cn => En }
63+
* ```
64+
*
65+
* which desugars to:
66+
*
67+
* ```
68+
* val x: PartialFunction[A, B] = {
69+
* def $anonfun(x: A): B = x match { case C1 => E1; ...; case Cn => En }
70+
* closure($anonfun: PartialFunction[A, B])
71+
* }
72+
* ```
73+
*
74+
* is expanded to an anomymous class:
75+
*
76+
* ```
77+
* val x: PartialFunction[A, B] = {
78+
* class $anon extends AbstractPartialFunction[A, B] {
79+
* final def isDefinedAt(x: A): Boolean = x match {
80+
* case C1 => true
81+
* ...
82+
* case Cn => true
83+
* case _ => false
84+
* }
85+
*
86+
* final def applyOrElse[A1 <: A, B1 >: B](x: A1, default: A1 => B1): B1 = x match {
87+
* case C1 => E1
88+
* ...
89+
* case Cn => En
90+
* case _ => default(x)
91+
* }
92+
* }
93+
*
94+
* new $anon
95+
* }
96+
* ```
97+
*/
5998
private def toPartialFunction(tree: Block, tpe: Type)(implicit ctx: Context): Tree = {
60-
// /** An extractor for match, either contained in a block or standalone. */
99+
/** An extractor for match, either contained in a block or standalone. */
61100
object PartialFunctionRHS {
62101
def unapply(tree: Tree): Option[Match] = tree match {
63102
case Block(Nil, expr) => unapply(expr)
@@ -71,15 +110,18 @@ class ExpandSAMs extends MiniPhase {
71110
case PartialFunctionRHS(pf) =>
72111
val anonSym = anon.symbol
73112

113+
val parents = List(defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos), defn.SerializableType)
114+
val pfSym = ctx.newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS, Synthetic | Final, parents, coord = tree.pos)
115+
74116
def overrideSym(sym: Symbol) = sym.copy(
75-
owner = anonSym.owner,
76-
flags = Synthetic | Method | Final,
117+
owner = pfSym,
118+
flags = Synthetic | Method | Final | Override,
77119
info = tpe.memberInfo(sym),
78-
coord = tree.pos).asTerm
120+
coord = tree.pos).asTerm.entered
79121
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
80122
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)
81123

82-
def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree) = {
124+
def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(implicit ctx: Context) = {
83125
val selector = tree.selector
84126
val selectorTpe = selector.tpe.widen
85127
val defaultSym = ctx.newSymbol(pfParam.owner, nme.WILDCARD, Synthetic, selectorTpe)
@@ -96,7 +138,7 @@ class ExpandSAMs extends MiniPhase {
96138
// And we need to update all references to 'param'
97139
}
98140

99-
def isDefinedAtRhs(paramRefss: List[List[Tree]]) = {
141+
def isDefinedAtRhs(paramRefss: List[List[Tree]])(implicit ctx: Context) = {
100142
val tru = Literal(Constant(true))
101143
def translateCase(cdef: CaseDef) =
102144
cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
@@ -105,20 +147,19 @@ class ExpandSAMs extends MiniPhase {
105147
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
106148
}
107149

108-
def applyOrElseRhs(paramRefss: List[List[Tree]]) = {
150+
def applyOrElseRhs(paramRefss: List[List[Tree]])(implicit ctx: Context) = {
109151
val List(paramRef, defaultRef) = paramRefss.head
110152
def translateCase(cdef: CaseDef) =
111153
cdef.changeOwner(anonSym, applyOrElseFn)
112154
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
113155
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
114156
}
115157

116-
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)))
117-
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)))
118-
119-
val parents = List(defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos), defn.SerializableType)
120-
val anonCls = AnonClass(parents, List(isDefinedAtFn, applyOrElseFn), List(nme.isDefinedAt, nme.applyOrElse))
121-
cpy.Block(tree)(List(isDefinedAtDef, applyOrElseDef), anonCls)
158+
val constr = ctx.newConstructor(pfSym, Synthetic, Nil, Nil).entered
159+
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)(ctx.withOwner(isDefinedAtFn))))
160+
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)(ctx.withOwner(applyOrElseFn))))
161+
val pfDef = ClassDef(pfSym, DefDef(constr), List(isDefinedAtDef, applyOrElseDef))
162+
cpy.Block(tree)(pfDef :: Nil, New(pfSym.typeRef, Nil))
122163

123164
case _ =>
124165
val found = tpe.baseType(defn.FunctionClass(1))

compiler/src/dotty/tools/dotc/transform/ProtectedAccessors.scala

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,8 @@ class ProtectedAccessors extends MiniPhase {
6767

6868
override def transformAssign(tree: Assign)(implicit ctx: Context): Tree =
6969
tree.lhs match {
70-
case lhs: RefTree =>
71-
lhs.name match {
72-
case ProtectedAccessorName(name) =>
73-
cpy.Apply(tree)(Accessors.insert.useSetter(lhs), tree.rhs :: Nil)
74-
case _ =>
75-
tree
76-
}
70+
case lhs: RefTree if lhs.name.is(ProtectedAccessorName) =>
71+
cpy.Apply(tree)(Accessors.insert.useSetter(lhs), tree.rhs :: Nil)
7772
case _ =>
7873
tree
7974
}

compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,4 +287,27 @@ class TestBCode extends DottyBytecodeTest {
287287
assertTrue(containsExpectedCall)
288288
}
289289
}
290+
291+
@Test def partialFunctions = {
292+
val source =
293+
"""object Foo {
294+
| def magic(x: Int) = x
295+
| val foo: PartialFunction[Int, Int] = { case x => magic(x) }
296+
|}
297+
""".stripMargin
298+
299+
checkBCode(source) { dir =>
300+
// We test that the anonymous class generated for the partial function
301+
// holds the method implementations and does not use forwarders
302+
val clsIn = dir.lookupName("Foo$$anon$1.class", directory = false).input
303+
val clsNode = loadClassNode(clsIn)
304+
val applyOrElse = getMethod(clsNode, "applyOrElse")
305+
val instructions = instructionsFromMethod(applyOrElse)
306+
val callMagic = instructions.exists {
307+
case Invoke(_, _, "magic", _, _) => true
308+
case _ => false
309+
}
310+
assertTrue(callMagic)
311+
}
312+
}
290313
}

tests/run/i4446.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
class Foo {
2+
def foo: PartialFunction[Int, Int] = { case x => x + 1 }
3+
}
4+
5+
object Test {
6+
def serializeDeserialize[T <: AnyRef](obj: T): T = {
7+
import java.io._
8+
val buffer = new ByteArrayOutputStream
9+
val out = new ObjectOutputStream(buffer)
10+
out.writeObject(obj)
11+
val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
12+
in.readObject.asInstanceOf[T]
13+
}
14+
15+
def main(args: Array[String]): Unit = {
16+
val adder = serializeDeserialize((new Foo).foo)
17+
assert(adder(1) == 2)
18+
}
19+
}

0 commit comments

Comments
 (0)