Skip to content

Commit 3bfea78

Browse files
committed
Support a limited form of rewrite unapplys
Currently only unapplys that return a tuple are supported.
1 parent f09ba1b commit 3bfea78

File tree

6 files changed

+77
-15
lines changed

6 files changed

+77
-15
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,19 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
716716
Nil
717717
}
718718

719+
/** If `tree` is an instance of `TupleN[...](e1, ..., eN)`, the arguments `e1, ..., eN`
720+
* otherwise the empty list.
721+
*/
722+
def tupleArgs(tree: Tree)(implicit ctx: Context): List[Tree] = tree match {
723+
case Block(Nil, expr) => tupleArgs(expr)
724+
case Inlined(_, Nil, expr) => tupleArgs(expr)
725+
case Apply(fn, args)
726+
if fn.symbol.name == nme.apply &&
727+
fn.symbol.owner.is(Module) &&
728+
defn.isTupleClass(fn.symbol.owner.companionClass) => args
729+
case _ => Nil
730+
}
731+
719732
/** The qualifier part of a Select or Ident.
720733
* For an Ident, this is the `This` of the current class.
721734
*/

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -732,21 +732,34 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
732732
case UnApply(unapp, _, pats) =>
733733
unapp.tpe.widen match {
734734
case mt: MethodType if mt.paramInfos.length == 1 =>
735+
736+
def reduceSubPatterns(pats: List[Tree], selectors: List[Tree]): Boolean = (pats, selectors) match {
737+
case (Nil, Nil) => true
738+
case (pat :: pats1, selector :: selectors1) =>
739+
val elem = newBinding(RewriteBinderName.fresh(), Synthetic, selector)
740+
reducePattern(bindingsBuf, elem.termRef, pat) &&
741+
reduceSubPatterns(pats1, selectors1)
742+
case _ => false
743+
}
744+
735745
val paramType = mt.paramInfos.head
736746
val paramCls = paramType.classSymbol
737-
paramCls.is(Case) && unapp.symbol.is(Synthetic) && scrut <:< paramType && {
747+
if (paramCls.is(Case) && unapp.symbol.is(Synthetic) && scrut <:< paramType) {
738748
val caseAccessors =
739749
if (paramCls.is(Scala2x)) paramCls.caseAccessors.filter(_.is(Method))
740750
else paramCls.asClass.paramAccessors
741-
var subOK = caseAccessors.length == pats.length
742-
for ((pat, accessor) <- (pats, caseAccessors).zipped)
743-
subOK = subOK && {
744-
val rhs = constToLiteral(reduceProjection(ref(scrut).select(accessor).ensureApplied))
745-
val elem = newBinding(RewriteBinderName.fresh(), Synthetic, rhs)
746-
reducePattern(bindingsBuf, elem.termRef, pat)
747-
}
748-
subOK
751+
val selectors =
752+
for (accessor <- caseAccessors)
753+
yield constToLiteral(reduceProjection(ref(scrut).select(accessor).ensureApplied))
754+
caseAccessors.length == pats.length && reduceSubPatterns(pats, selectors)
755+
}
756+
else if (unapp.symbol.isRewriteMethod) {
757+
val app = untpd.Apply(untpd.TypedSplice(unapp), untpd.ref(scrut))
758+
val app1 = typer.typedExpr(app)
759+
val args = tupleArgs(app1)
760+
args.nonEmpty && reduceSubPatterns(pats, args)
749761
}
762+
else false
750763
case _ =>
751764
false
752765
}

compiler/src/dotty/tools/dotc/typer/PrepareInlineable.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,8 @@ object PrepareInlineable {
243243
val typedBody =
244244
if (ctx.reporter.hasErrors) rawBody
245245
else ctx.compilationUnit.inlineAccessors.makeInlineable(rawBody)
246+
if (inlined.isRewriteMethod)
247+
checkRewriteMethod(inlined, typedBody)
246248
val inlineableBody = addReferences(inlined, originalBody, typedBody)
247249
inlining.println(i"Body to inline for $inlined: $inlineableBody")
248250
inlineableBody
@@ -251,6 +253,13 @@ object PrepareInlineable {
251253
}
252254
}
253255

256+
def checkRewriteMethod(inlined: Symbol, body: Tree)(implicit ctx: Context) = {
257+
if (inlined.name == nme.unapply && tupleArgs(body).isEmpty)
258+
ctx.warning(
259+
em"rewrite unapply method can be rewritten only if its tight hand side is a tuple (e1, ..., eN)",
260+
body.pos)
261+
}
262+
254263
/** Tweak untyped tree `original` so that all external references are typed
255264
* and it reflects the changes in the corresponding typed tree `typed` that
256265
* make `typed` inlineable. Concretely:

library/src-scala3/scala/Tuple.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,5 +254,5 @@ sealed class *:[+H, +T <: Tuple] extends Tuple {
254254
}
255255

256256
object *: {
257-
rewrite def unapply[H, T <: Tuple](x: H *: T) = Some((x.head, x.tail))
257+
rewrite def unapply[H, T <: Tuple](x: H *: T) = (x.head, x.tail)
258258
}

tests/run/tuples1.check

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,8 @@ c2_1 = (A,1,1)
3232
c2_2 = (A,1,A,1)
3333
c2_3 = (A,1,2,A,1)
3434
c3_3 = (2,A,1,2,A,1)
35-
276
3635
(1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23)
36+
276
37+
(A,1) -> A, (1)
38+
(A,1) -> A, 1, ()
39+
(1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23) -> 1, 2, (3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23)

tests/run/tuples1.scala

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,35 @@ object Test extends App {
4141
Int, Int, Int, Int, Int,
4242
Int, Int, Int)
4343
val x23c: T23 = x23
44+
println(x23)
45+
assert(x23(0) == 1)
46+
assert(x23(22) == 23)
47+
4448
x23 match {
4549
case (x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15, x16, x17, x18, x19, x20, x21, x22, x23) =>
4650
println(x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10 + x11 + x12 + x13 + x14 + x15 + x16 + x17 + x18 + x19 + x20 + x21 + x22 + x23)
4751
}
48-
println(x23)
49-
assert(x23(0) == 1)
50-
assert(x23(22) == 23)
51-
}
52+
rewrite def decompose1 = rewrite x2 match { case x *: xs => (x, xs) }
53+
rewrite def decompose2 = rewrite x2 match { case x *: y *: xs => (x, y, xs) }
54+
rewrite def decompose3 = rewrite x23 match { case x *: y *: xs => (x, y, xs) }
55+
56+
{ val (x, xs) = decompose1
57+
val xc: String = x
58+
val xsc: Int *: Unit = xs
59+
println(s"$x2 -> $x, $xs")
60+
}
61+
62+
{ val (x, y, xs) = decompose2
63+
val xc: String = x
64+
val yc: Int = y
65+
val xsc: Unit = xs
66+
println(s"$x2 -> $x, $y, $xs")
67+
}
68+
69+
{ val (x, y, xs) = decompose3
70+
val xc: Int = x
71+
val yc: Int = y
72+
val xsc: Unit = xs
73+
println(s"$x23 -> $x, $y, $xs")
74+
}
75+
}

0 commit comments

Comments
 (0)