Skip to content

Commit fb618ad

Browse files
authored
Merge pull request #14628 from dotty-staging/fix-14623
Handle more corner cases in etaReduce
2 parents 3c5dbc3 + 8a54e2c commit fb618ad

File tree

3 files changed

+54
-13
lines changed

3 files changed

+54
-13
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

+3
Original file line numberDiff line numberDiff line change
@@ -1599,6 +1599,9 @@ class Definitions {
15991599
yield
16001600
nme.apply.specializedFunction(r, List(t1, t2)).asTermName
16011601

1602+
@tu lazy val FunctionSpecializedApplyNames: collection.Set[Name] =
1603+
Function0SpecializedApplyNames ++ Function1SpecializedApplyNames ++ Function2SpecializedApplyNames
1604+
16021605
def functionArity(tp: Type)(using Context): Int = tp.dropDependentRefinement.dealias.argInfos.length - 1
16031606

16041607
/** Return underlying context function type (i.e. instance of an ContextFunctionN class)

compiler/src/dotty/tools/dotc/transform/EtaReduce.scala

+36-13
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import MegaPhase.MiniPhase
66
import core.*
77
import Symbols.*, Contexts.*, Types.*, Decorators.*
88
import StdNames.nme
9+
import SymUtils.*
10+
import NameKinds.AdaptedClosureName
911

1012
/** Rewrite `(x1, ... xN) => f(x1, ... xN)` for N >= 0 to `f`,
1113
* provided `f` is a pure path of function type.
@@ -15,6 +17,11 @@ import StdNames.nme
1517
* where a context function is expected, unless that value has the
1618
* syntactic form of a context function literal.
1719
*
20+
* Also handle variants of eta-expansions where
21+
* - result f.apply(X_1,...,X_n) is subject to a synthetic cast, or
22+
* - the application uses a specialized apply method, or
23+
* - the closure is adapted (see Erasure#adaptClosure)
24+
*
1825
* Without this phase, when a contextual function is passed as an argument to a
1926
* recursive function, that would have the unfortunate effect of a linear growth
2027
* in transient thunks of identical type wrapped around each other, leading
@@ -27,20 +34,36 @@ class EtaReduce extends MiniPhase:
2734

2835
override def description: String = EtaReduce.description
2936

30-
override def transformBlock(tree: Block)(using Context): Tree = tree match
31-
case Block((meth : DefDef) :: Nil, closure: Closure)
32-
if meth.symbol == closure.meth.symbol =>
33-
meth.rhs match
34-
case Apply(Select(fn, nme.apply), args)
35-
if meth.paramss.head.corresponds(args)((param, arg) =>
37+
override def transformBlock(tree: Block)(using Context): Tree =
38+
39+
def tryReduce(mdef: DefDef, rhs: Tree): Tree = rhs match
40+
case Apply(Select(fn, name), args)
41+
if (name == nme.apply || defn.FunctionSpecializedApplyNames.contains(name))
42+
&& mdef.paramss.head.corresponds(args)((param, arg) =>
3643
arg.isInstanceOf[Ident] && arg.symbol == param.symbol)
37-
&& isPurePath(fn)
38-
&& fn.tpe <:< tree.tpe
39-
&& defn.isFunctionClass(fn.tpe.widen.typeSymbol) =>
40-
report.log(i"eta reducing $tree --> $fn")
41-
fn
42-
case _ => tree
43-
case _ => tree
44+
&& isPurePath(fn)
45+
&& fn.tpe <:< tree.tpe
46+
&& defn.isFunctionClass(fn.tpe.widen.typeSymbol) =>
47+
report.log(i"eta reducing $tree --> $fn")
48+
fn
49+
case TypeApply(Select(qual, _), _) if rhs.symbol.isTypeCast && rhs.span.isSynthetic =>
50+
tryReduce(mdef, qual)
51+
case _ =>
52+
tree
53+
54+
tree match
55+
case Block((meth: DefDef) :: Nil, expr) if meth.symbol.isAnonymousFunction =>
56+
expr match
57+
case closure: Closure if meth.symbol == closure.meth.symbol =>
58+
tryReduce(meth, meth.rhs)
59+
case Block((adapted: DefDef) :: Nil, closure: Closure)
60+
if adapted.name.is(AdaptedClosureName) && adapted.symbol == closure.meth.symbol =>
61+
tryReduce(meth, meth.rhs)
62+
case _ =>
63+
tree
64+
case _ =>
65+
tree
66+
end transformBlock
4467

4568
end EtaReduce
4669

tests/run/i14623.scala

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
object Thunk {
2+
private[this] val impl =
3+
((x: Any) => x).asInstanceOf[(=> Any) => Function0[Any]]
4+
5+
def asFunction0[A](thunk: => A): Function0[A] = impl(thunk).asInstanceOf[Function0[A]]
6+
}
7+
8+
@main def Test =
9+
var i = 0
10+
val f1 = { () => i += 1; "" }
11+
assert(Thunk.asFunction0(f1()) eq f1)
12+
val f2 = { () => i += 1; i }
13+
assert(Thunk.asFunction0(f2()) eq f2)
14+
val f3 = { () => i += 1 }
15+
assert(Thunk.asFunction0(f3()) eq f3)

0 commit comments

Comments
 (0)