Skip to content

Commit ab70f18

Browse files
committed
Fix mapping and pickling of non-trivial annotated types
`Annotation.mapWith` maps an `Annotation` with a type map `tm`. Before actually applying `tm` to the annotation’s `tree`, it first checks if `tm` would result in any change by traversing the annotation’s arguments, and checking if the mapped type of any any tree is different. This optimization had two problems: it didn’t include type parameters, and used `frozen_=:=` to compare types, which failed to detected some changes. This commit changes `Annotation.arguments` to also include type parameters, and changes `Annotation.MapWith` to use `==` to compare types instead of `frozen_=:=`. Furthermore, in case of changes, the symbol in the annotation's tree should be copied to make sure that the same symbol is not used for different trees. This commit achieves this by using a custom `TreeTypeMap` with an overridden `withMappedSyms` method where `Symbols.mapSymbols` is called with the argument `mapAlways = true`. Finally, positions of trees that appear only inside `AnnotatedType` were not pickled. This commit also fixes this.
1 parent 6ee0967 commit ab70f18

19 files changed

+180
-15
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

+9-1
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,17 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
141141
loop(tree, Nil)
142142

143143
/** All term arguments of an application in a single flattened list */
144+
def allTermArguments(tree: Tree): List[Tree] = unsplice(tree) match {
145+
case Apply(fn, args) => allArguments(fn) ::: args
146+
case TypeApply(fn, args) => allArguments(fn)
147+
case Block(_, expr) => allArguments(expr)
148+
case _ => Nil
149+
}
150+
151+
/** All type and term arguments of an application in a single flattened list */
144152
def allArguments(tree: Tree): List[Tree] = unsplice(tree) match {
145153
case Apply(fn, args) => allArguments(fn) ::: args
146-
case TypeApply(fn, _) => allArguments(fn)
154+
case TypeApply(fn, args) => allArguments(fn) ::: args
147155
case Block(_, expr) => allArguments(expr)
148156
case _ => Nil
149157
}

compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala

