Skip to content

Commit d03b226

Browse files
committed
Fix #4241: Follow Scalac for partial function literals
- `x => x match { case x => x }` is a PF - `x => { x match { case x => x } }` is a PF - `x => { println("foo"); x match { case x => x } }` is not a PF - `x => x` is not a PF
1 parent c8904c5 commit d03b226

File tree

3 files changed

+70
-57
lines changed

3 files changed

+70
-57
lines changed

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

+58-55
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import MegaPhase._
88
import SymUtils._
99
import ast.untpd
1010
import ast.Trees._
11+
import dotty.tools.dotc.reporting.diagnostic.messages.TypeMismatch
1112
import dotty.tools.dotc.util.Positions.Position
1213

1314
/** Expand SAM closures that cannot be represented by the JVM as lambdas to anonymous classes.
@@ -51,72 +52,74 @@ class ExpandSAMs extends MiniPhase {
5152
}
5253

5354
private def toPartialFunction(tree: Block, tpe: Type)(implicit ctx: Context): Tree = {
54-
val Block(
55-
(applyDef @ DefDef(nme.ANON_FUN, Nil, List(List(param)), _, _)) :: Nil, _) = tree
56-
57-
def translateMatch(tree: Match, selector: Tree, cases: List[CaseDef], defaultValue: Tree) = {
58-
assert(tree.selector.symbol == param.symbol)
59-
val selectorTpe = selector.tpe.widen
60-
val defaultSym = ctx.newSymbol(selector.symbol.owner, nme.WILDCARD, Synthetic, selectorTpe)
61-
val defaultCase =
62-
CaseDef(
63-
Bind(defaultSym, Underscore(selectorTpe)),
64-
EmptyTree,
65-
defaultValue)
66-
val unchecked = Annotated(selector, New(ref(defn.UncheckedAnnotType)))
67-
cpy.Match(tree)(unchecked, cases :+ defaultCase)
68-
.subst(param.symbol :: Nil, selector.symbol :: Nil)
69-
// Needed because a partial function can be written as:
70-
// param => param match { case "foo" if foo(param) => param }
71-
// And we need to update all references to 'param'
55+
// /** An extractor for match, either contained in a block or standalone. */
56+
object PartialFunctionRHS {
57+
def unapply(tree: Tree): Option[Match] = tree match {
58+
case Block(Nil, expr) => unapply(expr)
59+
case m: Match => Some(m)
60+
case _ => None
61+
}
7262
}
7363

74-
val applyRhs = applyDef.rhs
75-
val applyFn = applyDef.symbol
76-
77-
def overrideSym(sym: Symbol) = sym.copy(
78-
owner = applyFn.owner,
79-
flags = Synthetic | Method | Final,
80-
info = tpe.memberInfo(sym),
81-
coord = tree.pos).asTerm
82-
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
83-
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)
84-
85-
def isDefinedAtRhs(paramRefss: List[List[Tree]]) = {
86-
val tru = Literal(Constant(true))
87-
applyRhs match {
88-
case tree @ Match(_, cases) =>
64+
val closureDef(anon @ DefDef(_, _, List(List(param)), _, _)) = tree
65+
anon.rhs match {
66+
case PartialFunctionRHS(pf) =>
67+
val anonSym = anon.symbol
68+
69+
def overrideSym(sym: Symbol) = sym.copy(
70+
owner = anonSym.owner,
71+
flags = Synthetic | Method | Final,
72+
info = tpe.memberInfo(sym),
73+
coord = tree.pos).asTerm
74+
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
75+
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)
76+
77+
def isDefinedAtRhs(paramRefss: List[List[Tree]]) = {
78+
val tru = Literal(Constant(true))
8979
def translateCase(cdef: CaseDef) =
90-
cpy.CaseDef(cdef)(body = tru).changeOwner(applyFn, isDefinedAtFn)
80+
cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
9181
val paramRef = paramRefss.head.head
9282
val defaultValue = Literal(Constant(false))
93-
translateMatch(tree, paramRef, cases.map(translateCase), defaultValue)
94-
case _ =>
95-
tru
96-
}
97-
}
83+
translateMatch(pf, paramRef, pf.cases.map(translateCase), defaultValue)
84+
}
9885

