diff --git a/.changes/next-release/bugfix-AWSCRTHTTPClient-7b95a65.json b/.changes/next-release/bugfix-AWSCRTHTTPClient-7b95a65.json
new file mode 100644
index 000000000000..31b52d8a3707
--- /dev/null
+++ b/.changes/next-release/bugfix-AWSCRTHTTPClient-7b95a65.json
@@ -0,0 +1,6 @@
+{
+ "type": "bugfix",
+ "category": "AWS CRT HTTP Client",
+ "contributor": "",
+ "description": "Fixed a thread safety issue that could cause application to crash in the edge case where the SDK attempted to invoke `incrementWindow` after the stream is closed in AWS CRT HTTP Client."
+}
diff --git a/bom-internal/pom.xml b/bom-internal/pom.xml
index bb8a9386ad21..5138b01d3bf5 100644
--- a/bom-internal/pom.xml
+++ b/bom-internal/pom.xml
@@ -255,6 +255,12 @@
${mockito.version}
test
+
+ org.mockito
+ mockito-inline
+ ${mockito.version}
+ test
+
nl.jqno.equalsverifier
equalsverifier
diff --git a/http-clients/aws-crt-client/pom.xml b/http-clients/aws-crt-client/pom.xml
index a6fe67926eaa..77debaee6bea 100644
--- a/http-clients/aws-crt-client/pom.xml
+++ b/http-clients/aws-crt-client/pom.xml
@@ -105,7 +105,7 @@
org.mockito
- mockito-core
+ mockito-inline
test
diff --git a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseAdapter.java b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseAdapter.java
index bfb75050e55a..d2f246336ca8 100644
--- a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseAdapter.java
+++ b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseAdapter.java
@@ -20,6 +20,7 @@
import java.nio.ByteBuffer;
import java.util.concurrent.CompletableFuture;
import software.amazon.awssdk.annotations.SdkInternalApi;
+import software.amazon.awssdk.annotations.SdkTestInternalApi;
import software.amazon.awssdk.crt.CRT;
import software.amazon.awssdk.crt.http.HttpClientConnection;
import software.amazon.awssdk.crt.http.HttpException;
@@ -46,7 +47,7 @@ public final class CrtResponseAdapter implements HttpStreamResponseHandler {
private final HttpClientConnection connection;
private final CompletableFuture completionFuture;
private final SdkAsyncHttpResponseHandler responseHandler;
- private final SimplePublisher responsePublisher = new SimplePublisher<>();
+ private final SimplePublisher responsePublisher;
private final SdkHttpResponse.Builder responseBuilder;
private final ResponseHandlerHelper responseHandlerHelper;
@@ -54,11 +55,21 @@ public final class CrtResponseAdapter implements HttpStreamResponseHandler {
private CrtResponseAdapter(HttpClientConnection connection,
CompletableFuture completionFuture,
SdkAsyncHttpResponseHandler responseHandler) {
+ this(connection, completionFuture, responseHandler, new SimplePublisher<>());
+ }
+
+
+ @SdkTestInternalApi
+ public CrtResponseAdapter(HttpClientConnection connection,
+ CompletableFuture completionFuture,
+ SdkAsyncHttpResponseHandler responseHandler,
+ SimplePublisher simplePublisher) {
this.connection = Validate.paramNotNull(connection, "connection");
this.completionFuture = Validate.paramNotNull(completionFuture, "completionFuture");
this.responseHandler = Validate.paramNotNull(responseHandler, "responseHandler");
this.responseBuilder = SdkHttpResponse.builder();
this.responseHandlerHelper = new ResponseHandlerHelper(responseBuilder, connection);
+ this.responsePublisher = simplePublisher;
}
public static HttpStreamResponseHandler toCrtResponseHandler(HttpClientConnection crtConn,
@@ -95,9 +106,7 @@ public int onResponseBody(HttpStream stream, byte[] bodyBytesIn) {
return;
}
- if (!responseHandlerHelper.connectionClosed().get()) {
- stream.incrementWindow(bodyBytesIn.length);
- }
+ responseHandlerHelper.incrementWindow(stream, bodyBytesIn.length);
});
return 0;
diff --git a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/InputStreamAdaptingHttpStreamResponseHandler.java b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/InputStreamAdaptingHttpStreamResponseHandler.java
index 939405db38b8..b6b95307722e 100644
--- a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/InputStreamAdaptingHttpStreamResponseHandler.java
+++ b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/InputStreamAdaptingHttpStreamResponseHandler.java
@@ -20,6 +20,7 @@
import java.nio.ByteBuffer;
import java.util.concurrent.CompletableFuture;
import software.amazon.awssdk.annotations.SdkInternalApi;
+import software.amazon.awssdk.annotations.SdkTestInternalApi;
import software.amazon.awssdk.crt.CRT;
import software.amazon.awssdk.crt.http.HttpClientConnection;
import software.amazon.awssdk.crt.http.HttpException;
@@ -42,7 +43,7 @@
public final class InputStreamAdaptingHttpStreamResponseHandler implements HttpStreamResponseHandler {
private static final Logger log = Logger.loggerFor(InputStreamAdaptingHttpStreamResponseHandler.class);
private volatile AbortableInputStreamSubscriber inputStreamSubscriber;
- private final SimplePublisher simplePublisher = new SimplePublisher<>();
+ private final SimplePublisher simplePublisher;
private final CompletableFuture requestCompletionFuture;
private final HttpClientConnection crtConn;
@@ -52,10 +53,18 @@ public final class InputStreamAdaptingHttpStreamResponseHandler implements HttpS
public InputStreamAdaptingHttpStreamResponseHandler(HttpClientConnection crtConn,
CompletableFuture requestCompletionFuture) {
+ this(crtConn, requestCompletionFuture, new SimplePublisher<>());
+ }
+
+ @SdkTestInternalApi
+ public InputStreamAdaptingHttpStreamResponseHandler(HttpClientConnection crtConn,
+ CompletableFuture requestCompletionFuture,
+ SimplePublisher simplePublisher) {
this.crtConn = crtConn;
this.requestCompletionFuture = requestCompletionFuture;
this.responseBuilder = SdkHttpResponse.builder();
this.responseHandlerHelper = new ResponseHandlerHelper(responseBuilder, crtConn);
+ this.simplePublisher = simplePublisher;
}
@Override
@@ -101,11 +110,8 @@ public int onResponseBody(HttpStream stream, byte[] bodyBytesIn) {
failFutureAndCloseConnection(stream, failure);
return;
}
-
- if (!responseHandlerHelper.connectionClosed().get()) {
- // increment the window upon buffer consumption.
- stream.incrementWindow(bodyBytesIn.length);
- }
+ // increment the window upon buffer consumption.
+ responseHandlerHelper.incrementWindow(stream, bodyBytesIn.length);
});
// Window will be incremented after the subscriber consumes the data, returning 0 here to disable it.
diff --git a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/ResponseHandlerHelper.java b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/ResponseHandlerHelper.java
index 69665d1aeff9..8aa4037df88b 100644
--- a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/ResponseHandlerHelper.java
+++ b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/ResponseHandlerHelper.java
@@ -15,7 +15,6 @@
package software.amazon.awssdk.http.crt.internal.response;
-import java.util.concurrent.atomic.AtomicBoolean;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.crt.http.HttpClientConnection;
import software.amazon.awssdk.crt.http.HttpHeader;
@@ -30,14 +29,16 @@
*
* CRT connection will only be closed, i.e., not reused, in one of the following conditions:
* 1. 5xx server error OR
- * 2. It fails to read the response.
+ * 2. It fails to read the response OR
+ * 3. the response stream is closed/aborted by the caller.
*/
@SdkInternalApi
public class ResponseHandlerHelper {
private final SdkHttpResponse.Builder responseBuilder;
private final HttpClientConnection connection;
- private AtomicBoolean connectionClosed = new AtomicBoolean(false);
+ private boolean connectionClosed;
+ private final Object lock = new Object();
public ResponseHandlerHelper(SdkHttpResponse.Builder responseBuilder, HttpClientConnection connection) {
this.responseBuilder = responseBuilder;
@@ -57,9 +58,20 @@ public void onResponseHeaders(HttpStream stream, int responseStatusCode, int hea
* Release the connection back to the pool so that it can be reused.
*/
public void releaseConnection(HttpStream stream) {
- if (connectionClosed.compareAndSet(false, true)) {
- connection.close();
- stream.close();
+ synchronized (lock) {
+ if (!connectionClosed) {
+ connectionClosed = true;
+ connection.close();
+ stream.close();
+ }
+ }
+ }
+
+ public void incrementWindow(HttpStream stream, int windowSize) {
+ synchronized (lock) {
+ if (!connectionClosed) {
+ stream.incrementWindow(windowSize);
+ }
}
}
@@ -67,10 +79,13 @@ public void releaseConnection(HttpStream stream) {
* Close the connection completely
*/
public void closeConnection(HttpStream stream) {
- if (connectionClosed.compareAndSet(false, true)) {
- connection.shutdown();
- connection.close();
- stream.close();
+ synchronized (lock) {
+ if (!connectionClosed) {
+ connectionClosed = true;
+ connection.shutdown();
+ connection.close();
+ stream.close();
+ }
}
}
@@ -82,8 +97,4 @@ public void cleanUpConnectionBasedOnStatusCode(HttpStream stream) {
releaseConnection(stream);
}
}
-
- public AtomicBoolean connectionClosed() {
- return connectionClosed;
- }
}
diff --git a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/BaseHttpStreamResponseHandlerTest.java b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/BaseHttpStreamResponseHandlerTest.java
index 10c51b1a6028..f9b20742613f 100644
--- a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/BaseHttpStreamResponseHandlerTest.java
+++ b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/BaseHttpStreamResponseHandlerTest.java
@@ -16,18 +16,25 @@
package software.amazon.awssdk.http.crt.internal;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
+import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.crt.http.HttpClientConnection;
import software.amazon.awssdk.crt.http.HttpException;
@@ -35,6 +42,8 @@
import software.amazon.awssdk.crt.http.HttpHeaderBlock;
import software.amazon.awssdk.crt.http.HttpStream;
import software.amazon.awssdk.crt.http.HttpStreamResponseHandler;
+import software.amazon.awssdk.http.crt.internal.response.InputStreamAdaptingHttpStreamResponseHandler;
+import software.amazon.awssdk.utils.async.SimplePublisher;
@ExtendWith(MockitoExtension.class)
public abstract class BaseHttpStreamResponseHandlerTest {
@@ -44,10 +53,15 @@ public abstract class BaseHttpStreamResponseHandlerTest {
@Mock
HttpStream httpStream;
+ @Mock
+ SimplePublisher simplePublisher;
+
HttpStreamResponseHandler responseHandler;
abstract HttpStreamResponseHandler responseHandler();
+ abstract HttpStreamResponseHandler responseHandlerWithMockedPublisher(SimplePublisher simplePublisher);
+
@BeforeEach
public void setUp() {
requestFuture = new CompletableFuture<>();
@@ -113,6 +127,101 @@ void streamClosed_shouldNotIncreaseStreamWindow() throws InterruptedException {
verify(httpStream, never()).incrementWindow(anyInt());
}
+ @Test
+ void publisherWritesFutureFails_shouldShutdownConnection() {
+ SimplePublisher simplePublisher = Mockito.mock(SimplePublisher.class);
+ CompletableFuture future = new CompletableFuture<>();
+ when(simplePublisher.send(any(ByteBuffer.class))).thenReturn(future);
+
+ HttpStreamResponseHandler handler = responseHandlerWithMockedPublisher(simplePublisher);
+ HttpHeader[] httpHeaders = getHttpHeaders();
+
+ handler.onResponseHeaders(httpStream, 200, HttpHeaderBlock.MAIN.getValue(),
+ httpHeaders);
+ handler.onResponseHeadersDone(httpStream, 0);
+ handler.onResponseBody(httpStream,
+ RandomStringUtils.random(1 * 1024 * 1024).getBytes(StandardCharsets.UTF_8));
+ RuntimeException runtimeException = new RuntimeException();
+ future.completeExceptionally(runtimeException);
+
+ try {
+ requestFuture.join();
+ } catch (Exception e) {
+ // we don't verify here because it behaves differently in async and sync
+ }
+
+ verify(crtConn).shutdown();
+ verify(crtConn).close();
+ verify(httpStream).close();
+ verify(httpStream, never()).incrementWindow(anyInt());
+ }
+
+ @Test
+ void publisherWritesFutureCompletesAfterConnectionClosed_shouldNotInvokeIncrementWindow() {
+ CompletableFuture future = new CompletableFuture<>();
+ when(simplePublisher.send(any(ByteBuffer.class))).thenReturn(future);
+ when(simplePublisher.complete()).thenReturn(future);
+
+ HttpStreamResponseHandler handler = responseHandlerWithMockedPublisher(simplePublisher);
+
+
+ HttpHeader[] httpHeaders = getHttpHeaders();
+
+ handler.onResponseHeaders(httpStream, 200, HttpHeaderBlock.MAIN.getValue(),
+ httpHeaders);
+ handler.onResponseHeadersDone(httpStream, 0);
+ handler.onResponseBody(httpStream,
+ RandomStringUtils.random(1 * 1024 * 1024).getBytes(StandardCharsets.UTF_8));
+ handler.onResponseComplete(httpStream, 0);
+ future.complete(null);
+
+ requestFuture.join();
+ verify(crtConn, never()).shutdown();
+ verify(crtConn).close();
+ verify(httpStream).close();
+ verify(httpStream, never()).incrementWindow(anyInt());
+ }
+
+ @Test
+ void publisherWritesFutureCompletesWhenConnectionClosed_shouldNotInvokeIncrementWindow() {
+ CompletableFuture future = new CompletableFuture<>();
+ when(simplePublisher.send(any(ByteBuffer.class))).thenReturn(future);
+ when(simplePublisher.complete()).thenReturn(future);
+
+ HttpStreamResponseHandler handler = responseHandlerWithMockedPublisher(simplePublisher);
+
+
+ HttpHeader[] httpHeaders = getHttpHeaders();
+
+ handler.onResponseHeaders(httpStream, 200, HttpHeaderBlock.MAIN.getValue(),
+ httpHeaders);
+ handler.onResponseHeadersDone(httpStream, 0);
+ handler.onResponseBody(httpStream,
+ RandomStringUtils.random(1 * 1024 * 1024).getBytes(StandardCharsets.UTF_8));
+
+ // This tracker tracks which of the two operation completes first
+ AtomicInteger whenCompleteTracker = new AtomicInteger(0);
+ CompletableFuture onResponseComplete = CompletableFuture.runAsync(() -> handler.onResponseComplete(httpStream, 0))
+ .whenComplete((r, t) -> whenCompleteTracker.compareAndSet(0, 1));
+
+ CompletableFuture writeComplete = CompletableFuture.runAsync(() -> future.complete(null))
+ .whenComplete((r, t) -> whenCompleteTracker.compareAndSet(0, 2));
+ requestFuture.join();
+
+ CompletableFuture.allOf(onResponseComplete, writeComplete).join();
+
+ if (whenCompleteTracker.get() == 1) {
+ // onResponseComplete finishes first
+ verify(httpStream, never()).incrementWindow(anyInt());
+ } else {
+ verify(httpStream).incrementWindow(anyInt());
+ }
+
+ verify(crtConn, never()).shutdown();
+ verify(crtConn).close();
+ verify(httpStream).close();
+ }
+
static HttpHeader[] getHttpHeaders() {
HttpHeader[] httpHeaders = new HttpHeader[1];
httpHeaders[0] = new HttpHeader("Content-Length", "1");
diff --git a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/CrtResponseHandlerTest.java b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/CrtResponseHandlerTest.java
index 8668e0fc0054..0628efaf15b2 100644
--- a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/CrtResponseHandlerTest.java
+++ b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/CrtResponseHandlerTest.java
@@ -41,6 +41,7 @@
import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler;
import software.amazon.awssdk.http.crt.internal.response.CrtResponseAdapter;
import software.amazon.awssdk.http.crt.internal.response.InputStreamAdaptingHttpStreamResponseHandler;
+import software.amazon.awssdk.utils.async.SimplePublisher;
public class CrtResponseHandlerTest extends BaseHttpStreamResponseHandlerTest {
@@ -53,6 +54,15 @@ HttpStreamResponseHandler responseHandler() {
return CrtResponseAdapter.toCrtResponseHandler(crtConn, requestFuture, responseHandler);
}
+ @Override
+ HttpStreamResponseHandler responseHandlerWithMockedPublisher(SimplePublisher simplePublisher) {
+ AsyncResponseHandler responseHandler = new AsyncResponseHandler<>((response,
+ executionAttributes) -> null, Function.identity(), new ExecutionAttributes());
+
+ responseHandler.prepare();
+ return new CrtResponseAdapter(crtConn, requestFuture, responseHandler, simplePublisher);
+ }
+
@Test
void publisherFailedToDeliverEvents_shouldShutDownConnection() {
SdkAsyncHttpResponseHandler responseHandler = new TestAsyncHttpResponseHandler();
diff --git a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/InputStreamAdaptingHttpStreamResponseHandlerTest.java b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/InputStreamAdaptingHttpStreamResponseHandlerTest.java
index b5ef71dbc5cd..45c7fdccbfe9 100644
--- a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/InputStreamAdaptingHttpStreamResponseHandlerTest.java
+++ b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/InputStreamAdaptingHttpStreamResponseHandlerTest.java
@@ -17,12 +17,17 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
import io.reactivex.Completable;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
@@ -34,6 +39,7 @@
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
+import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.crt.http.HttpClientConnection;
import software.amazon.awssdk.crt.http.HttpHeader;
@@ -45,6 +51,8 @@
import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler;
import software.amazon.awssdk.http.crt.internal.response.CrtResponseAdapter;
import software.amazon.awssdk.http.crt.internal.response.InputStreamAdaptingHttpStreamResponseHandler;
+import software.amazon.awssdk.utils.CompletableFutureUtils;
+import software.amazon.awssdk.utils.async.SimplePublisher;
public class InputStreamAdaptingHttpStreamResponseHandlerTest extends BaseHttpStreamResponseHandlerTest {
@@ -53,6 +61,11 @@ HttpStreamResponseHandler responseHandler() {
return new InputStreamAdaptingHttpStreamResponseHandler(crtConn, requestFuture);
}
+ @Override
+ HttpStreamResponseHandler responseHandlerWithMockedPublisher(SimplePublisher simplePublisher) {
+ return new InputStreamAdaptingHttpStreamResponseHandler(crtConn, requestFuture, simplePublisher);
+ }
+
@Test
void abortStream_shouldShutDownConnection() throws IOException {
HttpHeader[] httpHeaders = getHttpHeaders();