Skip to content

Commit cf4591f

Browse files
authored
Introduce boundary/break control abstraction. (#16612)
The abstractions are intended to replace the `scala.util.control.Breaks` and `scala.util.control.NonLocalReturns`. They are simpler, safer, and more performant, since there is a new MiniPhase `DropBreaks` that rewrites local breaks to labeled returns, i.e. jumps. The abstractions are not experimental. This break from usual procedure is because we need to roll them out fast. Non local returns were just deprecated in 3.2, and we proposed `NonLocalReturns.{returning,throwReturn}` as an alternative. But these APIs were a mistake and should be deprecated themselves. So rolling out boundary/break now counts as a bugfix.
2 parents 865aa63 + 69b7a48 commit cf4591f

25 files changed

+1114
-20
lines changed

compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala

+74-17
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ package jvm
44

55
import scala.language.unsafeNulls
66

7-
import scala.annotation.switch
7+
import scala.annotation.{switch, tailrec}
88
import scala.collection.mutable.SortedMap
99

1010
import scala.tools.asm
@@ -79,9 +79,14 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
7979

8080
tree match {
8181
case Assign(lhs @ DesugaredSelect(qual, _), rhs) =>
82+
val savedStackHeight = stackHeight
8283
val isStatic = lhs.symbol.isStaticMember
83-
if (!isStatic) { genLoadQualifier(lhs) }
84+
if (!isStatic) {
85+
genLoadQualifier(lhs)
86+
stackHeight += 1
87+
}
8488
genLoad(rhs, symInfoTK(lhs.symbol))
89+
stackHeight = savedStackHeight
8590
lineNumber(tree)
8691
// receiverClass is used in the bytecode to access the field. using sym.owner may lead to IllegalAccessError
8792
val receiverClass = qual.tpe.typeSymbol
@@ -145,7 +150,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
145150
}
146151

147152
genLoad(larg, resKind)
153+
stackHeight += resKind.size
148154
genLoad(rarg, if (isShift) INT else resKind)
155+
stackHeight -= resKind.size
149156

150157
(code: @switch) match {
151158
case ADD => bc add resKind
@@ -182,14 +189,19 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
182189
if (isArrayGet(code)) {
183190
// load argument on stack
184191
assert(args.length == 1, s"Too many arguments for array get operation: $tree");
192+
stackHeight += 1
185193
genLoad(args.head, INT)
194+
stackHeight -= 1
186195
generatedType = k.asArrayBType.componentType
187196
bc.aload(elementType)
188197
}
189198
else if (isArraySet(code)) {
190199
val List(a1, a2) = args
200+
stackHeight += 1
191201
genLoad(a1, INT)
202+
stackHeight += 1
192203
genLoad(a2)
204+
stackHeight -= 2
193205
generatedType = UNIT
194206
bc.astore(elementType)
195207
} else {
@@ -223,7 +235,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
223235
val resKind = if (hasUnitBranch) UNIT else tpeTK(tree)
224236

225237
val postIf = new asm.Label
226-
genLoadTo(thenp, resKind, LoadDestination.Jump(postIf))
238+
genLoadTo(thenp, resKind, LoadDestination.Jump(postIf, stackHeight))
227239
markProgramPoint(failure)
228240
genLoadTo(elsep, resKind, LoadDestination.FallThrough)
229241
markProgramPoint(postIf)
@@ -482,7 +494,17 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
482494
dest match
483495
case LoadDestination.FallThrough =>
484496
()
485-
case LoadDestination.Jump(label) =>
497+
case LoadDestination.Jump(label, targetStackHeight) =>
498+
if targetStackHeight < stackHeight then
499+
val stackDiff = stackHeight - targetStackHeight
500+
if expectedType == UNIT then
501+
bc dropMany stackDiff
502+
else
503+
val loc = locals.makeTempLocal(expectedType)
504+
bc.store(loc.idx, expectedType)
505+
bc dropMany stackDiff
506+
bc.load(loc.idx, expectedType)
507+
end if
486508
bc goTo label
487509
case LoadDestination.Return =>
488510
bc emitRETURN returnType
@@ -577,7 +599,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
577599
if dest == LoadDestination.FallThrough then
578600
val resKind = tpeTK(tree)
579601
val jumpTarget = new asm.Label
580-
registerJumpDest(labelSym, resKind, LoadDestination.Jump(jumpTarget))
602+
registerJumpDest(labelSym, resKind, LoadDestination.Jump(jumpTarget, stackHeight))
581603
genLoad(expr, resKind)
582604
markProgramPoint(jumpTarget)
583605
resKind
@@ -635,7 +657,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
635657
markProgramPoint(loop)
636658

637659
if isInfinite then
638-
val dest = LoadDestination.Jump(loop)
660+
val dest = LoadDestination.Jump(loop, stackHeight)
639661
genLoadTo(body, UNIT, dest)
640662
dest
641663
else
@@ -650,7 +672,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
650672
val failure = new asm.Label
651673
genCond(cond, success, failure, targetIfNoJump = success)
652674
markProgramPoint(success)
653-
genLoadTo(body, UNIT, LoadDestination.Jump(loop))
675+
genLoadTo(body, UNIT, LoadDestination.Jump(loop, stackHeight))
654676
markProgramPoint(failure)
655677
end match
656678
LoadDestination.FallThrough
@@ -744,7 +766,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
744766

745767
// scala/bug#10290: qual can be `this.$outer()` (not just `this`), so we call genLoad (not just ALOAD_0)
746768
genLoad(superQual)
769+
stackHeight += 1
747770
genLoadArguments(args, paramTKs(app))
771+
stackHeight -= 1
748772
generatedType = genCallMethod(fun.symbol, InvokeStyle.Super, app.span)
749773

750774
// 'new' constructor call: Note: since constructors are
@@ -766,7 +790,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
766790
assert(classBTypeFromSymbol(ctor.owner) == rt, s"Symbol ${ctor.owner.showFullName} is different from $rt")
767791
mnode.visitTypeInsn(asm.Opcodes.NEW, rt.internalName)
768792
bc dup generatedType
793+
stackHeight += 2
769794
genLoadArguments(args, paramTKs(app))
795+
stackHeight -= 2
770796
genCallMethod(ctor, InvokeStyle.Special, app.span)
771797

772798
case _ =>
@@ -799,8 +825,12 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
799825
else if (app.hasAttachment(BCodeHelpers.UseInvokeSpecial)) InvokeStyle.Special
800826
else InvokeStyle.Virtual
801827

802-
if (invokeStyle.hasInstance) genLoadQualifier(fun)
828+
val savedStackHeight = stackHeight
829+
if invokeStyle.hasInstance then
830+
genLoadQualifier(fun)
831+
stackHeight += 1
803832
genLoadArguments(args, paramTKs(app))
833+
stackHeight = savedStackHeight
804834

805835
val DesugaredSelect(qual, name) = fun: @unchecked // fun is a Select, also checked in genLoadQualifier
806836
val isArrayClone = name == nme.clone_ && qual.tpe.widen.isInstanceOf[JavaArrayType]
@@ -858,6 +888,8 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
858888
bc iconst elems.length
859889
bc newarray elmKind
860890

891+
stackHeight += 3 // during the genLoad below, there is the result, its dup, and the index
892+
861893
var i = 0
862894
var rest = elems
863895
while (!rest.isEmpty) {
@@ -869,6 +901,8 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
869901
i = i + 1
870902
}
871903

904+
stackHeight -= 3
905+
872906
generatedType
873907
}
874908

@@ -883,7 +917,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
883917
val (generatedType, postMatch, postMatchDest) =
884918
if dest == LoadDestination.FallThrough then
885919
val postMatch = new asm.Label
886-
(tpeTK(tree), postMatch, LoadDestination.Jump(postMatch))
920+
(tpeTK(tree), postMatch, LoadDestination.Jump(postMatch, stackHeight))
887921
else
888922
(expectedType, null, dest)
889923

@@ -1160,14 +1194,21 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
11601194
}
11611195

11621196
def genLoadArguments(args: List[Tree], btpes: List[BType]): Unit =
1163-
args match
1164-
case arg :: args1 =>
1165-
btpes match
1166-
case btpe :: btpes1 =>
1167-
genLoad(arg, btpe)
1168-
genLoadArguments(args1, btpes1)
1169-
case _ =>
1170-
case _ =>
1197+
@tailrec def loop(args: List[Tree], btpes: List[BType]): Unit =
1198+
args match
1199+
case arg :: args1 =>
1200+
btpes match
1201+
case btpe :: btpes1 =>
1202+
genLoad(arg, btpe)
1203+
stackHeight += btpe.size
1204+
loop(args1, btpes1)
1205+
case _ =>
1206+
case _ =>
1207+
1208+
val savedStackHeight = stackHeight
1209+
loop(args, btpes)
1210+
stackHeight = savedStackHeight
1211+
end genLoadArguments
11711212

11721213
def genLoadModule(tree: Tree): BType = {
11731214
val module = (
@@ -1266,11 +1307,14 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
12661307
}.sum
12671308
bc.genNewStringBuilder(approxBuilderSize)
12681309

1310+
stackHeight += 1 // during the genLoad below, there is a reference to the StringBuilder on the stack
12691311
for (elem <- concatArguments) {
12701312
val elemType = tpeTK(elem)
12711313
genLoad(elem, elemType)
12721314
bc.genStringBuilderAppend(elemType)
12731315
}
1316+
stackHeight -= 1
1317+
12741318
bc.genStringBuilderEnd
12751319
} else {
12761320

@@ -1287,12 +1331,15 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
12871331
var totalArgSlots = 0
12881332
var countConcats = 1 // ie. 1 + how many times we spilled
12891333

1334+
val savedStackHeight = stackHeight
1335+
12901336
for (elem <- concatArguments) {
12911337
val tpe = tpeTK(elem)
12921338
val elemSlots = tpe.size
12931339

12941340
// Unlikely spill case
12951341
if (totalArgSlots + elemSlots >= MaxIndySlots) {
1342+
stackHeight = savedStackHeight + countConcats
12961343
bc.genIndyStringConcat(recipe.toString, argTypes.result(), constVals.result())
12971344
countConcats += 1
12981345
totalArgSlots = 0
@@ -1317,8 +1364,10 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
13171364
val tpe = tpeTK(elem)
13181365
argTypes += tpe.toASMType
13191366
genLoad(elem, tpe)
1367+
stackHeight += 1
13201368
}
13211369
}
1370+
stackHeight = savedStackHeight
13221371
bc.genIndyStringConcat(recipe.toString, argTypes.result(), constVals.result())
13231372

13241373
// If we spilled, generate one final concat
@@ -1513,7 +1562,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
15131562
} else {
15141563
val tk = tpeTK(l).maxType(tpeTK(r))
15151564
genLoad(l, tk)
1565+
stackHeight += tk.size
15161566
genLoad(r, tk)
1567+
stackHeight -= tk.size
15171568
genCJUMP(success, failure, op, tk, targetIfNoJump)
15181569
}
15191570
}
@@ -1628,7 +1679,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
16281679
}
16291680

