@@ -60,47 +60,56 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
60
60
ValDef (Modifiers (Flag .MUTABLE ), resultName, TypeTree (resultType), rhs)
61
61
}
62
62
63
- def mkHandlerTree (num : Int , rhs : c.Tree ): c.Tree = {
63
+ def mkHandlerCase (num : Int , rhs : c.Tree ): CaseDef =
64
+ CaseDef (
65
+ // pattern
66
+ Bind (newTermName(" any" ), Typed (Ident (nme.WILDCARD ), Ident (definitions.IntClass ))),
67
+ // guard
68
+ Apply (Select (Ident (newTermName(" any" )), newTermName(" $eq$eq" )), List (Literal (Constant (num)))),
69
+ rhs
70
+ )
71
+
72
+ def mkHandlerTreeFor (cases : List [(CaseDef , Int )]): c.Tree = {
64
73
val partFunIdent = Ident (c.mirror.staticClass(" scala.PartialFunction" ))
65
74
val intIdent = Ident (definitions.IntClass )
66
75
val unitIdent = Ident (definitions.UnitClass )
67
76
77
+ val caseCheck =
78
+ Apply (Select (Apply (Select (Ident (newTermName(" List" )), newTermName(" apply" )),
79
+ cases.map(p => Literal (Constant (p._2)))), newTermName(" contains" )), List (Ident (newTermName(" x$1" ))))
80
+
68
81
Block (List (
69
82
// anonymous subclass of PartialFunction[Int, Unit]
70
83
ClassDef (Modifiers (FINAL ), newTypeName(" $anon" ), List (), Template (List (AppliedTypeTree (partFunIdent, List (intIdent, unitIdent))),
71
84
emptyValDef, List (
72
-
73
85
DefDef (Modifiers (), nme.CONSTRUCTOR , List (), List (List ()), TypeTree (),
74
86
Block (List (Apply (Select (Super (This (tpnme.EMPTY ), tpnme.EMPTY ), nme.CONSTRUCTOR ), List ())), Literal (Constant (())))),
75
-
87
+
76
88
DefDef (Modifiers (), newTermName(" isDefinedAt" ), List (), List (List (ValDef (Modifiers (PARAM ), newTermName(" x$1" ), intIdent, EmptyTree ))), TypeTree (),
77
- Apply ( Select ( Ident (newTermName( " x$1 " )), newTermName( " $eq$eq " )), List ( Literal ( Constant (num)))) ),
89
+ caseCheck ),
78
90
79
91
DefDef (Modifiers (), newTermName(" apply" ), List (), List (List (ValDef (Modifiers (PARAM ), newTermName(" x$1" ), intIdent, EmptyTree ))), TypeTree (),
80
- Match (Ident (newTermName(" x$1" )), List (
81
- CaseDef (
82
- // pattern
83
- Bind (newTermName(" any" ), Typed (Ident (nme.WILDCARD ), intIdent)),
84
- // guard
85
- Apply (Select (Ident (newTermName(" any" )), newTermName(" $eq$eq" )), List (Literal (Constant (num)))),
86
- rhs
87
- )
88
- ))
92
+ Match (Ident (newTermName(" x$1" )), cases.map(_._1)) // combine all cases into a single match
89
93
)
90
-
91
94
))
92
95
)),
93
96
Apply (Select (New (Ident (newTypeName(" $anon" ))), nme.CONSTRUCTOR ), List ())
94
97
)
95
98
}
96
99
100
+ def mkHandlerTree (num : Int , rhs : c.Tree ): c.Tree =
101
+ mkHandlerTreeFor(List (mkHandlerCase(num, rhs) -> num))
102
+
97
103
class AsyncState (stats : List [c.Tree ], val state : Int , val nextState : Int ) {
98
104
val body : c.Tree =
99
105
if (stats.size == 1 ) stats.head
100
106
else Block (stats : _* )
101
107
102
108
val varDefs : List [(c.universe.TermName , c.universe.Type )] = List ()
103
109
110
+ def mkHandlerCaseForState (): CaseDef =
111
+ mkHandlerCase(state, Block ((stats :+ mkStateTree(nextState) :+ Apply (Ident (" resume" ), List ())): _* ))
112
+
104
113
def mkHandlerTreeForState (): c.Tree =
105
114
mkHandlerTree(state, Block ((stats :+ mkStateTree(nextState) :+ Apply (Ident (" resume" ), List ())): _* ))
106
115
@@ -125,6 +134,9 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
125
134
126
135
// TODO mkHandlerTreeForState(nextState: Int)
127
136
137
+ override def mkHandlerCaseForState (): CaseDef =
138
+ mkHandlerCase(state, Block (stats : _* ))
139
+
128
140
override val toString : String =
129
141
s " AsyncStateWithIf # $state, next = $nextState"
130
142
}
@@ -267,6 +279,11 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
267
279
mkHandlerTree(state, Block ((stats :+ mkOnCompleteStateTree(nextState)): _* ))
268
280
}
269
281
282
+ override def mkHandlerCaseForState (): CaseDef = {
283
+ assert(awaitable != null )
284
+ mkHandlerCase(state, Block ((stats :+ mkOnCompleteIncrStateTree): _* ))
285
+ }
286
+
270
287
override def varDefForResult : Option [c.Tree ] = {
271
288
val rhs =
272
289
if (resultType <:< definitions.IntTpe ) Literal (Constant (0 ))
@@ -469,6 +486,13 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
469
486
val lastState = stateBuilder.complete(endState).result
470
487
asyncStates += lastState
471
488
489
+ def mkCombinedHandlerExpr (): c.Expr [PartialFunction [Int , Unit ]] = {
490
+ assert(asyncStates.size > 1 )
491
+
492
+ val cases = for (state <- asyncStates.toList) yield state.mkHandlerCaseForState()
493
+ c.Expr (mkHandlerTreeFor(cases zip asyncStates.init.map(_.state))).asInstanceOf [c.Expr [PartialFunction [Int , Unit ]]]
494
+ }
495
+
472
496
/* Builds the handler expression for a sequence of async states.
473
497
*/
474
498
def mkHandlerExpr (): c.Expr [PartialFunction [Int , Unit ]] = {
0 commit comments