Skip to content

Commit f940d92

Browse files
committed
Fixup List optimisation improvement
1 parent 3ed582c commit f940d92

File tree

4 files changed

+88
-29
lines changed

4 files changed

+88
-29
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

+6-4
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,7 @@ class Definitions {
517517
def ListType: TypeRef = ListClass.typeRef
518518
@tu lazy val ListModule: Symbol = requiredModule("scala.collection.immutable.List")
519519
@tu lazy val ListModule_apply: Symbol = ListModule.requiredMethod(nme.apply)
520+
def ListModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.List)
520521
@tu lazy val NilModule: Symbol = requiredModule("scala.collection.immutable.Nil")
521522
def NilType: TermRef = NilModule.termRef
522523
@tu lazy val ConsClass: Symbol = requiredClass("scala.collection.immutable.::")
@@ -531,17 +532,18 @@ class Definitions {
531532
List(AnyType), EmptyScope)
532533
@tu lazy val SingletonType: TypeRef = SingletonClass.typeRef
533534

534-
@tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq")
535-
@tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq")
535+
@tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq")
536+
@tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq")
537+
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")
538+
@tu lazy val SeqModule_apply: Symbol = SeqModule.requiredMethod(nme.apply)
539+
def SeqModuleAlias: Symbol = ScalaPackageClass.requiredMethod(nme.Seq)
536540
def SeqClass(using Context): ClassSymbol = SeqType.symbol.asClass
537541
@tu lazy val Seq_apply : Symbol = SeqClass.requiredMethod(nme.apply)
538542
@tu lazy val Seq_head : Symbol = SeqClass.requiredMethod(nme.head)
539543
@tu lazy val Seq_drop : Symbol = SeqClass.requiredMethod(nme.drop)
540544
@tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType))
541545
@tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length)
542546
@tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq)
543-
@tu lazy val SeqModule : Symbol = requiredModule("scala.collection.immutable.Seq")
544-
@tu lazy val SeqModule_apply : Symbol = SeqModule.requiredMethod(nme.apply)
545547

546548

547549
@tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps")

compiler/src/dotty/tools/dotc/transform/ArrayApply.scala

+39-21
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1-
package dotty.tools.dotc
1+
package dotty.tools
2+
package dotc
23
package transform
34

45
import core._
56
import MegaPhase._
67
import Contexts._
8+
import Decorators.*
79
import Symbols._
810
import Flags._
911
import StdNames._
10-
import dotty.tools.dotc.ast.tpd
11-
12-
12+
import ast.tpd
13+
import reporting.trace
1314

