Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/bugfix-AWSCRTHTTPClient-7b95a65.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "bugfix",
"category": "AWS CRT HTTP Client",
"contributor": "",
"description": "Fixed a thread safety issue that could cause application to crash in the edge case where the SDK attempted to invoke `incrementWindow` after the stream is closed in AWS CRT HTTP Client."
}
6 changes: 6 additions & 0 deletions bom-internal/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,12 @@
<version>${mockito.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<version>${mockito.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>nl.jqno.equalsverifier</groupId>
<artifactId>equalsverifier</artifactId>
Expand Down
2 changes: 1 addition & 1 deletion http-clients/aws-crt-client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<artifactId>mockito-inline</artifactId>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This allows us to mock final classes/methods

<scope>test</scope>
</dependency>
<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.nio.ByteBuffer;
import java.util.concurrent.CompletableFuture;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.annotations.SdkTestInternalApi;
import software.amazon.awssdk.crt.CRT;
import software.amazon.awssdk.crt.http.HttpClientConnection;
import software.amazon.awssdk.crt.http.HttpException;
Expand All @@ -46,19 +47,29 @@ public final class CrtResponseAdapter implements HttpStreamResponseHandler {
private final HttpClientConnection connection;
private final CompletableFuture<Void> completionFuture;
private final SdkAsyncHttpResponseHandler responseHandler;
private final SimplePublisher<ByteBuffer> responsePublisher = new SimplePublisher<>();
private final SimplePublisher<ByteBuffer> responsePublisher;

private final SdkHttpResponse.Builder responseBuilder;
private final ResponseHandlerHelper responseHandlerHelper;

private CrtResponseAdapter(HttpClientConnection connection,
CompletableFuture<Void> completionFuture,
SdkAsyncHttpResponseHandler responseHandler) {
this(connection, completionFuture, responseHandler, new SimplePublisher<>());
}


@SdkTestInternalApi
public CrtResponseAdapter(HttpClientConnection connection,
CompletableFuture<Void> completionFuture,
SdkAsyncHttpResponseHandler responseHandler,
SimplePublisher<ByteBuffer> simplePublisher) {
this.connection = Validate.paramNotNull(connection, "connection");
this.completionFuture = Validate.paramNotNull(completionFuture, "completionFuture");
this.responseHandler = Validate.paramNotNull(responseHandler, "responseHandler");
this.responseBuilder = SdkHttpResponse.builder();
this.responseHandlerHelper = new ResponseHandlerHelper(responseBuilder, connection);
this.responsePublisher = simplePublisher;
}

public static HttpStreamResponseHandler toCrtResponseHandler(HttpClientConnection crtConn,
Expand Down Expand Up @@ -95,9 +106,7 @@ public int onResponseBody(HttpStream stream, byte[] bodyBytesIn) {
return;
}

if (!responseHandlerHelper.connectionClosed().get()) {
stream.incrementWindow(bodyBytesIn.length);
}
responseHandlerHelper.incrementWindow(stream, bodyBytesIn.length);
});

return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.nio.ByteBuffer;
import java.util.concurrent.CompletableFuture;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.annotations.SdkTestInternalApi;
import software.amazon.awssdk.crt.CRT;
import software.amazon.awssdk.crt.http.HttpClientConnection;
import software.amazon.awssdk.crt.http.HttpException;
Expand All @@ -42,7 +43,7 @@
public final class InputStreamAdaptingHttpStreamResponseHandler implements HttpStreamResponseHandler {
private static final Logger log = Logger.loggerFor(InputStreamAdaptingHttpStreamResponseHandler.class);
private volatile AbortableInputStreamSubscriber inputStreamSubscriber;
private final SimplePublisher<ByteBuffer> simplePublisher = new SimplePublisher<>();
private final SimplePublisher<ByteBuffer> simplePublisher;

private final CompletableFuture<SdkHttpFullResponse> requestCompletionFuture;
private final HttpClientConnection crtConn;
Expand All @@ -52,10 +53,18 @@ public final class InputStreamAdaptingHttpStreamResponseHandler implements HttpS

public InputStreamAdaptingHttpStreamResponseHandler(HttpClientConnection crtConn,
CompletableFuture<SdkHttpFullResponse> requestCompletionFuture) {
this(crtConn, requestCompletionFuture, new SimplePublisher<>());
}

@SdkTestInternalApi
public InputStreamAdaptingHttpStreamResponseHandler(HttpClientConnection crtConn,
CompletableFuture<SdkHttpFullResponse> requestCompletionFuture,
SimplePublisher<ByteBuffer> simplePublisher) {
this.crtConn = crtConn;
this.requestCompletionFuture = requestCompletionFuture;
this.responseBuilder = SdkHttpResponse.builder();
this.responseHandlerHelper = new ResponseHandlerHelper(responseBuilder, crtConn);
this.simplePublisher = simplePublisher;
}

@Override
Expand Down Expand Up @@ -101,11 +110,8 @@ public int onResponseBody(HttpStream stream, byte[] bodyBytesIn) {
failFutureAndCloseConnection(stream, failure);
return;
}

if (!responseHandlerHelper.connectionClosed().get()) {
// increment the window upon buffer consumption.
stream.incrementWindow(bodyBytesIn.length);
}
// increment the window upon buffer consumption.
responseHandlerHelper.incrementWindow(stream, bodyBytesIn.length);
});

// Window will be incremented after the subscriber consumes the data, returning 0 here to disable it.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

package software.amazon.awssdk.http.crt.internal.response;

import java.util.concurrent.atomic.AtomicBoolean;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.crt.http.HttpClientConnection;
import software.amazon.awssdk.crt.http.HttpHeader;
Expand All @@ -30,14 +29,16 @@
*
* CRT connection will only be closed, i.e., not reused, in one of the following conditions:
* 1. 5xx server error OR
* 2. It fails to read the response.
* 2. It fails to read the response OR
* 3. the response stream is closed/aborted by the caller.
*/
@SdkInternalApi
public class ResponseHandlerHelper {

private final SdkHttpResponse.Builder responseBuilder;
private final HttpClientConnection connection;
private AtomicBoolean connectionClosed = new AtomicBoolean(false);
private boolean connectionClosed;
private final Object lock = new Object();

public ResponseHandlerHelper(SdkHttpResponse.Builder responseBuilder, HttpClientConnection connection) {
this.responseBuilder = responseBuilder;
Expand All @@ -57,20 +58,34 @@ public void onResponseHeaders(HttpStream stream, int responseStatusCode, int hea
* Release the connection back to the pool so that it can be reused.
*/
public void releaseConnection(HttpStream stream) {
if (connectionClosed.compareAndSet(false, true)) {
connection.close();
stream.close();
synchronized (lock) {
if (!connectionClosed) {
connectionClosed = true;
connection.close();
stream.close();
}
}
}

public void incrementWindow(HttpStream stream, int windowSize) {
Copy link
Contributor

@joviegas joviegas Jan 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please confirm if we have tested and confirmed that there is not latency because of this synchronized check in multi threaded environment with multiple requests going on at the same time?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't anticipate any latency change as part of this change since most of times, there should only be one thread accessing those methods. The only exception is that when the request fails, it may be possible one thread is trying to invoke incrementWindow and the other is closing it.

I'll double check our benchmarks after today's release.

synchronized (lock) {
if (!connectionClosed) {
stream.incrementWindow(windowSize);
}
}
}

/**
* Close the connection completely
*/
public void closeConnection(HttpStream stream) {
if (connectionClosed.compareAndSet(false, true)) {
connection.shutdown();
connection.close();
stream.close();
synchronized (lock) {
if (!connectionClosed) {
connectionClosed = true;
connection.shutdown();
connection.close();
stream.close();
}
}
}

Expand All @@ -82,8 +97,4 @@ public void cleanUpConnectionBasedOnStatusCode(HttpStream stream) {
releaseConnection(stream);
}
}

public AtomicBoolean connectionClosed() {
return connectionClosed;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,34 @@
package software.amazon.awssdk.http.crt.internal;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.crt.http.HttpClientConnection;
import software.amazon.awssdk.crt.http.HttpException;
import software.amazon.awssdk.crt.http.HttpHeader;
import software.amazon.awssdk.crt.http.HttpHeaderBlock;
import software.amazon.awssdk.crt.http.HttpStream;
import software.amazon.awssdk.crt.http.HttpStreamResponseHandler;
import software.amazon.awssdk.http.crt.internal.response.InputStreamAdaptingHttpStreamResponseHandler;
import software.amazon.awssdk.utils.async.SimplePublisher;

@ExtendWith(MockitoExtension.class)
public abstract class BaseHttpStreamResponseHandlerTest {
Expand All @@ -44,10 +53,15 @@ public abstract class BaseHttpStreamResponseHandlerTest {
@Mock
HttpStream httpStream;

@Mock
SimplePublisher<ByteBuffer> simplePublisher;

HttpStreamResponseHandler responseHandler;

abstract HttpStreamResponseHandler responseHandler();

abstract HttpStreamResponseHandler responseHandlerWithMockedPublisher(SimplePublisher<ByteBuffer> simplePublisher);

@BeforeEach
public void setUp() {
requestFuture = new CompletableFuture<>();
Expand Down Expand Up @@ -113,6 +127,101 @@ void streamClosed_shouldNotIncreaseStreamWindow() throws InterruptedException {
verify(httpStream, never()).incrementWindow(anyInt());
}

@Test
void publisherWritesFutureFails_shouldShutdownConnection() {
SimplePublisher<ByteBuffer> simplePublisher = Mockito.mock(SimplePublisher.class);
CompletableFuture<Void> future = new CompletableFuture<>();
when(simplePublisher.send(any(ByteBuffer.class))).thenReturn(future);

HttpStreamResponseHandler handler = responseHandlerWithMockedPublisher(simplePublisher);
HttpHeader[] httpHeaders = getHttpHeaders();

handler.onResponseHeaders(httpStream, 200, HttpHeaderBlock.MAIN.getValue(),
httpHeaders);
handler.onResponseHeadersDone(httpStream, 0);
handler.onResponseBody(httpStream,
RandomStringUtils.random(1 * 1024 * 1024).getBytes(StandardCharsets.UTF_8));
RuntimeException runtimeException = new RuntimeException();
future.completeExceptionally(runtimeException);

try {
requestFuture.join();
} catch (Exception e) {
// we don't verify here because it behaves differently in async and sync
}

verify(crtConn).shutdown();
verify(crtConn).close();
verify(httpStream).close();
verify(httpStream, never()).incrementWindow(anyInt());
}

@Test
void publisherWritesFutureCompletesAfterConnectionClosed_shouldNotInvokeIncrementWindow() {
CompletableFuture<Void> future = new CompletableFuture<>();
when(simplePublisher.send(any(ByteBuffer.class))).thenReturn(future);
when(simplePublisher.complete()).thenReturn(future);

HttpStreamResponseHandler handler = responseHandlerWithMockedPublisher(simplePublisher);


HttpHeader[] httpHeaders = getHttpHeaders();

handler.onResponseHeaders(httpStream, 200, HttpHeaderBlock.MAIN.getValue(),
httpHeaders);
handler.onResponseHeadersDone(httpStream, 0);
handler.onResponseBody(httpStream,
RandomStringUtils.random(1 * 1024 * 1024).getBytes(StandardCharsets.UTF_8));
handler.onResponseComplete(httpStream, 0);
future.complete(null);

requestFuture.join();
verify(crtConn, never()).shutdown();
verify(crtConn).close();
verify(httpStream).close();
verify(httpStream, never()).incrementWindow(anyInt());
}

@Test
void publisherWritesFutureCompletesWhenConnectionClosed_shouldNotInvokeIncrementWindow() {
CompletableFuture<Void> future = new CompletableFuture<>();
when(simplePublisher.send(any(ByteBuffer.class))).thenReturn(future);
when(simplePublisher.complete()).thenReturn(future);

HttpStreamResponseHandler handler = responseHandlerWithMockedPublisher(simplePublisher);


HttpHeader[] httpHeaders = getHttpHeaders();

handler.onResponseHeaders(httpStream, 200, HttpHeaderBlock.MAIN.getValue(),
httpHeaders);
handler.onResponseHeadersDone(httpStream, 0);
handler.onResponseBody(httpStream,
RandomStringUtils.random(1 * 1024 * 1024).getBytes(StandardCharsets.UTF_8));

// This tracker tracks which of the two operation completes first
AtomicInteger whenCompleteTracker = new AtomicInteger(0);
CompletableFuture<Void> onResponseComplete = CompletableFuture.runAsync(() -> handler.onResponseComplete(httpStream, 0))
.whenComplete((r, t) -> whenCompleteTracker.compareAndSet(0, 1));

CompletableFuture<Void> writeComplete = CompletableFuture.runAsync(() -> future.complete(null))
.whenComplete((r, t) -> whenCompleteTracker.compareAndSet(0, 2));
requestFuture.join();

CompletableFuture.allOf(onResponseComplete, writeComplete).join();

if (whenCompleteTracker.get() == 1) {
// onResponseComplete finishes first
verify(httpStream, never()).incrementWindow(anyInt());
} else {
verify(httpStream).incrementWindow(anyInt());
}

verify(crtConn, never()).shutdown();
verify(crtConn).close();
verify(httpStream).close();
}

static HttpHeader[] getHttpHeaders() {
HttpHeader[] httpHeaders = new HttpHeader[1];
httpHeaders[0] = new HttpHeader("Content-Length", "1");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler;
import software.amazon.awssdk.http.crt.internal.response.CrtResponseAdapter;
import software.amazon.awssdk.http.crt.internal.response.InputStreamAdaptingHttpStreamResponseHandler;
import software.amazon.awssdk.utils.async.SimplePublisher;

public class CrtResponseHandlerTest extends BaseHttpStreamResponseHandlerTest {

Expand All @@ -53,6 +54,15 @@ HttpStreamResponseHandler responseHandler() {
return CrtResponseAdapter.toCrtResponseHandler(crtConn, requestFuture, responseHandler);
}

@Override
HttpStreamResponseHandler responseHandlerWithMockedPublisher(SimplePublisher<ByteBuffer> simplePublisher) {
AsyncResponseHandler<Void> responseHandler = new AsyncResponseHandler<>((response,
executionAttributes) -> null, Function.identity(), new ExecutionAttributes());

responseHandler.prepare();
return new CrtResponseAdapter(crtConn, requestFuture, responseHandler, simplePublisher);
}

@Test
void publisherFailedToDeliverEvents_shouldShutDownConnection() {
SdkAsyncHttpResponseHandler responseHandler = new TestAsyncHttpResponseHandler();
Expand Down
Loading