Skip to content

Commit 4f3632f

Browse files
Support records in JavaParsers (#16762)
This is a port of scala/scala#9551. Fixes #14846.
2 parents b67e269 + da4996a commit 4f3632f

File tree

12 files changed

+208
-20
lines changed

12 files changed

+208
-20
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,7 @@ class Definitions {
688688
@tu lazy val JavaCalendarClass: ClassSymbol = requiredClass("java.util.Calendar")
689689
@tu lazy val JavaDateClass: ClassSymbol = requiredClass("java.util.Date")
690690
@tu lazy val JavaFormattableClass: ClassSymbol = requiredClass("java.util.Formattable")
691+
@tu lazy val JavaRecordClass: Symbol = getClassIfDefined("java.lang.Record")
691692

692693
@tu lazy val JavaEnumClass: ClassSymbol = {
693694
val cls = requiredClass("java.lang.Enum")

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ object StdNames {
204204
final val Null: N = "Null"
205205
final val Object: N = "Object"
206206
final val FromJavaObject: N = "<FromJavaObject>"
207+
final val Record: N = "Record"
207208
final val Product: N = "Product"
208209
final val PartialFunction: N = "PartialFunction"
209210
final val PrefixType: N = "PrefixType"
@@ -913,6 +914,10 @@ object StdNames {
913914
final val VOLATILEkw: N = kw("volatile")
914915
final val WHILEkw: N = kw("while")
915916

917+
final val RECORDid: N = "record"
918+
final val VARid: N = "var"
919+
final val YIELDid: N = "yield"
920+
916921
final val BoxedBoolean: N = "java.lang.Boolean"
917922
final val BoxedByte: N = "java.lang.Byte"
918923
final val BoxedCharacter: N = "java.lang.Character"
@@ -945,6 +950,8 @@ object StdNames {
945950
final val JavaSerializable: N = "java.io.Serializable"
946951
}
947952

953+
954+
948955
class JavaTermNames extends JavaNames[TermName] {
949956
protected def fromString(s: String): TermName = termName(s)
950957
}

compiler/src/dotty/tools/dotc/parsing/JavaParsers.scala

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ import StdNames._
2020
import reporting._
2121
import dotty.tools.dotc.util.SourceFile
2222
import util.Spans._
23-
import scala.collection.mutable.ListBuffer
23+
24+
import scala.collection.mutable.{ListBuffer, LinkedHashMap}
2425

2526
object JavaParsers {
2627

@@ -96,8 +97,12 @@ object JavaParsers {
9697
def javaLangDot(name: Name): Tree =
9798
Select(javaDot(nme.lang), name)
9899

100+
/** Tree representing `java.lang.Object` */
99101
def javaLangObject(): Tree = javaLangDot(tpnme.Object)
100102

103+
/** Tree representing `java.lang.Record` */
104+
def javaLangRecord(): Tree = javaLangDot(tpnme.Record)
105+
101106
def arrayOf(tpt: Tree): AppliedTypeTree =
102107
AppliedTypeTree(scalaDot(tpnme.Array), List(tpt))
103108

@@ -555,6 +560,14 @@ object JavaParsers {
555560

556561
def definesInterface(token: Int): Boolean = token == INTERFACE || token == AT
557562

563+
/** If the next token is the identifier "record", convert it into the RECORD token.
564+
* This makes it easier to handle records in various parts of the code,
565+
* in particular when a `parentToken` is passed to some functions.
566+
*/
567+
def adaptRecordIdentifier(): Unit =
568+
if in.token == IDENTIFIER && in.name == jnme.RECORDid then
569+
in.token = RECORD
570+
558571
def termDecl(start: Offset, mods: Modifiers, parentToken: Int, parentTParams: List[TypeDef]): List[Tree] = {
559572
val inInterface = definesInterface(parentToken)
560573
val tparams = if (in.token == LT) typeParams(Flags.JavaDefined | Flags.Param) else List()
@@ -581,6 +594,16 @@ object JavaParsers {
581594
TypeTree(), methodBody()).withMods(mods)
582595
}
583596
}
597+
} else if (in.token == LBRACE && rtptName != nme.EMPTY && parentToken == RECORD) {
598+
/*
599+
record RecordName(T param1, ...) {
600+
RecordName { // <- here
601+
// methodBody
602+
}
603+
}
604+
*/
605+
methodBody()
606+
Nil
584607
}
585608
else {
586609
var mods1 = mods
@@ -717,12 +740,11 @@ object JavaParsers {
717740
ValDef(name, tpt2, if (mods.is(Flags.Param)) EmptyTree else unimplementedExpr).withMods(mods1)
718741
}
719742

720-
def memberDecl(start: Offset, mods: Modifiers, parentToken: Int, parentTParams: List[TypeDef]): List[Tree] = in.token match {
721-
case CLASS | ENUM | INTERFACE | AT =>
722-
typeDecl(start, if (definesInterface(parentToken)) mods | Flags.JavaStatic else mods)
743+
def memberDecl(start: Offset, mods: Modifiers, parentToken: Int, parentTParams: List[TypeDef]): List[Tree] = in.token match
744+
case CLASS | ENUM | RECORD | INTERFACE | AT =>
745+
typeDecl(start, if definesInterface(parentToken) then mods | Flags.JavaStatic else mods)
723746
case _ =>
724747
termDecl(start, mods, parentToken, parentTParams)
725-
}
726748

727749
def makeCompanionObject(cdef: TypeDef, statics: List[Tree]): Tree =
728750
atSpan(cdef.span) {
@@ -804,6 +826,51 @@ object JavaParsers {
804826
addCompanionObject(statics, cls)
805827
}
806828

829+
def recordDecl(start: Offset, mods: Modifiers): List[Tree] =
830+
accept(RECORD)
831+
val nameOffset = in.offset
832+
val name = identForType()
833+
val tparams = typeParams()
834+
val header = formalParams()
835+
val superclass = javaLangRecord() // records always extend java.lang.Record
836+
val interfaces = interfacesOpt() // records may implement interfaces
837+
val (statics, body) = typeBody(RECORD, name, tparams)
838+
839+
// We need to generate accessors for every param, if no method with the same name is already defined
840+
841+
var fieldsByName = header.map(v => (v.name, (v.tpt, v.mods.annotations))).to(LinkedHashMap)
842+
843+
for case DefDef(name, paramss, _, _) <- body
844+
if paramss.isEmpty && fieldsByName.contains(name)
845+
do
846+
fieldsByName -= name
847+
end for
848+
849+
val accessors =
850+
(for (name, (tpt, annots)) <- fieldsByName yield
851+
DefDef(name, Nil, tpt, unimplementedExpr)
852+
.withMods(Modifiers(Flags.JavaDefined | Flags.Method | Flags.Synthetic))
853+
).toList
854+
855+
// generate the canonical constructor
856+
val canonicalConstructor =
857+
DefDef(nme.CONSTRUCTOR, joinParams(tparams, List(header)), TypeTree(), EmptyTree)
858+
.withMods(Modifiers(Flags.JavaDefined | Flags.Synthetic, mods.privateWithin))
859+
860+
// return the trees
861+
val recordTypeDef = atSpan(start, nameOffset) {
862+
TypeDef(name,
863+
makeTemplate(
864+
parents = superclass :: interfaces,
865+
stats = canonicalConstructor :: accessors ::: body,
866+
tparams = tparams,
867+
true
868+
)
869+
).withMods(mods)
870+
}
871+
addCompanionObject(statics, recordTypeDef)
872+
end recordDecl
873+
807874
def interfaceDecl(start: Offset, mods: Modifiers): List[Tree] = {
808875
accept(INTERFACE)
809876
val nameOffset = in.offset
@@ -846,7 +913,8 @@ object JavaParsers {
846913
else if (in.token == SEMI)
847914
in.nextToken()
848915
else {
849-
if (in.token == ENUM || definesInterface(in.token)) mods |= Flags.JavaStatic
916+
adaptRecordIdentifier()
917+
if (in.token == ENUM || in.token == RECORD || definesInterface(in.token)) mods |= Flags.JavaStatic
850918
val decls = memberDecl(start, mods, parentToken, parentTParams)
851919
(if (mods.is(Flags.JavaStatic) || inInterface && !(decls exists (_.isInstanceOf[DefDef])))
852920
statics
@@ -947,13 +1015,13 @@ object JavaParsers {
9471015
}
9481016
}
9491017

950-
def typeDecl(start: Offset, mods: Modifiers): List[Tree] = in.token match {
1018+
def typeDecl(start: Offset, mods: Modifiers): List[Tree] = in.token match
9511019
case ENUM => enumDecl(start, mods)
9521020
case INTERFACE => interfaceDecl(start, mods)
9531021
case AT => annotationDecl(start, mods)
9541022
case CLASS => classDecl(start, mods)
1023+
case RECORD => recordDecl(start, mods)
9551024
case _ => in.nextToken(); syntaxError(em"illegal start of type declaration", skipIt = true); List(errorTypeTree)
956-
}
9571025

9581026
def tryConstant: Option[Constant] = {
9591027
val negate = in.token match {
@@ -1004,6 +1072,7 @@ object JavaParsers {
10041072
if (in.token != EOF) {
10051073
val start = in.offset
10061074
val mods = modifiers(inInterface = false)
1075+
adaptRecordIdentifier() // needed for typeDecl
10071076
buf ++= typeDecl(start, mods)
10081077
}
10091078
}

compiler/src/dotty/tools/dotc/parsing/JavaTokens.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ object JavaTokens extends TokensCommon {
4141
inline val SWITCH = 133; enter(SWITCH, "switch")
4242
inline val ASSERT = 134; enter(ASSERT, "assert")
4343

44+
/** contextual keywords (turned into keywords in certain conditions, see JLS 3.9 of Java 9+) */
45+
inline val RECORD = 135; enter(RECORD, "record")
46+
4447
/** special symbols */
4548
inline val EQEQ = 140
4649
inline val BANGEQ = 141

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,6 @@ class Namer { typer: Typer =>
862862
* with a user-defined method in the same scope with a matching type.
863863
*/
864864
private def invalidateIfClashingSynthetic(denot: SymDenotation): Unit =
865-
866865
def isCaseClassOrCompanion(owner: Symbol) =
867866
owner.isClass && {
868867
if (owner.is(Module)) owner.linkedClass.is(CaseClass)
@@ -879,10 +878,19 @@ class Namer { typer: Typer =>
879878
!sd.symbol.is(Deferred) && sd.matches(denot)))
880879

881880
val isClashingSynthetic =
882-
denot.is(Synthetic, butNot = ConstructorProxy)
883-
&& desugar.isRetractableCaseClassMethodName(denot.name)
884-
&& isCaseClassOrCompanion(denot.owner)
885-
&& (definesMember || inheritsConcreteMember)
881+
denot.is(Synthetic, butNot = ConstructorProxy) &&
882+
(
883+
(desugar.isRetractableCaseClassMethodName(denot.name)
884+
&& isCaseClassOrCompanion(denot.owner)
885+
&& (definesMember || inheritsConcreteMember)
886+
)
887+
||
888+
// remove synthetic constructor of a java Record if it clashes with a non-synthetic constructor
889+
(denot.isConstructor
890+
&& denot.owner.is(JavaDefined) && denot.owner.derivesFrom(defn.JavaRecordClass)
891+
&& denot.owner.unforcedDecls.lookupAll(denot.name).exists(c => c != denot.symbol && c.info.matches(denot.info))
892+
)
893+
)
886894

887895
if isClashingSynthetic then
888896
typr.println(i"invalidating clashing $denot in ${denot.owner}")

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2441,11 +2441,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
24412441
}
24422442

24432443
def typedDefDef(ddef: untpd.DefDef, sym: Symbol)(using Context): Tree = {
2444-
if (!sym.info.exists) { // it's a discarded synthetic case class method, drop it
2445-
assert(sym.is(Synthetic) && desugar.isRetractableCaseClassMethodName(sym.name))
2444+
def canBeInvalidated(sym: Symbol): Boolean =
2445+
sym.is(Synthetic)
2446+
&& (desugar.isRetractableCaseClassMethodName(sym.name) ||
2447+
(sym.isConstructor && sym.owner.derivesFrom(defn.JavaRecordClass)))
2448+
2449+
if !sym.info.exists then
2450+
// it's a discarded method (synthetic case class method or synthetic java record constructor), drop it
2451+
assert(canBeInvalidated(sym))
24462452
sym.owner.info.decls.openForMutations.unlink(sym)
24472453
return EmptyTree
2448-
}
2454+
24492455
// TODO: - Remove this when `scala.language.experimental.erasedDefinitions` is no longer experimental.
24502456
// - Modify signature to `erased def erasedValue[T]: T`
24512457
if sym.eq(defn.Compiletime_erasedValue) then
@@ -3598,7 +3604,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
35983604
adapt(tree, pt, ctx.typerState.ownedVars)
35993605

36003606
private def adapt1(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = {
3601-
assert(pt.exists && !pt.isInstanceOf[ExprType] || ctx.reporter.errorsReported)
3607+
assert(pt.exists && !pt.isInstanceOf[ExprType] || ctx.reporter.errorsReported, i"tree: $tree, pt: $pt")
36023608
def methodStr = err.refStr(methPart(tree).tpe)
36033609

36043610
def readapt(tree: Tree)(using Context) = adapt(tree, pt, locked)

compiler/test/dotty/tools/dotc/CompilationTests.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class CompilationTests {
2929

3030
@Test def pos: Unit = {
3131
implicit val testGroup: TestGroup = TestGroup("compilePos")
32-
aggregateTests(
32+
var tests = List(
3333
compileFile("tests/pos/nullarify.scala", defaultOptions.and("-Ycheck:nullarify")),
3434
compileFile("tests/pos-special/utf8encoded.scala", explicitUTF8),
3535
compileFile("tests/pos-special/utf16encoded.scala", explicitUTF16),
@@ -65,8 +65,13 @@ class CompilationTests {
6565
compileFile("tests/pos-special/extend-java-enum.scala", defaultOptions.and("-source", "3.0-migration")),
6666
compileFile("tests/pos-custom-args/help.scala", defaultOptions.and("-help", "-V", "-W", "-X", "-Y")),
6767
compileFile("tests/pos-custom-args/i13044.scala", defaultOptions.and("-Xmax-inlines:33")),
68-
compileFile("tests/pos-custom-args/jdk-8-app.scala", defaultOptions.and("-release:8")),
69-
).checkCompile()
68+
compileFile("tests/pos-custom-args/jdk-8-app.scala", defaultOptions.and("-release:8"))
69+
)
70+
71+
if scala.util.Properties.isJavaAtLeast("16") then
72+
tests ::= compileFilesInDir("tests/pos-java16+", defaultOptions.and("-Ysafe-init"))
73+
74+
aggregateTests(tests*).checkCompile()
7075
}
7176

7277
@Test def rewrites: Unit = {
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
object C:
2+
def useR1: Unit =
3+
// constructor signature
4+
val r = R1(123, "hello")
5+
6+
// accessors
7+
val i: Int = r.i
8+
val s: String = r.s
9+
10+
// methods
11+
val iRes: Int = r.getInt()
12+
val sRes: String = r.getString()
13+
14+
// supertype
15+
val record: java.lang.Record = r
16+
17+
def useR2: Unit =
18+
// constructor signature
19+
val r2 = R2.R(123, "hello")
20+
21+
// accessors signature
22+
val i: Int = r2.i
23+
val s: String = r2.s
24+
25+
// method
26+
val i2: Int = r2.getInt
27+
28+
// supertype
29+
val isIntLike: IntLike = r2
30+
val isRecord: java.lang.Record = r2
31+
32+
def useR3 =
33+
// constructor signature
34+
val r3 = R3(123, 42L, "hi")
35+
new R3("hi", 123)
36+
// accessors signature
37+
val i: Int = r3.i
38+
val l: Long = r3.l
39+
val s: String = r3.s
40+
// method
41+
val l2: Long = r3.l(43L, 44L)
42+
// supertype
43+
val isRecord: java.lang.Record = r3
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
trait IntLike:
2+
def getInt: Int
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
public record R1(int i, String s) {
2+
public String getString() {
3+
return s + i;
4+
}
5+
6+
public int getInt() {
7+
return 0;
8+
}
9+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
public class R2 {
2+
final record R(int i, String s) implements IntLike {
3+
public int getInt() {
4+
return i;
5+
}
6+
7+
// Canonical constructor
8+
public R(int i, java.lang.String s) {
9+
this.i = i;
10+
this.s = s.intern();
11+
}
12+
}
13+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
public record R3(int i, long l, String s) {
2+
3+
// User-specified accessor
4+
public int i() {
5+
return i + 1; // evil >:)
6+
}
7+
8+
// Not an accessor - too many parameters
9+
public long l(long a1, long a2) {
10+
return a1 + a2;
11+
}
12+
13+
// Secondary constructor
14+
public R3(String s, int i) {
15+
this(i, 42L, s);
16+
}
17+
18+
// Compact constructor
19+
public R3 {
20+
s = s.intern();
21+
}
22+
}

0 commit comments

Comments
 (0)