1415
/** This phase rewrites calls to `Array.apply` to a direct instantiation of the array in the bytecode.
1516
*
@@ -22,49 +23,66 @@ class ArrayApply extends MiniPhase {
2223

2324
override def description: String = ArrayApply.description
2425

25-
private var transformListApplyLimit = 8
26+
private val transformListApplyLimit = 8
2627

27-
private def reducingTransformListApply[A](depth: Int)(body: => A): A = {
28-
val saved = transformListApplyLimit
29-
transformListApplyLimit -= depth
30-
try body
31-
finally transformListApplyLimit = saved
32-
}
28+
override def transformTypeApply(tree: TypeApply)(using Context): Tree =
29+
stripCast(tree) match
30+
case app: Apply if isConsChain(app) => app
31+
case _ => tree
3332

34-
override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree =
33+
override def transformApply(tree: Apply)(using Context): Tree =
3534
if isArrayModuleApply(tree.symbol) then
3635
tree.args match
37-
case StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil
36+
case StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: ct :: Nil
3837
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) =>
3938
seqLit
4039

41-
case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: Nil
40+
case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: JavaSeqLiteral) :: Nil)) :: Nil
4241
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) =>
43-
tpd.JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt)
42+
JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt)
4443

4544
case _ =>
4645
tree
4746

48-
else if isListOrSeqModuleApply(tree.symbol) then
47+
else if isSeqApply(tree) then
4948
tree.args match
5049
// <List or Seq>(a, b, c) ~> new ::(a, new ::(b, new ::(c, Nil))) but only for reference types
51-
case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: tpd.JavaSeqLiteral)))) :: Nil
50+
case StripAscription(Apply(wrapArrayMeth, List(StripAscription(rest: JavaSeqLiteral)))) :: Nil
5251
if defn.WrapArrayMethods().contains(wrapArrayMeth.symbol) &&
5352
rest.elems.lengthIs < transformListApplyLimit =>
54-
rest.elems.foldRight(tpd.ref(defn.NilModule)): (elem, acc) =>
55-
tpd.New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc))
53+
rest.elems.foldRight(ref(defn.NilModule)): (elem, acc) =>
54+
New(defn.ConsType, List(elem.ensureConforms(defn.ObjectType), acc))
5655

5756
case _ =>
5857
tree
5958

6059
else tree
6160

61+
private def isConsChain(tree: Tree)(using Context): Boolean = tree match
62+
case Apply(Select(New(tt), nme.CONSTRUCTOR), List(_, arg)) =>
63+
tt.symbol == defn.ConsClass && (arg.symbol == defn.NilModule || isConsChain(arg))
64+
case _ => false
65+
6266
private def isArrayModuleApply(sym: Symbol)(using Context): Boolean =
6367
sym.name == nme.apply
6468
&& (sym.owner == defn.ArrayModuleClass || (sym.owner == defn.IArrayModuleClass && !sym.is(Extension)))
6569

66-
private def isListOrSeqModuleApply(sym: Symbol)(using Context): Boolean =
67-
sym == defn.ListModule_apply || sym == defn.SeqModule_apply
70+
private def isListApply(tree: Tree)(using Context): Boolean =
71+
(tree.symbol == defn.ListModule_apply || tree.symbol.name == nme.apply) && appliedCore(tree).match
72+
case Select(qual, _) =>
73+
val sym = qual.symbol
74+
sym == defn.ListModule
75+
|| sym == defn.ListModuleAlias
76+
case _ => false
77+
78+
private def isSeqApply(tree: Tree)(using Context): Boolean =
79+
isListApply(tree) || tree.symbol == defn.SeqModule_apply && appliedCore(tree).match
80+
case Select(qual, _) =>
81+
val sym = qual.symbol
82+
sym == defn.SeqModule
83+
|| sym == defn.SeqModuleAlias
84+
|| sym == defn.CollectionSeqType.symbol.companionModule
85+
case _ => false
6886

6987
/** Only optimize when classtag if it is one of
7088
* - `ClassTag.apply(classOf[XYZ])`

compiler/test/dotty/tools/backend/jvm/ArrayApplyOptTest.scala

+40-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
package dotty.tools.backend.jvm
1+
package dotty.tools
2+
package backend.jvm
23

34
import org.junit.Test
45
import org.junit.Assert._
@@ -161,15 +162,50 @@ class ArrayApplyOptTest extends DottyBytecodeTest {
161162
}
162163

163164
@Test def testListApplyAvoidsIntermediateArray = {
164-
val source =
165+
checkApplyAvoidsIntermediateArray("List"):
165166
"""
166167
|class Foo {
167168
| def meth1: List[String] = List("1", "2", "3")
168169
| def meth2: List[String] =
169-
| new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]]
170+
| new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]]
171+
|}
172+
""".stripMargin
173+
}
174+
175+
@Test def testSeqApplyAvoidsIntermediateArray = {
176+
checkApplyAvoidsIntermediateArray("Seq"):
177+
"""
178+
|class Foo {
179+
| def meth1: Seq[String] = Seq("1", "2", "3")
180+
| def meth2: Seq[String] =
181+
| new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]]
170182
|}
171183
""".stripMargin
184+
}
185+
186+
@Test def testSeqApplyAvoidsIntermediateArray2 = {
187+
checkApplyAvoidsIntermediateArray("scala.collection.immutable.Seq"):
188+
"""import scala.collection.immutable.Seq
189+
|class Foo {
190+
| def meth1: Seq[String] = Seq("1", "2", "3")
191+
| def meth2: Seq[String] =
192+
| new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]]
193+
|}
194+
""".stripMargin
195+
}
196+
197+
@Test def testSeqApplyAvoidsIntermediateArray3 = {
198+
checkApplyAvoidsIntermediateArray("scala.collection.Seq"):
199+
"""import scala.collection.Seq
200+
|class Foo {
201+
| def meth1: Seq[String] = Seq("1", "2", "3")
202+
| def meth2: Seq[String] =
203+
| new scala.collection.immutable.::("1", new scala.collection.immutable.::("2", new scala.collection.immutable.::("3", scala.collection.immutable.Nil))).asInstanceOf[List[String]]
204+
|}
205+
""".stripMargin
206+
}
172207

208+
def checkApplyAvoidsIntermediateArray(name: String)(source: String) = {
173209
checkBCode(source) { dir =>
174210
val clsIn = dir.lookupName("Foo.class", directory = false).input
175211
val clsNode = loadClassNode(clsIn)
@@ -180,7 +216,7 @@ class ArrayApplyOptTest extends DottyBytecodeTest {
180216
val instructions2 = instructionsFromMethod(meth2)
181217

182218
assert(instructions1 == instructions2,
183-
"the List.apply method " +
219+
s"the $name.apply method\n" +
184220
diffInstructions(instructions1, instructions2))
185221
}
186222
}

tests/run/list-apply-eval.scala

+3
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ object Test:
1919

2020
val emptyList = List[Int]()
2121
assert(emptyList == Nil)
22+
23+
// just assert it doesn't throw CCE to List
24+
val queue = scala.collection.mutable.Queue[String]()

0 commit comments

Comments
 (0)