16301681
genLoad(l, ObjectRef)
1682+
stackHeight += 1
16311683
genLoad(r, ObjectRef)
1684+
stackHeight -= 1
16321685
genCallMethod(equalsMethod, InvokeStyle.Static)
16331686
genCZJUMP(success, failure, Primitives.NE, BOOL, targetIfNoJump)
16341687
}
@@ -1644,7 +1697,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
16441697
} else if (isNonNullExpr(l)) {
16451698
// SI-7852 Avoid null check if L is statically non-null.
16461699
genLoad(l, ObjectRef)
1700+
stackHeight += 1
16471701
genLoad(r, ObjectRef)
1702+
stackHeight -= 1
16481703
genCallMethod(defn.Any_equals, InvokeStyle.Virtual)
16491704
genCZJUMP(success, failure, Primitives.NE, BOOL, targetIfNoJump)
16501705
} else {
@@ -1654,7 +1709,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
16541709
val lNonNull = new asm.Label
16551710

16561711
genLoad(l, ObjectRef)
1712+
stackHeight += 1
16571713
genLoad(r, ObjectRef)
1714+
stackHeight -= 1
16581715
locals.store(eqEqTempLocal)
16591716
bc dup ObjectRef
16601717
genCZJUMP(lNull, lNonNull, Primitives.EQ, ObjectRef, targetIfNoJump = lNull)

compiler/src/dotty/tools/backend/jvm/BCodeIdiomatic.scala

+10
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,16 @@ trait BCodeIdiomatic {
620620
// can-multi-thread
621621
final def drop(tk: BType): Unit = { emit(if (tk.isWideType) Opcodes.POP2 else Opcodes.POP) }
622622

623+
// can-multi-thread
624+
final def dropMany(size: Int): Unit = {
625+
var s = size
626+
while s >= 2 do
627+
emit(Opcodes.POP2)
628+
s -= 2
629+
if s > 0 then
630+
emit(Opcodes.POP)
631+
}
632+
623633
// can-multi-thread
624634
final def dup(tk: BType): Unit = { emit(if (tk.isWideType) Opcodes.DUP2 else Opcodes.DUP) }
625635

compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala

+12-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ trait BCodeSkelBuilder extends BCodeHelpers {
4545
/** The value is put on the stack, and control flows through to the next opcode. */
4646
case FallThrough
4747
/** The value is put on the stack, and control flow is transferred to the given `label`. */
48-
case Jump(label: asm.Label)
48+
case Jump(label: asm.Label, targetStackHeight: Int)
4949
/** The value is RETURN'ed from the enclosing method. */
5050
case Return
5151
/** The value is ATHROW'n. */
@@ -368,6 +368,8 @@ trait BCodeSkelBuilder extends BCodeHelpers {
368368
// used by genLoadTry() and genSynchronized()
369369
var earlyReturnVar: Symbol = null
370370
var shouldEmitCleanup = false
371+
// stack tracking
372+
var stackHeight = 0
371373
// line numbers
372374
var lastEmittedLineNr = -1
373375

@@ -504,6 +506,13 @@ trait BCodeSkelBuilder extends BCodeHelpers {
504506
loc
505507
}
506508

509+
def makeTempLocal(tk: BType): Local =
510+
assert(nxtIdx != -1, "not a valid start index")
511+
assert(tk.size > 0, "makeLocal called for a symbol whose type is Unit.")
512+
val loc = Local(tk, "temp", nxtIdx, isSynth = true)
513+
nxtIdx += tk.size
514+
loc
515+
507516
// not to be confused with `fieldStore` and `fieldLoad` which also take a symbol but a field-symbol.
508517
def store(locSym: Symbol): Unit = {
509518
val Local(tk, _, idx, _) = slots(locSym)
@@ -574,6 +583,8 @@ trait BCodeSkelBuilder extends BCodeHelpers {
574583
earlyReturnVar = null
575584
shouldEmitCleanup = false
576585

586+
stackHeight = 0
587+
577588
lastEmittedLineNr = -1
578589
}
579590

compiler/src/dotty/tools/dotc/Compiler.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ class Compiler {
8888
new sjs.ExplicitJSClasses, // Make all JS classes explicit (Scala.js only)
8989
new ExplicitOuter, // Add accessors to outer classes from nested ones.
9090
new ExplicitSelf, // Make references to non-trivial self types explicit as casts
91-
new StringInterpolatorOpt) :: // Optimizes raw and s and f string interpolators by rewriting them to string concatenations or formats
91+
new StringInterpolatorOpt, // Optimizes raw and s and f string interpolators by rewriting them to string concatenations or formats
92+
new DropBreaks) :: // Optimize local Break throws by rewriting them
9293
List(new PruneErasedDefs, // Drop erased definitions from scopes and simplify erased expressions
9394
new UninitializedDefs, // Replaces `compiletime.uninitialized` by `_`
9495
new InlinePatterns, // Remove placeholders of inlined patterns

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

+6
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
102102
case _ => tree
103103
}
104104

105+
def stripTyped(tree: Tree): Tree = unsplice(tree) match
106+
case Typed(expr, _) =>
107+
stripTyped(expr)
108+
case _ =>
109+
tree
110+
105111
/** The number of arguments in an application */
106112
def numArgs(tree: Tree): Int = unsplice(tree) match {
107113
case Apply(fn, args) => numArgs(fn) + args.length

compiler/src/dotty/tools/dotc/core/Definitions.scala

+4
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,10 @@ class Definitions {
970970
def TupledFunctionClass(using Context): ClassSymbol = TupledFunctionTypeRef.symbol.asClass
971971
def RuntimeTupleFunctionsModule(using Context): Symbol = requiredModule("scala.runtime.TupledFunctions")
972972

973+
@tu lazy val boundaryModule: Symbol = requiredModule("scala.util.boundary")
974+
@tu lazy val LabelClass: Symbol = requiredClass("scala.util.boundary.Label")
975+
@tu lazy val BreakClass: Symbol = requiredClass("scala.util.boundary.Break")
976+
973977
@tu lazy val CapsModule: Symbol = requiredModule("scala.caps")
974978
@tu lazy val captureRoot: TermSymbol = CapsModule.requiredValue("*")
975979
@tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe")

compiler/src/dotty/tools/dotc/core/NameKinds.scala

+2
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,8 @@ object NameKinds {
325325

326326
val LocalOptInlineLocalObj: UniqueNameKind = new UniqueNameKind("ilo")
327327

328+
val BoundaryName: UniqueNameKind = new UniqueNameKind("boundary")
329+
328330
/** The kind of names of default argument getters */
329331
val DefaultGetterName: NumberedNameKind = new NumberedNameKind(DEFAULTGETTER, "DefaultGetter") {
330332
def mkString(underlying: TermName, info: ThisInfo) = {

0 commit comments

Comments
 (0)