Skip to content

Commit 3bd7df2

Browse files
committed
[WIP] Fix mapping of annotations containing defs
1 parent d148973 commit 3bd7df2

File tree

8 files changed

+94
-5
lines changed

8 files changed

+94
-5
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,10 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
134134
case _ => argss
135135
loop(tree, Nil)
136136

137-
/** All term arguments of an application in a single flattened list */
137+
/** All type and term arguments of an application in a single flattened list */
138138
def allArguments(tree: Tree): List[Tree] = unsplice(tree) match {
139139
case Apply(fn, args) => allArguments(fn) ::: args
140-
case TypeApply(fn, _) => allArguments(fn)
140+
case TypeApply(fn, args) => allArguments(fn) ::: args
141141
case Block(_, expr) => allArguments(expr)
142142
case _ => Nil
143143
}

compiler/src/dotty/tools/dotc/core/Annotations.scala

+14-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import printing.{Showable, Printer}
99
import printing.Texts.Text
1010

1111
import scala.annotation.internal.sharable
12+
import dotty.tools.dotc.ast.TreeTypeMap
1213

1314
object Annotations {
1415

@@ -30,7 +31,7 @@ object Annotations {
3031
def derivedAnnotation(tree: Tree)(using Context): Annotation =
3132
if (tree eq this.tree) this else Annotation(tree)
3233

33-
/** All arguments to this annotation in a single flat list */
34+
/** All type and term arguments to this annotation in a single flat list */
3435
def arguments(using Context): List[Tree] = tpd.allArguments(tree)
3536

3637
def argument(i: Int)(using Context): Option[Tree] = {
@@ -62,10 +63,20 @@ object Annotations {
6263
if tm.isRange(x) then x
6364
else
6465
val tp1 = tm(tree.tpe)
65-
foldOver(if tp1 frozen_=:= tree.tpe then x else tp1, tree)
66+
foldOver(if tp1 == tree.tpe then x else tp1, tree)
6667
val diff = findDiff(NoType, args)
6768
if tm.isRange(diff) then EmptyAnnotation
68-
else if diff.exists then derivedAnnotation(tm.mapOver(tree))
69+
else if diff.exists then
70+
val ttm = new TreeTypeMap(typeMap = tm):
71+
/*
72+
final override def transformDefs[TT <: Tree](trees: List[TT])(using Context): (TreeTypeMap, List[TT]) =
73+
val syms = localSyms(trees)
74+
val ttmap = withMappedSyms(syms, mapSymbols(syms, this, mapAlways = true))
75+
(ttmap, ttmap.transformSub(trees))
76+
*/
77+
final override def withMappedSyms(syms: List[Symbol]): TreeTypeMap =
78+
withMappedSyms(syms, mapSymbols(syms, this, mapAlways = true))
79+
derivedAnnotation(ttm.transform(tree))
6980
else this
7081

7182
/** Does this annotation refer to a parameter of `tl`? */

tests/pos/annot-17939.scala

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class qualified[T](f: T => Boolean) extends annotation.StaticAnnotation
2+
3+
class Box[T](val x: T)
4+
class Box2(val x: Int)
5+
6+
class A(a: String @qualified((x: Int) => Box(3).x == 3)) // crash
7+
class A2(a2: String @qualified((x: Int) => Box2(3).x == 3)) // works

tests/pos/annot-19846.scala

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
class qualified[T](predicate: T => Boolean) extends annotation.StaticAnnotation
2+
3+
class EqualPair(val x: Int, val y: Int @qualified[Int](it => it == x))
4+
5+
@main def main =
6+
val p = EqualPair(42, 42)
7+
val y = p.y
8+
println(42)

tests/printing/annot-18064.check

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
[[syntax trees at end of typer]] // tests/printing/annot-18064.scala
2+
package <empty> {
3+
class myAnnot[T >: Nothing <: Any]() extends annotation.Annotation() {
4+
T
5+
}
6+
trait Tensor[T >: Nothing <: Any]() extends Object {
7+
T
8+
def add: Tensor[Tensor.this.T] @myAnnot[T]
9+
}
10+
class TensorImpl[A >: Nothing <: Any]() extends Object(), Tensor[
11+
TensorImpl.this.A] {
12+
A
13+
def add: Tensor[A] @myAnnot[A] = this
14+
}
15+
}
16+

tests/printing/annot-18064.scala

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class myAnnot[T]() extends annotation.Annotation
2+
3+
trait Tensor[T]:
4+
def add: Tensor[T] @myAnnot[T]()
5+
6+
class TensorImpl[A]() extends Tensor[A]:
7+
def add /* : Tensor[A] @myAnnot[A] */ = this

tests/printing/annot-19846b.check

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
[[syntax trees at end of typer]] // tests/printing/annot-19846b.scala
2+
package <empty> {
3+
class lambdaAnnot(g: () => Int) extends scala.annotation.Annotation(),
4+
annotation.StaticAnnotation {
5+
private[this] val g: () => Int
6+
}
7+
final lazy module val Test: Test = new Test()
8+
final module class Test() extends Object() { this: Test.type =>
9+
val y: Int = ???
10+
val z:
11+
Int @lambdaAnnot(
12+
{
13+
def $anonfun(): Int = Test.y
14+
closure($anonfun)
15+
}
16+
)
17+
= f(Test.y)
18+
}
19+
final lazy module val annot-19846b$package: annot-19846b$package =
20+
new annot-19846b$package()
21+
final module class annot-19846b$package() extends Object() {
22+
this: annot-19846b$package.type =>
23+
def f(x: Int):
24+
Int @lambdaAnnot(
25+
{
26+
def $anonfun(): Int = x
27+
closure($anonfun)
28+
}
29+
)
30+
= x
31+
}
32+
}
33+

tests/printing/annot-19846b.scala

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class lambdaAnnot(g: () => Int) extends annotation.StaticAnnotation
2+
3+
def f(x: Int): Int @lambdaAnnot(() => x) = x
4+
5+
object Test:
6+
val y: Int = ???
7+
val z /* : Int @lambdaAnnot(() => y) */ = f(y)

0 commit comments

Comments
 (0)