+12-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ import Decorators.*
2121
* @param newOwners New owners, replacing previous owners.
2222
* @param substFrom The symbols that need to be substituted.
2323
* @param substTo The substitution targets.
24+
* @param cpy A tree copier that is used to create new trees.
25+
* @param alwaysCopySymbols If set, symbols are always copied, even when they
26+
* are not impacted by the transformation.
2427
*
2528
* The reason the substitution is broken out from the rest of the type map is
2629
* that all symbols have to be substituted at the same time. If we do not do this,
@@ -38,7 +41,9 @@ class TreeTypeMap(
3841
val newOwners: List[Symbol] = Nil,
3942
val substFrom: List[Symbol] = Nil,
4043
val substTo: List[Symbol] = Nil,
41-
cpy: tpd.TreeCopier = tpd.cpy)(using Context) extends tpd.TreeMap(cpy) {
44+
cpy: tpd.TreeCopier = tpd.cpy,
45+
alwaysCopySymbols: Boolean = false,
46+
)(using Context) extends tpd.TreeMap(cpy) {
4247
import tpd.*
4348

4449
def copy(
@@ -48,7 +53,7 @@ class TreeTypeMap(
4853
newOwners: List[Symbol],
4954
substFrom: List[Symbol],
5055
substTo: List[Symbol])(using Context): TreeTypeMap =
51-
new TreeTypeMap(typeMap, treeMap, oldOwners, newOwners, substFrom, substTo)
56+
new TreeTypeMap(typeMap, treeMap, oldOwners, newOwners, substFrom, substTo, cpy, alwaysCopySymbols)
5257

5358
/** If `sym` is one of `oldOwners`, replace by corresponding symbol in `newOwners` */
5459
def mapOwner(sym: Symbol): Symbol = sym.subst(oldOwners, newOwners)
@@ -207,7 +212,7 @@ class TreeTypeMap(
207212
* between original and mapped symbols.
208213
*/
209214
def withMappedSyms(syms: List[Symbol]): TreeTypeMap =
210-
withMappedSyms(syms, mapSymbols(syms, this))
215+
withMappedSyms(syms, mapSymbols(syms, this, mapAlways = alwaysCopySymbols))
211216

212217
/** The tree map with the substitution between originals `syms`
213218
* and mapped symbols `mapped`. Also goes into mapped classes
@@ -229,6 +234,10 @@ class TreeTypeMap(
229234
tmap1
230235
}
231236

237+
def withAlwaysCopySymbols: TreeTypeMap =
238+
if alwaysCopySymbols then this
239+
else new TreeTypeMap(typeMap, treeMap, oldOwners, newOwners, substFrom, substTo, cpy, alwaysCopySymbols = true)
240+
232241
override def toString =
233242
def showSyms(syms: List[Symbol]) =
234243
syms.map(sym => s"$sym#${sym.id}").mkString(", ")

compiler/src/dotty/tools/dotc/core/Annotations.scala

+16-7
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ package dotc
33
package core
44

55
import Symbols.*, Types.*, Contexts.*, Constants.*, Phases.*
6-
import ast.tpd, tpd.*
7-
import util.Spans.Span
6+
import ast.{tpd, untpd, TreeTypeMap}
7+
import tpd.*
8+
import util.Spans.{Span, NoSpan}
89
import printing.{Showable, Printer}
910
import printing.Texts.Text
1011

@@ -30,8 +31,8 @@ object Annotations {
3031
def derivedAnnotation(tree: Tree)(using Context): Annotation =
3132
if (tree eq this.tree) this else Annotation(tree)
3233

33-
/** All arguments to this annotation in a single flat list */
34-
def arguments(using Context): List[Tree] = tpd.allArguments(tree)
34+
/** All term arguments of this annotation in a single flat list */
35+
def arguments(using Context): List[Tree] = tpd.allTermArguments(tree)
3536

3637
def argument(i: Int)(using Context): Option[Tree] = {
3738
val args = arguments
@@ -54,18 +55,26 @@ object Annotations {
5455
* type, since ranges cannot be types of trees.
5556
*/
5657
def mapWith(tm: TypeMap)(using Context) =
57-
val args = arguments
58+
val args = tpd.allArguments(tree)
5859
if args.isEmpty then this
5960
else
61+
// Checks if `tm` would result in any change by applying it to types
62+
// inside the annotations' arguments and checking if the resulting types
63+
// are different.
6064
val findDiff = new TreeAccumulator[Type]:
6165
def apply(x: Type, tree: Tree)(using Context): Type =
6266
if tm.isRange(x) then x
6367
else
6468
val tp1 = tm(tree.tpe)
65-
foldOver(if tp1 frozen_=:= tree.tpe then x else tp1, tree)
69+
foldOver(if !tp1.exists || (tp1 frozen_=:= tree.tpe) then x else tp1, tree)
6670
val diff = findDiff(NoType, args)
6771
if tm.isRange(diff) then EmptyAnnotation
68-
else if diff.exists then derivedAnnotation(tm.mapOver(tree))
72+
else if diff.exists then
73+
// If the annotation has been transformed, we need to make sure that the
74+
// symbol are copied so that we don't end up with the same symbol in different
75+
// trees, which would lead to a crash in pickling.
76+
val mappedTree = TreeTypeMap(typeMap = tm, alwaysCopySymbols = true).transform(tree)
77+
derivedAnnotation(mappedTree)
6978
else this
7079

7180
/** Does this annotation refer to a parameter of `tl`? */

compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala

+4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ object PositionPickler:
3333
pickler: TastyPickler,
3434
addrOfTree: TreeToAddr,
3535
treeAnnots: untpd.MemberDef => List[tpd.Tree],
36+
typeAnnots: List[tpd.Tree],
3637
relativePathReference: String,
3738
source: SourceFile,
3839
roots: List[Tree],
@@ -136,6 +137,9 @@ object PositionPickler:
136137
}
137138
for (root <- roots)
138139
traverse(root, NoSource)
140+
141+
for annotTree <- typeAnnots do
142+
traverse(annotTree, NoSource)
139143
end picklePositions
140144
end PositionPickler
141145

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

+7
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
4141
*/
4242
private val annotTrees = util.EqHashMap[untpd.MemberDef, mutable.ListBuffer[Tree]]()
4343

44+
/** A set of annotation trees appearing in annotated types.
45+
*/
46+
private val annotatedTypeTrees = mutable.ListBuffer[Tree]()
47+
4448
/** A map from member definitions to their doc comments, so that later
4549
* parallel comment pickling does not need to access symbols of trees (which
4650
* would involve accessing symbols of named types and possibly changing phases
@@ -57,6 +61,8 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
5761
val ts = annotTrees.lookup(tree)
5862
if ts == null then Nil else ts.toList
5963

64+
def typeAnnots: List[Tree] = annotatedTypeTrees.toList
65+
6066
def docString(tree: untpd.MemberDef): Option[Comment] =
6167
Option(docStrings.lookup(tree))
6268

@@ -278,6 +284,7 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
278284
case tpe: AnnotatedType =>
279285
writeByte(ANNOTATEDtype)
280286
withLength { pickleType(tpe.parent, richTypes); pickleTree(tpe.annot.tree) }
287+
annotatedTypeTrees += tpe.annot.tree
281288
case tpe: AndType =>
282289
writeByte(ANDtype)
283290
withLength { pickleType(tpe.tp1, richTypes); pickleType(tpe.tp2, richTypes) }

compiler/src/dotty/tools/dotc/quoted/PickledQuotes.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ object PickledQuotes {
224224
if tree.span.exists then
225225
val positionWarnings = new mutable.ListBuffer[Message]()
226226
val reference = ctx.settings.sourceroot.value
227-
PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference,
227+
PositionPickler.picklePositions(pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference,
228228
ctx.compilationUnit.source, tree :: Nil, positionWarnings)
229229
positionWarnings.foreach(report.warning(_))
230230

compiler/src/dotty/tools/dotc/transform/Pickler.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ class Pickler extends Phase {
322322
if tree.span.exists then
323323
val reference = ctx.settings.sourceroot.value
324324
PositionPickler.picklePositions(
325-
pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, reference,
325+
pickler, treePkl.buf.addrOfTree, treePkl.treeAnnots, treePkl.typeAnnots, reference,
326326
unit.source, tree :: Nil, positionWarnings,
327327
scratch.positionBuffer, scratch.pickledIndices)
328328

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

+9-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package dotty.tools
22
package dotc
33
package transform
44

5-
import dotty.tools.dotc.ast.{Trees, tpd, untpd, desugar}
5+
import dotty.tools.dotc.ast.{Trees, TreeTypeMap, tpd, untpd, desugar}
66
import scala.collection.mutable
77
import core.*
88
import dotty.tools.dotc.typer.Checking
@@ -158,7 +158,14 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
158158
val saved = inJavaAnnot
159159
inJavaAnnot = annot.symbol.is(JavaDefined)
160160
if (inJavaAnnot) checkValidJavaAnnotation(annot)
161-
try transform(annot)
161+
try
162+
val res = transform(annot)
163+
if res ne annot then
164+
// If the annotation has been transformed, we need to make sure that the
165+
// symbol are copied so that we don't end up with the same symbol in different
166+
// trees, which would lead to a crash in pickling.
167+
TreeTypeMap(alwaysCopySymbols = true)(res)
168+
else res
162169
finally inJavaAnnot = saved
163170
}
164171

tests/pos/annot-17939.scala

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class qualified[T](f: T => Boolean) extends annotation.StaticAnnotation
2+
3+
class Box[T](val x: T)
4+
class Box2(val x: Int)
5+
6+
class A(a: String @qualified((x: Int) => Box(3).x == 3)) // crash
7+
class A2(a2: String @qualified((x: Int) => Box2(3).x == 3)) // works

tests/pos/annot-17939b.scala

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import scala.annotation.Annotation
2+
class myRefined(f: ? => Boolean) extends Annotation
3+
4+
def test(axes: Int) = true
5+
6+
trait Tensor:
7+
def mean(axes: Int): Int @myRefined(_ => test(axes))
8+
9+
class TensorImpl() extends Tensor:
10+
def mean(axes: Int) = ???

tests/pos/annot-17939c.scala

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
class qualified(f: Int => Boolean) extends annotation.StaticAnnotation
2+
class Box[T](val y: T)
3+
def Test =
4+
val x: String @qualified((x: Int) => Box(42).y == 2) = ???
5+
val y = x

tests/pos/annot-18064.scala

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//> using options "-Xprint:typer"
2+
3+
class myAnnot[T]() extends annotation.Annotation
4+
5+
trait Tensor[T]:
6+
def add: Tensor[T] @myAnnot[T]()
7+
8+
class TensorImpl[A]() extends Tensor[A]:
9+
def add /* : Tensor[A] @myAnnot[A] */ = this

tests/pos/annot-19846.scala

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
class qualified[T](predicate: T => Boolean) extends annotation.StaticAnnotation
2+
3+
class EqualPair(val x: Int, val y: Int @qualified[Int](it => it == x))
4+
5+
@main def main =
6+
val p = EqualPair(42, 42)
7+
val y = p.y
8+
println(42)

tests/pos/annot-19846b.scala

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class lambdaAnnot(g: () => Int) extends annotation.StaticAnnotation
2+
3+
def f(x: Int): Int @lambdaAnnot(() => x) = x
4+
5+
object Test:
6+
val y: Int = ???
7+
val z /* : Int @lambdaAnnot(() => y) */ = f(y)

tests/pos/annot-5789.scala

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class Annot[T] extends scala.annotation.Annotation
2+
3+
class D[T](val f: Int@Annot[T])
4+
5+
object A{
6+
def main(a:Array[String]) = {
7+
val c = new D[Int](1)
8+
c.f
9+
}
10+
}

tests/printing/annot-18064.check

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
[[syntax trees at end of typer]] // tests/printing/annot-18064.scala
2+
package <empty> {
3+
class myAnnot[T >: Nothing <: Any]() extends annotation.Annotation() {
4+
T
5+
}
6+
trait Tensor[T >: Nothing <: Any]() extends Object {
7+
T
8+
def add: Tensor[Tensor.this.T] @myAnnot[T]
9+
}
10+
class TensorImpl[A >: Nothing <: Any]() extends Object(), Tensor[
11+
TensorImpl.this.A] {
12+
A
13+
def add: Tensor[A] @myAnnot[A] = this
14+
}
15+
}
16+

tests/printing/annot-18064.scala

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//> using options "-Xprint:typer"
2+
3+
class myAnnot[T]() extends annotation.Annotation
4+
5+
trait Tensor[T]:
6+
def add: Tensor[T] @myAnnot[T]()
7+
8+
class TensorImpl[A]() extends Tensor[A]:
9+
def add /* : Tensor[A] @myAnnot[A] */ = this

tests/printing/annot-19846b.check

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
[[syntax trees at end of typer]] // tests/printing/annot-19846b.scala
2+
package <empty> {
3+
class lambdaAnnot(g: () => Int) extends scala.annotation.Annotation(),
4+
annotation.StaticAnnotation {
5+
private[this] val g: () => Int
6+
}
7+
final lazy module val Test: Test = new Test()
8+
final module class Test() extends Object() { this: Test.type =>
9+
val y: Int = ???
10+
val z:
11+
Int @lambdaAnnot(
12+
{
13+
def $anonfun(): Int = Test.y
14+
closure($anonfun)
15+
}
16+
)
17+
= f(Test.y)
18+
}
19+
final lazy module val annot-19846b$package: annot-19846b$package =
20+
new annot-19846b$package()
21+
final module class annot-19846b$package() extends Object() {
22+
this: annot-19846b$package.type =>
23+
def f(x: Int):
24+
Int @lambdaAnnot(
25+
{
26+
def $anonfun(): Int = x
27+
closure($anonfun)
28+
}
29+
)
30+
= x
31+
}
32+
}
33+

tests/printing/annot-19846b.scala

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class lambdaAnnot(g: () => Int) extends annotation.StaticAnnotation
2+
3+
def f(x: Int): Int @lambdaAnnot(() => x) = x
4+
5+
object Test:
6+
val y: Int = ???
7+
val z /* : Int @lambdaAnnot(() => y) */ = f(y)

0 commit comments

Comments
 (0)