Skip to content

Commit fdb14f2

Browse files
nicolasstuckiNicolas Stucki
authored and
Nicolas Stucki
committed
Enable returning classes from MacroAnnotations (part 3)
Enable the addition of classes from a `MacroAnnotation`: * Can add new `class` definitions next to the annotated definition Special cases: * An annotated top-level `def`, `val`, `var`, `lazy val` can return a `class` definition that is owned by the package or package object. Related PRs: * Follows #16454
1 parent 2916c3d commit fdb14f2

File tree

18 files changed

+261
-6
lines changed

18 files changed

+261
-6
lines changed

compiler/src/dotty/tools/dotc/transform/Inlining.scala

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ class Inlining extends MacroTransform with IdentityDenotTransformer {
6363
}
6464

6565
private class InliningTreeMap extends TreeMapWithImplicits {
66+
67+
/** List of top level classes added by macro annotation in a package object.
68+
* These are added the PackageDef that owns this particular package object.
69+
*/
70+
private var topClasses = List.empty[Tree]
71+
6672
override def transform(tree: Tree)(using Context): Tree = {
6773
tree match
6874
case tree: MemberDef =>
@@ -74,7 +80,15 @@ class Inlining extends MacroTransform with IdentityDenotTransformer {
7480
&& MacroAnnotations.hasMacroAnnotation(tree.symbol)
7581
then
7682
val trees = new MacroAnnotations(thisPhase).expandAnnotations(tree)
77-
flatTree(trees.map(super.transform))
83+
val trees1 = trees.map(super.transform)
84+
85+
// Find classes added to the top level from a package object
86+
val (topClasses0, trees2) =
87+
if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner)
88+
else (Nil, trees1)
89+
topClasses = topClasses0
90+
91+
flatTree(trees2)
7892
else super.transform(tree)
7993
case _: Typed | _: Block =>
8094
super.transform(tree)
@@ -86,6 +100,13 @@ class Inlining extends MacroTransform with IdentityDenotTransformer {
86100
super.transform(tree)(using StagingContext.quoteContext)
87101
case _: GenericApply if tree.symbol.isExprSplice =>
88102
super.transform(tree)(using StagingContext.spliceContext)
103+
case _: PackageDef =>
104+
super.transform(tree) match
105+
case tree1: PackageDef if topClasses.nonEmpty =>
106+
val newStats = topClasses ::: tree1.stats
107+
topClasses = Nil
108+
cpy.PackageDef(tree1)(tree1.pid, newStats)
109+
case tree1 => tree1
89110
case _ =>
90111
super.transform(tree)
91112
}

compiler/src/dotty/tools/dotc/transform/MacroAnnotations.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,11 @@ class MacroAnnotations(thisPhase: DenotTransformer):
9191
annotInstance.transform(using quotes)(tree.asInstanceOf[quotes.reflect.Definition])
9292

9393
/** Check that this tree can be added by the macro annotation and enter it if needed */
94-
private def checkAndEnter(newTree: Tree, annotated: Symbol, annot: Annotation)(using Context) =
94+
private def checkAndEnter(newTree: DefTree, annotated: Symbol, annot: Annotation)(using Context) =
9595
val sym = newTree.symbol
96-
if sym.isClass then
97-
report.error(i"macro annotation returning a `class` is not yet supported. $annot tried to add $sym", annot.tree)
98-
else if sym.isType then
96+
if sym.isType && !sym.isClass then
9997
report.error(i"macro annotation cannot return a `type`. $annot tried to add $sym", annot.tree)
100-
else if sym.owner != annotated.owner then
98+
else if sym.owner != annotated.owner && !(annotated.owner.isPackageObject && sym.isClass && sym.owner == annotated.owner.owner) then
10199
report.error(i"macro annotation $annot added $sym with an inconsistent owner. Expected it to be owned by ${annotated.owner} but was owned by ${sym.owner}.", annot.tree)
102100
else if annotated.isClass && annotated.owner.is(Package) /*&& !sym.isClass*/ then
103101
report.error(i"macro annotation can not add top-level ${sym.showKind}. $annot tried to add $sym.", annot.tree)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import scala.annotation.{experimental, MacroAnnotation}
2+
import scala.quoted._
3+
import scala.collection.mutable
4+
5+
@experimental
6+
// Assumes annotation is on top level def or val
7+
class addTopLevelMethodOutsidePackageObject extends MacroAnnotation:
8+
def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
9+
import quotes.reflect._
10+
val methType = MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Int])
11+
val methSym = Symbol.newMethod(Symbol.spliceOwner.owner, Symbol.freshName("toLevelMethod"), methType, Flags.EmptyFlags, Symbol.noSymbol)
12+
val methDef = ValDef(methSym, Some(Literal(IntConstant(1))))
13+
List(methDef, tree)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
@addTopLevelMethodOutsidePackageObject // error
2+
def foo = 1
3+
4+
@addTopLevelMethodOutsidePackageObject // error
5+
val bar = 1
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import scala.annotation.{experimental, MacroAnnotation}
2+
import scala.quoted._
3+
import scala.collection.mutable
4+
5+
@experimental
6+
// Assumes annotation is on top level def or val
7+
class addTopLevelValOutsidePackageObject extends MacroAnnotation:
8+
def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
9+
import quotes.reflect._
10+
val valSym = Symbol.newVal(Symbol.spliceOwner.owner, Symbol.freshName("toLevelVal"), TypeRepr.of[Int], Flags.EmptyFlags, Symbol.noSymbol)
11+
val valDef = ValDef(valSym, Some(Literal(IntConstant(1))))
12+
List(valDef, tree)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
@addTopLevelValOutsidePackageObject // error
2+
def foo = 1
3+
4+
@addTopLevelValOutsidePackageObject // error
5+
val bar = 1
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
macro generated main
2+
executed in: Test_2$package$Baz$1
3+
macro generated main
4+
executed in: Test_2$package$Baz$2
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import scala.annotation.{experimental, MacroAnnotation}
2+
import scala.quoted._
3+
import scala.collection.mutable
4+
5+
@experimental
6+
class addClass extends MacroAnnotation:
7+
def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
8+
import quotes.reflect._
9+
tree match
10+
case DefDef(name, List(TermParamClause(Nil)), tpt, Some(rhs)) =>
11+
val parents = List(TypeTree.of[Object])
12+
def decls(cls: Symbol): List[Symbol] =
13+
List(Symbol.newMethod(cls, "run", MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Unit]), Flags.Static, Symbol.noSymbol))
14+
15+
val cls = Symbol.newClass(Symbol.spliceOwner, "Baz", parents = parents.map(_.tpe), decls, selfType = None)
16+
val runSym = cls.declaredMethod("run").head
17+
18+
val runDef = DefDef(runSym, _ => Some(rhs))
19+
val clsDef = ClassDef(cls, parents, body = List(runDef))
20+
21+
val newCls = Apply(Select(New(TypeIdent(cls)), cls.primaryConstructor), Nil)
22+
23+
val newDef = DefDef.copy(tree)(name, List(TermParamClause(Nil)), tpt, Some(Apply(Select(newCls, runSym), Nil)))
24+
List(clsDef, newDef)
25+
case _ =>
26+
report.error("Annotation only supports `def` with one argument")
27+
List(tree)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
@main def Test(): Unit =
2+
@addClass def foo(): Unit =
3+
println("macro generated main")
4+
println("executed in: " + (new Throwable().getStackTrace().head.getClassName))
5+
//> class Baz extends Object {
6+
//> def run() =
7+
//> println("macro generated main")
8+
//> println("executed in: " + getClass.getName)
9+
//> }
10+
//> def foo(): Unit =
11+
//> new Baz().run
12+
13+
@addClass def bar(): Unit =
14+
println("macro generated main")
15+
println("executed in: " + (new Throwable().getStackTrace().head.getClassName))
16+
//> class Baz extends Object {
17+
//> def run() =
18+
//> println("macro generated main")
19+
//> println("executed in: " + getClass.getName)
20+
//> }
21+
//> def Baz(): Unit =
22+
//> new Baz().run
23+
24+
foo()
25+
bar()
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
macro generated main
2+
executed in: Test_2$package$Baz$2
3+
macro generated main
4+
executed in: Test_2$package$Baz$4
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import scala.annotation.{experimental, MacroAnnotation}
2+
import scala.quoted._
3+
import scala.collection.mutable
4+
5+
@experimental
6+
class addClass extends MacroAnnotation:
7+
def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
8+
import quotes.reflect._
9+
tree match
10+
case DefDef(name, List(TermParamClause(Nil)), tpt, Some(rhs)) =>
11+
val parents = List(TypeTree.of[Object])
12+
def decls(cls: Symbol): List[Symbol] =
13+
List(Symbol.newMethod(cls, "run", MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Unit]), Flags.Static, Symbol.noSymbol))
14+
15+
// FIXME: missing flags: Final | Module
16+
// FIXME: how to set the self type?
17+
val cls = Symbol.newClass(Symbol.spliceOwner, "Baz", parents = parents.map(_.tpe), decls, selfType = None)
18+
val mod = Symbol.newVal(Symbol.spliceOwner, "Baz", cls.typeRef, Flags.Module | Flags.Lazy | Flags.Final, Symbol.noSymbol)
19+
val runSym = cls.declaredMethod("run").head
20+
21+
val runDef = DefDef(runSym, _ => Some(rhs))
22+
23+
val clsDef = ClassDef(cls, parents, body = List(runDef))
24+
25+
val newCls = Apply(Select(New(TypeIdent(cls)), cls.primaryConstructor), Nil)
26+
val modVal = ValDef(mod, Some(newCls))
27+
28+
val newDef = DefDef.copy(tree)(name, List(TermParamClause(Nil)), tpt, Some(Apply(Select(Ref(mod), runSym), Nil)))
29+
List(modVal, clsDef, newDef)
30+
case _ =>
31+
report.error("Annotation only supports `def` with one argument")
32+
List(tree)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
@main def Test(): Unit =
2+
@addClass def foo(): Unit =
3+
println("macro generated main")
4+
println("executed in: " + (new Throwable().getStackTrace().head.getClassName))
5+
//> object Baz {
6+
//> def run() =
7+
//> println("macro generated main")
8+
//> println("executed in: " + getClass.getName)
9+
//> }
10+
//> def foo(): Unit =
11+
//> Baz.run
12+
13+
@addClass def bar(): Unit =
14+
println("macro generated main")
15+
println("executed in: " + (new Throwable().getStackTrace().head.getClassName))
16+
//> object Baz {
17+
//> def run() =
18+
//> println("macro generated main")
19+
//> println("executed in: " + getClass.getName)
20+
//> }
21+
//> def Baz(): Unit =
22+
//> Baz.run
23+
24+
foo()
25+
bar()
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
macro generated main
2+
executed in: Foo$Bar$macro$1
3+
macro generated main
4+
executed in: Foo$Bar$macro$2
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import scala.annotation.{experimental, MacroAnnotation}
2+
import scala.quoted._
3+
import scala.collection.mutable
4+
5+
@experimental
6+
class addClass extends MacroAnnotation:
7+
def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
8+
import quotes.reflect._
9+
tree match
10+
case DefDef(name, List(TermParamClause(Nil)), tpt, Some(rhs)) =>
11+
val parents = List(TypeTree.of[Object])
12+
def decls(cls: Symbol): List[Symbol] =
13+
List(Symbol.newMethod(cls, "run", MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Unit]), Flags.Static, Symbol.noSymbol))
14+
15+
val newClassName = Symbol.freshName("Bar")
16+
val cls = Symbol.newClass(Symbol.spliceOwner, newClassName, parents = parents.map(_.tpe), decls, selfType = None)
17+
val runSym = cls.declaredMethod("run").head
18+
19+
val runDef = DefDef(runSym, _ => Some(rhs))
20+
val clsDef = ClassDef(cls, parents, body = List(runDef))
21+
22+
val newCls = Apply(Select(New(TypeIdent(cls)), cls.primaryConstructor), Nil)
23+
24+
val newDef = DefDef.copy(tree)(name, List(TermParamClause(Nil)), tpt, Some(Apply(Select(newCls, runSym), Nil)))
25+
List(clsDef, newDef)
26+
case _ =>
27+
report.error("Annotation only supports `def` with one argument")
28+
List(tree)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
class Foo():
2+
@addClass def foo(): Unit =
3+
println("macro generated main")
4+
println("executed in: " + (new Throwable().getStackTrace().head.getClassName))
5+
//> class Baz$macro$1 extends Object {
6+
//> def run() =
7+
//> println("macro generated main")
8+
//> println("executed in: " + getClass.getName)
9+
//> }
10+
//> def foo(): Unit =
11+
//> new Baz$macro$1.run
12+
13+
@addClass def bar(): Unit =
14+
println("macro generated main")
15+
println("executed in: " + (new Throwable().getStackTrace().head.getClassName))
16+
//> class Baz$macro$2 extends Object {
17+
//> def run() =
18+
//> println("macro generated main")
19+
//> println("executed in: " + getClass.getName)
20+
//> }
21+
//> def foo(): Unit =
22+
//> new Baz$macro$2.run
23+
24+
@main def Test(): Unit =
25+
new Foo().foo()
26+
new Foo().bar()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
macro generated main
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import scala.annotation.{experimental, MacroAnnotation}
2+
import scala.quoted._
3+
import scala.collection.mutable
4+
5+
@experimental
6+
class mainMacro extends MacroAnnotation:
7+
def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
8+
import quotes.reflect._
9+
tree match
10+
case DefDef(name, List(TermParamClause(Nil)), _, _) =>
11+
val parents = List(TypeTree.of[Object])
12+
def decls(cls: Symbol): List[Symbol] =
13+
List(Symbol.newMethod(cls, "main", MethodType(List("args"))(_ => List(TypeRepr.of[Array[String]]), _ => TypeRepr.of[Unit]), Flags.Static, Symbol.noSymbol))
14+
15+
val cls = Symbol.newClass(Symbol.spliceOwner.owner, name, parents = parents.map(_.tpe), decls, selfType = None)
16+
val mainSym = cls.declaredMethod("main").head
17+
18+
val mainDef = DefDef(mainSym, _ => Some(Apply(Ref(tree.symbol), Nil)))
19+
val clsDef = ClassDef(cls, parents, body = List(mainDef))
20+
21+
List(clsDef, tree)
22+
case _ =>
23+
report.error("Annotation only supports `def` without arguments")
24+
List(tree)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
@mainMacro def Test(): Unit = println("macro generated main")

0 commit comments

Comments
 (0)