diff --git a/compiler/src/dotty/tools/dotc/transform/TreeExtractors.scala b/compiler/src/dotty/tools/dotc/transform/TreeExtractors.scala index d02fa330cdf8..aec44d5987bf 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeExtractors.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeExtractors.scala @@ -19,11 +19,15 @@ object TreeExtractors { } } - /** Match new C(args) and extract (C, args) */ + /** Match new C(args) and extract (C, args). + * Also admit new C(args): T and {new C(args)}. + */ object NewWithArgs { def unapply(t: Tree)(using Context): Option[(Type, List[Tree])] = t match { case Apply(Select(New(_), nme.CONSTRUCTOR), args) => Some((t.tpe, args)) + case Typed(expr, _) => unapply(expr) + case Block(Nil, expr) => unapply(expr) case _ => None } diff --git a/compiler/src/dotty/tools/dotc/transform/VCElideAllocations.scala b/compiler/src/dotty/tools/dotc/transform/VCElideAllocations.scala index caa387b9e937..d9fb4567e169 100644 --- a/compiler/src/dotty/tools/dotc/transform/VCElideAllocations.scala +++ b/compiler/src/dotty/tools/dotc/transform/VCElideAllocations.scala @@ -3,7 +3,7 @@ package transform import ast.tpd import core._ -import Contexts._, Symbols._ +import Contexts._, Symbols._, Types._, Flags._, Phases._ import DenotTransformers._, MegaPhase._ import TreeExtractors._, ValueClasses._ @@ -23,13 +23,19 @@ class VCElideAllocations extends MiniPhase with IdentityDenotTransformer { override def runsAfter: Set[String] = Set(ElimErasedValueType.name) override def transformApply(tree: Apply)(using Context): Tree = + def hasUserDefinedEquals(tp: Type): Boolean = + val eql = atPhase(erasurePhase) { + defn.Any_equals.matchingMember(tp.typeSymbol.thisType) + } + eql.owner != defn.AnyClass && !eql.is(Synthetic) + tree match { - // new V(u1) == new V(u2) => u1 == u2 + // new V(u1) == new V(u2) => u1 == u2, unless V defines its own equals. // (We don't handle != because it has been eliminated by InterceptedMethods) case BinaryOp(NewWithArgs(tp1, List(u1)), op, NewWithArgs(tp2, List(u2))) - if (tp1 eq tp2) && (op eq defn.Any_==) && - isDerivedValueClass(tp1.typeSymbol) && - !defn.Any_equals.overridingSymbol(tp1.typeSymbol.asClass).exists => + if (tp1 eq tp2) && (op eq defn.Any_==) + && isDerivedValueClass(tp1.typeSymbol) + && !hasUserDefinedEquals(tp1) => // == is overloaded in primitive classes u1.equal(u2) diff --git a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala index ebf6d9f09d5d..197a9f7a8ef8 100644 --- a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala @@ -976,6 +976,47 @@ class TestBCode extends DottyBytecodeTest { assert((getMethod(c, "x_$eq").access & Opcodes.ACC_DEPRECATED) != 0) } } + + @Test def vcElideAllocations = { + val source = + s"""class ApproxState(val bits: Int) extends AnyVal + |class Foo { + | val FreshApprox: ApproxState = new ApproxState(4) + | var approx: ApproxState = FreshApprox + | def meth1: Boolean = approx == FreshApprox + | def meth2: Boolean = (new ApproxState(4): ApproxState) == FreshApprox + |} + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Foo.class", directory = false).input + val clsNode = loadClassNode(clsIn) + val meth1 = getMethod(clsNode, "meth1") + val meth2 = getMethod(clsNode, "meth2") + val instructions1 = instructionsFromMethod(meth1) + val instructions2 = instructionsFromMethod(meth2) + + val isFrameLine = (x: Instruction) => x.isInstanceOf[FrameEntry] || x.isInstanceOf[LineNumber] + + // No allocations of ApproxState + + assertSameCode(instructions1.filterNot(isFrameLine), List( + VarOp(ALOAD, 0), Invoke(INVOKEVIRTUAL, "Foo", "approx", "()I", false), + VarOp(ALOAD, 0), Invoke(INVOKEVIRTUAL, "Foo", "FreshApprox", "()I", false), + Jump(IF_ICMPNE, Label(7)), Op(ICONST_1), + Jump(GOTO, Label(10)), + Label(7), Op(ICONST_0), + Label(10), Op(IRETURN))) + + assertSameCode(instructions2.filterNot(isFrameLine), List( + Op(ICONST_4), + VarOp(ALOAD, 0), Invoke(INVOKEVIRTUAL, "Foo", "FreshApprox", "()I", false), + Jump(IF_ICMPNE, Label(6)), Op(ICONST_1), + Jump(GOTO, Label(9)), + Label(6), Op(ICONST_0), + Label(9), Op(IRETURN))) + } + } } object invocationReceiversTestCode {