diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index feaf5d134a09..0d22c55a2356 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -60,13 +60,13 @@ class Compiler { protected def transformPhases: List[List[Phase]] = List(new FirstTransform, // Some transformations to put trees into a canonical form new CheckReentrant, // Internal use only: Check that compiled program has no data races involving global vars - new ProtectedAccessors, // Add accessors for protected members new ElimPackagePrefixes) :: // Eliminate references to package prefixes in Select nodes List(new CheckStatic, // Check restrictions that apply to @static members new ElimRepeated, // Rewrite vararg parameters and arguments new NormalizeFlags, // Rewrite some definition flags - new ExtensionMethods, // Expand methods of value classes with extension methods new ExpandSAMs, // Expand single abstract method closures to anonymous classes + new ProtectedAccessors, // Add accessors for protected members + new ExtensionMethods, // Expand methods of value classes with extension methods new ShortcutImplicits, // Allow implicit functions without creating closures new TailRec, // Rewrite tail recursion to loops new ByNameClosures, // Expand arguments to by-name parameters to closures diff --git a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala index eb1e5edb087c..3cac673ae5b0 100644 --- a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala +++ b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala @@ -56,8 +56,47 @@ class ExpandSAMs extends MiniPhase { tree } + /** A partial function literal: + * + * ``` + * val x: PartialFunction[A, B] = { case C1 => E1; ...; case Cn => En } + * ``` + * + * which desugars to: + * + * ``` + * val x: PartialFunction[A, B] = { + * def $anonfun(x: A): B = x match { case C1 => E1; ...; case Cn => En } + * closure($anonfun: PartialFunction[A, B]) + * } + * ``` + * + * is expanded to an anomymous class: + * + * ``` + * val x: PartialFunction[A, B] = { + * class $anon extends AbstractPartialFunction[A, B] { + * final def isDefinedAt(x: A): Boolean = x match { + * case C1 => true + * ... + * case Cn => true + * case _ => false + * } + * + * final def applyOrElse[A1 <: A, B1 >: B](x: A1, default: A1 => B1): B1 = x match { + * case C1 => E1 + * ... + * case Cn => En + * case _ => default(x) + * } + * } + * + * new $anon + * } + * ``` + */ private def toPartialFunction(tree: Block, tpe: Type)(implicit ctx: Context): Tree = { - // /** An extractor for match, either contained in a block or standalone. */ + /** An extractor for match, either contained in a block or standalone. */ object PartialFunctionRHS { def unapply(tree: Tree): Option[Match] = tree match { case Block(Nil, expr) => unapply(expr) @@ -71,15 +110,18 @@ class ExpandSAMs extends MiniPhase { case PartialFunctionRHS(pf) => val anonSym = anon.symbol + val parents = List(defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos), defn.SerializableType) + val pfSym = ctx.newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS, Synthetic | Final, parents, coord = tree.pos) + def overrideSym(sym: Symbol) = sym.copy( - owner = anonSym.owner, - flags = Synthetic | Method | Final, + owner = pfSym, + flags = Synthetic | Method | Final | Override, info = tpe.memberInfo(sym), - coord = tree.pos).asTerm + coord = tree.pos).asTerm.entered val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt) val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse) - def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree) = { + def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(implicit ctx: Context) = { val selector = tree.selector val selectorTpe = selector.tpe.widen val defaultSym = ctx.newSymbol(pfParam.owner, nme.WILDCARD, Synthetic, selectorTpe) @@ -96,7 +138,7 @@ class ExpandSAMs extends MiniPhase { // And we need to update all references to 'param' } - def isDefinedAtRhs(paramRefss: List[List[Tree]]) = { + def isDefinedAtRhs(paramRefss: List[List[Tree]])(implicit ctx: Context) = { val tru = Literal(Constant(true)) def translateCase(cdef: CaseDef) = cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn) @@ -105,7 +147,7 @@ class ExpandSAMs extends MiniPhase { translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue) } - def applyOrElseRhs(paramRefss: List[List[Tree]]) = { + def applyOrElseRhs(paramRefss: List[List[Tree]])(implicit ctx: Context) = { val List(paramRef, defaultRef) = paramRefss.head def translateCase(cdef: CaseDef) = cdef.changeOwner(anonSym, applyOrElseFn) @@ -113,12 +155,11 @@ class ExpandSAMs extends MiniPhase { translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue) } - val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_))) - val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_))) - - val parents = List(defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos), defn.SerializableType) - val anonCls = AnonClass(parents, List(isDefinedAtFn, applyOrElseFn), List(nme.isDefinedAt, nme.applyOrElse)) - cpy.Block(tree)(List(isDefinedAtDef, applyOrElseDef), anonCls) + val constr = ctx.newConstructor(pfSym, Synthetic, Nil, Nil).entered + val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)(ctx.withOwner(isDefinedAtFn)))) + val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)(ctx.withOwner(applyOrElseFn)))) + val pfDef = ClassDef(pfSym, DefDef(constr), List(isDefinedAtDef, applyOrElseDef)) + cpy.Block(tree)(pfDef :: Nil, New(pfSym.typeRef, Nil)) case _ => val found = tpe.baseType(defn.FunctionClass(1)) diff --git a/compiler/src/dotty/tools/dotc/transform/ProtectedAccessors.scala b/compiler/src/dotty/tools/dotc/transform/ProtectedAccessors.scala index 401aa6f900a3..d6fb0e7e1787 100644 --- a/compiler/src/dotty/tools/dotc/transform/ProtectedAccessors.scala +++ b/compiler/src/dotty/tools/dotc/transform/ProtectedAccessors.scala @@ -67,13 +67,8 @@ class ProtectedAccessors extends MiniPhase { override def transformAssign(tree: Assign)(implicit ctx: Context): Tree = tree.lhs match { - case lhs: RefTree => - lhs.name match { - case ProtectedAccessorName(name) => - cpy.Apply(tree)(Accessors.insert.useSetter(lhs), tree.rhs :: Nil) - case _ => - tree - } + case lhs: RefTree if lhs.name.is(ProtectedAccessorName) => + cpy.Apply(tree)(Accessors.insert.useSetter(lhs), tree.rhs :: Nil) case _ => tree } diff --git a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala index 007e9174be30..65f736e3e184 100644 --- a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala @@ -287,4 +287,27 @@ class TestBCode extends DottyBytecodeTest { assertTrue(containsExpectedCall) } } + + @Test def partialFunctions = { + val source = + """object Foo { + | def magic(x: Int) = x + | val foo: PartialFunction[Int, Int] = { case x => magic(x) } + |} + """.stripMargin + + checkBCode(source) { dir => + // We test that the anonymous class generated for the partial function + // holds the method implementations and does not use forwarders + val clsIn = dir.lookupName("Foo$$anon$1.class", directory = false).input + val clsNode = loadClassNode(clsIn) + val applyOrElse = getMethod(clsNode, "applyOrElse") + val instructions = instructionsFromMethod(applyOrElse) + val callMagic = instructions.exists { + case Invoke(_, _, "magic", _, _) => true + case _ => false + } + assertTrue(callMagic) + } + } } diff --git a/tests/run/i4446.scala b/tests/run/i4446.scala new file mode 100644 index 000000000000..080153c3598a --- /dev/null +++ b/tests/run/i4446.scala @@ -0,0 +1,19 @@ +class Foo { + def foo: PartialFunction[Int, Int] = { case x => x + 1 } +} + +object Test { + def serializeDeserialize[T <: AnyRef](obj: T): T = { + import java.io._ + val buffer = new ByteArrayOutputStream + val out = new ObjectOutputStream(buffer) + out.writeObject(obj) + val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray)) + in.readObject.asInstanceOf[T] + } + + def main(args: Array[String]): Unit = { + val adder = serializeDeserialize((new Foo).foo) + assert(adder(1) == 2) + } +}