Skip to content

Commit 669423a

Browse files
Merge pull request #8193 from dotty-staging/extract-tree-utils-from-reflection
Extract Tree utils from Reflection
2 parents ff93bb1 + 06da51f commit 669423a

File tree

8 files changed

+343
-276
lines changed

8 files changed

+343
-276
lines changed

docs/docs/reference/metaprogramming/tasty-reflect.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def macroImpl()(qctx: QuoteContext): Expr[Unit] = {
119119

120120
### Tree Utilities
121121

122-
`scala.tasty.reflect.TreeUtils` contains three facilities for tree traversal and
122+
`scala.tasty.reflect` contains three facilities for tree traversal and
123123
transformations.
124124

125125
`TreeAccumulator` ties the knot of a traversal. By calling `foldOver(x, tree))`
@@ -144,7 +144,7 @@ but without returning any value. Finally a `TreeMap` performs a transformation.
144144

145145
#### Let
146146

147-
`scala.tasty.reflect.utils.TreeUtils` also offers a method `let` that allows us
147+
`scala.tasty.Reflection` also offers a method `let` that allows us
148148
to bind the `rhs` to a `val` and use it in `body`. Additionally, `lets` binds
149149
the given `terms` to names and use them in the `body`. Their type definitions
150150
are shown below:

library/src/scala/internal/quoted/Matcher.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ private[quoted] object Matcher {
317317
if freePatternVars(term).isEmpty then Some(term) else None
318318

319319
/** Return all free variables of the term defined in the pattern (i.e. defined in `Env`) */
320-
def freePatternVars(term: Term)(given qctx: Context, env: Env): Set[Symbol] =
320+
def freePatternVars(term: Term)(given ctx: Context, env: Env): Set[Symbol] =
321321
val accumulator = new TreeAccumulator[Set[Symbol]] {
322322
def foldTree(x: Set[Symbol], tree: Tree)(given ctx: Context): Set[Symbol] =
323323
tree match

library/src/scala/tasty/Reflection.scala

+11-254
Original file line numberDiff line numberDiff line change
@@ -2748,266 +2748,23 @@ class Reflection(private[scala] val internal: CompilerInterface) { self =>
27482748
// UTILS //
27492749
///////////////
27502750

2751-
abstract class TreeAccumulator[X] {
2752-
2753-
// Ties the knot of the traversal: call `foldOver(x, tree))` to dive in the `tree` node.
2754-
def foldTree(x: X, tree: Tree)(given ctx: Context): X
2755-
2756-
def foldTrees(x: X, trees: Iterable[Tree])(given ctx: Context): X = trees.foldLeft(x)(foldTree)
2757-
2758-
def foldOverTree(x: X, tree: Tree)(given ctx: Context): X = {
2759-
def localCtx(definition: Definition): Context = definition.symbol.localContext
2760-
tree match {
2761-
case Ident(_) =>
2762-
x
2763-
case Select(qualifier, _) =>
2764-
foldTree(x, qualifier)
2765-
case This(qual) =>
2766-
x
2767-
case Super(qual, _) =>
2768-
foldTree(x, qual)
2769-
case Apply(fun, args) =>
2770-
foldTrees(foldTree(x, fun), args)
2771-
case TypeApply(fun, args) =>
2772-
foldTrees(foldTree(x, fun), args)
2773-
case Literal(const) =>
2774-
x
2775-
case New(tpt) =>
2776-
foldTree(x, tpt)
2777-
case Typed(expr, tpt) =>
2778-
foldTree(foldTree(x, expr), tpt)
2779-
case NamedArg(_, arg) =>
2780-
foldTree(x, arg)
2781-
case Assign(lhs, rhs) =>
2782-
foldTree(foldTree(x, lhs), rhs)
2783-
case Block(stats, expr) =>
2784-
foldTree(foldTrees(x, stats), expr)
2785-
case If(cond, thenp, elsep) =>
2786-
foldTree(foldTree(foldTree(x, cond), thenp), elsep)
2787-
case While(cond, body) =>
2788-
foldTree(foldTree(x, cond), body)
2789-
case Closure(meth, tpt) =>
2790-
foldTree(x, meth)
2791-
case Match(selector, cases) =>
2792-
foldTrees(foldTree(x, selector), cases)
2793-
case Return(expr) =>
2794-
foldTree(x, expr)
2795-
case Try(block, handler, finalizer) =>
2796-
foldTrees(foldTrees(foldTree(x, block), handler), finalizer)
2797-
case Repeated(elems, elemtpt) =>
2798-
foldTrees(foldTree(x, elemtpt), elems)
2799-
case Inlined(call, bindings, expansion) =>
2800-
foldTree(foldTrees(x, bindings), expansion)
2801-
case vdef @ ValDef(_, tpt, rhs) =>
2802-
val ctx = localCtx(vdef)
2803-
given Context = ctx
2804-
foldTrees(foldTree(x, tpt), rhs)
2805-
case ddef @ DefDef(_, tparams, vparamss, tpt, rhs) =>
2806-
val ctx = localCtx(ddef)
2807-
given Context = ctx
2808-
foldTrees(foldTree(vparamss.foldLeft(foldTrees(x, tparams))(foldTrees), tpt), rhs)
2809-
case tdef @ TypeDef(_, rhs) =>
2810-
val ctx = localCtx(tdef)
2811-
given Context = ctx
2812-
foldTree(x, rhs)
2813-
case cdef @ ClassDef(_, constr, parents, derived, self, body) =>
2814-
val ctx = localCtx(cdef)
2815-
given Context = ctx
2816-
foldTrees(foldTrees(foldTrees(foldTrees(foldTree(x, constr), parents), derived), self), body)
2817-
case Import(expr, _) =>
2818-
foldTree(x, expr)
2819-
case clause @ PackageClause(pid, stats) =>
2820-
foldTrees(foldTree(x, pid), stats)(given clause.symbol.localContext)
2821-
case Inferred() => x
2822-
case TypeIdent(_) => x
2823-
case TypeSelect(qualifier, _) => foldTree(x, qualifier)
2824-
case Projection(qualifier, _) => foldTree(x, qualifier)
2825-
case Singleton(ref) => foldTree(x, ref)
2826-
case Refined(tpt, refinements) => foldTrees(foldTree(x, tpt), refinements)
2827-
case Applied(tpt, args) => foldTrees(foldTree(x, tpt), args)
2828-
case ByName(result) => foldTree(x, result)
2829-
case Annotated(arg, annot) => foldTree(foldTree(x, arg), annot)
2830-
case LambdaTypeTree(typedefs, arg) => foldTree(foldTrees(x, typedefs), arg)
2831-
case TypeBind(_, tbt) => foldTree(x, tbt)
2832-
case TypeBlock(typedefs, tpt) => foldTree(foldTrees(x, typedefs), tpt)
2833-
case MatchTypeTree(boundopt, selector, cases) =>
2834-
foldTrees(foldTree(boundopt.fold(x)(foldTree(x, _)), selector), cases)
2835-
case WildcardTypeTree() => x
2836-
case TypeBoundsTree(lo, hi) => foldTree(foldTree(x, lo), hi)
2837-
case CaseDef(pat, guard, body) => foldTree(foldTrees(foldTree(x, pat), guard), body)
2838-
case TypeCaseDef(pat, body) => foldTree(foldTree(x, pat), body)
2839-
case Bind(_, body) => foldTree(x, body)
2840-
case Unapply(fun, implicits, patterns) => foldTrees(foldTrees(foldTree(x, fun), implicits), patterns)
2841-
case Alternatives(patterns) => foldTrees(x, patterns)
2842-
}
2843-
}
2751+
/** TASTy Reflect tree accumulator */
2752+
trait TreeAccumulator[X] extends reflect.TreeAccumulator[X] {
2753+
val reflect: self.type = self
28442754
}
28452755

2846-
abstract class TreeTraverser extends TreeAccumulator[Unit] {
2847-
2848-
def traverseTree(tree: Tree)(given ctx: Context): Unit = traverseTreeChildren(tree)
2849-
2850-
def foldTree(x: Unit, tree: Tree)(given ctx: Context): Unit = traverseTree(tree)
2851-
2852-
protected def traverseTreeChildren(tree: Tree)(given ctx: Context): Unit = foldOverTree((), tree)
2853-
2756+
/** TASTy Reflect tree traverser */
2757+
trait TreeTraverser extends reflect.TreeTraverser {
2758+
val reflect: self.type = self
28542759
}
28552760

2856-
abstract class TreeMap { self =>
2857-
2858-
def transformTree(tree: Tree)(given ctx: Context): Tree = {
2859-
tree match {
2860-
case tree: PackageClause =>
2861-
PackageClause.copy(tree)(transformTerm(tree.pid).asInstanceOf[Ref], transformTrees(tree.stats)(given tree.symbol.localContext))
2862-
case tree: Import =>
2863-
Import.copy(tree)(transformTerm(tree.expr), tree.selectors)
2864-
case tree: Statement =>
2865-
transformStatement(tree)
2866-
case tree: TypeTree => transformTypeTree(tree)
2867-
case tree: TypeBoundsTree => tree // TODO traverse tree
2868-
case tree: WildcardTypeTree => tree // TODO traverse tree
2869-
case tree: CaseDef =>
2870-
transformCaseDef(tree)
2871-
case tree: TypeCaseDef =>
2872-
transformTypeCaseDef(tree)
2873-
case pattern: Bind =>
2874-
Bind.copy(pattern)(pattern.name, pattern.pattern)
2875-
case pattern: Unapply =>
2876-
Unapply.copy(pattern)(transformTerm(pattern.fun), transformSubTrees(pattern.implicits), transformTrees(pattern.patterns))
2877-
case pattern: Alternatives =>
2878-
Alternatives.copy(pattern)(transformTrees(pattern.patterns))
2879-
}
2880-
}
2881-
2882-
def transformStatement(tree: Statement)(given ctx: Context): Statement = {
2883-
def localCtx(definition: Definition): Context = definition.symbol.localContext
2884-
tree match {
2885-
case tree: Term =>
2886-
transformTerm(tree)
2887-
case tree: ValDef =>
2888-
val ctx = localCtx(tree)
2889-
given Context = ctx
2890-
val tpt1 = transformTypeTree(tree.tpt)
2891-
val rhs1 = tree.rhs.map(x => transformTerm(x))
2892-
ValDef.copy(tree)(tree.name, tpt1, rhs1)
2893-
case tree: DefDef =>
2894-
val ctx = localCtx(tree)
2895-
given Context = ctx
2896-
DefDef.copy(tree)(tree.name, transformSubTrees(tree.typeParams), tree.paramss mapConserve (transformSubTrees(_)), transformTypeTree(tree.returnTpt), tree.rhs.map(x => transformTerm(x)))
2897-
case tree: TypeDef =>
2898-
val ctx = localCtx(tree)
2899-
given Context = ctx
2900-
TypeDef.copy(tree)(tree.name, transformTree(tree.rhs))
2901-
case tree: ClassDef =>
2902-
ClassDef.copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, tree.body)
2903-
case tree: Import =>
2904-
Import.copy(tree)(transformTerm(tree.expr), tree.selectors)
2905-
}
2906-
}
2907-
2908-
def transformTerm(tree: Term)(given ctx: Context): Term = {
2909-
tree match {
2910-
case Ident(name) =>
2911-
tree
2912-
case Select(qualifier, name) =>
2913-
Select.copy(tree)(transformTerm(qualifier), name)
2914-
case This(qual) =>
2915-
tree
2916-
case Super(qual, mix) =>
2917-
Super.copy(tree)(transformTerm(qual), mix)
2918-
case Apply(fun, args) =>
2919-
Apply.copy(tree)(transformTerm(fun), transformTerms(args))
2920-
case TypeApply(fun, args) =>
2921-
TypeApply.copy(tree)(transformTerm(fun), transformTypeTrees(args))
2922-
case Literal(const) =>
2923-
tree
2924-
case New(tpt) =>
2925-
New.copy(tree)(transformTypeTree(tpt))
2926-
case Typed(expr, tpt) =>
2927-
Typed.copy(tree)(transformTerm(expr), transformTypeTree(tpt))
2928-
case tree: NamedArg =>
2929-
NamedArg.copy(tree)(tree.name, transformTerm(tree.value))
2930-
case Assign(lhs, rhs) =>
2931-
Assign.copy(tree)(transformTerm(lhs), transformTerm(rhs))
2932-
case Block(stats, expr) =>
2933-
Block.copy(tree)(transformStats(stats), transformTerm(expr))
2934-
case If(cond, thenp, elsep) =>
2935-
If.copy(tree)(transformTerm(cond), transformTerm(thenp), transformTerm(elsep))
2936-
case Closure(meth, tpt) =>
2937-
Closure.copy(tree)(transformTerm(meth), tpt)
2938-
case Match(selector, cases) =>
2939-
Match.copy(tree)(transformTerm(selector), transformCaseDefs(cases))
2940-
case Return(expr) =>
2941-
Return.copy(tree)(transformTerm(expr))
2942-
case While(cond, body) =>
2943-
While.copy(tree)(transformTerm(cond), transformTerm(body))
2944-
case Try(block, cases, finalizer) =>
2945-
Try.copy(tree)(transformTerm(block), transformCaseDefs(cases), finalizer.map(x => transformTerm(x)))
2946-
case Repeated(elems, elemtpt) =>
2947-
Repeated.copy(tree)(transformTerms(elems), transformTypeTree(elemtpt))
2948-
case Inlined(call, bindings, expansion) =>
2949-
Inlined.copy(tree)(call, transformSubTrees(bindings), transformTerm(expansion)/*()call.symbol.localContext)*/)
2950-
}
2951-
}
2952-
2953-
def transformTypeTree(tree: TypeTree)(given ctx: Context): TypeTree = tree match {
2954-
case Inferred() => tree
2955-
case tree: TypeIdent => tree
2956-
case tree: TypeSelect =>
2957-
TypeSelect.copy(tree)(tree.qualifier, tree.name)
2958-
case tree: Projection =>
2959-
Projection.copy(tree)(tree.qualifier, tree.name)
2960-
case tree: Annotated =>
2961-
Annotated.copy(tree)(tree.arg, tree.annotation)
2962-
case tree: Singleton =>
2963-
Singleton.copy(tree)(transformTerm(tree.ref))
2964-
case tree: Refined =>
2965-
Refined.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.refinements).asInstanceOf[List[Definition]])
2966-
case tree: Applied =>
2967-
Applied.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.args))
2968-
case tree: MatchTypeTree =>
2969-
MatchTypeTree.copy(tree)(tree.bound.map(b => transformTypeTree(b)), transformTypeTree(tree.selector), transformTypeCaseDefs(tree.cases))
2970-
case tree: ByName =>
2971-
ByName.copy(tree)(transformTypeTree(tree.result))
2972-
case tree: LambdaTypeTree =>
2973-
LambdaTypeTree.copy(tree)(transformSubTrees(tree.tparams), transformTree(tree.body))(given tree.symbol.localContext)
2974-
case tree: TypeBind =>
2975-
TypeBind.copy(tree)(tree.name, tree.body)
2976-
case tree: TypeBlock =>
2977-
TypeBlock.copy(tree)(tree.aliases, tree.tpt)
2978-
}
2979-
2980-
def transformCaseDef(tree: CaseDef)(given ctx: Context): CaseDef = {
2981-
CaseDef.copy(tree)(transformTree(tree.pattern), tree.guard.map(transformTerm), transformTerm(tree.rhs))
2982-
}
2983-
2984-
def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef = {
2985-
TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs))
2986-
}
2987-
2988-
def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] =
2989-
trees mapConserve (transformStatement(_))
2990-
2991-
def transformTrees(trees: List[Tree])(given ctx: Context): List[Tree] =
2992-
trees mapConserve (transformTree(_))
2993-
2994-
def transformTerms(trees: List[Term])(given ctx: Context): List[Term] =
2995-
trees mapConserve (transformTerm(_))
2996-
2997-
def transformTypeTrees(trees: List[TypeTree])(given ctx: Context): List[TypeTree] =
2998-
trees mapConserve (transformTypeTree(_))
2999-
3000-
def transformCaseDefs(trees: List[CaseDef])(given ctx: Context): List[CaseDef] =
3001-
trees mapConserve (transformCaseDef(_))
3002-
3003-
def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] =
3004-
trees mapConserve (transformTypeCaseDef(_))
3005-
3006-
def transformSubTrees[Tr <: Tree](trees: List[Tr])(given ctx: Context): List[Tr] =
3007-
transformTrees(trees).asInstanceOf[List[Tr]]
3008-
2761+
/** TASTy Reflect tree map */
2762+
trait TreeMap extends reflect.TreeMap {
2763+
val reflect: self.type = self
30092764
}
30102765

2766+
// TODO extract from Reflection
2767+
30112768
/** Bind the `rhs` to a `val` and use it in `body` */
30122769
def let(rhs: Term)(body: Ident => Term)(given ctx: Context): Term = {
30132770
import scala.quoted.QuoteContext

0 commit comments

Comments
 (0)