Skip to content

Commit 5df2904

Browse files
committed
Add MiniPhaseTransform to add specialized methods to FunctionN
1 parent 4a84b0c commit 5df2904

File tree

4 files changed

+273
-66
lines changed

4 files changed

+273
-66
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class Compiler {
5858
new ByNameClosures, // Expand arguments to by-name parameters to closures
5959
new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods
6060
new HoistSuperArgs, // Hoist complex arguments of supercalls to enclosing scope
61+
new SpecializedApplyMethods, // Adds specialized methods to FunctionN
6162
new ClassOf), // Expand `Predef.classOf` calls.
6263
List(new TryCatchPatterns, // Compile cases in try/catch
6364
new PatternMatcher, // Compile pattern matches
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import TreeTransforms.{ MiniPhaseTransform, TransformerInfo }
5+
import ast.Trees._, ast.tpd, core._
6+
import Contexts.Context, Types._, Decorators._, Symbols._, DenotTransformers._
7+
import SymDenotations._, Scopes._, StdNames._, NameOps._, Names._
8+
9+
/** This phase synthesizes specialized methods for FunctionN, this is done
10+
* since there are no scala signatures in the bytecode for the specialized
11+
* methods.
12+
*
13+
* We know which specializations exist for the different arities, therefore we
14+
* can hardcode them. This should, however be removed once we're using a
15+
* different standard library.
16+
*/
17+
class SpecializedApplyMethods extends MiniPhaseTransform with InfoTransformer {
18+
import ast.tpd._
19+
20+
val phaseName = "specializedApplyMethods"
21+
22+
def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match {
23+
case tp: ClassInfo if defn.isFunctionClass(sym) => {
24+
def specApply(ret: Type, args: List[Type])(implicit ctx: Context) = {
25+
val all = args :+ ret
26+
val name = nme.apply.specializedFor(all, all.map(_.typeSymbol.name), Nil, Nil)
27+
ctx.newSymbol(sym, name, Flags.Method, MethodType(args, ret))
28+
}
29+
30+
val newDecls = sym.name.functionArity match {
31+
case 0 =>
32+
List(
33+
specApply(defn.UnitType, Nil),
34+
specApply(defn.ByteType, Nil),
35+
specApply(defn.ShortType, Nil),
36+
specApply(defn.IntType, Nil),
37+
specApply(defn.LongType, Nil),
38+
specApply(defn.CharType, Nil),
39+
specApply(defn.FloatType, Nil),
40+
specApply(defn.DoubleType, Nil),
41+
specApply(defn.BooleanType, Nil)
42+
)
43+
.foldLeft(tp.decls.cloneScope){ (decls, sym) => decls.enter(sym); decls }
44+
45+
case 1 =>
46+
List(
47+
specApply(defn.UnitType, List(defn.IntType)),
48+
specApply(defn.IntType, List(defn.IntType)),
49+
specApply(defn.FloatType, List(defn.IntType)),
50+
specApply(defn.LongType, List(defn.IntType)),
51+
specApply(defn.DoubleType, List(defn.IntType)),
52+
specApply(defn.UnitType, List(defn.LongType)),
53+
specApply(defn.BooleanType, List(defn.LongType)),
54+
specApply(defn.IntType, List(defn.LongType)),
55+
specApply(defn.FloatType, List(defn.LongType)),
56+
specApply(defn.LongType, List(defn.LongType)),
57+
specApply(defn.DoubleType, List(defn.LongType)),
58+
specApply(defn.UnitType, List(defn.FloatType)),
59+
specApply(defn.BooleanType, List(defn.FloatType)),
60+
specApply(defn.IntType, List(defn.FloatType)),
61+
specApply(defn.FloatType, List(defn.FloatType)),
62+
specApply(defn.LongType, List(defn.FloatType)),
63+
specApply(defn.DoubleType, List(defn.FloatType)),
64+
specApply(defn.UnitType, List(defn.DoubleType)),
65+
specApply(defn.BooleanType, List(defn.DoubleType)),
66+
specApply(defn.IntType, List(defn.DoubleType)),
67+
specApply(defn.FloatType, List(defn.DoubleType)),
68+
specApply(defn.LongType, List(defn.DoubleType)),
69+
specApply(defn.DoubleType, List(defn.DoubleType))
70+
)
71+
.foldLeft(tp.decls.cloneScope){ (decls, sym) => decls.enter(sym); decls }
72+
73+
case 2 =>
74+
List(
75+
specApply(defn.UnitType, List(defn.IntType, defn.IntType)),
76+
specApply(defn.BooleanType, List(defn.IntType, defn.IntType)),
77+
specApply(defn.IntType, List(defn.IntType, defn.IntType)),
78+
specApply(defn.FloatType, List(defn.IntType, defn.IntType)),
79+
specApply(defn.LongType, List(defn.IntType, defn.IntType)),
80+
specApply(defn.DoubleType, List(defn.IntType, defn.IntType)),
81+
specApply(defn.UnitType, List(defn.IntType, defn.LongType)),
82+
specApply(defn.BooleanType, List(defn.IntType, defn.LongType)),
83+
specApply(defn.IntType, List(defn.IntType, defn.LongType)),
84+
specApply(defn.FloatType, List(defn.IntType, defn.LongType)),
85+
specApply(defn.LongType, List(defn.IntType, defn.LongType)),
86+
specApply(defn.DoubleType, List(defn.IntType, defn.LongType)),
87+
specApply(defn.UnitType, List(defn.IntType, defn.DoubleType)),
88+
specApply(defn.BooleanType, List(defn.IntType, defn.DoubleType)),
89+
specApply(defn.IntType, List(defn.IntType, defn.DoubleType)),
90+
specApply(defn.FloatType, List(defn.IntType, defn.DoubleType)),
91+
specApply(defn.LongType, List(defn.IntType, defn.DoubleType)),
92+
specApply(defn.DoubleType, List(defn.IntType, defn.DoubleType)),
93+
specApply(defn.UnitType, List(defn.LongType, defn.IntType)),
94+
specApply(defn.BooleanType, List(defn.LongType, defn.IntType)),
95+
specApply(defn.IntType, List(defn.LongType, defn.IntType)),
96+
specApply(defn.FloatType, List(defn.LongType, defn.IntType)),
97+
specApply(defn.LongType, List(defn.LongType, defn.IntType)),
98+
specApply(defn.DoubleType, List(defn.LongType, defn.IntType)),
99+
specApply(defn.UnitType, List(defn.LongType, defn.LongType)),
100+
specApply(defn.BooleanType, List(defn.LongType, defn.LongType)),
101+
specApply(defn.IntType, List(defn.LongType, defn.LongType)),
102+
specApply(defn.FloatType, List(defn.LongType, defn.LongType)),
103+
specApply(defn.LongType, List(defn.LongType, defn.LongType)),
104+
specApply(defn.DoubleType, List(defn.LongType, defn.LongType)),
105+
specApply(defn.UnitType, List(defn.LongType, defn.DoubleType)),
106+
specApply(defn.BooleanType, List(defn.LongType, defn.DoubleType)),
107+
specApply(defn.IntType, List(defn.LongType, defn.DoubleType)),
108+
specApply(defn.FloatType, List(defn.LongType, defn.DoubleType)),
109+
specApply(defn.LongType, List(defn.LongType, defn.DoubleType)),
110+
specApply(defn.DoubleType, List(defn.LongType, defn.DoubleType)),
111+
specApply(defn.UnitType, List(defn.DoubleType, defn.IntType)),
112+
specApply(defn.BooleanType, List(defn.DoubleType, defn.IntType)),
113+
specApply(defn.IntType, List(defn.DoubleType, defn.IntType)),
114+
specApply(defn.FloatType, List(defn.DoubleType, defn.IntType)),
115+
specApply(defn.LongType, List(defn.DoubleType, defn.IntType)),
116+
specApply(defn.DoubleType, List(defn.DoubleType, defn.IntType)),
117+
specApply(defn.UnitType, List(defn.DoubleType, defn.LongType)),
118+
specApply(defn.BooleanType, List(defn.DoubleType, defn.LongType)),
119+
specApply(defn.IntType, List(defn.DoubleType, defn.LongType)),
120+
specApply(defn.FloatType, List(defn.DoubleType, defn.LongType)),
121+
specApply(defn.LongType, List(defn.DoubleType, defn.LongType)),
122+
specApply(defn.DoubleType, List(defn.DoubleType, defn.LongType)),
123+
specApply(defn.UnitType, List(defn.DoubleType, defn.DoubleType)),
124+
specApply(defn.BooleanType, List(defn.DoubleType, defn.DoubleType)),
125+
specApply(defn.IntType, List(defn.DoubleType, defn.DoubleType)),
126+
specApply(defn.FloatType, List(defn.DoubleType, defn.DoubleType)),
127+
specApply(defn.LongType, List(defn.DoubleType, defn.DoubleType)),
128+
specApply(defn.DoubleType, List(defn.DoubleType, defn.DoubleType))
129+
)
130+
.foldLeft(tp.decls.cloneScope){ (decls, sym) => decls.enter(sym); decls }
131+
132+
case _ =>
133+
tp.decls
134+
}
135+
136+
tp.derivedClassInfo(decls = newDecls)
137+
}
138+
case _ => tp
139+
}
140+
}

compiler/test/dotty/tools/backend/jvm/DottyBytecodeTest.scala

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import dotc.core.Contexts.{Context, ContextBase}
55
import dotc.core.Phases.Phase
66
import dotc.Compiler
77

8-
import scala.reflect.io.{VirtualDirectory => Directory}
8+
import dotty.tools.io.{VirtualDirectory => Directory}
99
import scala.tools.asm
1010
import asm._
1111
import asm.tree._
@@ -17,6 +17,8 @@ import scala.tools.asm.{ClassWriter, ClassReader}
1717
import scala.tools.asm.tree._
1818
import java.io.{File => JFile, InputStream}
1919

20+
import org.junit.Assert._
21+
2022
class TestGenBCode(val outDir: String) extends GenBCode {
2123
override def phaseName: String = "testGenBCode"
2224
val virtualDir = new Directory(outDir, None)
@@ -89,6 +91,14 @@ trait DottyBytecodeTest extends DottyTest {
8991
cn
9092
}
9193

94+
/** Finds a class with `cls` as name in `dir`, throws if it can't find it */
95+
def findClass(cls: String, dir: Directory) = {
96+
val clsIn = dir.lookupName(s"$cls.class", directory = false).input
97+
val clsNode = loadClassNode(clsIn)
98+
assert(clsNode.name == cls, s"inspecting wrong class: ${clsNode.name}")
99+
clsNode
100+
}
101+
92102
protected def getMethod(classNode: ClassNode, name: String): MethodNode =
93103
classNode.methods.asScala.find(_.name == name) getOrElse
94104
sys.error(s"Didn't find method '$name' in class '${classNode.name}'")
@@ -205,4 +215,41 @@ trait DottyBytecodeTest extends DottyTest {
205215
s"Wrong number of null checks ($actualChecks), expected: $expectedChecks"
206216
)
207217
}
218+
219+
def assertBoxing(nodeName: String, methods: java.lang.Iterable[MethodNode])(implicit source: String): Unit =
220+
methods.asScala.find(_.name == nodeName)
221+
.map { node =>
222+
val (ins, boxed) = boxingInstructions(node)
223+
if (!boxed) fail("No boxing in:\n" + boxingError(ins, source))
224+
}
225+
.getOrElse(fail("Could not find constructor for object `Test`"))
226+
227+
private def boxingError(ins: List[_], source: String) =
228+
s"""|----------------------------------
229+
|${ins.mkString("\n")}
230+
|----------------------------------
231+
|From code:
232+
|$source
233+
|----------------------------------""".stripMargin
234+
235+
236+
protected def assertNoBoxing(nodeName: String, methods: java.lang.Iterable[MethodNode])(implicit source: String): Unit =
237+
methods.asScala.find(_.name == nodeName)
238+
.map { node =>
239+
val (ins, boxed) = boxingInstructions(node)
240+
if (boxed) fail(boxingError(ins, source))
241+
}
242+
.getOrElse(fail("Could not find constructor for object `Test`"))
243+
244+
protected def boxingInstructions(method: MethodNode): (List[_], Boolean) = {
245+
val ins = instructionsFromMethod(method)
246+
val boxed = ins.exists {
247+
case Invoke(op, owner, name, desc, itf) =>
248+
owner.toLowerCase.contains("box") || name.toLowerCase.contains("box")
249+
case _ => false
250+
}
251+
252+
(ins, boxed)
253+
}
254+
208255
}

0 commit comments

Comments
 (0)