Skip to content

Commit 84865a1

Browse files
authored
Prevent concurrent access to local breaker in rerank (#128162)
When an async operator receives a response, we can't create new blocks on the responding thread because multiple threads may adjust the local breaker simultaneously, leading to a data race. To address this, we can either use the global breaker or delay block creation in getOutput. While using the global block factory is simpler, I prefer the second option to use the local breaker when possible. Therefore, I opted to keep the results in the queue and create new blocks in getOutput. Our tests didn't catch this issue because: (1) only one block is created in the test, and (2) there is no delay when simulating the inference service. Closes #127638 Closes #127051
1 parent 28b10c3 commit 84865a1

File tree

8 files changed

+195
-91
lines changed

8 files changed

+195
-91
lines changed

test/framework/src/main/java/org/elasticsearch/indices/CrankyCircuitBreakerService.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public class CrankyCircuitBreakerService extends CircuitBreakerService {
2929
*/
3030
public static final String ERROR_MESSAGE = "cranky breaker";
3131

32-
private final CircuitBreaker breaker = new CircuitBreaker() {
32+
public static final class CrankyCircuitBreaker implements CircuitBreaker {
3333
private final AtomicLong used = new AtomicLong();
3434

3535
@Override
@@ -82,7 +82,9 @@ public Durability getDurability() {
8282
public void setLimitAndOverhead(long limit, double overhead) {
8383

8484
}
85-
};
85+
}
86+
87+
private final CrankyCircuitBreaker breaker = new CrankyCircuitBreaker();
8688

8789
@Override
8890
public CircuitBreaker getBreaker(String name) {

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/LocalCircuitBreaker.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ public final class LocalCircuitBreaker implements CircuitBreaker, Releasable {
2828
private final long maxOverReservedBytes;
2929
private long reservedBytes;
3030
private final AtomicBoolean closed = new AtomicBoolean(false);
31+
private volatile Thread activeThread;
3132

3233
public record SizeSettings(long overReservedBytes, long maxOverReservedBytes) {
3334
public SizeSettings(Settings settings) {
@@ -57,6 +58,7 @@ public void circuitBreak(String fieldName, long bytesNeeded) {
5758

5859
@Override
5960
public void addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException {
61+
assert assertSingleThread();
6062
if (bytes <= reservedBytes) {
6163
reservedBytes -= bytes;
6264
maybeReduceReservedBytes();
@@ -68,6 +70,7 @@ public void addEstimateBytesAndMaybeBreak(long bytes, String label) throws Circu
6870

6971
@Override
7072
public void addWithoutBreaking(long bytes) {
73+
assert assertSingleThread();
7174
if (bytes <= reservedBytes) {
7275
reservedBytes -= bytes;
7376
maybeReduceReservedBytes();
@@ -130,6 +133,7 @@ public void setLimitAndOverhead(long limit, double overhead) {
130133

131134
@Override
132135
public void close() {
136+
assert assertSingleThread();
133137
if (closed.compareAndSet(false, true)) {
134138
breaker.addWithoutBreaking(-reservedBytes);
135139
}
@@ -139,4 +143,34 @@ public void close() {
139143
public String toString() {
140144
return "LocalCircuitBreaker[" + reservedBytes + "/" + overReservedBytes + ":" + maxOverReservedBytes + "]";
141145
}
146+
147+
private boolean assertSingleThread() {
148+
Thread activeThread = this.activeThread;
149+
Thread currentThread = Thread.currentThread();
150+
assert activeThread == null || activeThread == currentThread
151+
: "Local breaker must be accessed by a single thread at a time: expected ["
152+
+ activeThread
153+
+ "] != actual ["
154+
+ currentThread
155+
+ "]";
156+
return true;
157+
}
158+
159+
/**
160+
* Marks the beginning of a run loop for assertion purposes.
161+
* Sets the current thread as the only thread allowed to access this breaker.
162+
*/
163+
public boolean assertBeginRunLoop() {
164+
activeThread = Thread.currentThread();
165+
return true;
166+
}
167+
168+
/**
169+
* Marks the end of a run loop for assertion purposes.
170+
* Clears the active thread to allow other threads to access this breaker.
171+
*/
172+
public boolean assertEndRunLoop() {
173+
activeThread = null;
174+
return true;
175+
}
142176
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,13 @@ SubscribableListener<Void> run(TimeValue maxTime, int maxIterations, LongSupplie
181181
while (true) {
182182
IsBlockedResult isBlocked = Operator.NOT_BLOCKED;
183183
try {
184+
assert driverContext.assertBeginRunLoop();
184185
isBlocked = runSingleLoopIteration();
185186
} catch (DriverEarlyTerminationException unused) {
186187
closeEarlyFinishedOperators();
187188
assert isFinished() : "not finished after early termination";
189+
} finally {
190+
assert driverContext.assertEndRunLoop();
188191
}
189192
totalIterationsThisRun++;
190193
iterationsSinceLastStatusUpdate++;

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverContext.java

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.action.support.SubscribableListener;
1212
import org.elasticsearch.common.breaker.CircuitBreaker;
13-
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
1413
import org.elasticsearch.common.util.BigArrays;
1514
import org.elasticsearch.compute.data.BlockFactory;
15+
import org.elasticsearch.compute.data.LocalCircuitBreaker;
1616
import org.elasticsearch.core.Releasable;
1717
import org.elasticsearch.core.Releasables;
1818

@@ -74,14 +74,6 @@ private DriverContext(BigArrays bigArrays, BlockFactory blockFactory, WarningsMo
7474
this.warningsMode = warningsMode;
7575
}
7676

77-
public static DriverContext getLocalDriver() {
78-
return new DriverContext(
79-
BigArrays.NON_RECYCLING_INSTANCE,
80-
// TODO maybe this should have a small fixed limit?
81-
new BlockFactory(new NoopCircuitBreaker(CircuitBreaker.REQUEST), BigArrays.NON_RECYCLING_INSTANCE)
82-
);
83-
}
84-
8577
public BigArrays bigArrays() {
8678
return bigArrays;
8779
}
@@ -208,6 +200,26 @@ public enum WarningsMode {
208200
IGNORE
209201
}
210202

203+
/**
204+
* Marks the beginning of a run loop for assertion purposes.
205+
*/
206+
public boolean assertBeginRunLoop() {
207+
if (blockFactory.breaker() instanceof LocalCircuitBreaker localBreaker) {
208+
assert localBreaker.assertBeginRunLoop();
209+
}
210+
return true;
211+
}
212+
213+
/**
214+
* Marks the end of a run loop for assertion purposes.
215+
*/
216+
public boolean assertEndRunLoop() {
217+
if (blockFactory.breaker() instanceof LocalCircuitBreaker localBreaker) {
218+
assert localBreaker.assertEndRunLoop();
219+
}
220+
return true;
221+
}
222+
211223
private static class AsyncActions {
212224
private final SubscribableListener<Void> completion = new SubscribableListener<>();
213225
private final AtomicBoolean finished = new AtomicBoolean();

x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/OperatorTestCase.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.compute.data.Block;
2424
import org.elasticsearch.compute.data.BlockFactory;
2525
import org.elasticsearch.compute.data.Page;
26+
import org.elasticsearch.compute.operator.AsyncOperator;
2627
import org.elasticsearch.compute.operator.Driver;
2728
import org.elasticsearch.compute.operator.DriverContext;
2829
import org.elasticsearch.compute.operator.DriverRunner;
@@ -247,9 +248,20 @@ public void testSimpleFinishClose() {
247248
try (var operator = simple().get(driverContext)) {
248249
assert operator.needsInput();
249250
for (Page page : input) {
250-
operator.addInput(page);
251+
if (operator.needsInput()) {
252+
operator.addInput(page);
253+
} else {
254+
page.releaseBlocks();
255+
}
251256
}
252257
operator.finish();
258+
// for async operator, we need to wait for async actions to finish.
259+
if (operator instanceof AsyncOperator<?> || randomBoolean()) {
260+
driverContext.finish();
261+
PlainActionFuture<Void> waitForAsync = new PlainActionFuture<>();
262+
driverContext.waitForAsyncActions(waitForAsync);
263+
waitForAsync.actionGet(TimeValue.timeValueSeconds(30));
264+
}
253265
}
254266
}
255267

x-pack/plugin/esql/compute/test/src/main/java/org/elasticsearch/compute/test/TestDriverFactory.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,18 @@
77

88
package org.elasticsearch.compute.test;
99

10+
import org.elasticsearch.common.Randomness;
11+
import org.elasticsearch.common.unit.ByteSizeValue;
12+
import org.elasticsearch.compute.data.BlockFactory;
13+
import org.elasticsearch.compute.data.LocalCircuitBreaker;
1014
import org.elasticsearch.compute.operator.Driver;
1115
import org.elasticsearch.compute.operator.DriverContext;
1216
import org.elasticsearch.compute.operator.Operator;
1317
import org.elasticsearch.compute.operator.SinkOperator;
1418
import org.elasticsearch.compute.operator.SourceOperator;
1519
import org.elasticsearch.core.Releasable;
20+
import org.elasticsearch.core.Releasables;
21+
import org.elasticsearch.indices.CrankyCircuitBreakerService;
1622

1723
import java.util.List;
1824

@@ -38,6 +44,20 @@ public static Driver create(
3844
SinkOperator sink,
3945
Releasable releasable
4046
) {
47+
// Do not wrap the local breaker for small local breakers, as the output mights not match expectations.
48+
if (driverContext.breaker() instanceof CrankyCircuitBreakerService.CrankyCircuitBreaker == false
49+
&& driverContext.breaker() instanceof LocalCircuitBreaker == false
50+
&& driverContext.breaker().getLimit() >= ByteSizeValue.ofMb(100).getBytes()
51+
&& Randomness.get().nextBoolean()) {
52+
final int overReservedBytes = Randomness.get().nextInt(1024 * 1024);
53+
final int maxOverReservedBytes = overReservedBytes + Randomness.get().nextInt(1024 * 1024);
54+
var localBreaker = new LocalCircuitBreaker(driverContext.breaker(), overReservedBytes, maxOverReservedBytes);
55+
BlockFactory localBlockFactory = driverContext.blockFactory().newChildFactory(localBreaker);
56+
driverContext = new DriverContext(localBlockFactory.bigArrays(), localBlockFactory);
57+
}
58+
if (driverContext.breaker() instanceof LocalCircuitBreaker localBreaker) {
59+
releasable = Releasables.wrap(releasable, localBreaker);
60+
}
4161
return new Driver(
4262
"unset",
4363
"test-task",

0 commit comments

Comments
 (0)