Skip to content

Move TreeMap, TreeAccumulator and TreeTraverser into Reflection #10184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
302 changes: 290 additions & 12 deletions library/src/scala/tasty/Reflection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3313,20 +3313,298 @@ trait Reflection { reflection =>
// UTILS //
///////////////

/** TASTy Reflect tree accumulator */
trait TreeAccumulator[X] extends reflect.TreeAccumulator[X] {
val reflect: reflection.type = reflection
}
/** TASTy Reflect tree accumulator.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember one of the previous PRs is to avoid the name TASTy.

Suggested change
/** TASTy Reflect tree accumulator.
/** Tree accumulator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll do a general cleanup of the docs in a separate PR

*
* Usage:
* ```
* class MyTreeAccumulator[R <: scala.tasty.Reflection & Singleton](val reflect: R)
* extends scala.tasty.reflect.TreeAccumulator[X] {
* import reflect._
* def foldTree(x: X, tree: Tree)(using ctx: Context): X = ...
* }
* ```
*/
trait TreeAccumulator[X]:

// Ties the knot of the traversal: call `foldOver(x, tree))` to dive in the `tree` node.
def foldTree(x: X, tree: Tree)(using ctx: Context): X

def foldTrees(x: X, trees: Iterable[Tree])(using ctx: Context): X = trees.foldLeft(x)(foldTree)

def foldOverTree(x: X, tree: Tree)(using ctx: Context): X = {
def localCtx(definition: Definition): Context = definition.symbol.localContext
tree match {
case Ident(_) =>
x
case Select(qualifier, _) =>
foldTree(x, qualifier)
case This(qual) =>
x
case Super(qual, _) =>
foldTree(x, qual)
case Apply(fun, args) =>
foldTrees(foldTree(x, fun), args)
case TypeApply(fun, args) =>
foldTrees(foldTree(x, fun), args)
case Literal(const) =>
x
case New(tpt) =>
foldTree(x, tpt)
case Typed(expr, tpt) =>
foldTree(foldTree(x, expr), tpt)
case NamedArg(_, arg) =>
foldTree(x, arg)
case Assign(lhs, rhs) =>
foldTree(foldTree(x, lhs), rhs)
case Block(stats, expr) =>
foldTree(foldTrees(x, stats), expr)
case If(cond, thenp, elsep) =>
foldTree(foldTree(foldTree(x, cond), thenp), elsep)
case While(cond, body) =>
foldTree(foldTree(x, cond), body)
case Closure(meth, tpt) =>
foldTree(x, meth)
case Match(selector, cases) =>
foldTrees(foldTree(x, selector), cases)
case Return(expr, _) =>
foldTree(x, expr)
case Try(block, handler, finalizer) =>
foldTrees(foldTrees(foldTree(x, block), handler), finalizer)
case Repeated(elems, elemtpt) =>
foldTrees(foldTree(x, elemtpt), elems)
case Inlined(call, bindings, expansion) =>
foldTree(foldTrees(x, bindings), expansion)
case vdef @ ValDef(_, tpt, rhs) =>
val ctx = localCtx(vdef)
given Context = ctx
foldTrees(foldTree(x, tpt), rhs)
case ddef @ DefDef(_, tparams, vparamss, tpt, rhs) =>
val ctx = localCtx(ddef)
given Context = ctx
foldTrees(foldTree(vparamss.foldLeft(foldTrees(x, tparams))(foldTrees), tpt), rhs)
case tdef @ TypeDef(_, rhs) =>
val ctx = localCtx(tdef)
given Context = ctx
foldTree(x, rhs)
case cdef @ ClassDef(_, constr, parents, derived, self, body) =>
val ctx = localCtx(cdef)
given Context = ctx
foldTrees(foldTrees(foldTrees(foldTrees(foldTree(x, constr), parents), derived), self), body)
case Import(expr, _) =>
foldTree(x, expr)
case clause @ PackageClause(pid, stats) =>
foldTrees(foldTree(x, pid), stats)(using clause.symbol.localContext)
case Inferred() => x
case TypeIdent(_) => x
case TypeSelect(qualifier, _) => foldTree(x, qualifier)
case Projection(qualifier, _) => foldTree(x, qualifier)
case Singleton(ref) => foldTree(x, ref)
case Refined(tpt, refinements) => foldTrees(foldTree(x, tpt), refinements)
case Applied(tpt, args) => foldTrees(foldTree(x, tpt), args)
case ByName(result) => foldTree(x, result)
case Annotated(arg, annot) => foldTree(foldTree(x, arg), annot)
case LambdaTypeTree(typedefs, arg) => foldTree(foldTrees(x, typedefs), arg)
case TypeBind(_, tbt) => foldTree(x, tbt)
case TypeBlock(typedefs, tpt) => foldTree(foldTrees(x, typedefs), tpt)
case MatchTypeTree(boundopt, selector, cases) =>
foldTrees(foldTree(boundopt.fold(x)(foldTree(x, _)), selector), cases)
case WildcardTypeTree() => x
case TypeBoundsTree(lo, hi) => foldTree(foldTree(x, lo), hi)
case CaseDef(pat, guard, body) => foldTree(foldTrees(foldTree(x, pat), guard), body)
case TypeCaseDef(pat, body) => foldTree(foldTree(x, pat), body)
case Bind(_, body) => foldTree(x, body)
case Unapply(fun, implicits, patterns) => foldTrees(foldTrees(foldTree(x, fun), implicits), patterns)
case Alternatives(patterns) => foldTrees(x, patterns)
}
}
end TreeAccumulator

/** TASTy Reflect tree traverser */
trait TreeTraverser extends reflect.TreeTraverser {
val reflect: reflection.type = reflection
}

/** TASTy Reflect tree map */
trait TreeMap extends reflect.TreeMap {
val reflect: reflection.type = reflection
}
/** TASTy Reflect tree traverser.
*
* Usage:
* ```
* class MyTraverser[R <: scala.tasty.Reflection & Singleton](val reflect: R)
* extends scala.tasty.reflect.TreeTraverser {
* import reflect._
* override def traverseTree(tree: Tree)(using ctx: Context): Unit = ...
* }
* ```
*/
trait TreeTraverser extends TreeAccumulator[Unit]:

def traverseTree(tree: Tree)(using ctx: Context): Unit = traverseTreeChildren(tree)

def foldTree(x: Unit, tree: Tree)(using ctx: Context): Unit = traverseTree(tree)

protected def traverseTreeChildren(tree: Tree)(using ctx: Context): Unit = foldOverTree((), tree)

end TreeTraverser

/** TASTy Reflect tree map.
*
* Usage:
* ```
* import qctx.reflect._
* class MyTreeMap extends TreeMap {
* override def transformTree(tree: Tree)(using ctx: Context): Tree = ...
* }
* ```
*/
trait TreeMap:

def transformTree(tree: Tree)(using ctx: Context): Tree = {
tree match {
case tree: PackageClause =>
PackageClause.copy(tree)(transformTerm(tree.pid).asInstanceOf[Ref], transformTrees(tree.stats)(using tree.symbol.localContext))
case tree: Import =>
Import.copy(tree)(transformTerm(tree.expr), tree.selectors)
case tree: Statement =>
transformStatement(tree)
case tree: TypeTree => transformTypeTree(tree)
case tree: TypeBoundsTree => tree // TODO traverse tree
case tree: WildcardTypeTree => tree // TODO traverse tree
case tree: CaseDef =>
transformCaseDef(tree)
case tree: TypeCaseDef =>
transformTypeCaseDef(tree)
case pattern: Bind =>
Bind.copy(pattern)(pattern.name, pattern.pattern)
case pattern: Unapply =>
Unapply.copy(pattern)(transformTerm(pattern.fun), transformSubTrees(pattern.implicits), transformTrees(pattern.patterns))
case pattern: Alternatives =>
Alternatives.copy(pattern)(transformTrees(pattern.patterns))
}
}

def transformStatement(tree: Statement)(using ctx: Context): Statement = {
def localCtx(definition: Definition): Context = definition.symbol.localContext
tree match {
case tree: Term =>
transformTerm(tree)
case tree: ValDef =>
val ctx = localCtx(tree)
given Context = ctx
val tpt1 = transformTypeTree(tree.tpt)
val rhs1 = tree.rhs.map(x => transformTerm(x))
ValDef.copy(tree)(tree.name, tpt1, rhs1)
case tree: DefDef =>
val ctx = localCtx(tree)
given Context = ctx
DefDef.copy(tree)(tree.name, transformSubTrees(tree.typeParams), tree.paramss mapConserve (transformSubTrees(_)), transformTypeTree(tree.returnTpt), tree.rhs.map(x => transformTerm(x)))
case tree: TypeDef =>
val ctx = localCtx(tree)
given Context = ctx
TypeDef.copy(tree)(tree.name, transformTree(tree.rhs))
case tree: ClassDef =>
ClassDef.copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, tree.body)
case tree: Import =>
Import.copy(tree)(transformTerm(tree.expr), tree.selectors)
}
}

