Skip to content

Commit 131f2e3

Browse files
committed
Fix mapping and pickling of annotated types
`Annotation.mapWith` maps an `Annotation` with a type map `tm`. Before actually applying `tm` to the annotation’s `tree`, it first checks if `tm` would result in any change by applying it to the types of the annotation’s arguments, and checking if the mapped types are different. This optimization had two problems: it didn’t include type parameters, and used `frozen_=:=` to compare types, which failed to detected some changes. This commit changes `Annotation.arguments` to also include type parameters, and, and changes `Annotation.MapWith` to use `==` to compare types instead of `frozen_=:=`. Furthermore, in case of changes, the symbol in the annotation's tree should be copied to make sure that the same symbol is not used for different trees. This commit achieves this by using a custom `TreeTypeMap` with an overridden `withMappedSyms` method where `Symbols.mapSymbols` is called with the argument `mapAlways = true`. Finally, positons of trees that appear inside `AnnotatedType` only were not pickled. This commit also fixes this.
1 parent d148973 commit 131f2e3

15 files changed

+134
-9
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,10 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
134134
case _ => argss
135135
loop(tree, Nil)
136136

137-
/** All term arguments of an application in a single flattened list */
137+
/** All type and term arguments of an application in a single flattened list */
138138
def allArguments(tree: Tree): List[Tree] = unsplice(tree) match {
139139
case Apply(fn, args) => allArguments(fn) ::: args
140-
case TypeApply(fn, _) => allArguments(fn)
140+
case TypeApply(fn, args) => allArguments(fn) ::: args
141141
case Block(_, expr) => allArguments(expr)
142142
case _ => Nil
143143
}

compiler/src/dotty/tools/dotc/core/Annotations.scala

+14-5
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ package dotc
33
package core
44

55
import Symbols.*, Types.*, Contexts.*, Constants.*, Phases.*
6-
import ast.tpd, tpd.*
7-
import util.Spans.Span
6+
import ast.{tpd, untpd, TreeTypeMap}
7+
import tpd.*
8+
import util.Spans.{Span, NoSpan}
89
import printing.{Showable, Printer}
910
import printing.Texts.Text
1011

