Skip to content

Backport "Make sure symbols in annotation trees are fresh before pickling" to 3.3 LTS #153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 32 additions & 11 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package dotty.tools
package dotc
package transform

import dotty.tools.dotc.ast.{Trees, tpd, untpd, desugar}
import dotty.tools.dotc.ast.{Trees, tpd, untpd, desugar, TreeTypeMap}
import scala.collection.mutable
import core.*
import dotty.tools.dotc.typer.Checking
Expand All @@ -16,7 +16,7 @@ import Symbols.*, NameOps.*
import ContextFunctionResults.annotateContextResults
import config.Printers.typr
import config.Feature
import util.SrcPos
import util.{SrcPos, Stats}
import reporting.*
import NameKinds.WildcardParamName

Expand Down Expand Up @@ -132,7 +132,21 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
case _ =>
case _ =>

private def transformAnnot(annot: Tree)(using Context): Tree = {
/** Returns a copy of the given tree with all symbols fresh.
*
* Used to guarantee that no symbols are shared between trees in different
* annotations.
*/
private def copySymbols(tree: Tree)(using Context) =
Stats.trackTime("Annotations copySymbols"):
val ttm =
new TreeTypeMap:
override def withMappedSyms(syms: List[Symbol]) =
withMappedSyms(syms, mapSymbols(syms, this, true))
ttm(tree)

/** Transforms the given annotation tree. */
private def transformAnnotTree(annot: Tree)(using Context): Tree = {
val saved = inJavaAnnot
inJavaAnnot = annot.symbol.is(JavaDefined)
if (inJavaAnnot) checkValidJavaAnnotation(annot)
Expand All @@ -141,7 +155,19 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
}

private def transformAnnot(annot: Annotation)(using Context): Annotation =
annot.derivedAnnotation(transformAnnot(annot.tree))
val tree1 =
annot match
case _: BodyAnnotation => annot.tree
case _ => copySymbols(annot.tree)
annot.derivedAnnotation(transformAnnotTree(tree1))

/** Transforms all annotations in the given type. */
private def transformAnnotsIn(using Context) =
new TypeMap:
def apply(tp: Type) = tp match
case tp @ AnnotatedType(parent, annot) =>
tp.derivedAnnotatedType(mapOver(parent), transformAnnot(annot))
case _ => mapOver(tp)

private def processMemberDef(tree: Tree)(using Context): tree.type = {
val sym = tree.symbol
Expand Down Expand Up @@ -438,7 +464,7 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
Checking.checkRealizable(tree.tpt.tpe, tree.srcPos, "SAM type")
super.transform(tree)
case tree @ Annotated(annotated, annot) =>
cpy.Annotated(tree)(transform(annotated), transformAnnot(annot))
cpy.Annotated(tree)(transform(annotated), transformAnnotTree(annot))
case tree: AppliedTypeTree =>
if (tree.tpt.symbol == defn.andType)
Checking.checkNonCyclicInherited(tree.tpe, tree.args.tpes, EmptyScope, tree.srcPos)
Expand All @@ -460,12 +486,7 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
report.error(em"type ${alias.tpe} outside bounds $bounds", tree.srcPos)
super.transform(tree)
case tree: TypeTree =>
tree.withType(
tree.tpe match {
case AnnotatedType(tpe, annot) => AnnotatedType(tpe, transformAnnot(annot))
case tpe => tpe
}
)
tree.withType(transformAnnotsIn(tree.tpe))
case Typed(Ident(nme.WILDCARD), _) =>
withMode(Mode.Pattern)(super.transform(tree))
// The added mode signals that bounds in a pattern need not
Expand Down
8 changes: 8 additions & 0 deletions tests/pos/annot-17939.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import scala.annotation.Annotation
class myRefined[T](f: T => Boolean) extends Annotation

class Box[T](val x: T)
class Box2(val x: Int)

class A(a: String @myRefined((x: Int) => Box(3).x == 3)) // crash
class A2(a2: String @myRefined((x: Int) => Box2(3).x == 3)) // works
9 changes: 9 additions & 0 deletions tests/pos/annot-19846.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package dependentAnnotation

class lambdaAnnot(g: () => Int) extends annotation.StaticAnnotation

def f(x: Int): Int @lambdaAnnot(() => x + 1) = x

@main def main =
val y: Int = 5
val z = f(y)
8 changes: 8 additions & 0 deletions tests/pos/annot-19846b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class qualified[T](predicate: T => Boolean) extends annotation.StaticAnnotation

class EqualPair(val x: Int, val y: Int @qualified[Int](it => it == x))

@main def main =
val p = EqualPair(42, 42)
val y = p.y
println(42)
15 changes: 15 additions & 0 deletions tests/pos/annot-body.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// This test checks that symbols in `BodyAnnotation` are not copied in
// `transformAnnot` during `PostTyper`.

package json

trait Reads[A] {
def reads(a: Any): A
}

object JsMacroImpl {
inline def reads[A]: Reads[A] =
new Reads[A] { self =>
def reads(a: Any) = ???
}
}