def transformTerm(tree: Term)(using ctx: Context): Term = {
tree match {
case Ident(name) =>
tree
case Select(qualifier, name) =>
Select.copy(tree)(transformTerm(qualifier), name)
case This(qual) =>
tree
case Super(qual, mix) =>
Super.copy(tree)(transformTerm(qual), mix)
case Apply(fun, args) =>
Apply.copy(tree)(transformTerm(fun), transformTerms(args))
case TypeApply(fun, args) =>
TypeApply.copy(tree)(transformTerm(fun), transformTypeTrees(args))
case Literal(const) =>
tree
case New(tpt) =>
New.copy(tree)(transformTypeTree(tpt))
case Typed(expr, tpt) =>
Typed.copy(tree)(transformTerm(expr), transformTypeTree(tpt))
case tree: NamedArg =>
NamedArg.copy(tree)(tree.name, transformTerm(tree.value))
case Assign(lhs, rhs) =>
Assign.copy(tree)(transformTerm(lhs), transformTerm(rhs))
case Block(stats, expr) =>
Block.copy(tree)(transformStats(stats), transformTerm(expr))
case If(cond, thenp, elsep) =>
If.copy(tree)(transformTerm(cond), transformTerm(thenp), transformTerm(elsep))
case Closure(meth, tpt) =>
Closure.copy(tree)(transformTerm(meth), tpt)
case Match(selector, cases) =>
Match.copy(tree)(transformTerm(selector), transformCaseDefs(cases))
case Return(expr, from) =>
Return.copy(tree)(transformTerm(expr), from)
case While(cond, body) =>
While.copy(tree)(transformTerm(cond), transformTerm(body))
case Try(block, cases, finalizer) =>
Try.copy(tree)(transformTerm(block), transformCaseDefs(cases), finalizer.map(x => transformTerm(x)))
case Repeated(elems, elemtpt) =>
Repeated.copy(tree)(transformTerms(elems), transformTypeTree(elemtpt))
case Inlined(call, bindings, expansion) =>
Inlined.copy(tree)(call, transformSubTrees(bindings), transformTerm(expansion)/*()call.symbol.localContext)*/)
}
}