99-
def applyOrElseRhs(paramRefss: List[List[Tree]]) = {
100-
val List(paramRef, defaultRef) = paramRefss.head
101-
applyRhs match {
102-
case tree @ Match(_, cases) =>
86+
def applyOrElseRhs(paramRefss: List[List[Tree]]) = {
87+
val List(paramRef, defaultRef) = paramRefss.head
10388
def translateCase(cdef: CaseDef) =
104-
cdef.changeOwner(applyFn, applyOrElseFn)
89+
cdef.changeOwner(anonSym, applyOrElseFn)
10590
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
106-
translateMatch(tree, paramRef, cases.map(translateCase), defaultValue)
107-
case _ =>
108-
applyRhs
109-
.changeOwner(applyFn, applyOrElseFn)
110-
.subst(param.symbol :: Nil, paramRef.symbol :: Nil)
91+
translateMatch(pf, paramRef, pf.cases.map(translateCase), defaultValue)
92+
}
93+
94+
def translateMatch(tree: Match, selector: Tree, cases: List[CaseDef], defaultValue: Tree) = {
95+
assert(tree.selector.symbol == param.symbol)
96+
val selectorTpe = selector.tpe.widen
97+
val defaultSym = ctx.newSymbol(selector.symbol.owner, nme.WILDCARD, Synthetic, selectorTpe)
98+
val defaultCase =
99+
CaseDef(
100+
Bind(defaultSym, Underscore(selectorTpe)),
101+
EmptyTree,
102+
defaultValue)
103+
val unchecked = Annotated(selector, New(ref(defn.UncheckedAnnotType)))
104+
cpy.Match(tree)(unchecked, cases :+ defaultCase)
105+
.subst(param.symbol :: Nil, selector.symbol :: Nil)
106+
// Needed because a partial function can be written as:
107+
// param => param match { case "foo" if foo(param) => param }
108+
// And we need to update all references to 'param'
111109
}
112-
}
113110

114-
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)))
115-
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)))
111+
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)))
112+
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)))
116113

117-
val parent = defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos)
118-
val anonCls = AnonClass(parent :: Nil, List(isDefinedAtFn, applyOrElseFn), List(nme.isDefinedAt, nme.applyOrElse))
119-
cpy.Block(tree)(List(isDefinedAtDef, applyOrElseDef), anonCls)
114+
val parent = defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos)
115+
val anonCls = AnonClass(parent :: Nil, List(isDefinedAtFn, applyOrElseFn), List(nme.isDefinedAt, nme.applyOrElse))
116+
cpy.Block(tree)(List(isDefinedAtDef, applyOrElseDef), anonCls)
117+
118+
case _ =>
119+
val found = tpe.baseType(defn.FunctionClass(1))
120+
ctx.error(TypeMismatch(found, tpe), tree.pos)
121+
tree
122+
}
120123
}
121124

122125
private def checkRefinements(tpe: Type, pos: Position)(implicit ctx: Context): Type = tpe.dealias match {

tests/neg/i4241.scala

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
class Test {
2+
def test: Unit = {
3+
val a: PartialFunction[Int, Int] = { case x => x }
4+
val b: PartialFunction[Int, Int] = x => x match { case 1 => 1; case _ => 2 }
5+
val c: PartialFunction[Int, Int] = x => { x match { case y => y } }
6+
val d: PartialFunction[Int, Int] = x => { { x match { case y => y } } }
7+
8+
val e: PartialFunction[Int, Int] = x => { println("foo"); x match { case y => y } } // error
9+
val f: PartialFunction[Int, Int] = x => x // error
10+
val g: PartialFunction[Int, String] = { x => x.toString } // error
11+
}
12+
}

tests/pos/i4177.scala

-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ class Test {
55
def test: Unit = {
66
val a: PartialFunction[Int, String] = { case Foo(x) => x }
77
val b: PartialFunction[Int, String] = { case x => x.toString }
8-
val c: PartialFunction[Int, String] = { x => x.toString }
9-
val d: PartialFunction[Int, String] = x => x.toString
108

119
val e: PartialFunction[String, String] = { case x @ "abc" => x }
1210
val f: PartialFunction[String, String] = x => x match { case "abc" => x }

0 commit comments

Comments
 (0)