Skip to content

Commit efac590

Browse files
authored
Fix race condition in RestCancellableNodeClient (#126703)
Today we rely on registering the channel after registering the task to be cancelled to ensure that the task is cancelled even if the channel is closed concurrently. However the client may already have processed a cancellable request on the channel and therefore this mechanism doesn't work. With this change we make sure not to register another task after draining the registrations in order to cancel them. Closes #88201 Backport of #126686 to `7.17`
1 parent 9866f21 commit efac590

File tree

4 files changed

+90
-37
lines changed

4 files changed

+90
-37
lines changed

docs/changelog/126686.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 126686
2+
summary: Fix race condition in `RestCancellableNodeClient`
3+
area: Task Management
4+
type: bug
5+
issues:
6+
- 88201

qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,12 @@
1111
import org.apache.http.client.methods.HttpGet;
1212
import org.elasticsearch.action.admin.indices.segments.IndicesSegmentsAction;
1313
import org.elasticsearch.client.Request;
14-
import org.elasticsearch.test.junit.annotations.TestIssueLogging;
1514

1615
public class IndicesSegmentsRestCancellationIT extends BlockedSearcherRestCancellationTestCase {
17-
@TestIssueLogging(
18-
issueUrl = "https://github.com/elastic/elasticsearch/issues/88201",
19-
value = "org.elasticsearch.http.BlockedSearcherRestCancellationTestCase:DEBUG"
20-
+ ",org.elasticsearch.transport.TransportService:TRACE"
21-
)
2216
public void testIndicesSegmentsRestCancellation() throws Exception {
2317
runTest(new Request(HttpGet.METHOD_NAME, "/_segments"), IndicesSegmentsAction.NAME);
2418
}
2519

26-
@TestIssueLogging(
27-
issueUrl = "https://github.com/elastic/elasticsearch/issues/88201",
28-
value = "org.elasticsearch.http.BlockedSearcherRestCancellationTestCase:DEBUG"
29-
+ ",org.elasticsearch.transport.TransportService:TRACE"
30-
)
3120
public void testCatSegmentsRestCancellation() throws Exception {
3221
runTest(new Request(HttpGet.METHOD_NAME, "/_cat/segments"), IndicesSegmentsAction.NAME);
3322
}

server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
import org.elasticsearch.client.FilterClient;
1818
import org.elasticsearch.client.OriginSettingClient;
1919
import org.elasticsearch.client.node.NodeClient;
20+
import org.elasticsearch.core.Nullable;
2021
import org.elasticsearch.http.HttpChannel;
2122
import org.elasticsearch.tasks.CancellableTask;
2223
import org.elasticsearch.tasks.Task;
2324
import org.elasticsearch.tasks.TaskId;
2425

25-
import java.util.ArrayList;
26+
import java.util.Collection;
2627
import java.util.HashSet;
27-
import java.util.List;
2828
import java.util.Map;
2929
import java.util.Set;
3030
import java.util.concurrent.ConcurrentHashMap;
@@ -111,12 +111,14 @@ private void cancelTask(TaskId taskId) {
111111

112112
private class CloseListener implements ActionListener<Void> {
113113
private final AtomicReference<HttpChannel> channel = new AtomicReference<>();
114-
private final Set<TaskId> tasks = new HashSet<>();
114+
115+
@Nullable // if already drained
116+
private Set<TaskId> tasks = new HashSet<>();
115117

116118
CloseListener() {}
117119

118120
synchronized int getNumTasks() {
119-
return tasks.size();
121+
return tasks == null ? 0 : tasks.size();
120122
}
121123

122124
void maybeRegisterChannel(HttpChannel httpChannel) {
@@ -129,16 +131,23 @@ void maybeRegisterChannel(HttpChannel httpChannel) {
129131
}
130132
}
131133

132-
synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) {
133-
taskHolder.taskId = taskId;
134-
if (taskHolder.completed == false) {
135-
this.tasks.add(taskId);
134+
void registerTask(TaskHolder taskHolder, TaskId taskId) {
135+
synchronized (this) {
136+
taskHolder.taskId = taskId;
137+
if (tasks != null) {
138+
if (taskHolder.completed == false) {
139+
tasks.add(taskId);
140+
}
141+
return;
142+
}
136143
}
144+
// else tasks == null so the channel is already closed
145+
cancelTask(taskId);
137146
}
138147

139148
synchronized void unregisterTask(TaskHolder taskHolder) {
140-
if (taskHolder.taskId != null) {
141-
this.tasks.remove(taskHolder.taskId);
149+
if (taskHolder.taskId != null && tasks != null) {
150+
tasks.remove(taskHolder.taskId);
142151
}
143152
taskHolder.completed = true;
144153
}
@@ -148,18 +157,20 @@ public void onResponse(Void aVoid) {
148157
final HttpChannel httpChannel = channel.get();
149158
assert httpChannel != null : "channel not registered";
150159
// when the channel gets closed it won't be reused: we can remove it from the map and forget about it.
151-
CloseListener closeListener = httpChannels.remove(httpChannel);
152-
assert closeListener != null : "channel not found in the map of tracked channels";
153-
final List<TaskId> toCancel;
154-
synchronized (this) {
155-
toCancel = new ArrayList<>(tasks);
156-
tasks.clear();
157-
}
158-
for (TaskId taskId : toCancel) {
160+
final CloseListener closeListener = httpChannels.remove(httpChannel);
161+
assert closeListener != null : "channel not found in the map of tracked channels: " + httpChannel;
162+
assert closeListener == CloseListener.this : "channel had a different CloseListener registered: " + httpChannel;
163+
for (final TaskId taskId : drainTasks()) {
159164
cancelTask(taskId);
160165
}
161166
}
162167

168+
private synchronized Collection<TaskId> drainTasks() {
169+
final Collection<TaskId> drained = tasks;
170+
tasks = null;
171+
return drained;
172+
}
173+
163174
@Override
164175
public void onFailure(Exception e) {
165176
onResponse(null);

server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import java.net.InetSocketAddress;
3434
import java.util.ArrayList;
3535
import java.util.Collections;
36+
import java.util.HashSet;
3637
import java.util.List;
3738
import java.util.Set;
3839
import java.util.concurrent.CopyOnWriteArraySet;
@@ -43,6 +44,7 @@
4344
import java.util.concurrent.atomic.AtomicInteger;
4445
import java.util.concurrent.atomic.AtomicLong;
4546
import java.util.concurrent.atomic.AtomicReference;
47+
import java.util.function.LongSupplier;
4648

4749
public class RestCancellableNodeClientTests extends ESTestCase {
4850

@@ -150,8 +152,42 @@ public void testChannelAlreadyClosed() {
150152
}
151153
}
152154

155+
public void testConcurrentExecuteAndClose() {
156+
final TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true);
157+
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
158+
int numTasks = randomIntBetween(1, 30);
159+
TestHttpChannel channel = new TestHttpChannel();
160+
final CountDownLatch startLatch = new CountDownLatch(1);
161+
final CountDownLatch doneLatch = new CountDownLatch(numTasks + 1);
162+
final Set<TaskId> expectedTasks = new HashSet<>(numTasks);
163+
for (int j = 0; j < numTasks; j++) {
164+
RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel);
165+
threadPool.generic().execute(() -> {
166+
client.execute(SearchAction.INSTANCE, new SearchRequest(), ActionListener.wrap(ESTestCase::fail));
167+
startLatch.countDown();
168+
doneLatch.countDown();
169+
});
170+
expectedTasks.add(new TaskId(testClient.getLocalNodeId(), j));
171+
}
172+
threadPool.generic().execute(() -> {
173+
try {
174+
safeAwait(startLatch);
175+
channel.awaitClose();
176+
} catch (InterruptedException e) {
177+
Thread.currentThread().interrupt();
178+
throw new AssertionError(e);
179+
} finally {
180+
doneLatch.countDown();
181+
}
182+
});
183+
safeAwait(doneLatch);
184+
assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels());
185+
assertEquals(expectedTasks, testClient.cancelledTasks);
186+
}
187+
153188
private static class TestClient extends NodeClient {
154-
private final AtomicLong counter = new AtomicLong(0);
189+
private final LongSupplier searchTaskIdGenerator = new AtomicLong(0)::getAndIncrement;
190+
private final LongSupplier cancelTaskIdGenerator = new AtomicLong(1000)::getAndIncrement;
155191
private final Set<TaskId> cancelledTasks = new CopyOnWriteArraySet<>();
156192
private final AtomicInteger searchRequests = new AtomicInteger(0);
157193
private final boolean timeout;
@@ -171,7 +207,13 @@ public <Request extends ActionRequest, Response extends ActionResponse> Task exe
171207
case CancelTasksAction.NAME:
172208
CancelTasksRequest cancelTasksRequest = (CancelTasksRequest) request;
173209
assertTrue("tried to cancel the same task more than once", cancelledTasks.add(cancelTasksRequest.getTaskId()));
174-
Task task = request.createTask(counter.getAndIncrement(), "cancel_task", action.name(), null, Collections.emptyMap());
210+
Task task = request.createTask(
211+
cancelTaskIdGenerator.getAsLong(),
212+
"cancel_task",
213+
action.name(),
214+
null,
215+
Collections.emptyMap()
216+
);
175217
if (randomBoolean()) {
176218
listener.onResponse(null);
177219
} else {
@@ -182,7 +224,13 @@ public <Request extends ActionRequest, Response extends ActionResponse> Task exe
182224
return task;
183225
case SearchAction.NAME:
184226
searchRequests.incrementAndGet();
185-
Task searchTask = request.createTask(counter.getAndIncrement(), "search", action.name(), null, Collections.emptyMap());
227+
Task searchTask = request.createTask(
228+
searchTaskIdGenerator.getAsLong(),
229+
"search",
230+
action.name(),
231+
null,
232+
Collections.emptyMap()
233+
);
186234
if (timeout == false) {
187235
if (rarely()) {
188236
// make sure that search is sometimes also called from the same thread before the task is returned
@@ -193,7 +241,7 @@ public <Request extends ActionRequest, Response extends ActionResponse> Task exe
193241
}
194242
return searchTask;
195243
default:
196-
throw new UnsupportedOperationException();
244+
throw new AssertionError("unexpected action " + action.name());
197245
}
198246

199247
}
@@ -224,9 +272,7 @@ public InetSocketAddress getRemoteAddress() {
224272

225273
@Override
226274
public void close() {
227-
if (open.compareAndSet(true, false) == false) {
228-
throw new IllegalStateException("channel already closed!");
229-
}
275+
assertTrue("HttpChannel is already closed", open.compareAndSet(true, false));
230276
ActionListener<Void> listener = closeListener.get();
231277
if (listener != null) {
232278
boolean failure = randomBoolean();
@@ -242,6 +288,7 @@ public void close() {
242288
}
243289

244290
private void awaitClose() throws InterruptedException {
291+
assertNotNull("must set closeListener before calling awaitClose", closeListener.get());
245292
close();
246293
closeLatch.await();
247294
}
@@ -258,7 +305,7 @@ public void addCloseListener(ActionListener<Void> listener) {
258305
listener.onResponse(null);
259306
} else {
260307
if (closeListener.compareAndSet(null, listener) == false) {
261-
throw new IllegalStateException("close listener already set, only one is allowed!");
308+
throw new AssertionError("close listener already set, only one is allowed!");
262309
}
263310
}
264311
}

0 commit comments

Comments
 (0)