Skip to content

Commit adbe51e

Browse files
committed
Merge pull request #50 from retronym/ticket/48
Fix crashers in do/while and while(await(..))
2 parents 93ab624 + 4fc5463 commit adbe51e

File tree

5 files changed

+73
-34
lines changed

5 files changed

+73
-34
lines changed

src/main/scala/scala/async/internal/AnfTransform.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,14 @@ private[async] trait AnfTransform {
162162
def _transformToList(tree: Tree): List[Tree] = trace(tree) {
163163
val containsAwait = tree exists isAwait
164164
if (!containsAwait) {
165-
List(tree)
165+
tree match {
166+
case Block(stats, expr) =>
167+
// avoids nested block in `while(await(false)) ...`.
168+
// TODO I think `containsAwait` really should return true if the code contains a label jump to an enclosing
169+
// while/doWhile and there is an await *anywhere* inside that construct.
170+
stats :+ expr
171+
case _ => List(tree)
172+
}
166173
} else tree match {
167174
case Select(qual, sel) =>
168175
val stats :+ expr = linearize.transformToList(qual)

src/main/scala/scala/async/internal/ExprBuilder.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,11 @@ trait ExprBuilder {
127127
private var nextJumpState: Option[Int] = None
128128

129129
def +=(stat: Tree): this.type = {
130-
assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat")
130+
stat match {
131+
case Literal(Constant(())) => // This case occurs in do/while
132+
case _ =>
133+
assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat")
134+
}
131135
def addStat() = stats += stat
132136
stat match {
133137
case Apply(fun, Nil) =>
@@ -228,7 +232,7 @@ trait ExprBuilder {
228232
currState = afterAwaitState
229233
stateBuilder = new AsyncStateBuilder(currState, symLookup)
230234

231-
case If(cond, thenp, elsep) if stat exists isAwait =>
235+
case If(cond, thenp, elsep) if (stat exists isAwait) || containsForiegnLabelJump(stat) =>
232236
checkForUnsupportedAwait(cond)
233237

234238
val thenStartState = nextState()

src/main/scala/scala/async/internal/TransformUtils.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,19 @@ private[async] trait TransformUtils {
9696
treeInfo.isExprSafeToInline(tree)
9797
}
9898

99+
// `while(await(x))` ... or `do { await(x); ... } while(...)` contain an `If` that loops;
100+
// we must break that `If` into states so that it convert the label jump into a state machine
101+
// transition
102+
final def containsForiegnLabelJump(t: Tree): Boolean = {
103+
val labelDefs = t.collect {
104+
case ld: LabelDef => ld.symbol
105+
}.toSet
106+
t.exists {
107+
case rt: RefTree => !(labelDefs contains rt.symbol)
108+
case _ => false
109+
}
110+
}
111+
99112
/** Map a list of arguments to:
100113
* - A list of argument Trees
101114
* - A list of auxillary results.

src/test/scala/scala/async/TreeInterrogation.scala

Lines changed: 6 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -66,43 +66,18 @@ object TreeInterrogation extends App {
6666
withDebug {
6767
val cm = reflect.runtime.currentMirror
6868
val tb = mkToolbox("-cp ${toolboxClasspath} -Xprint:typer -uniqid")
69-
import scala.async.internal.AsyncTestLV._
69+
import scala.async.internal.AsyncId._
7070
val tree = tb.parse(
7171
"""
72-
| import scala.async.internal.AsyncTestLV._
73-
| import scala.async.internal.AsyncTestLV
74-
|
75-
| case class MCell[T](var v: T)
76-
| val f = async { MCell(1) }
77-
|
78-
| def m1(x: MCell[Int], y: Int): Int =
79-
| async { x.v + y }
80-
| case class Cell[T](v: T)
81-
|
72+
| import scala.async.internal.AsyncId._
8273
| async {
83-
| // state #1
84-
| val a: MCell[Int] = await(f) // await$13$1
85-
| // state #2
86-
| var y = MCell(0)
87-
|
88-
| while (a.v < 10) {
89-
| // state #4
90-
| a.v = a.v + 1
91-
| y = MCell(await(a).v + 1) // await$14$1
92-
| // state #7
74+
| var b = true
75+
| while(await(b)) {
76+
| b = false
9377
| }
94-
|
95-
| // state #3
96-
| assert(AsyncTestLV.log.exists(entry => entry._1 == "await$14$1"))
97-
|
98-
| val b = await(m1(a, y.v)) // await$15$1
99-
| // state #8
100-
| assert(AsyncTestLV.log.exists(_ == ("a$1" -> MCell(10))))
101-
| assert(AsyncTestLV.log.exists(_ == ("y$1" -> MCell(11))))
102-
| b
78+
| await(b)
10379
| }
10480
|
105-
|
10681
| """.stripMargin)
10782
println(tree)
10883
val tree1 = tb.typeCheck(tree.duplicate)

src/test/scala/scala/async/run/ifelse0/WhileSpec.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,44 @@ class WhileSpec {
7676
}
7777
result mustBe ()
7878
}
79+
80+
@Test def doWhile() {
81+
import AsyncId._
82+
val result = async {
83+
var b = 0
84+
var x = ""
85+
await(do {
86+
x += "1"
87+
x += await("2")
88+
x += "3"
89+
b += await(1)
90+
} while (b < 2))
91+
await(x)
92+
}
93+
result mustBe "123123"
94+
}
95+
96+
@Test def whileAwaitCondition() {
97+
import AsyncId._
98+
val result = async {
99+
var b = true
100+
while(await(b)) {
101+
b = false
102+
}
103+
await(b)
104+
}
105+
result mustBe false
106+
}
107+
108+
@Test def doWhileAwaitCondition() {
109+
import AsyncId._
110+
val result = async {
111+
var b = true
112+
do {
113+
b = false
114+
} while(await(b))
115+
b
116+
}
117+
result mustBe false
118+
}
79119
}

0 commit comments

Comments
 (0)