@@ -30,7 +31,7 @@ object Annotations {
3031
def derivedAnnotation(tree: Tree)(using Context): Annotation =
3132
if (tree eq this.tree) this else Annotation(tree)
3233

33-
/** All arguments to this annotation in a single flat list */
34+
/** All type and term arguments to this annotation in a single flat list */
3435
def arguments(using Context): List[Tree] = tpd.allArguments(tree)
3536

3637
def argument(i: Int)(using Context): Option[Tree] = {
@@ -57,15 +58,23 @@ object Annotations {
5758
val args = arguments
5859
if args.isEmpty then this
5960
else
61+
// Checks if tm would result in any change by applying on the annotations's argument and checking if the resulting types are different.
6062
val findDiff = new TreeAccumulator[Type]:
6163
def apply(x: Type, tree: Tree)(using Context): Type =
6264
if tm.isRange(x) then x
6365
else
6466
val tp1 = tm(tree.tpe)
65-
foldOver(if tp1 frozen_=:= tree.tpe then x else tp1, tree)
67+
foldOver(if tp1 == tree.tpe then x else tp1, tree)
6668
val diff = findDiff(NoType, args)
6769
if tm.isRange(diff) then EmptyAnnotation
68-
else if diff.exists then derivedAnnotation(tm.mapOver(tree))
70+
else if diff.exists then
71+
// In case of changes, the symbol in the annotation's tree should be
72+
// copied so that the same symbol is not used for different trees.
73+
val ttm =
74+
new TreeTypeMap(typeMap = tm):
75+
final override def withMappedSyms(syms: List[Symbol]): TreeTypeMap =
76+
withMappedSyms(syms, mapSymbols(syms, this, mapAlways = true))
77+
derivedAnnotation(ttm.transform(tree))
6978
else this
7079

7180
/** Does this annotation refer to a parameter of `tl`? */

compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala

+4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ object PositionPickler:
3333
pickler: TastyPickler,
3434
addrOfTree: TreeToAddr,
3535
treeAnnots: untpd.MemberDef => List[tpd.Tree],
36+
typeAnnots: List[tpd.Tree],
3637
relativePathReference: String,
3738
source: SourceFile,
3839
roots: List[Tree],
@@ -136,6 +137,9 @@ object PositionPickler:
136137
}
137138
for (root <- roots)
138139
traverse(root, NoSource)
140+
141+
for annotTree <- typeAnnots do
142+
traverse(annotTree, NoSource)
139143
end picklePositions
140144
end PositionPickler
141145

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

+7
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
4040
*/
4141
private val annotTrees = util.EqHashMap[untpd.MemberDef, mutable.ListBuffer[Tree]]()
4242

43+
/** A set of annotation trees appearing in annotated types.
44+
*/
45+
private val annotatedTypeTrees = mutable.ListBuffer[Tree]()
46+
4347
/** A map from member definitions to their doc comments, so that later
4448
* parallel comment pickling does not need to access symbols of trees (which
4549
* would involve accessing symbols of named types and possibly changing phases
@@ -56,6 +60,8 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
5660
val ts = annotTrees.lookup(tree)
5761
if ts == null then Nil else ts.toList
5862

63+
def typeAnnots: List[Tree] = annotatedTypeTrees.toList
64+
5965
def docString(tree: untpd.MemberDef): Option[Comment] =
6066
Option(docStrings.lookup(tree))
6167

@@ -266,6 +272,7 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
266272
case tpe: AnnotatedType =>
267273
writeByte(ANNOTATEDtype)
268274
withLength { pickleType(tpe.parent, richTypes); pickleTree(tpe.annot.tree) }
275+
annotatedTypeTrees += tpe.annot.tree
269276
case tpe: AndType =>
270277
writeByte(ANDtype)
271278
withLength { pickleType(tpe.tp1, richTypes); pickleType(tpe.tp2, richTypes) }

compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ object PickledQuotes {
224224
if tree.span.exists then
225225
val positionWarnings = new mutable.ListBuffer[Message]()
226226
val reference = ctx.settings.sourceroot.value
227-
PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference,
227+
PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference,
228228
ctx.compilationUnit.source, tree :: Nil, positionWarnings)
229229
positionWarnings.foreach(report.warning(_))
230230

compiler/src/dotty/tools/dotc/transform/Pickler.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ class Pickler extends Phase {
143143
if tree.span.exists then
144144
val reference = ctx.settings.sourceroot.value
145145
PositionPickler.picklePositions(
146-
pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference,
146+
pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference,
147147
unit.source, tree :: Nil, positionWarnings,
148148
scratch.positionBuffer, scratch.pickledIndices)
149149

tests/pos/annot-17939.scala

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class qualified[T](f: T => Boolean) extends annotation.StaticAnnotation
2+
3+
class Box[T](val x: T)
4+
class Box2(val x: Int)
5+
6+
class A(a: String @qualified((x: Int) => Box(3).x == 3)) // crash
7+
class A2(a2: String @qualified((x: Int) => Box2(3).x == 3)) // works

tests/pos/annot-17939b.scala

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import scala.annotation.Annotation
2+
class myRefined(f: ? => Boolean) extends Annotation
3+
4+
def test(axes: Int) = true
5+
6+
trait Tensor:
7+
def mean(axes: Int): Int @myRefined(_ => test(axes))
8+
9+
class TensorImpl() extends Tensor:
10+
def mean(axes: Int) = ???

tests/pos/annot-19846.scala

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
class qualified[T](predicate: T => Boolean) extends annotation.StaticAnnotation
2+
3+
class EqualPair(val x: Int, val y: Int @qualified[Int](it => it == x))
4+
5+
@main def main =
6+
val p = EqualPair(42, 42)
7+
val y = p.y
8+
println(42)

tests/pos/annot-19846b.scala

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class qualified[T](predicate: T => Boolean) extends annotation.StaticAnnotation
2+
3+
def f(x: Int): Int @qualified[Int](it => it == x) = ???
4+
5+
@main def main =
6+
val z = f(42)
7+
()

tests/pos/annot-5789.scala

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class Annot[T] extends scala.annotation.Annotation
2+
3+
class D[T](val f: Int@Annot[T])
4+
5+
object A{
6+
def main(a:Array[String]) = {
7+
val c = new D[Int](1)
8+
c.f
9+
}
10+
}

tests/printing/annot-18064.check

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
[[syntax trees at end of typer]] // tests/printing/annot-18064.scala
2+
package <empty> {
3+
class myAnnot[T >: Nothing <: Any]() extends annotation.Annotation() {
4+
T
5+
}
6+
trait Tensor[T >: Nothing <: Any]() extends Object {
7+
T
8+
def add: Tensor[Tensor.this.T] @myAnnot[T]
9+
}
10+
class TensorImpl[A >: Nothing <: Any]() extends Object(), Tensor[
11+
TensorImpl.this.A] {
12+
A
13+
def add: Tensor[A] @myAnnot[A] = this
14+
}
15+
}
16+

tests/printing/annot-18064.scala

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class myAnnot[T]() extends annotation.Annotation
2+
3+
trait Tensor[T]:
4+
def add: Tensor[T] @myAnnot[T]()
5+
6+
class TensorImpl[A]() extends Tensor[A]:
7+
def add /* : Tensor[A] @myAnnot[A] */ = this

tests/printing/annot-19846b.check

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
[[syntax trees at end of typer]] // tests/printing/annot-19846b.scala
2+
package <empty> {
3+
class lambdaAnnot(g: () => Int) extends scala.annotation.Annotation(),
4+
annotation.StaticAnnotation {
5+
private[this] val g: () => Int
6+
}
7+
final lazy module val Test: Test = new Test()
8+
final module class Test() extends Object() { this: Test.type =>
9+
val y: Int = ???
10+
val z:
11+
Int @lambdaAnnot(
12+
{
13+
def $anonfun(): Int = Test.y
14+
closure($anonfun)
15+
}
16+
)
17+
= f(Test.y)
18+
}
19+
final lazy module val annot-19846b$package: annot-19846b$package =
20+
new annot-19846b$package()
21+
final module class annot-19846b$package() extends Object() {
22+
this: annot-19846b$package.type =>
23+
def f(x: Int):
24+
Int @lambdaAnnot(
25+
{
26+
def $anonfun(): Int = x
27+
closure($anonfun)
28+
}
29+
)
30+
= x
31+
}
32+
}
33+

tests/printing/annot-19846b.scala

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class lambdaAnnot(g: () => Int) extends annotation.StaticAnnotation
2+
3+
def f(x: Int): Int @lambdaAnnot(() => x) = x
4+
5+
object Test:
6+
val y: Int = ???
7+
val z /* : Int @lambdaAnnot(() => y) */ = f(y)

0 commit comments

Comments
 (0)