Skip to content

Commit d5409fd

Browse files
committed
Combine cases of several states into a single partial function
1 parent fc6a492 commit d5409fd

File tree

2 files changed

+39
-15
lines changed

2 files changed

+39
-15
lines changed

src/async/library/scala/async/Async.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ object Async extends AsyncUtils {
4949
vprintln(s"states of current method:")
5050
asyncBlockBuilder.asyncStates foreach vprintln
5151

52-
val handlerExpr = asyncBlockBuilder.mkHandlerExpr()
52+
val handlerExpr = asyncBlockBuilder.mkCombinedHandlerExpr()
5353

5454
vprintln(s"GENERATED handler expr:")
5555
vprintln(handlerExpr)

src/async/library/scala/async/ExprBuilder.scala

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,47 +60,56 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
6060
ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs)
6161
}
6262

63-
def mkHandlerTree(num: Int, rhs: c.Tree): c.Tree = {
63+
def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef =
64+
CaseDef(
65+
// pattern
66+
Bind(newTermName("any"), Typed(Ident(nme.WILDCARD), Ident(definitions.IntClass))),
67+
// guard
68+
Apply(Select(Ident(newTermName("any")), newTermName("$eq$eq")), List(Literal(Constant(num)))),
69+
rhs
70+
)
71+
72+
def mkHandlerTreeFor(cases: List[(CaseDef, Int)]): c.Tree = {
6473
val partFunIdent = Ident(c.mirror.staticClass("scala.PartialFunction"))
6574
val intIdent = Ident(definitions.IntClass)
6675
val unitIdent = Ident(definitions.UnitClass)
6776

77+
val caseCheck =
78+
Apply(Select(Apply(Select(Ident(newTermName("List")), newTermName("apply")),
79+
cases.map(p => Literal(Constant(p._2)))), newTermName("contains")), List(Ident(newTermName("x$1"))))
80+
6881
Block(List(
6982
// anonymous subclass of PartialFunction[Int, Unit]
7083
ClassDef(Modifiers(FINAL), newTypeName("$anon"), List(), Template(List(AppliedTypeTree(partFunIdent, List(intIdent, unitIdent))),
7184
emptyValDef, List(
72-
7385
DefDef(Modifiers(), nme.CONSTRUCTOR, List(), List(List()), TypeTree(),
7486
Block(List(Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), List())), Literal(Constant(())))),
75-
87+
7688
DefDef(Modifiers(), newTermName("isDefinedAt"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(),
77-
Apply(Select(Ident(newTermName("x$1")), newTermName("$eq$eq")), List(Literal(Constant(num))))),
89+
caseCheck),
7890

7991
DefDef(Modifiers(), newTermName("apply"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(),
80-
Match(Ident(newTermName("x$1")), List(
81-
CaseDef(
82-
// pattern
83-
Bind(newTermName("any"), Typed(Ident(nme.WILDCARD), intIdent)),
84-
// guard
85-
Apply(Select(Ident(newTermName("any")), newTermName("$eq$eq")), List(Literal(Constant(num)))),
86-
rhs
87-
)
88-
))
92+
Match(Ident(newTermName("x$1")), cases.map(_._1)) // combine all cases into a single match
8993
)
90-
9194
))
9295
)),
9396
Apply(Select(New(Ident(newTypeName("$anon"))), nme.CONSTRUCTOR), List())
9497
)
9598
}
9699

100+
def mkHandlerTree(num: Int, rhs: c.Tree): c.Tree =
101+
mkHandlerTreeFor(List(mkHandlerCase(num, rhs) -> num))
102+
97103
class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) {
98104
val body: c.Tree =
99105
if (stats.size == 1) stats.head
100106
else Block(stats: _*)
101107

102108
val varDefs: List[(c.universe.TermName, c.universe.Type)] = List()
103109

110+
def mkHandlerCaseForState(): CaseDef =
111+
mkHandlerCase(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*))
112+
104113
def mkHandlerTreeForState(): c.Tree =
105114
mkHandlerTree(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*))
106115

@@ -125,6 +134,9 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
125134

126135
//TODO mkHandlerTreeForState(nextState: Int)
127136

137+
override def mkHandlerCaseForState(): CaseDef =
138+
mkHandlerCase(state, Block(stats: _*))
139+
128140
override val toString: String =
129141
s"AsyncStateWithIf #$state, next = $nextState"
130142
}
@@ -267,6 +279,11 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
267279
mkHandlerTree(state, Block((stats :+ mkOnCompleteStateTree(nextState)): _*))
268280
}
269281

282+
override def mkHandlerCaseForState(): CaseDef = {
283+
assert(awaitable != null)
284+
mkHandlerCase(state, Block((stats :+ mkOnCompleteIncrStateTree): _*))
285+
}
286+
270287
override def varDefForResult: Option[c.Tree] = {
271288
val rhs =
272289
if (resultType <:< definitions.IntTpe) Literal(Constant(0))
@@ -469,6 +486,13 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
469486
val lastState = stateBuilder.complete(endState).result
470487
asyncStates += lastState
471488

489+
def mkCombinedHandlerExpr(): c.Expr[PartialFunction[Int, Unit]] = {
490+
assert(asyncStates.size > 1)
491+
492+
val cases = for (state <- asyncStates.toList) yield state.mkHandlerCaseForState()
493+
c.Expr(mkHandlerTreeFor(cases zip asyncStates.init.map(_.state))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
494+
}
495+
472496
/* Builds the handler expression for a sequence of async states.
473497
*/
474498
def mkHandlerExpr(): c.Expr[PartialFunction[Int, Unit]] = {

0 commit comments

Comments
 (0)