Skip to content

Commit 8a54e2c

Browse files
committed
Handle more corner cases in etaReduce
Generalize eta reduce so that it also handles eta expansions that are subject to adaptation or specialization. Fixes #14623
1 parent 1b25f65 commit 8a54e2c

File tree

3 files changed

+54
-13
lines changed

3 files changed

+54
-13
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

+3
Original file line numberDiff line numberDiff line change
@@ -1594,6 +1594,9 @@ class Definitions {
15941594
yield
15951595
nme.apply.specializedFunction(r, List(t1, t2)).asTermName
15961596

1597+
@tu lazy val FunctionSpecializedApplyNames: collection.Set[Name] =
1598+
Function0SpecializedApplyNames ++ Function1SpecializedApplyNames ++ Function2SpecializedApplyNames
1599+
15971600
def functionArity(tp: Type)(using Context): Int = tp.dropDependentRefinement.dealias.argInfos.length - 1
15981601

15991602
/** Return underlying context function type (i.e. instance of an ContextFunctionN class)

compiler/src/dotty/tools/dotc/transform/EtaReduce.scala

+36-13
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import MegaPhase.MiniPhase
66
import core.*
77
import Symbols.*, Contexts.*, Types.*, Decorators.*
88
import StdNames.nme
9+
import SymUtils.*
10+
import NameKinds.AdaptedClosureName
911

1012
/** Rewrite `(x1, ... xN) => f(x1, ... xN)` for N >= 0 to `f`,
1113
* provided `f` is a pure path of function type.
@@ -15,6 +17,11 @@ import StdNames.nme
1517
* where a context function is expected, unless that value has the
1618
* syntactic form of a context function literal.
1719
*
20+
* Also handle variants of eta-expansions where
21+
* - result f.apply(X_1,...,X_n) is subject to a synthetic cast, or
22+
* - the application uses a specialized apply method, or
23+
* - the closure is adapted (see Erasure#adaptClosure)
24+
*
1825
* Without this phase, when a contextual function is passed as an argument to a
1926
* recursive function, that would have the unfortunate effect of a linear growth
2027
* in transient thunks of identical type wrapped around each other, leading
@@ -27,20 +34,36 @@ class EtaReduce extends MiniPhase:
2734

2835
override def description: String = EtaReduce.description
2936

30-
override def transformBlock(tree: Block)(using Context): Tree = tree match
31-
case Block((meth : DefDef) :: Nil, closure: Closure)
32-
if meth.symbol == closure.meth.symbol =>
33-
meth.rhs match
34-
case Apply(Select(fn, nme.apply), args)
35-
if meth.paramss.head.corresponds(args)((param, arg) =>
37+
override def transformBlock(tree: Block)(using Context): Tree =
38+
39+
def tryReduce(mdef: DefDef, rhs: Tree): Tree = rhs match
40+
case Apply(Select(fn, name), args)
41+
if (name == nme.apply || defn.FunctionSpecializedApplyNames.contains(name))
42+
&& mdef.paramss.head.corresponds(args)((param, arg) =>
3643
arg.isInstanceOf[Ident] && arg.symbol == param.symbol)
37-
&& isPurePath(fn)
38-
&& fn.tpe <:< tree.tpe
39-
&& defn.isFunctionClass(fn.tpe.widen.typeSymbol) =>
40-
report.log(i"eta reducing $tree --> $fn")
41-
fn
42-
case _ => tree
43-
case _ => tree
44+
&& isPurePath(fn)
45+
&& fn.tpe <:< tree.tpe
46+
&& defn.isFunctionClass(fn.tpe.widen.typeSymbol) =>
47+
report.log(i"eta reducing $tree --> $fn")
48+
fn
49+
case TypeApply(Select(qual, _), _) if rhs.symbol.isTypeCast && rhs.span.isSynthetic =>
50+
tryReduce(mdef, qual)
51+
case _ =>
52+
tree
53+
54+
tree match
55+
case Block((meth: DefDef) :: Nil, expr) if meth.symbol.isAnonymousFunction =>
56+
expr match
57+
case closure: Closure if meth.symbol == closure.meth.symbol =>
58+
tryReduce(meth, meth.rhs)
59+
case Block((adapted: DefDef) :: Nil, closure: Closure)
60+
if adapted.name.is(AdaptedClosureName) && adapted.symbol == closure.meth.symbol =>
61+
tryReduce(meth, meth.rhs)
62+
case _ =>
63+
tree
64+
case _ =>
65+
tree
66+
end transformBlock
4467

4568
end EtaReduce
4669

tests/run/i14623.scala

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
object Thunk {
2+
private[this] val impl =
3+
((x: Any) => x).asInstanceOf[(=> Any) => Function0[Any]]
4+
5+
def asFunction0[A](thunk: => A): Function0[A] = impl(thunk).asInstanceOf[Function0[A]]
6+
}
7+
8+
@main def Test =
9+
var i = 0
10+
val f1 = { () => i += 1; "" }
11+
assert(Thunk.asFunction0(f1()) eq f1)
12+
val f2 = { () => i += 1; i }
13+
assert(Thunk.asFunction0(f2()) eq f2)
14+
val f3 = { () => i += 1 }
15+
assert(Thunk.asFunction0(f3()) eq f3)

0 commit comments

Comments
 (0)