diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index b087f313c33e..56e71c29110b 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -92,6 +92,7 @@ class Compiler { List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements. List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations + new ArrayApply, // Optimize `scala.Array.apply([....])` and `scala.Array.apply(..., [....])` into `[...]` new ElimPolyFunction, // Rewrite PolyFunction subclasses to FunctionN subclasses new TailRec, // Rewrite tail recursion to loops new Mixin, // Expand trait fields and trait initializers diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 31804f36b9cb..c76c3cf75da6 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -755,6 +755,8 @@ class Definitions { @threadUnsafe lazy val ClassTagType: TypeRef = ctx.requiredClassRef("scala.reflect.ClassTag") def ClassTagClass(implicit ctx: Context): ClassSymbol = ClassTagType.symbol.asClass def ClassTagModule(implicit ctx: Context): Symbol = ClassTagClass.companionModule + @threadUnsafe lazy val ClassTagModule_applyR: TermRef = ClassTagModule.requiredMethodRef(nme.apply) + def ClassTagModule_apply(implicit ctx: Context): Symbol = ClassTagModule_applyR.symbol @threadUnsafe lazy val QuotedExprType: TypeRef = ctx.requiredClassRef("scala.quoted.Expr") def QuotedExprClass(implicit ctx: Context): ClassSymbol = QuotedExprType.symbol.asClass diff --git a/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala new file mode 100644 index 000000000000..39260a98850b --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/ArrayApply.scala @@ -0,0 +1,69 @@ +package dotty.tools.dotc +package transform + +import core._ +import MegaPhase._ +import Contexts.Context +import Symbols._ +import Types._ +import StdNames._ +import ast.Trees._ +import dotty.tools.dotc.ast.tpd + +import scala.reflect.ClassTag + + +/** This phase rewrites calls to `Array.apply` to a direct instantiation of the array in the bytecode. + * + * Transforms `scala.Array.apply([....])` and `scala.Array.apply(..., [....])` into `[...]` + */ +class ArrayApply extends MiniPhase { + import tpd._ + + override def phaseName: String = "arrayApply" + + override def transformApply(tree: tpd.Apply)(implicit ctx: Context): tpd.Tree = { + if (tree.symbol.name == nme.apply && tree.symbol.owner == defn.ArrayModule) { // Is `Array.apply` + tree.args match { + case StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil + if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) => + seqLit + + case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: Nil + if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) => + tpd.JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt) + + case _ => + tree + } + + } else tree + } + + /** Only optimize when classtag if it is one of + * - `ClassTag.apply(classOf[XYZ])` + * - `ClassTag.apply(java.lang.XYZ.Type)` for boxed primitives `XYZ`` + * - `ClassTag.XYZ` for primitive types + */ + private def elideClassTag(ct: Tree)(implicit ctx: Context): Boolean = ct match { + case Apply(_, rc :: Nil) if ct.symbol == defn.ClassTagModule_apply => + rc match { + case _: Literal => true // ClassTag.apply(classOf[XYZ]) + case rc: RefTree if rc.name == nme.TYPE_ => + // ClassTag.apply(java.lang.XYZ.Type) + defn.ScalaBoxedClasses().contains(rc.symbol.maybeOwner.companionClass) + case _ => false + } + case Apply(ctm: RefTree, _) if ctm.symbol.maybeOwner.companionModule == defn.ClassTagModule => + // ClassTag.XYZ + nme.ScalaValueNames.contains(ctm.name) + case _ => false + } + + object StripAscription { + def unapply(tree: Tree)(implicit ctx: Context): Some[Tree] = tree match { + case Typed(expr, _) => unapply(expr) + case _ => Some(tree) + } + } +} diff --git a/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala b/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala new file mode 100644 index 000000000000..1089abf331a8 --- /dev/null +++ b/compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala @@ -0,0 +1,109 @@ +package dotty.tools.backend.jvm + +import org.junit.Test +import org.junit.Assert._ + +import scala.tools.asm.Opcodes._ + +class ArrayApplyOptTest extends DottyBytecodeTest { + import ASMConverters._ + + @Test def testArrayEmptyGenericApply= { + test("Array[String]()", List(Op(ICONST_0), TypeOp(ANEWARRAY, "java/lang/String"), Op(POP), Op(RETURN))) + test("Array[Unit]()", List(Op(ICONST_0), TypeOp(ANEWARRAY, "scala/runtime/BoxedUnit"), Op(POP), Op(RETURN))) + test("Array[Object]()", List(Op(ICONST_0), TypeOp(ANEWARRAY, "java/lang/Object"), Op(POP), Op(RETURN))) + test("Array[Boolean]()", newArray0Opcodes(T_BOOLEAN)) + test("Array[Byte]()", newArray0Opcodes(T_BYTE)) + test("Array[Short]()", newArray0Opcodes(T_SHORT)) + test("Array[Int]()", newArray0Opcodes(T_INT)) + test("Array[Long]()", newArray0Opcodes(T_LONG)) + test("Array[Float]()", newArray0Opcodes(T_FLOAT)) + test("Array[Double]()", newArray0Opcodes(T_DOUBLE)) + test("Array[Char]()", newArray0Opcodes(T_CHAR)) + test("Array[T]()", newArray0Opcodes(T_INT)) + } + + @Test def testArrayGenericApply= { + def opCodes(tpe: String) = + List(Op(ICONST_2), TypeOp(ANEWARRAY, tpe), Op(DUP), Op(ICONST_0), Ldc(LDC, "a"), Op(AASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, "b"), Op(AASTORE), Op(POP), Op(RETURN)) + test("""Array("a", "b")""", opCodes("java/lang/String")) + test("""Array[Object]("a", "b")""", opCodes("java/lang/Object")) + } + + @Test def testArrayApplyBoolean = + test("Array(true, false)", newArray2Opcodes(T_BOOLEAN, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(BASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_0), Op(BASTORE)))) + + @Test def testArrayApplyByte = + test("Array[Byte](1, 2)", newArray2Opcodes(T_BYTE, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(BASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(BASTORE)))) + + @Test def testArrayApplyShort = + test("Array[Short](1, 2)", newArray2Opcodes(T_SHORT, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(SASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(SASTORE)))) + + @Test def testArrayApplyInt = { + test("Array(1, 2)", newArray2Opcodes(T_INT, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(IASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(IASTORE)))) + test("""Array[T](t, t)""", newArray2Opcodes(T_INT, List(Op(DUP), Op(ICONST_0), Field(GETSTATIC, "Foo$", "MODULE$", "LFoo$;"), Invoke(INVOKEVIRTUAL, "Foo$", "t", "()I", false), Op(IASTORE), Op(DUP), Op(ICONST_1), Field(GETSTATIC, "Foo$", "MODULE$", "LFoo$;"), Invoke(INVOKEVIRTUAL, "Foo$", "t", "()I", false), Op(IASTORE)))) + } + + @Test def testArrayApplyLong = + test("Array(2L, 3L)", newArray2Opcodes(T_LONG, List(Op(DUP), Op(ICONST_0), Ldc(LDC, 2), Op(LASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, 3), Op(LASTORE)))) + + @Test def testArrayApplyFloat = + test("Array(2.1f, 3.1f)", newArray2Opcodes(T_FLOAT, List(Op(DUP), Op(ICONST_0), Ldc(LDC, 2.1f), Op(FASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, 3.1f), Op(FASTORE)))) + + @Test def testArrayApplyDouble = + test("Array(2.2d, 3.2d)", newArray2Opcodes(T_DOUBLE, List(Op(DUP), Op(ICONST_0), Ldc(LDC, 2.2d), Op(DASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, 3.2d), Op(DASTORE)))) + + @Test def testArrayApplyChar = + test("Array('x', 'y')", newArray2Opcodes(T_CHAR, List(Op(DUP), Op(ICONST_0), IntOp(BIPUSH, 120), Op(CASTORE), Op(DUP), Op(ICONST_1), IntOp(BIPUSH, 121), Op(CASTORE)))) + + @Test def testArrayApplyUnit = + test("Array[Unit]((), ())", List(Op(ICONST_2), TypeOp(ANEWARRAY, "scala/runtime/BoxedUnit"), Op(DUP), + Op(ICONST_0), Field(GETSTATIC, "scala/runtime/BoxedUnit", "UNIT", "Lscala/runtime/BoxedUnit;"), Op(AASTORE), Op(DUP), + Op(ICONST_1), Field(GETSTATIC, "scala/runtime/BoxedUnit", "UNIT", "Lscala/runtime/BoxedUnit;"), Op(AASTORE), Op(POP), Op(RETURN))) + + @Test def testArrayInlined = test( + """{ + | inline def array(xs: =>Int*): Array[Int] = Array(xs: _*) + | array(1, 2) + |}""".stripMargin, + newArray2Opcodes(T_INT, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(IASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(IASTORE), TypeOp(CHECKCAST, "[I"))) + ) + + @Test def testArrayInlined2 = test( + """{ + | inline def array(x: =>Int, xs: =>Int*): Array[Int] = Array(x, xs: _*) + | array(1, 2) + |}""".stripMargin, + newArray2Opcodes(T_INT, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(IASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(IASTORE))) + ) + + private def newArray0Opcodes(tpe: Int, init: List[Any] = Nil): List[Any] = + Op(ICONST_0) :: IntOp(NEWARRAY, tpe) :: init ::: Op(POP) :: Op(RETURN) :: Nil + + private def newArray2Opcodes(tpe: Int, init: List[Any] = Nil): List[Any] = + Op(ICONST_2) :: IntOp(NEWARRAY, tpe) :: init ::: Op(POP) :: Op(RETURN) :: Nil + + private def test(code: String, expectedInstructions: List[Any])= { + val source = + s"""class Foo { + | import Foo._ + | def test: Unit = $code + |} + |object Foo { + | opaque type T = Int + | def t: T = 1 + |} + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Foo.class", directory = false).input + val clsNode = loadClassNode(clsIn) + val meth = getMethod(clsNode, "test") + + val instructions = instructionsFromMethod(meth) + + assertEquals(expectedInstructions, instructions) + } + } + +} diff --git a/tests/run/i502.check b/tests/run/i502.check new file mode 100644 index 000000000000..1bd58020669f --- /dev/null +++ b/tests/run/i502.check @@ -0,0 +1,3 @@ +Ok +foo +bar diff --git a/tests/run/i502.scala b/tests/run/i502.scala new file mode 100644 index 000000000000..c56825c134b0 --- /dev/null +++ b/tests/run/i502.scala @@ -0,0 +1,16 @@ +import scala.reflect.ClassTag + +object Test extends App { + Array[Int](1, 2) + + try { + Array[Int](1, 2)(null) + ??? + } catch { + case _: NullPointerException => println("Ok") + } + + Array[Int](1, 2)({println("foo"); the[ClassTag[Int]]}) + + Array[Int](1, 2)(ClassTag.apply({ println("bar"); classOf[Int]})) +} diff --git a/tests/run/t6611b.scala b/tests/run/t6611b.scala new file mode 100644 index 000000000000..1c3cdf8d9742 --- /dev/null +++ b/tests/run/t6611b.scala @@ -0,0 +1,6 @@ +object Test extends App { + val a = Array("1") + val a2 = Array(a: _*) + a2(0) = "2" + assert(a(0) == "1") +}