Skip to content

Fix #4446: Inline implementation of PF methods into its anonymous class #4604

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ class Compiler {
protected def transformPhases: List[List[Phase]] =
List(new FirstTransform, // Some transformations to put trees into a canonical form
new CheckReentrant, // Internal use only: Check that compiled program has no data races involving global vars
new ProtectedAccessors, // Add accessors for protected members
new ElimPackagePrefixes) :: // Eliminate references to package prefixes in Select nodes
List(new CheckStatic, // Check restrictions that apply to @static members
new ElimRepeated, // Rewrite vararg parameters and arguments
new NormalizeFlags, // Rewrite some definition flags
new ExtensionMethods, // Expand methods of value classes with extension methods
new ExpandSAMs, // Expand single abstract method closures to anonymous classes
new ProtectedAccessors, // Add accessors for protected members
new ExtensionMethods, // Expand methods of value classes with extension methods
new ShortcutImplicits, // Allow implicit functions without creating closures
new TailRec, // Rewrite tail recursion to loops
new ByNameClosures, // Expand arguments to by-name parameters to closures
Expand Down
67 changes: 54 additions & 13 deletions compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,47 @@ class ExpandSAMs extends MiniPhase {
tree
}

/** A partial function literal:
*
* ```
* val x: PartialFunction[A, B] = { case C1 => E1; ...; case Cn => En }
* ```
*
* which desugars to:
*
* ```
* val x: PartialFunction[A, B] = {
* def $anonfun(x: A): B = x match { case C1 => E1; ...; case Cn => En }
* closure($anonfun: PartialFunction[A, B])
* }
* ```
*
* is expanded to an anomymous class:
*
* ```
* val x: PartialFunction[A, B] = {
* class $anon extends AbstractPartialFunction[A, B] {
* final def isDefinedAt(x: A): Boolean = x match {
* case C1 => true
* ...
* case Cn => true
* case _ => false
* }
*
* final def applyOrElse[A1 <: A, B1 >: B](x: A1, default: A1 => B1): B1 = x match {
* case C1 => E1
* ...
* case Cn => En
* case _ => default(x)
* }
* }
*
* new $anon
* }
* ```
*/
private def toPartialFunction(tree: Block, tpe: Type)(implicit ctx: Context): Tree = {
// /** An extractor for match, either contained in a block or standalone. */
/** An extractor for match, either contained in a block or standalone. */
object PartialFunctionRHS {
def unapply(tree: Tree): Option[Match] = tree match {
case Block(Nil, expr) => unapply(expr)
Expand All @@ -71,15 +110,18 @@ class ExpandSAMs extends MiniPhase {
case PartialFunctionRHS(pf) =>
val anonSym = anon.symbol

val parents = List(defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos), defn.SerializableType)
val pfSym = ctx.newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS, Synthetic | Final, parents, coord = tree.pos)

def overrideSym(sym: Symbol) = sym.copy(
owner = anonSym.owner,
flags = Synthetic | Method | Final,
owner = pfSym,
flags = Synthetic | Method | Final | Override,
info = tpe.memberInfo(sym),
coord = tree.pos).asTerm
coord = tree.pos).asTerm.entered
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)

def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree) = {
def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(implicit ctx: Context) = {
val selector = tree.selector
val selectorTpe = selector.tpe.widen
val defaultSym = ctx.newSymbol(pfParam.owner, nme.WILDCARD, Synthetic, selectorTpe)
Expand All @@ -96,7 +138,7 @@ class ExpandSAMs extends MiniPhase {
// And we need to update all references to 'param'
}

def isDefinedAtRhs(paramRefss: List[List[Tree]]) = {
def isDefinedAtRhs(paramRefss: List[List[Tree]])(implicit ctx: Context) = {
val tru = Literal(Constant(true))
def translateCase(cdef: CaseDef) =
cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
Expand All @@ -105,20 +147,19 @@ class ExpandSAMs extends MiniPhase {
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
}

def applyOrElseRhs(paramRefss: List[List[Tree]]) = {
def applyOrElseRhs(paramRefss: List[List[Tree]])(implicit ctx: Context) = {
val List(paramRef, defaultRef) = paramRefss.head
def translateCase(cdef: CaseDef) =
cdef.changeOwner(anonSym, applyOrElseFn)
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
}

val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)))
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)))

val parents = List(defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos), defn.SerializableType)
val anonCls = AnonClass(parents, List(isDefinedAtFn, applyOrElseFn), List(nme.isDefinedAt, nme.applyOrElse))
cpy.Block(tree)(List(isDefinedAtDef, applyOrElseDef), anonCls)
val constr = ctx.newConstructor(pfSym, Synthetic, Nil, Nil).entered
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)(ctx.withOwner(isDefinedAtFn))))
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)(ctx.withOwner(applyOrElseFn))))
val pfDef = ClassDef(pfSym, DefDef(constr), List(isDefinedAtDef, applyOrElseDef))
cpy.Block(tree)(pfDef :: Nil, New(pfSym.typeRef, Nil))

case _ =>
val found = tpe.baseType(defn.FunctionClass(1))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,8 @@ class ProtectedAccessors extends MiniPhase {

override def transformAssign(tree: Assign)(implicit ctx: Context): Tree =
tree.lhs match {
case lhs: RefTree =>
lhs.name match {
case ProtectedAccessorName(name) =>
cpy.Apply(tree)(Accessors.insert.useSetter(lhs), tree.rhs :: Nil)
case _ =>
tree
}
case lhs: RefTree if lhs.name.is(ProtectedAccessorName) =>
cpy.Apply(tree)(Accessors.insert.useSetter(lhs), tree.rhs :: Nil)
case _ =>
tree
}
Expand Down
23 changes: 23 additions & 0 deletions compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -287,4 +287,27 @@ class TestBCode extends DottyBytecodeTest {
assertTrue(containsExpectedCall)
}
}

@Test def partialFunctions = {
val source =
"""object Foo {
| def magic(x: Int) = x
| val foo: PartialFunction[Int, Int] = { case x => magic(x) }
|}
""".stripMargin

checkBCode(source) { dir =>
// We test that the anonymous class generated for the partial function
// holds the method implementations and does not use forwarders
val clsIn = dir.lookupName("Foo$$anon$1.class", directory = false).input
val clsNode = loadClassNode(clsIn)
val applyOrElse = getMethod(clsNode, "applyOrElse")
val instructions = instructionsFromMethod(applyOrElse)
val callMagic = instructions.exists {
case Invoke(_, _, "magic", _, _) => true
case _ => false
}
assertTrue(callMagic)
}
}
}
19 changes: 19 additions & 0 deletions tests/run/i4446.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
class Foo {
def foo: PartialFunction[Int, Int] = { case x => x + 1 }
}

object Test {
def serializeDeserialize[T <: AnyRef](obj: T): T = {
import java.io._
val buffer = new ByteArrayOutputStream
val out = new ObjectOutputStream(buffer)
out.writeObject(obj)
val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
in.readObject.asInstanceOf[T]
}

def main(args: Array[String]): Unit = {
val adder = serializeDeserialize((new Foo).foo)
assert(adder(1) == 2)
}
}