def transformTypeTree(tree: TypeTree)(using ctx: Context): TypeTree = tree match {
case Inferred() => tree
case tree: TypeIdent => tree
case tree: TypeSelect =>
TypeSelect.copy(tree)(tree.qualifier, tree.name)
case tree: Projection =>
Projection.copy(tree)(tree.qualifier, tree.name)
case tree: Annotated =>
Annotated.copy(tree)(tree.arg, tree.annotation)
case tree: Singleton =>
Singleton.copy(tree)(transformTerm(tree.ref))
case tree: Refined =>
Refined.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.refinements).asInstanceOf[List[Definition]])
case tree: Applied =>
Applied.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.args))
case tree: MatchTypeTree =>
MatchTypeTree.copy(tree)(tree.bound.map(b => transformTypeTree(b)), transformTypeTree(tree.selector), transformTypeCaseDefs(tree.cases))
case tree: ByName =>
ByName.copy(tree)(transformTypeTree(tree.result))
case tree: LambdaTypeTree =>
LambdaTypeTree.copy(tree)(transformSubTrees(tree.tparams), transformTree(tree.body))
case tree: TypeBind =>
TypeBind.copy(tree)(tree.name, tree.body)
case tree: TypeBlock =>
TypeBlock.copy(tree)(tree.aliases, tree.tpt)
}

def transformCaseDef(tree: CaseDef)(using ctx: Context): CaseDef = {
CaseDef.copy(tree)(transformTree(tree.pattern), tree.guard.map(transformTerm), transformTerm(tree.rhs))
}

def transformTypeCaseDef(tree: TypeCaseDef)(using ctx: Context): TypeCaseDef = {
TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs))
}

def transformStats(trees: List[Statement])(using ctx: Context): List[Statement] =
trees mapConserve (transformStatement(_))

def transformTrees(trees: List[Tree])(using ctx: Context): List[Tree] =
trees mapConserve (transformTree(_))

def transformTerms(trees: List[Term])(using ctx: Context): List[Term] =
trees mapConserve (transformTerm(_))

def transformTypeTrees(trees: List[TypeTree])(using ctx: Context): List[TypeTree] =
trees mapConserve (transformTypeTree(_))

def transformCaseDefs(trees: List[CaseDef])(using ctx: Context): List[CaseDef] =
trees mapConserve (transformCaseDef(_))

def transformTypeCaseDefs(trees: List[TypeCaseDef])(using ctx: Context): List[TypeCaseDef] =
trees mapConserve (transformTypeCaseDef(_))

def transformSubTrees[Tr <: Tree](trees: List[Tr])(using ctx: Context): List[Tr] =
transformTrees(trees).asInstanceOf[List[Tr]]

end TreeMap

// TODO: extract from Reflection

Expand Down
Loading