diff --git a/library/src/scala/tasty/Reflection.scala b/library/src/scala/tasty/Reflection.scala index 2c108f785c2a..c2eee40a2ac6 100644 --- a/library/src/scala/tasty/Reflection.scala +++ b/library/src/scala/tasty/Reflection.scala @@ -3313,20 +3313,298 @@ trait Reflection { reflection => // UTILS // /////////////// - /** TASTy Reflect tree accumulator */ - trait TreeAccumulator[X] extends reflect.TreeAccumulator[X] { - val reflect: reflection.type = reflection - } + /** TASTy Reflect tree accumulator. + * + * Usage: + * ``` + * class MyTreeAccumulator[R <: scala.tasty.Reflection & Singleton](val reflect: R) + * extends scala.tasty.reflect.TreeAccumulator[X] { + * import reflect._ + * def foldTree(x: X, tree: Tree)(using ctx: Context): X = ... + * } + * ``` + */ + trait TreeAccumulator[X]: + + // Ties the knot of the traversal: call `foldOver(x, tree))` to dive in the `tree` node. + def foldTree(x: X, tree: Tree)(using ctx: Context): X + + def foldTrees(x: X, trees: Iterable[Tree])(using ctx: Context): X = trees.foldLeft(x)(foldTree) + + def foldOverTree(x: X, tree: Tree)(using ctx: Context): X = { + def localCtx(definition: Definition): Context = definition.symbol.localContext + tree match { + case Ident(_) => + x + case Select(qualifier, _) => + foldTree(x, qualifier) + case This(qual) => + x + case Super(qual, _) => + foldTree(x, qual) + case Apply(fun, args) => + foldTrees(foldTree(x, fun), args) + case TypeApply(fun, args) => + foldTrees(foldTree(x, fun), args) + case Literal(const) => + x + case New(tpt) => + foldTree(x, tpt) + case Typed(expr, tpt) => + foldTree(foldTree(x, expr), tpt) + case NamedArg(_, arg) => + foldTree(x, arg) + case Assign(lhs, rhs) => + foldTree(foldTree(x, lhs), rhs) + case Block(stats, expr) => + foldTree(foldTrees(x, stats), expr) + case If(cond, thenp, elsep) => + foldTree(foldTree(foldTree(x, cond), thenp), elsep) + case While(cond, body) => + foldTree(foldTree(x, cond), body) + case Closure(meth, tpt) => + foldTree(x, meth) + case Match(selector, cases) => + foldTrees(foldTree(x, selector), cases) + case Return(expr, _) => + foldTree(x, expr) + case Try(block, handler, finalizer) => + foldTrees(foldTrees(foldTree(x, block), handler), finalizer) + case Repeated(elems, elemtpt) => + foldTrees(foldTree(x, elemtpt), elems) + case Inlined(call, bindings, expansion) => + foldTree(foldTrees(x, bindings), expansion) + case vdef @ ValDef(_, tpt, rhs) => + val ctx = localCtx(vdef) + given Context = ctx + foldTrees(foldTree(x, tpt), rhs) + case ddef @ DefDef(_, tparams, vparamss, tpt, rhs) => + val ctx = localCtx(ddef) + given Context = ctx + foldTrees(foldTree(vparamss.foldLeft(foldTrees(x, tparams))(foldTrees), tpt), rhs) + case tdef @ TypeDef(_, rhs) => + val ctx = localCtx(tdef) + given Context = ctx + foldTree(x, rhs) + case cdef @ ClassDef(_, constr, parents, derived, self, body) => + val ctx = localCtx(cdef) + given Context = ctx + foldTrees(foldTrees(foldTrees(foldTrees(foldTree(x, constr), parents), derived), self), body) + case Import(expr, _) => + foldTree(x, expr) + case clause @ PackageClause(pid, stats) => + foldTrees(foldTree(x, pid), stats)(using clause.symbol.localContext) + case Inferred() => x + case TypeIdent(_) => x + case TypeSelect(qualifier, _) => foldTree(x, qualifier) + case Projection(qualifier, _) => foldTree(x, qualifier) + case Singleton(ref) => foldTree(x, ref) + case Refined(tpt, refinements) => foldTrees(foldTree(x, tpt), refinements) + case Applied(tpt, args) => foldTrees(foldTree(x, tpt), args) + case ByName(result) => foldTree(x, result) + case Annotated(arg, annot) => foldTree(foldTree(x, arg), annot) + case LambdaTypeTree(typedefs, arg) => foldTree(foldTrees(x, typedefs), arg) + case TypeBind(_, tbt) => foldTree(x, tbt) + case TypeBlock(typedefs, tpt) => foldTree(foldTrees(x, typedefs), tpt) + case MatchTypeTree(boundopt, selector, cases) => + foldTrees(foldTree(boundopt.fold(x)(foldTree(x, _)), selector), cases) + case WildcardTypeTree() => x + case TypeBoundsTree(lo, hi) => foldTree(foldTree(x, lo), hi) + case CaseDef(pat, guard, body) => foldTree(foldTrees(foldTree(x, pat), guard), body) + case TypeCaseDef(pat, body) => foldTree(foldTree(x, pat), body) + case Bind(_, body) => foldTree(x, body) + case Unapply(fun, implicits, patterns) => foldTrees(foldTrees(foldTree(x, fun), implicits), patterns) + case Alternatives(patterns) => foldTrees(x, patterns) + } + } + end TreeAccumulator - /** TASTy Reflect tree traverser */ - trait TreeTraverser extends reflect.TreeTraverser { - val reflect: reflection.type = reflection - } - /** TASTy Reflect tree map */ - trait TreeMap extends reflect.TreeMap { - val reflect: reflection.type = reflection - } + /** TASTy Reflect tree traverser. + * + * Usage: + * ``` + * class MyTraverser[R <: scala.tasty.Reflection & Singleton](val reflect: R) + * extends scala.tasty.reflect.TreeTraverser { + * import reflect._ + * override def traverseTree(tree: Tree)(using ctx: Context): Unit = ... + * } + * ``` + */ + trait TreeTraverser extends TreeAccumulator[Unit]: + + def traverseTree(tree: Tree)(using ctx: Context): Unit = traverseTreeChildren(tree) + + def foldTree(x: Unit, tree: Tree)(using ctx: Context): Unit = traverseTree(tree) + + protected def traverseTreeChildren(tree: Tree)(using ctx: Context): Unit = foldOverTree((), tree) + + end TreeTraverser + + /** TASTy Reflect tree map. + * + * Usage: + * ``` + * import qctx.reflect._ + * class MyTreeMap extends TreeMap { + * override def transformTree(tree: Tree)(using ctx: Context): Tree = ... + * } + * ``` + */ + trait TreeMap: + + def transformTree(tree: Tree)(using ctx: Context): Tree = { + tree match { + case tree: PackageClause => + PackageClause.copy(tree)(transformTerm(tree.pid).asInstanceOf[Ref], transformTrees(tree.stats)(using tree.symbol.localContext)) + case tree: Import => + Import.copy(tree)(transformTerm(tree.expr), tree.selectors) + case tree: Statement => + transformStatement(tree) + case tree: TypeTree => transformTypeTree(tree) + case tree: TypeBoundsTree => tree // TODO traverse tree + case tree: WildcardTypeTree => tree // TODO traverse tree + case tree: CaseDef => + transformCaseDef(tree) + case tree: TypeCaseDef => + transformTypeCaseDef(tree) + case pattern: Bind => + Bind.copy(pattern)(pattern.name, pattern.pattern) + case pattern: Unapply => + Unapply.copy(pattern)(transformTerm(pattern.fun), transformSubTrees(pattern.implicits), transformTrees(pattern.patterns)) + case pattern: Alternatives => + Alternatives.copy(pattern)(transformTrees(pattern.patterns)) + } + } + + def transformStatement(tree: Statement)(using ctx: Context): Statement = { + def localCtx(definition: Definition): Context = definition.symbol.localContext + tree match { + case tree: Term => + transformTerm(tree) + case tree: ValDef => + val ctx = localCtx(tree) + given Context = ctx + val tpt1 = transformTypeTree(tree.tpt) + val rhs1 = tree.rhs.map(x => transformTerm(x)) + ValDef.copy(tree)(tree.name, tpt1, rhs1) + case tree: DefDef => + val ctx = localCtx(tree) + given Context = ctx + DefDef.copy(tree)(tree.name, transformSubTrees(tree.typeParams), tree.paramss mapConserve (transformSubTrees(_)), transformTypeTree(tree.returnTpt), tree.rhs.map(x => transformTerm(x))) + case tree: TypeDef => + val ctx = localCtx(tree) + given Context = ctx + TypeDef.copy(tree)(tree.name, transformTree(tree.rhs)) + case tree: ClassDef => + ClassDef.copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, tree.body) + case tree: Import => + Import.copy(tree)(transformTerm(tree.expr), tree.selectors) + } + } + + def transformTerm(tree: Term)(using ctx: Context): Term = { + tree match { + case Ident(name) => + tree + case Select(qualifier, name) => + Select.copy(tree)(transformTerm(qualifier), name) + case This(qual) => + tree + case Super(qual, mix) => + Super.copy(tree)(transformTerm(qual), mix) + case Apply(fun, args) => + Apply.copy(tree)(transformTerm(fun), transformTerms(args)) + case TypeApply(fun, args) => + TypeApply.copy(tree)(transformTerm(fun), transformTypeTrees(args)) + case Literal(const) => + tree + case New(tpt) => + New.copy(tree)(transformTypeTree(tpt)) + case Typed(expr, tpt) => + Typed.copy(tree)(transformTerm(expr), transformTypeTree(tpt)) + case tree: NamedArg => + NamedArg.copy(tree)(tree.name, transformTerm(tree.value)) + case Assign(lhs, rhs) => + Assign.copy(tree)(transformTerm(lhs), transformTerm(rhs)) + case Block(stats, expr) => + Block.copy(tree)(transformStats(stats), transformTerm(expr)) + case If(cond, thenp, elsep) => + If.copy(tree)(transformTerm(cond), transformTerm(thenp), transformTerm(elsep)) + case Closure(meth, tpt) => + Closure.copy(tree)(transformTerm(meth), tpt) + case Match(selector, cases) => + Match.copy(tree)(transformTerm(selector), transformCaseDefs(cases)) + case Return(expr, from) => + Return.copy(tree)(transformTerm(expr), from) + case While(cond, body) => + While.copy(tree)(transformTerm(cond), transformTerm(body)) + case Try(block, cases, finalizer) => + Try.copy(tree)(transformTerm(block), transformCaseDefs(cases), finalizer.map(x => transformTerm(x))) + case Repeated(elems, elemtpt) => + Repeated.copy(tree)(transformTerms(elems), transformTypeTree(elemtpt)) + case Inlined(call, bindings, expansion) => + Inlined.copy(tree)(call, transformSubTrees(bindings), transformTerm(expansion)/*()call.symbol.localContext)*/) + } + } + + def transformTypeTree(tree: TypeTree)(using ctx: Context): TypeTree = tree match { + case Inferred() => tree + case tree: TypeIdent => tree + case tree: TypeSelect => + TypeSelect.copy(tree)(tree.qualifier, tree.name) + case tree: Projection => + Projection.copy(tree)(tree.qualifier, tree.name) + case tree: Annotated => + Annotated.copy(tree)(tree.arg, tree.annotation) + case tree: Singleton => + Singleton.copy(tree)(transformTerm(tree.ref)) + case tree: Refined => + Refined.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.refinements).asInstanceOf[List[Definition]]) + case tree: Applied => + Applied.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.args)) + case tree: MatchTypeTree => + MatchTypeTree.copy(tree)(tree.bound.map(b => transformTypeTree(b)), transformTypeTree(tree.selector), transformTypeCaseDefs(tree.cases)) + case tree: ByName => + ByName.copy(tree)(transformTypeTree(tree.result)) + case tree: LambdaTypeTree => + LambdaTypeTree.copy(tree)(transformSubTrees(tree.tparams), transformTree(tree.body)) + case tree: TypeBind => + TypeBind.copy(tree)(tree.name, tree.body) + case tree: TypeBlock => + TypeBlock.copy(tree)(tree.aliases, tree.tpt) + } + + def transformCaseDef(tree: CaseDef)(using ctx: Context): CaseDef = { + CaseDef.copy(tree)(transformTree(tree.pattern), tree.guard.map(transformTerm), transformTerm(tree.rhs)) + } + + def transformTypeCaseDef(tree: TypeCaseDef)(using ctx: Context): TypeCaseDef = { + TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs)) + } + + def transformStats(trees: List[Statement])(using ctx: Context): List[Statement] = + trees mapConserve (transformStatement(_)) + + def transformTrees(trees: List[Tree])(using ctx: Context): List[Tree] = + trees mapConserve (transformTree(_)) + + def transformTerms(trees: List[Term])(using ctx: Context): List[Term] = + trees mapConserve (transformTerm(_)) + + def transformTypeTrees(trees: List[TypeTree])(using ctx: Context): List[TypeTree] = + trees mapConserve (transformTypeTree(_)) + + def transformCaseDefs(trees: List[CaseDef])(using ctx: Context): List[CaseDef] = + trees mapConserve (transformCaseDef(_)) + + def transformTypeCaseDefs(trees: List[TypeCaseDef])(using ctx: Context): List[TypeCaseDef] = + trees mapConserve (transformTypeCaseDef(_)) + + def transformSubTrees[Tr <: Tree](trees: List[Tr])(using ctx: Context): List[Tr] = + transformTrees(trees).asInstanceOf[List[Tr]] + + end TreeMap // TODO: extract from Reflection diff --git a/library/src/scala/tasty/reflect/TreeAccumulator.scala b/library/src/scala/tasty/reflect/TreeAccumulator.scala deleted file mode 100644 index ec2a509b658d..000000000000 --- a/library/src/scala/tasty/reflect/TreeAccumulator.scala +++ /dev/null @@ -1,111 +0,0 @@ -package scala.tasty -package reflect - -/** TASTy Reflect tree accumulator. - * - * Usage: - * ``` - * class MyTreeAccumulator[R <: scala.tasty.Reflection & Singleton](val reflect: R) - * extends scala.tasty.reflect.TreeAccumulator[X] { - * import reflect._ - * def foldTree(x: X, tree: Tree)(using ctx: Context): X = ... - * } - * ``` - */ -trait TreeAccumulator[X] { - - val reflect: Reflection - import reflect._ - - // Ties the knot of the traversal: call `foldOver(x, tree))` to dive in the `tree` node. - def foldTree(x: X, tree: Tree)(using ctx: Context): X - - def foldTrees(x: X, trees: Iterable[Tree])(using ctx: Context): X = trees.foldLeft(x)(foldTree) - - def foldOverTree(x: X, tree: Tree)(using ctx: Context): X = { - def localCtx(definition: Definition): Context = definition.symbol.localContext - tree match { - case Ident(_) => - x - case Select(qualifier, _) => - foldTree(x, qualifier) - case This(qual) => - x - case Super(qual, _) => - foldTree(x, qual) - case Apply(fun, args) => - foldTrees(foldTree(x, fun), args) - case TypeApply(fun, args) => - foldTrees(foldTree(x, fun), args) - case Literal(const) => - x - case New(tpt) => - foldTree(x, tpt) - case Typed(expr, tpt) => - foldTree(foldTree(x, expr), tpt) - case NamedArg(_, arg) => - foldTree(x, arg) - case Assign(lhs, rhs) => - foldTree(foldTree(x, lhs), rhs) - case Block(stats, expr) => - foldTree(foldTrees(x, stats), expr) - case If(cond, thenp, elsep) => - foldTree(foldTree(foldTree(x, cond), thenp), elsep) - case While(cond, body) => - foldTree(foldTree(x, cond), body) - case Closure(meth, tpt) => - foldTree(x, meth) - case Match(selector, cases) => - foldTrees(foldTree(x, selector), cases) - case Return(expr, _) => - foldTree(x, expr) - case Try(block, handler, finalizer) => - foldTrees(foldTrees(foldTree(x, block), handler), finalizer) - case Repeated(elems, elemtpt) => - foldTrees(foldTree(x, elemtpt), elems) - case Inlined(call, bindings, expansion) => - foldTree(foldTrees(x, bindings), expansion) - case vdef @ ValDef(_, tpt, rhs) => - val ctx = localCtx(vdef) - given Context = ctx - foldTrees(foldTree(x, tpt), rhs) - case ddef @ DefDef(_, tparams, vparamss, tpt, rhs) => - val ctx = localCtx(ddef) - given Context = ctx - foldTrees(foldTree(vparamss.foldLeft(foldTrees(x, tparams))(foldTrees), tpt), rhs) - case tdef @ TypeDef(_, rhs) => - val ctx = localCtx(tdef) - given Context = ctx - foldTree(x, rhs) - case cdef @ ClassDef(_, constr, parents, derived, self, body) => - val ctx = localCtx(cdef) - given Context = ctx - foldTrees(foldTrees(foldTrees(foldTrees(foldTree(x, constr), parents), derived), self), body) - case Import(expr, _) => - foldTree(x, expr) - case clause @ PackageClause(pid, stats) => - foldTrees(foldTree(x, pid), stats)(using clause.symbol.localContext) - case Inferred() => x - case TypeIdent(_) => x - case TypeSelect(qualifier, _) => foldTree(x, qualifier) - case Projection(qualifier, _) => foldTree(x, qualifier) - case Singleton(ref) => foldTree(x, ref) - case Refined(tpt, refinements) => foldTrees(foldTree(x, tpt), refinements) - case Applied(tpt, args) => foldTrees(foldTree(x, tpt), args) - case ByName(result) => foldTree(x, result) - case Annotated(arg, annot) => foldTree(foldTree(x, arg), annot) - case LambdaTypeTree(typedefs, arg) => foldTree(foldTrees(x, typedefs), arg) - case TypeBind(_, tbt) => foldTree(x, tbt) - case TypeBlock(typedefs, tpt) => foldTree(foldTrees(x, typedefs), tpt) - case MatchTypeTree(boundopt, selector, cases) => - foldTrees(foldTree(boundopt.fold(x)(foldTree(x, _)), selector), cases) - case WildcardTypeTree() => x - case TypeBoundsTree(lo, hi) => foldTree(foldTree(x, lo), hi) - case CaseDef(pat, guard, body) => foldTree(foldTrees(foldTree(x, pat), guard), body) - case TypeCaseDef(pat, body) => foldTree(foldTree(x, pat), body) - case Bind(_, body) => foldTree(x, body) - case Unapply(fun, implicits, patterns) => foldTrees(foldTrees(foldTree(x, fun), implicits), patterns) - case Alternatives(patterns) => foldTrees(x, patterns) - } - } -} diff --git a/library/src/scala/tasty/reflect/TreeMap.scala b/library/src/scala/tasty/reflect/TreeMap.scala deleted file mode 100644 index 50db1f5fd5d7..000000000000 --- a/library/src/scala/tasty/reflect/TreeMap.scala +++ /dev/null @@ -1,171 +0,0 @@ -package scala.tasty -package reflect - -/** TASTy Reflect tree map. - * - * Usage: - * ``` - * class MyTreeMap[R <: scala.tasty.Reflection & Singleton](val reflect: R) - * extends scala.tasty.reflect.TreeMap { - * import reflect._ - * override def transformTree(tree: Tree)(using ctx: Context): Tree = ... - * } - * ``` - */ -trait TreeMap { - - val reflect: Reflection - import reflect._ - - def transformTree(tree: Tree)(using ctx: Context): Tree = { - tree match { - case tree: PackageClause => - PackageClause.copy(tree)(transformTerm(tree.pid).asInstanceOf[Ref], transformTrees(tree.stats)(using tree.symbol.localContext)) - case tree: Import => - Import.copy(tree)(transformTerm(tree.expr), tree.selectors) - case tree: Statement => - transformStatement(tree) - case tree: TypeTree => transformTypeTree(tree) - case tree: TypeBoundsTree => tree // TODO traverse tree - case tree: WildcardTypeTree => tree // TODO traverse tree - case tree: CaseDef => - transformCaseDef(tree) - case tree: TypeCaseDef => - transformTypeCaseDef(tree) - case pattern: Bind => - Bind.copy(pattern)(pattern.name, pattern.pattern) - case pattern: Unapply => - Unapply.copy(pattern)(transformTerm(pattern.fun), transformSubTrees(pattern.implicits), transformTrees(pattern.patterns)) - case pattern: Alternatives => - Alternatives.copy(pattern)(transformTrees(pattern.patterns)) - } - } - - def transformStatement(tree: Statement)(using ctx: Context): Statement = { - def localCtx(definition: Definition): Context = definition.symbol.localContext - tree match { - case tree: Term => - transformTerm(tree) - case tree: ValDef => - val ctx = localCtx(tree) - given Context = ctx - val tpt1 = transformTypeTree(tree.tpt) - val rhs1 = tree.rhs.map(x => transformTerm(x)) - ValDef.copy(tree)(tree.name, tpt1, rhs1) - case tree: DefDef => - val ctx = localCtx(tree) - given Context = ctx - DefDef.copy(tree)(tree.name, transformSubTrees(tree.typeParams), tree.paramss mapConserve (transformSubTrees(_)), transformTypeTree(tree.returnTpt), tree.rhs.map(x => transformTerm(x))) - case tree: TypeDef => - val ctx = localCtx(tree) - given Context = ctx - TypeDef.copy(tree)(tree.name, transformTree(tree.rhs)) - case tree: ClassDef => - ClassDef.copy(tree)(tree.name, tree.constructor, tree.parents, tree.derived, tree.self, tree.body) - case tree: Import => - Import.copy(tree)(transformTerm(tree.expr), tree.selectors) - } - } - - def transformTerm(tree: Term)(using ctx: Context): Term = { - tree match { - case Ident(name) => - tree - case Select(qualifier, name) => - Select.copy(tree)(transformTerm(qualifier), name) - case This(qual) => - tree - case Super(qual, mix) => - Super.copy(tree)(transformTerm(qual), mix) - case Apply(fun, args) => - Apply.copy(tree)(transformTerm(fun), transformTerms(args)) - case TypeApply(fun, args) => - TypeApply.copy(tree)(transformTerm(fun), transformTypeTrees(args)) - case Literal(const) => - tree - case New(tpt) => - New.copy(tree)(transformTypeTree(tpt)) - case Typed(expr, tpt) => - Typed.copy(tree)(transformTerm(expr), transformTypeTree(tpt)) - case tree: NamedArg => - NamedArg.copy(tree)(tree.name, transformTerm(tree.value)) - case Assign(lhs, rhs) => - Assign.copy(tree)(transformTerm(lhs), transformTerm(rhs)) - case Block(stats, expr) => - Block.copy(tree)(transformStats(stats), transformTerm(expr)) - case If(cond, thenp, elsep) => - If.copy(tree)(transformTerm(cond), transformTerm(thenp), transformTerm(elsep)) - case Closure(meth, tpt) => - Closure.copy(tree)(transformTerm(meth), tpt) - case Match(selector, cases) => - Match.copy(tree)(transformTerm(selector), transformCaseDefs(cases)) - case Return(expr, from) => - Return.copy(tree)(transformTerm(expr), from) - case While(cond, body) => - While.copy(tree)(transformTerm(cond), transformTerm(body)) - case Try(block, cases, finalizer) => - Try.copy(tree)(transformTerm(block), transformCaseDefs(cases), finalizer.map(x => transformTerm(x))) - case Repeated(elems, elemtpt) => - Repeated.copy(tree)(transformTerms(elems), transformTypeTree(elemtpt)) - case Inlined(call, bindings, expansion) => - Inlined.copy(tree)(call, transformSubTrees(bindings), transformTerm(expansion)/*()call.symbol.localContext)*/) - } - } - - def transformTypeTree(tree: TypeTree)(using ctx: Context): TypeTree = tree match { - case Inferred() => tree - case tree: TypeIdent => tree - case tree: TypeSelect => - TypeSelect.copy(tree)(tree.qualifier, tree.name) - case tree: Projection => - Projection.copy(tree)(tree.qualifier, tree.name) - case tree: Annotated => - Annotated.copy(tree)(tree.arg, tree.annotation) - case tree: Singleton => - Singleton.copy(tree)(transformTerm(tree.ref)) - case tree: Refined => - Refined.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.refinements).asInstanceOf[List[Definition]]) - case tree: Applied => - Applied.copy(tree)(transformTypeTree(tree.tpt), transformTrees(tree.args)) - case tree: MatchTypeTree => - MatchTypeTree.copy(tree)(tree.bound.map(b => transformTypeTree(b)), transformTypeTree(tree.selector), transformTypeCaseDefs(tree.cases)) - case tree: ByName => - ByName.copy(tree)(transformTypeTree(tree.result)) - case tree: LambdaTypeTree => - LambdaTypeTree.copy(tree)(transformSubTrees(tree.tparams), transformTree(tree.body)) - case tree: TypeBind => - TypeBind.copy(tree)(tree.name, tree.body) - case tree: TypeBlock => - TypeBlock.copy(tree)(tree.aliases, tree.tpt) - } - - def transformCaseDef(tree: CaseDef)(using ctx: Context): CaseDef = { - CaseDef.copy(tree)(transformTree(tree.pattern), tree.guard.map(transformTerm), transformTerm(tree.rhs)) - } - - def transformTypeCaseDef(tree: TypeCaseDef)(using ctx: Context): TypeCaseDef = { - TypeCaseDef.copy(tree)(transformTypeTree(tree.pattern), transformTypeTree(tree.rhs)) - } - - def transformStats(trees: List[Statement])(using ctx: Context): List[Statement] = - trees mapConserve (transformStatement(_)) - - def transformTrees(trees: List[Tree])(using ctx: Context): List[Tree] = - trees mapConserve (transformTree(_)) - - def transformTerms(trees: List[Term])(using ctx: Context): List[Term] = - trees mapConserve (transformTerm(_)) - - def transformTypeTrees(trees: List[TypeTree])(using ctx: Context): List[TypeTree] = - trees mapConserve (transformTypeTree(_)) - - def transformCaseDefs(trees: List[CaseDef])(using ctx: Context): List[CaseDef] = - trees mapConserve (transformCaseDef(_)) - - def transformTypeCaseDefs(trees: List[TypeCaseDef])(using ctx: Context): List[TypeCaseDef] = - trees mapConserve (transformTypeCaseDef(_)) - - def transformSubTrees[Tr <: Tree](trees: List[Tr])(using ctx: Context): List[Tr] = - transformTrees(trees).asInstanceOf[List[Tr]] - -} diff --git a/library/src/scala/tasty/reflect/TreeTraverser.scala b/library/src/scala/tasty/reflect/TreeTraverser.scala deleted file mode 100644 index 8155ec3c1cac..000000000000 --- a/library/src/scala/tasty/reflect/TreeTraverser.scala +++ /dev/null @@ -1,25 +0,0 @@ -package scala.tasty -package reflect - -/** TASTy Reflect tree traverser. - * - * Usage: - * ``` - * class MyTraverser[R <: scala.tasty.Reflection & Singleton](val reflect: R) - * extends scala.tasty.reflect.TreeTraverser { - * import reflect._ - * override def traverseTree(tree: Tree)(using ctx: Context): Unit = ... - * } - * ``` - */ -trait TreeTraverser extends TreeAccumulator[Unit] { - - import reflect._ - - def traverseTree(tree: Tree)(using ctx: Context): Unit = traverseTreeChildren(tree) - - def foldTree(x: Unit, tree: Tree)(using ctx: Context): Unit = traverseTree(tree) - - protected def traverseTreeChildren(tree: Tree)(using ctx: Context): Unit = foldOverTree((), tree) - -} diff --git a/tests/run-custom-args/Yretain-trees/tasty-extractors-owners/quoted_1.scala b/tests/run-custom-args/Yretain-trees/tasty-extractors-owners/quoted_1.scala index 487f159b71fe..24b8f34608b9 100644 --- a/tests/run-custom-args/Yretain-trees/tasty-extractors-owners/quoted_1.scala +++ b/tests/run-custom-args/Yretain-trees/tasty-extractors-owners/quoted_1.scala @@ -10,15 +10,16 @@ object Macros { val buff = new StringBuilder - val output = new MyTraverser(qctx.reflect)(buff) + val output = myTraverser(buff) val tree = x.unseal output.traverseTree(tree) '{print(${Expr(buff.result())})} } - class MyTraverser[R <: scala.tasty.Reflection & Singleton](val reflect: R)(buff: StringBuilder) extends scala.tasty.reflect.TreeTraverser { - import reflect._ + + def myTraverser(using qctx: QuoteContext)(buff: StringBuilder): qctx.reflect.TreeTraverser = new { + import qctx.reflect._ override def traverseTree(tree: Tree)(implicit ctx: Context): Unit = { tree match { case tree @ DefDef(name, _, _, _, _) =>