diff --git a/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala b/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala index 828e6fcabfa0..51a040aec127 100644 --- a/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala +++ b/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala @@ -580,6 +580,9 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend def Closure_copy(original: Tree)(meth: Tree, tpe: Option[Type])(given Context): Closure = tpd.cpy.Closure(original)(Nil, meth, tpe.map(tpd.TypeTree(_)).getOrElse(tpd.EmptyTree)) + def Lambda_apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block = + tpd.Lambda(tpe, rhsFn) + type If = tpd.If def isInstanceOfIf(given ctx: Context): IsInstanceOf[If] = new { @@ -1141,17 +1144,10 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend def Type_isSubType(self: Type)(that: Type)(given Context): Boolean = self <:< that - /** Widen from singleton type to its underlying non-singleton - * base type by applying one or more `underlying` dereferences, - * Also go from => T to T. - * Identity for all other types. Example: - * - * class Outer { class C ; val x: C } - * def o: Outer - * .widen = o.C - */ def Type_widen(self: Type)(given Context): Type = self.widen + def Type_widenTermRefExpr(self: Type)(given Context): Type = self.widenTermRefExpr + def Type_dealias(self: Type)(given Context): Type = self.dealias def Type_simplified(self: Type)(given Context): Type = self.simplified @@ -1398,6 +1394,9 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend case _ => None } + def MethodType_apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType = + Types.MethodType(paramNames.map(_.toTermName))(paramInfosExp, resultTypeExp) + def MethodType_isErased(self: MethodType): Boolean = self.isErasedMethod def MethodType_isImplicit(self: MethodType): Boolean = self.isImplicitMethod def MethodType_paramNames(self: MethodType)(given Context): List[String] = self.paramNames.map(_.toString) diff --git a/library/src-non-bootstrapped/scala/tasty/reflect/TreeUtils.scala b/library/src-non-bootstrapped/scala/tasty/reflect/TreeUtils.scala index 204172d6e234..876d8cc02622 100644 --- a/library/src-non-bootstrapped/scala/tasty/reflect/TreeUtils.scala +++ b/library/src-non-bootstrapped/scala/tasty/reflect/TreeUtils.scala @@ -7,4 +7,50 @@ trait TreeUtils with SymbolOps with TreeOps { self: Reflection => + abstract class TreeAccumulator[X] { + def foldTree(x: X, tree: Tree)(given ctx: Context): X + def foldTrees(x: X, trees: Iterable[Tree])(given ctx: Context): X = + throw new Exception("non-bootstraped-library") + def foldOverTree(x: X, tree: Tree)(given ctx: Context): X = + throw new Exception("non-bootstraped-library") + } + + abstract class TreeTraverser extends TreeAccumulator[Unit] { + def traverseTree(tree: Tree)(given ctx: Context): Unit = + throw new Exception("non-bootstraped-library") + def foldTree(x: Unit, tree: Tree)(given ctx: Context): Unit = + throw new Exception("non-bootstraped-library") + protected def traverseTreeChildren(tree: Tree)(given ctx: Context): Unit = + throw new Exception("non-bootstraped-library") + } + + abstract class TreeMap { self => + def transformTree(tree: Tree)(given ctx: Context): Tree = + throw new Exception("non-bootstraped-library") + def transformStatement(tree: Statement)(given ctx: Context): Statement = + throw new Exception("non-bootstraped-library") + def transformTerm(tree: Term)(given ctx: Context): Term = + throw new Exception("non-bootstraped-library") + def transformTypeTree(tree: TypeTree)(given ctx: Context): TypeTree = + throw new Exception("non-bootstraped-library") + def transformCaseDef(tree: CaseDef)(given ctx: Context): CaseDef = + throw new Exception("non-bootstraped-library") + def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef = + throw new Exception("non-bootstraped-library") + def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] = + throw new Exception("non-bootstraped-library") + def transformTrees(trees: List[Tree])(given ctx: Context): List[Tree] = + throw new Exception("non-bootstraped-library") + def transformTerms(trees: List[Term])(given ctx: Context): List[Term] = + throw new Exception("non-bootstraped-library") + def transformTypeTrees(trees: List[TypeTree])(given ctx: Context): List[TypeTree] = + throw new Exception("non-bootstraped-library") + def transformCaseDefs(trees: List[CaseDef])(given ctx: Context): List[CaseDef] = + throw new Exception("non-bootstraped-library") + def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] = + throw new Exception("non-bootstraped-library") + def transformSubTrees[Tr <: Tree](trees: List[Tr])(given ctx: Context): List[Tr] = + throw new Exception("non-bootstraped-library") + } + } diff --git a/library/src/scala/internal/quoted/Matcher.scala b/library/src/scala/internal/quoted/Matcher.scala index e480f064b2da..b4ddfc856cea 100644 --- a/library/src/scala/internal/quoted/Matcher.scala +++ b/library/src/scala/internal/quoted/Matcher.scala @@ -10,19 +10,27 @@ private[quoted] object Matcher { class QuoteMatcher[QCtx <: QuoteContext & Singleton](given val qctx: QCtx) { // TODO improve performance + // TODO use flag from qctx.tasty.rootContext. Maybe -debug or add -debug-macros private final val debug = false import qctx.tasty.{_, given} import Matching._ - private type Env = Set[(Symbol, Symbol)] + /** A map relating equivalent symbols from the scrutinee and the pattern + * For example in + * ``` + * '{val a = 4; a * a} match case '{ val x = 4; x * x } + * ``` + * when matching `a * a` with `x * x` the enviroment will contain `Map(a -> x)`. + */ + private type Env = Map[Symbol, Symbol] inline private def withEnv[T](env: Env)(body: => (given Env) => T): T = body(given env) class SymBinding(val sym: Symbol, val fromAbove: Boolean) def termMatch(scrutineeTerm: Term, patternTerm: Term, hasTypeSplices: Boolean): Option[Tuple] = { - implicit val env: Env = Set.empty + implicit val env: Env = Map.empty if (hasTypeSplices) { implicit val ctx: Context = internal.Context_GADT_setFreshGADTBounds(rootContext) val matchings = scrutineeTerm.underlyingArgument =?= patternTerm.underlyingArgument @@ -42,7 +50,7 @@ private[quoted] object Matcher { // TODO factor out common logic with `termMatch` def typeTreeMatch(scrutineeTypeTree: TypeTree, patternTypeTree: TypeTree, hasTypeSplices: Boolean): Option[Tuple] = { - implicit val env: Env = Set.empty + implicit val env: Env = Map.empty if (hasTypeSplices) { implicit val ctx: Context = internal.Context_GADT_setFreshGADTBounds(rootContext) val matchings = scrutineeTypeTree =?= patternTypeTree @@ -138,11 +146,29 @@ private[quoted] object Matcher { matched(scrutinee.seal) // Match a scala.internal.Quoted.patternHole and return the scrutinee tree - case (scrutinee: Term, TypeApply(patternHole, tpt :: Nil)) + case (ClosedPatternTerm(scrutinee), TypeApply(patternHole, tpt :: Nil)) if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole && scrutinee.tpe <:< tpt.tpe => matched(scrutinee.seal) + // Matches an open term and wraps it into a lambda that provides the free variables + case (scrutinee, pattern @ Apply(Select(TypeApply(patternHole, List(Inferred())), "apply"), args0 @ IdentArgs(args))) + if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole => + def bodyFn(lambdaArgs: List[Tree]): Tree = { + val argsMap = args.map(_.symbol).zip(lambdaArgs.asInstanceOf[List[Term]]).toMap + new TreeMap { + override def transformTerm(tree: Term)(given ctx: Context): Term = + tree match + case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) + case tree => super.transformTerm(tree) + }.transformTree(scrutinee) + } + val names = args.map(_.name) + val argTypes = args0.map(x => x.tpe.widenTermRefExpr) + val resType = pattern.tpe + val res = Lambda(MethodType(names)(_ => argTypes, _ => resType), bodyFn) + matched(res.seal) + // // Match two equivalent trees // @@ -156,7 +182,7 @@ private[quoted] object Matcher { case (scrutinee, Typed(expr2, _)) => scrutinee =?= expr2 - case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || summon[Env].apply((scrutinee.symbol, pattern.symbol)) => + case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || summon[Env].get(scrutinee.symbol).contains(pattern.symbol) => matched case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol => @@ -165,10 +191,10 @@ private[quoted] object Matcher { case (_: Ref, _: Ref) if scrutinee.symbol == pattern.symbol => matched - case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol => + case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol || summon[Env].get(fn1.symbol).contains(fn2.symbol) => fn1 =?= fn2 && args1 =?= args2 - case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol => + case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol || summon[Env].get(fn1.symbol).contains(fn2.symbol) => fn1 =?= fn2 && args1 =?= args2 case (Block(stats1, expr1), Block(binding :: stats2, expr2)) if isTypeBinding(binding) => @@ -176,7 +202,13 @@ private[quoted] object Matcher { matched(new SymBinding(binding.symbol, hasFromAboveAnnotation(binding.symbol))) && Block(stats1, expr1) =?= Block(stats2, expr2) case (Block(stat1 :: stats1, expr1), Block(stat2 :: stats2, expr2)) => - withEnv(summon[Env] + (stat1.symbol -> stat2.symbol)) { + val newEnv = (stat1, stat2) match { + case (stat1: Definition, stat2: Definition) => + summon[Env] + (stat1.symbol -> stat2.symbol) + case _ => + summon[Env] + } + withEnv(newEnv) { stat1 =?= stat2 && Block(stats1, expr1) =?= Block(stats2, expr2) } @@ -268,7 +300,7 @@ private[quoted] object Matcher { | |${pattern.showExtractors} | - | + |with environment: ${summon[Env]} | | |""".stripMargin) @@ -277,6 +309,33 @@ private[quoted] object Matcher { } end treeOps + private object ClosedPatternTerm { + /** Matches a term that does not contain free variables defined in the pattern (i.e. not defined in `Env`) */ + def unapply(term: Term)(given Context, Env): Option[term.type] = + if freePatternVars(term).isEmpty then Some(term) else None + + /** Return all free variables of the term defined in the pattern (i.e. defined in `Env`) */ + def freePatternVars(term: Term)(given qctx: Context, env: Env): Set[Symbol] = + val accumulator = new TreeAccumulator[Set[Symbol]] { + def foldTree(x: Set[Symbol], tree: Tree)(given ctx: Context): Set[Symbol] = + tree match + case tree: Ident if env.contains(tree.symbol) => foldOverTree(x + tree.symbol, tree) + case _ => foldOverTree(x, tree) + } + accumulator.foldTree(Set.empty, term) + } + + private object IdentArgs { + def unapply(args: List[Term])(given Context): Option[List[Ident]] = + args.foldRight(Option(List.empty[Ident])) { + case (id: Ident, Some(acc)) => Some(id :: acc) + case (Block(List(DefDef("$anonfun", Nil, List(params), Inferred(), Some(Apply(id: Ident, args)))), Closure(Ident("$anonfun"), None)), Some(acc)) + if params.zip(args).forall(_.symbol == _.symbol) => + Some(id :: acc) + case _ => None + } + } + private def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree])(given Context, Env): Matching = { (scrutinee, pattern) match { case (Some(x), Some(y)) => x =?= y @@ -344,7 +403,7 @@ private[quoted] object Matcher { | |${pattern.showExtractors} | - | + |with environment: ${summon[Env]} | | |""".stripMargin) diff --git a/library/src/scala/quoted/Expr.scala b/library/src/scala/quoted/Expr.scala index 9cc62e55cd62..45b57db95e8e 100644 --- a/library/src/scala/quoted/Expr.scala +++ b/library/src/scala/quoted/Expr.scala @@ -204,8 +204,44 @@ package quoted { val elems: Seq[Expr[_]] = tup.asInstanceOf[Product].productIterator.toSeq.asInstanceOf[Seq[Expr[_]]] ofTuple(elems).cast[Tuple.InverseMap[T, Expr]] } - } + // TODO generalize for any function arity (see Expr.betaReduce) + def open[T1, R, X](f: Expr[T1 => R])(content: (Expr[R], [t] => Expr[t] => Expr[T1] => Expr[t]) => X)(given qctx: QuoteContext): X = { + import qctx.tasty.{given, _} + val (params, bodyExpr) = paramsAndBody(f) + content(bodyExpr, [t] => (e: Expr[t]) => (v: Expr[T1]) => bodyFn[t](e.unseal, params, List(v.unseal)).seal.asInstanceOf[Expr[t]]) + } + + def open[T1, T2, R, X](f: Expr[(T1, T2) => R])(content: (Expr[R], [t] => Expr[t] => (Expr[T1], Expr[T2]) => Expr[t]) => X)(given qctx: QuoteContext)(given DummyImplicit): X = { + import qctx.tasty.{given, _} + val (params, bodyExpr) = paramsAndBody(f) + content(bodyExpr, [t] => (e: Expr[t]) => (v1: Expr[T1], v2: Expr[T2]) => bodyFn[t](e.unseal, params, List(v1.unseal, v2.unseal)).seal.asInstanceOf[Expr[t]]) + } + + def open[T1, T2, T3, R, X](f: Expr[(T1, T2, T3) => R])(content: (Expr[R], [t] => Expr[t] => (Expr[T1], Expr[T2], Expr[T3]) => Expr[t]) => X)(given qctx: QuoteContext)(given DummyImplicit, DummyImplicit): X = { + import qctx.tasty.{given, _} + val (params, bodyExpr) = paramsAndBody(f) + content(bodyExpr, [t] => (e: Expr[t]) => (v1: Expr[T1], v2: Expr[T2], v3: Expr[T3]) => bodyFn[t](e.unseal, params, List(v1.unseal, v2.unseal, v3.unseal)).seal.asInstanceOf[Expr[t]]) + } + + private def paramsAndBody[R](given qctx: QuoteContext)(f: Expr[Any]) = { + import qctx.tasty.{given, _} + val Block(List(DefDef("$anonfun", Nil, List(params), _, Some(body))), Closure(Ident("$anonfun"), None)) = f.unseal.etaExpand + (params, body.seal.asInstanceOf[Expr[R]]) + } + + private def bodyFn[t](given qctx: QuoteContext)(e: qctx.tasty.Term, params: List[qctx.tasty.ValDef], args: List[qctx.tasty.Term]): qctx.tasty.Term = { + import qctx.tasty.{given, _} + val map = params.map(_.symbol).zip(args).toMap + new TreeMap { + override def transformTerm(tree: Term)(given ctx: Context): Term = + super.transformTerm(tree) match + case tree: Ident => map.getOrElse(tree.symbol, tree) + case tree => tree + }.transformTerm(e) + } + + } } package internal { diff --git a/library/src/scala/quoted/matching/Sym.scala b/library/src/scala/quoted/matching/Sym.scala index 069b43409c0d..47fd87d3335a 100644 --- a/library/src/scala/quoted/matching/Sym.scala +++ b/library/src/scala/quoted/matching/Sym.scala @@ -8,6 +8,8 @@ package matching */ class Sym[T <: AnyKind] private[scala](val name: String, private[Sym] val id: Object) { self => + override def toString: String = s"Sym($name)@${id.hashCode}" + override def equals(obj: Any): Boolean = obj match { case obj: Sym[_] => obj.id == id case _ => false diff --git a/library/src/scala/tasty/reflect/CompilerInterface.scala b/library/src/scala/tasty/reflect/CompilerInterface.scala index 5f56f8489408..0a7c637b885a 100644 --- a/library/src/scala/tasty/reflect/CompilerInterface.scala +++ b/library/src/scala/tasty/reflect/CompilerInterface.scala @@ -443,6 +443,8 @@ trait CompilerInterface { def Closure_apply(meth: Term, tpe: Option[Type])(given ctx: Context): Closure def Closure_copy(original: Tree)(meth: Tree, tpe: Option[Type])(given ctx: Context): Closure + def Lambda_apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block + /** Tree representing an if/then/else `if (...) ... else ...` in the source code */ type If <: Term @@ -810,6 +812,11 @@ trait CompilerInterface { */ def Type_widen(self: Type)(given ctx: Context): Type + /** Widen from TermRef to its underlying non-termref + * base type, while also skipping Expr types. + */ + def Type_widenTermRefExpr(self: Type)(given ctx: Context): Type + /** Follow aliases and dereferences LazyRefs, annotated types and instantiated * TypeVars until type is no longer alias type, annotated type, LazyRef, * or instantiated type variable. @@ -992,6 +999,8 @@ trait CompilerInterface { def isInstanceOfMethodType(given ctx: Context): IsInstanceOf[MethodType] + def MethodType_apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType + def MethodType_isErased(self: MethodType): Boolean def MethodType_isImplicit(self: MethodType): Boolean def MethodType_paramNames(self: MethodType)(given ctx: Context): List[String] diff --git a/library/src/scala/tasty/reflect/TreeOps.scala b/library/src/scala/tasty/reflect/TreeOps.scala index c015361b3662..4fc7f3af9859 100644 --- a/library/src/scala/tasty/reflect/TreeOps.scala +++ b/library/src/scala/tasty/reflect/TreeOps.scala @@ -615,6 +615,10 @@ trait TreeOps extends Core { case _ => None } + + def apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block = + internal.Lambda_apply(tpe, rhsFn) + } given (given Context): IsInstanceOf[If] = internal.isInstanceOfIf diff --git a/library/src/scala/tasty/reflect/TypeOrBoundsOps.scala b/library/src/scala/tasty/reflect/TypeOrBoundsOps.scala index bcaf2d747dcf..e8a063bf25d8 100644 --- a/library/src/scala/tasty/reflect/TypeOrBoundsOps.scala +++ b/library/src/scala/tasty/reflect/TypeOrBoundsOps.scala @@ -17,8 +17,22 @@ trait TypeOrBoundsOps extends Core { /** Is this type a subtype of that type? */ def <:<(that: Type)(given ctx: Context): Boolean = internal.Type_isSubType(self)(that) + /** Widen from singleton type to its underlying non-singleton + * base type by applying one or more `underlying` dereferences, + * Also go from => T to T. + * Identity for all other types. Example: + * + * class Outer { class C ; val x: C } + * def o: Outer + * .widen = o.C + */ def widen(given ctx: Context): Type = internal.Type_widen(self) + /** Widen from TermRef to its underlying non-termref + * base type, while also skipping `=>T` types. + */ + def widenTermRefExpr(given ctx: Context): Type = internal.Type_widenTermRefExpr(self) + /** Follow aliases and dereferences LazyRefs, annotated types and instantiated * TypeVars until type is no longer alias type, annotated type, LazyRef, * or instantiated type variable. @@ -325,6 +339,9 @@ trait TypeOrBoundsOps extends Core { def unapply(x: MethodType)(given ctx: Context): Option[MethodType] = Some(x) object MethodType { + def apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType = + internal.MethodType_apply(paramNames)(paramInfosExp, resultTypeExp) + def unapply(x: MethodType)(given ctx: Context): Option[(List[String], List[Type], Type)] = Some((x.paramNames, x.paramTypes, x.resType)) } diff --git a/tests/run-macros/quote-matcher-runtime.check b/tests/run-macros/quote-matcher-runtime.check index 6ff241b9fd61..41908e892b53 100644 --- a/tests/run-macros/quote-matcher-runtime.check +++ b/tests/run-macros/quote-matcher-runtime.check @@ -332,7 +332,7 @@ Pattern: { val x: scala.Int = 45 x.+(scala.internal.Quoted.patternHole[scala.Int]) } -Result: Some(List(Expr(a))) +Result: None Scrutinee: { lazy val a: scala.Int = 45 @@ -622,7 +622,7 @@ Pattern: { def a: scala.Int = scala.internal.Quoted.patternHole[scala.Int] a.+(scala.internal.Quoted.patternHole[scala.Int]) } -Result: Some(List(Expr(a), Expr(a))) +Result: None Scrutinee: { lazy val a: scala.Int = a diff --git a/tests/run-macros/quote-matcher-symantics-2/quoted_1.scala b/tests/run-macros/quote-matcher-symantics-2/quoted_1.scala index 6e80bdf92422..e9977c17290c 100644 --- a/tests/run-macros/quote-matcher-symantics-2/quoted_1.scala +++ b/tests/run-macros/quote-matcher-symantics-2/quoted_1.scala @@ -11,7 +11,7 @@ object Macros { private def impl[T: Type](sym: Symantics[T], a: Expr[DSL])(given qctx: QuoteContext): Expr[T] = { - def lift(e: Expr[DSL])(implicit env: Map[Sym[DSL], Expr[T]]): Expr[T] = e match { + def lift(e: Expr[DSL])(implicit env: Map[Int, Expr[T]]): Expr[T] = e match { case '{ LitDSL(${Const(c)}) } => sym.value(c) @@ -21,23 +21,31 @@ object Macros { case '{ ($f: DSL => DSL)($x: DSL) } => sym.app(liftFun(f), lift(x)) - case '{ val $x: DSL = $value; $body: DSL } => lift(body)(env + (x -> lift(value))) + case '{ val x: DSL = $value; ($bodyFn: DSL => DSL)(x) } => + Expr.open(bodyFn) { (body1, close) => + val (i, nEnvVar) = freshEnvVar() + lift(close(body1)(nEnvVar))(env + (i -> lift(value))) + } - case Sym(b) if env.contains(b) => env(b) + case '{ envVar(${Const(i)}) } => env(i) case _ => import qctx.tasty.{_, given} - error("Expected explicit DSL", e.unseal.pos) + error("Expected explicit DSL " + e.show, e.unseal.pos) ??? } - def liftFun(e: Expr[DSL => DSL])(implicit env: Map[Sym[DSL], Expr[T]]): Expr[T => T] = e match { - case '{ ($x: DSL) => ($body: DSL) } => - sym.lam((y: Expr[T]) => lift(body)(env + (x -> y))) - + def liftFun(e: Expr[DSL => DSL])(implicit env: Map[Int, Expr[T]]): Expr[T => T] = e match { + case '{ (x: DSL) => ($bodyFn: DSL => DSL)(x) } => + sym.lam((y: Expr[T]) => + Expr.open(bodyFn) { (body1, close) => + val (i, nEnvVar) = freshEnvVar() + lift(close(body1)(nEnvVar))(env + (i -> y)) + } + ) case _ => import qctx.tasty.{_, given} - error("Expected explicit DSL => DSL", e.unseal.pos) + error("Expected explicit DSL => DSL " + e.show, e.unseal.pos) ??? } @@ -46,6 +54,13 @@ object Macros { } +def freshEnvVar()(given QuoteContext): (Int, Expr[DSL]) = { + v += 1 + (v, '{envVar(${Expr(v)})}) +} +var v = 0 +def envVar(i: Int): DSL = ??? + // // DSL in which the user write the code // diff --git a/tests/run-macros/quote-matcher-symantics-3/quoted_1.scala b/tests/run-macros/quote-matcher-symantics-3/quoted_1.scala index 761398f7a632..ea10cf1a8b1c 100644 --- a/tests/run-macros/quote-matcher-symantics-3/quoted_1.scala +++ b/tests/run-macros/quote-matcher-symantics-3/quoted_1.scala @@ -9,16 +9,20 @@ object Macros { private def impl[R[_]: Type](sym: Expr[Symantics { type Repr[X] = R[X] }], expr: Expr[Int])(given QuoteContext): Expr[R[Int]] = { - type Env = Map[Any, Any] + type Env = Map[Int, Any] given ev0 : Env = Map.empty - def envWith[T](id: Sym[T], ref: Expr[R[T]])(given env: Env): Env = + def envWith[T](id: Int, ref: Expr[R[T]])(given env: Env): Env = env.updated(id, ref) object FromEnv { - def unapply[T](id: Sym[T])(given Env): Option[Expr[R[T]]] = - summon[Env].get(id).asInstanceOf[Option[Expr[R[T]]]] // We can only add binds that have the same type as the refs + def unapply[T](e: Expr[Any])(given env: Env): Option[Expr[R[T]]] = + e match + case '{envVar[$t](${Const(id)})} => + env.get(id).asInstanceOf[Option[Expr[R[T]]]] // We can only add binds that have the same type as the refs + case _ => + None } def lift[T: Type](e: Expr[T])(given env: Env): Expr[R[T]] = ((e: Expr[Any]) match { @@ -40,17 +44,25 @@ object Macros { case '{ (if ($cond) $thenp else $elsep): $t } => '{ $sym.ifThenElse[$t](${lift(cond)}, ${lift(thenp)}, ${lift(elsep)}) }.asInstanceOf[Expr[R[T]]] - case '{ ($x0: Int) => $body: Any } => - '{ $sym.lam((x: R[Int]) => ${given Env = envWith(x0, 'x)(given env); lift(body)}).asInstanceOf[R[T]] } - case '{ ($x0: Boolean) => $body: Any } => - '{ $sym.lam((x: R[Boolean]) => ${given Env = envWith(x0, 'x)(given env); lift(body)}).asInstanceOf[R[T]] } - case '{ ($x0: Int => Int) => $body: Any } => - '{ $sym.lam((x: R[Int => Int]) => ${given Env = envWith(x0, 'x)(given env); lift(body)}).asInstanceOf[R[T]] } + case '{ (x0: Int) => ($bodyFn: Int => Any)(x0) } => + val (i, nEnvVar) = freshEnvVar[Int]() + val body2 = Expr.open(bodyFn) { (body1, close) => close(body1)(nEnvVar) } + '{ $sym.lam((x: R[Int]) => ${given Env = envWith(i, 'x)(given env); lift(body2)}).asInstanceOf[R[T]] } + + case '{ (x0: Boolean) => ($bodyFn: Boolean => Any)(x0) } => + val (i, nEnvVar) = freshEnvVar[Boolean]() + val body2 = Expr.open(bodyFn) { (body1, close) => close(body1)(nEnvVar) } + '{ $sym.lam((x: R[Boolean]) => ${given Env = envWith(i, 'x)(given env); lift(body2)}).asInstanceOf[R[T]] } + + case '{ (x0: Int => Int) => ($bodyFn: (Int => Int) => Any)(x0) } => + val (i, nEnvVar) = freshEnvVar[Int => Int]() + val body2 = Expr.open(bodyFn) { (body1, close) => close(body1)(nEnvVar) } + '{ $sym.lam((x: R[Int => Int]) => ${given Env = envWith(i, 'x)(given env); lift(body2)}).asInstanceOf[R[T]] } case '{ Symantics.fix[$t, $u]($f) } => '{ $sym.fix[$t, $u]((x: R[$t => $u]) => $sym.app(${lift(f)}, x)).asInstanceOf[R[T]] } - case Sym(FromEnv(expr)) => expr.asInstanceOf[Expr[R[T]]] + case FromEnv(expr) => expr.asInstanceOf[Expr[R[T]]] case _ => summon[QuoteContext].error("Expected explicit value but got: " + e.show, e) @@ -63,6 +75,13 @@ object Macros { } +def freshEnvVar[T: Type]()(given QuoteContext): (Int, Expr[T]) = { + v += 1 + (v, '{envVar[T](${Expr(v)})}) +} +var v = 0 +def envVar[T](i: Int): T = ??? + trait Symantics { type Repr[X] def int(x: Int): Repr[Int] diff --git a/tests/run-macros/quote-matching-open.check b/tests/run-macros/quote-matching-open.check new file mode 100644 index 000000000000..8f9095133205 --- /dev/null +++ b/tests/run-macros/quote-matching-open.check @@ -0,0 +1,8 @@ +2 +4 + +5 +6 + +9 +24 diff --git a/tests/run-macros/quote-matching-open/Macro_1.scala b/tests/run-macros/quote-matching-open/Macro_1.scala new file mode 100644 index 000000000000..a39ccf9a592e --- /dev/null +++ b/tests/run-macros/quote-matching-open/Macro_1.scala @@ -0,0 +1,15 @@ +import scala.quoted._ + +object Macro { + + inline def openTest(x: => Any): Any = ${ Macro.impl('x) } + + def impl(x: Expr[Any])(given QuoteContext): Expr[Any] = { + x match { + case '{ (x: Int) => ($body: Int => Int)(x) } => Expr.open(body) { (body, close) => close(body)(Expr(2)) } + case '{ (x1: Int, x2: Int) => ($body: (Int, Int) => Int)(x1, x2) } => Expr.open(body) { (body, close) => close(body)(Expr(2), Expr(3)) } + case '{ (x1: Int, x2: Int, x3: Int) => ($body: (Int, Int, Int) => Int)(x1, x2, x3) } => Expr.open(body) { (body, close) => close(body)(Expr(2), Expr(3), Expr(4)) } + } + } + +} diff --git a/tests/run-macros/quote-matching-open/Test_2.scala b/tests/run-macros/quote-matching-open/Test_2.scala new file mode 100644 index 000000000000..3fa96cef8726 --- /dev/null +++ b/tests/run-macros/quote-matching-open/Test_2.scala @@ -0,0 +1,15 @@ +object Test { + import Macro._ + + def main(args: Array[String]): Unit = { + println(openTest((x: Int) => x)) + println(openTest((x: Int) => x * x)) + println() + println(openTest((x1: Int, x2: Int) => x1 + x2)) + println(openTest((x1: Int, x2: Int) => x1 * x2)) + println() + println(openTest((x1: Int, x2: Int, x3: Int) => x1 + x2 + x3)) + println(openTest((x1: Int, x2: Int, x3: Int) => x1 * x2 * x3)) + } + +} diff --git a/tests/run-macros/quoted-pattern-open-expr.check b/tests/run-macros/quoted-pattern-open-expr.check new file mode 100644 index 000000000000..a7d8a335a85a --- /dev/null +++ b/tests/run-macros/quoted-pattern-open-expr.check @@ -0,0 +1,10 @@ +Matched closed +(6: scala.Int) +Matched open +((y: scala.Int) => y.*(y)) +Matched open +((y: scala.Function1[scala.Int, scala.Int]) => y.apply(3)) +Matched open +((g: scala.Function1[scala.Int, scala.Int], x: scala.Int) => 5) +Matched open +((g: scala.Function1[scala.Int, scala.Int], x: scala.Int) => x) diff --git a/tests/run-macros/quoted-pattern-open-expr/Macro_1.scala b/tests/run-macros/quoted-pattern-open-expr/Macro_1.scala new file mode 100644 index 000000000000..8880f8398fb2 --- /dev/null +++ b/tests/run-macros/quoted-pattern-open-expr/Macro_1.scala @@ -0,0 +1,12 @@ +import scala.quoted._ + +inline def test(e: Int): String = ${testExpr('e)} + +private def testExpr(e: Expr[Int])(given QuoteContext): Expr[String] = { + e match { + case '{ val y: Int = 4; $body } => Expr("Matched closed\n" + body.show) + case '{ val y: Int = 4; ($body: Int => Int)(y) } => Expr("Matched open\n" + body.show) + case '{ val y: Int => Int = x => x + 1; ($body: (Int => Int) => Int)(y) } => Expr("Matched open\n" + body.show) + case '{ def g(x: Int): Int = ($body: (Int => Int, Int) => Int)(g, x); g(5) } => Expr("Matched open\n" + body.show) + } +} diff --git a/tests/run-macros/quoted-pattern-open-expr/Test_2.scala b/tests/run-macros/quoted-pattern-open-expr/Test_2.scala new file mode 100644 index 000000000000..02a5c6bec5c6 --- /dev/null +++ b/tests/run-macros/quoted-pattern-open-expr/Test_2.scala @@ -0,0 +1,10 @@ + +object Test { + def main(args: Array[String]): Unit = { + println(test { val x: Int = 4; 6: Int }) + println(test { val x: Int = 4; x * x }) + println(test { val f: Int => Int = x => x + 1; f(3) }) + println(test { def f(x: Int): Int = 5; f(5) }) + println(test { def f(x: Int): Int = x; f(5) }) + } +}