@@ -8,6 +8,7 @@ import MegaPhase._
8
8
import SymUtils ._
9
9
import ast .untpd
10
10
import ast .Trees ._
11
+ import dotty .tools .dotc .reporting .diagnostic .messages .TypeMismatch
11
12
import dotty .tools .dotc .util .Positions .Position
12
13
13
14
/** Expand SAM closures that cannot be represented by the JVM as lambdas to anonymous classes.
@@ -51,72 +52,74 @@ class ExpandSAMs extends MiniPhase {
51
52
}
52
53
53
54
private def toPartialFunction (tree : Block , tpe : Type )(implicit ctx : Context ): Tree = {
54
- val Block (
55
- (applyDef @ DefDef (nme.ANON_FUN , Nil , List (List (param)), _, _)) :: Nil , _) = tree
56
-
57
- def translateMatch (tree : Match , selector : Tree , cases : List [CaseDef ], defaultValue : Tree ) = {
58
- assert(tree.selector.symbol == param.symbol)
59
- val selectorTpe = selector.tpe.widen
60
- val defaultSym = ctx.newSymbol(selector.symbol.owner, nme.WILDCARD , Synthetic , selectorTpe)
61
- val defaultCase =
62
- CaseDef (
63
- Bind (defaultSym, Underscore (selectorTpe)),
64
- EmptyTree ,
65
- defaultValue)
66
- val unchecked = Annotated (selector, New (ref(defn.UncheckedAnnotType )))
67
- cpy.Match (tree)(unchecked, cases :+ defaultCase)
68
- .subst(param.symbol :: Nil , selector.symbol :: Nil )
69
- // Needed because a partial function can be written as:
70
- // param => param match { case "foo" if foo(param) => param }
71
- // And we need to update all references to 'param'
55
+ // /** An extractor for match, either contained in a block or standalone. */
56
+ object PartialFunctionRHS {
57
+ def unapply (tree : Tree ): Option [Match ] = tree match {
58
+ case Block (Nil , expr) => unapply(expr)
59
+ case m : Match => Some (m)
60
+ case _ => None
61
+ }
72
62
}
73
63
74
- val applyRhs = applyDef.rhs
75
- val applyFn = applyDef.symbol
76
-
77
- def overrideSym ( sym : Symbol ) = sym.copy(
78
- owner = applyFn.owner,
79
- flags = Synthetic | Method | Final ,
80
- info = tpe.memberInfo(sym) ,
81
- coord = tree.pos).asTerm
82
- val isDefinedAtFn = overrideSym(defn. PartialFunction_isDefinedAt )
83
- val applyOrElseFn = overrideSym(defn. PartialFunction_applyOrElse )
84
-
85
- def isDefinedAtRhs ( paramRefss : List [ List [ Tree ]]) = {
86
- val tru = Literal ( Constant ( true ))
87
- applyRhs match {
88
- case tree @ Match (_, cases) =>
64
+ val closureDef(anon @ DefDef (_, _, List ( List (param)), _, _)) = tree
65
+ anon.rhs match {
66
+ case PartialFunctionRHS (pf) =>
67
+ val anonSym = anon.symbol
68
+
69
+ def overrideSym ( sym : Symbol ) = sym.copy(
70
+ owner = anonSym.owner ,
71
+ flags = Synthetic | Method | Final ,
72
+ info = tpe.memberInfo(sym),
73
+ coord = tree.pos).asTerm
74
+ val isDefinedAtFn = overrideSym(defn. PartialFunction_isDefinedAt )
75
+ val applyOrElseFn = overrideSym(defn. PartialFunction_applyOrElse )
76
+
77
+ def isDefinedAtRhs ( paramRefss : List [ List [ Tree ]]) = {
78
+ val tru = Literal ( Constant ( true ))
89
79
def translateCase (cdef : CaseDef ) =
90
- cpy.CaseDef (cdef)(body = tru).changeOwner(applyFn , isDefinedAtFn)
80
+ cpy.CaseDef (cdef)(body = tru).changeOwner(anonSym , isDefinedAtFn)
91
81
val paramRef = paramRefss.head.head
92
82
val defaultValue = Literal (Constant (false ))
93
- translateMatch(tree, paramRef, cases.map(translateCase), defaultValue)
94
- case _ =>
95
- tru
96
- }
97
- }
83
+ translateMatch(pf, paramRef, pf.cases.map(translateCase), defaultValue)
84
+ }
98
85
99
- def applyOrElseRhs (paramRefss : List [List [Tree ]]) = {
100
- val List (paramRef, defaultRef) = paramRefss.head
101
- applyRhs match {
102
- case tree @ Match (_, cases) =>
86
+ def applyOrElseRhs (paramRefss : List [List [Tree ]]) = {
87
+ val List (paramRef, defaultRef) = paramRefss.head
103
88
def translateCase (cdef : CaseDef ) =
104
- cdef.changeOwner(applyFn , applyOrElseFn)
89
+ cdef.changeOwner(anonSym , applyOrElseFn)
105
90
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
106
- translateMatch(tree, paramRef, cases.map(translateCase), defaultValue)
107
- case _ =>
108
- applyRhs
109
- .changeOwner(applyFn, applyOrElseFn)
110
- .subst(param.symbol :: Nil , paramRef.symbol :: Nil )
91
+ translateMatch(pf, paramRef, pf.cases.map(translateCase), defaultValue)
92
+ }
93
+
94
+ def translateMatch (tree : Match , selector : Tree , cases : List [CaseDef ], defaultValue : Tree ) = {
95
+ assert(tree.selector.symbol == param.symbol)
96
+ val selectorTpe = selector.tpe.widen
97
+ val defaultSym = ctx.newSymbol(selector.symbol.owner, nme.WILDCARD , Synthetic , selectorTpe)
98
+ val defaultCase =
99
+ CaseDef (
100
+ Bind (defaultSym, Underscore (selectorTpe)),
101
+ EmptyTree ,
102
+ defaultValue)
103
+ val unchecked = Annotated (selector, New (ref(defn.UncheckedAnnotType )))
104
+ cpy.Match (tree)(unchecked, cases :+ defaultCase)
105
+ .subst(param.symbol :: Nil , selector.symbol :: Nil )
106
+ // Needed because a partial function can be written as:
107
+ // param => param match { case "foo" if foo(param) => param }
108
+ // And we need to update all references to 'param'
111
109
}
112
- }
113
110
114
- val isDefinedAtDef = transformFollowingDeep(DefDef (isDefinedAtFn, isDefinedAtRhs(_)))
115
- val applyOrElseDef = transformFollowingDeep(DefDef (applyOrElseFn, applyOrElseRhs(_)))
111
+ val isDefinedAtDef = transformFollowingDeep(DefDef (isDefinedAtFn, isDefinedAtRhs(_)))
112
+ val applyOrElseDef = transformFollowingDeep(DefDef (applyOrElseFn, applyOrElseRhs(_)))
116
113
117
- val parent = defn.AbstractPartialFunctionType .appliedTo(tpe.argInfos)
118
- val anonCls = AnonClass (parent :: Nil , List (isDefinedAtFn, applyOrElseFn), List (nme.isDefinedAt, nme.applyOrElse))
119
- cpy.Block (tree)(List (isDefinedAtDef, applyOrElseDef), anonCls)
114
+ val parent = defn.AbstractPartialFunctionType .appliedTo(tpe.argInfos)
115
+ val anonCls = AnonClass (parent :: Nil , List (isDefinedAtFn, applyOrElseFn), List (nme.isDefinedAt, nme.applyOrElse))
116
+ cpy.Block (tree)(List (isDefinedAtDef, applyOrElseDef), anonCls)
117
+
118
+ case _ =>
119
+ val found = tpe.baseType(defn.FunctionClass (1 ))
120
+ ctx.error(TypeMismatch (found, tpe), tree.pos)
121
+ tree
122
+ }
120
123
}
121
124
122
125
private def checkRefinements (tpe : Type , pos : Position )(implicit ctx : Context ): Type = tpe.dealias match {
0 commit comments