diff --git a/.changes/next-release/bugfix-AWSCRTHTTPClient-7b95a65.json b/.changes/next-release/bugfix-AWSCRTHTTPClient-7b95a65.json new file mode 100644 index 000000000000..31b52d8a3707 --- /dev/null +++ b/.changes/next-release/bugfix-AWSCRTHTTPClient-7b95a65.json @@ -0,0 +1,6 @@ +{ + "type": "bugfix", + "category": "AWS CRT HTTP Client", + "contributor": "", + "description": "Fixed a thread safety issue that could cause application to crash in the edge case where the SDK attempted to invoke `incrementWindow` after the stream is closed in AWS CRT HTTP Client." +} diff --git a/bom-internal/pom.xml b/bom-internal/pom.xml index bb8a9386ad21..5138b01d3bf5 100644 --- a/bom-internal/pom.xml +++ b/bom-internal/pom.xml @@ -255,6 +255,12 @@ ${mockito.version} test + + org.mockito + mockito-inline + ${mockito.version} + test + nl.jqno.equalsverifier equalsverifier diff --git a/http-clients/aws-crt-client/pom.xml b/http-clients/aws-crt-client/pom.xml index a6fe67926eaa..77debaee6bea 100644 --- a/http-clients/aws-crt-client/pom.xml +++ b/http-clients/aws-crt-client/pom.xml @@ -105,7 +105,7 @@ org.mockito - mockito-core + mockito-inline test diff --git a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseAdapter.java b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseAdapter.java index bfb75050e55a..d2f246336ca8 100644 --- a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseAdapter.java +++ b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseAdapter.java @@ -20,6 +20,7 @@ import java.nio.ByteBuffer; import java.util.concurrent.CompletableFuture; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.annotations.SdkTestInternalApi; import software.amazon.awssdk.crt.CRT; import software.amazon.awssdk.crt.http.HttpClientConnection; import software.amazon.awssdk.crt.http.HttpException; @@ -46,7 +47,7 @@ public final class CrtResponseAdapter implements HttpStreamResponseHandler { private final HttpClientConnection connection; private final CompletableFuture completionFuture; private final SdkAsyncHttpResponseHandler responseHandler; - private final SimplePublisher responsePublisher = new SimplePublisher<>(); + private final SimplePublisher responsePublisher; private final SdkHttpResponse.Builder responseBuilder; private final ResponseHandlerHelper responseHandlerHelper; @@ -54,11 +55,21 @@ public final class CrtResponseAdapter implements HttpStreamResponseHandler { private CrtResponseAdapter(HttpClientConnection connection, CompletableFuture completionFuture, SdkAsyncHttpResponseHandler responseHandler) { + this(connection, completionFuture, responseHandler, new SimplePublisher<>()); + } + + + @SdkTestInternalApi + public CrtResponseAdapter(HttpClientConnection connection, + CompletableFuture completionFuture, + SdkAsyncHttpResponseHandler responseHandler, + SimplePublisher simplePublisher) { this.connection = Validate.paramNotNull(connection, "connection"); this.completionFuture = Validate.paramNotNull(completionFuture, "completionFuture"); this.responseHandler = Validate.paramNotNull(responseHandler, "responseHandler"); this.responseBuilder = SdkHttpResponse.builder(); this.responseHandlerHelper = new ResponseHandlerHelper(responseBuilder, connection); + this.responsePublisher = simplePublisher; } public static HttpStreamResponseHandler toCrtResponseHandler(HttpClientConnection crtConn, @@ -95,9 +106,7 @@ public int onResponseBody(HttpStream stream, byte[] bodyBytesIn) { return; } - if (!responseHandlerHelper.connectionClosed().get()) { - stream.incrementWindow(bodyBytesIn.length); - } + responseHandlerHelper.incrementWindow(stream, bodyBytesIn.length); }); return 0; diff --git a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/InputStreamAdaptingHttpStreamResponseHandler.java b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/InputStreamAdaptingHttpStreamResponseHandler.java index 939405db38b8..b6b95307722e 100644 --- a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/InputStreamAdaptingHttpStreamResponseHandler.java +++ b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/InputStreamAdaptingHttpStreamResponseHandler.java @@ -20,6 +20,7 @@ import java.nio.ByteBuffer; import java.util.concurrent.CompletableFuture; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.annotations.SdkTestInternalApi; import software.amazon.awssdk.crt.CRT; import software.amazon.awssdk.crt.http.HttpClientConnection; import software.amazon.awssdk.crt.http.HttpException; @@ -42,7 +43,7 @@ public final class InputStreamAdaptingHttpStreamResponseHandler implements HttpStreamResponseHandler { private static final Logger log = Logger.loggerFor(InputStreamAdaptingHttpStreamResponseHandler.class); private volatile AbortableInputStreamSubscriber inputStreamSubscriber; - private final SimplePublisher simplePublisher = new SimplePublisher<>(); + private final SimplePublisher simplePublisher; private final CompletableFuture requestCompletionFuture; private final HttpClientConnection crtConn; @@ -52,10 +53,18 @@ public final class InputStreamAdaptingHttpStreamResponseHandler implements HttpS public InputStreamAdaptingHttpStreamResponseHandler(HttpClientConnection crtConn, CompletableFuture requestCompletionFuture) { + this(crtConn, requestCompletionFuture, new SimplePublisher<>()); + } + + @SdkTestInternalApi + public InputStreamAdaptingHttpStreamResponseHandler(HttpClientConnection crtConn, + CompletableFuture requestCompletionFuture, + SimplePublisher simplePublisher) { this.crtConn = crtConn; this.requestCompletionFuture = requestCompletionFuture; this.responseBuilder = SdkHttpResponse.builder(); this.responseHandlerHelper = new ResponseHandlerHelper(responseBuilder, crtConn); + this.simplePublisher = simplePublisher; } @Override @@ -101,11 +110,8 @@ public int onResponseBody(HttpStream stream, byte[] bodyBytesIn) { failFutureAndCloseConnection(stream, failure); return; } - - if (!responseHandlerHelper.connectionClosed().get()) { - // increment the window upon buffer consumption. - stream.incrementWindow(bodyBytesIn.length); - } + // increment the window upon buffer consumption. + responseHandlerHelper.incrementWindow(stream, bodyBytesIn.length); }); // Window will be incremented after the subscriber consumes the data, returning 0 here to disable it. diff --git a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/ResponseHandlerHelper.java b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/ResponseHandlerHelper.java index 69665d1aeff9..8aa4037df88b 100644 --- a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/ResponseHandlerHelper.java +++ b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/ResponseHandlerHelper.java @@ -15,7 +15,6 @@ package software.amazon.awssdk.http.crt.internal.response; -import java.util.concurrent.atomic.AtomicBoolean; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.crt.http.HttpClientConnection; import software.amazon.awssdk.crt.http.HttpHeader; @@ -30,14 +29,16 @@ * * CRT connection will only be closed, i.e., not reused, in one of the following conditions: * 1. 5xx server error OR - * 2. It fails to read the response. + * 2. It fails to read the response OR + * 3. the response stream is closed/aborted by the caller. */ @SdkInternalApi public class ResponseHandlerHelper { private final SdkHttpResponse.Builder responseBuilder; private final HttpClientConnection connection; - private AtomicBoolean connectionClosed = new AtomicBoolean(false); + private boolean connectionClosed; + private final Object lock = new Object(); public ResponseHandlerHelper(SdkHttpResponse.Builder responseBuilder, HttpClientConnection connection) { this.responseBuilder = responseBuilder; @@ -57,9 +58,20 @@ public void onResponseHeaders(HttpStream stream, int responseStatusCode, int hea * Release the connection back to the pool so that it can be reused. */ public void releaseConnection(HttpStream stream) { - if (connectionClosed.compareAndSet(false, true)) { - connection.close(); - stream.close(); + synchronized (lock) { + if (!connectionClosed) { + connectionClosed = true; + connection.close(); + stream.close(); + } + } + } + + public void incrementWindow(HttpStream stream, int windowSize) { + synchronized (lock) { + if (!connectionClosed) { + stream.incrementWindow(windowSize); + } } } @@ -67,10 +79,13 @@ public void releaseConnection(HttpStream stream) { * Close the connection completely */ public void closeConnection(HttpStream stream) { - if (connectionClosed.compareAndSet(false, true)) { - connection.shutdown(); - connection.close(); - stream.close(); + synchronized (lock) { + if (!connectionClosed) { + connectionClosed = true; + connection.shutdown(); + connection.close(); + stream.close(); + } } } @@ -82,8 +97,4 @@ public void cleanUpConnectionBasedOnStatusCode(HttpStream stream) { releaseConnection(stream); } } - - public AtomicBoolean connectionClosed() { - return connectionClosed; - } } diff --git a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/BaseHttpStreamResponseHandlerTest.java b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/BaseHttpStreamResponseHandlerTest.java index 10c51b1a6028..f9b20742613f 100644 --- a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/BaseHttpStreamResponseHandlerTest.java +++ b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/BaseHttpStreamResponseHandlerTest.java @@ -16,18 +16,25 @@ package software.amazon.awssdk.http.crt.internal; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import software.amazon.awssdk.crt.http.HttpClientConnection; import software.amazon.awssdk.crt.http.HttpException; @@ -35,6 +42,8 @@ import software.amazon.awssdk.crt.http.HttpHeaderBlock; import software.amazon.awssdk.crt.http.HttpStream; import software.amazon.awssdk.crt.http.HttpStreamResponseHandler; +import software.amazon.awssdk.http.crt.internal.response.InputStreamAdaptingHttpStreamResponseHandler; +import software.amazon.awssdk.utils.async.SimplePublisher; @ExtendWith(MockitoExtension.class) public abstract class BaseHttpStreamResponseHandlerTest { @@ -44,10 +53,15 @@ public abstract class BaseHttpStreamResponseHandlerTest { @Mock HttpStream httpStream; + @Mock + SimplePublisher simplePublisher; + HttpStreamResponseHandler responseHandler; abstract HttpStreamResponseHandler responseHandler(); + abstract HttpStreamResponseHandler responseHandlerWithMockedPublisher(SimplePublisher simplePublisher); + @BeforeEach public void setUp() { requestFuture = new CompletableFuture<>(); @@ -113,6 +127,101 @@ void streamClosed_shouldNotIncreaseStreamWindow() throws InterruptedException { verify(httpStream, never()).incrementWindow(anyInt()); } + @Test + void publisherWritesFutureFails_shouldShutdownConnection() { + SimplePublisher simplePublisher = Mockito.mock(SimplePublisher.class); + CompletableFuture future = new CompletableFuture<>(); + when(simplePublisher.send(any(ByteBuffer.class))).thenReturn(future); + + HttpStreamResponseHandler handler = responseHandlerWithMockedPublisher(simplePublisher); + HttpHeader[] httpHeaders = getHttpHeaders(); + + handler.onResponseHeaders(httpStream, 200, HttpHeaderBlock.MAIN.getValue(), + httpHeaders); + handler.onResponseHeadersDone(httpStream, 0); + handler.onResponseBody(httpStream, + RandomStringUtils.random(1 * 1024 * 1024).getBytes(StandardCharsets.UTF_8)); + RuntimeException runtimeException = new RuntimeException(); + future.completeExceptionally(runtimeException); + + try { + requestFuture.join(); + } catch (Exception e) { + // we don't verify here because it behaves differently in async and sync + } + + verify(crtConn).shutdown(); + verify(crtConn).close(); + verify(httpStream).close(); + verify(httpStream, never()).incrementWindow(anyInt()); + } + + @Test + void publisherWritesFutureCompletesAfterConnectionClosed_shouldNotInvokeIncrementWindow() { + CompletableFuture future = new CompletableFuture<>(); + when(simplePublisher.send(any(ByteBuffer.class))).thenReturn(future); + when(simplePublisher.complete()).thenReturn(future); + + HttpStreamResponseHandler handler = responseHandlerWithMockedPublisher(simplePublisher); + + + HttpHeader[] httpHeaders = getHttpHeaders(); + + handler.onResponseHeaders(httpStream, 200, HttpHeaderBlock.MAIN.getValue(), + httpHeaders); + handler.onResponseHeadersDone(httpStream, 0); + handler.onResponseBody(httpStream, + RandomStringUtils.random(1 * 1024 * 1024).getBytes(StandardCharsets.UTF_8)); + handler.onResponseComplete(httpStream, 0); + future.complete(null); + + requestFuture.join(); + verify(crtConn, never()).shutdown(); + verify(crtConn).close(); + verify(httpStream).close(); + verify(httpStream, never()).incrementWindow(anyInt()); + } + + @Test + void publisherWritesFutureCompletesWhenConnectionClosed_shouldNotInvokeIncrementWindow() { + CompletableFuture future = new CompletableFuture<>(); + when(simplePublisher.send(any(ByteBuffer.class))).thenReturn(future); + when(simplePublisher.complete()).thenReturn(future); + + HttpStreamResponseHandler handler = responseHandlerWithMockedPublisher(simplePublisher); + + + HttpHeader[] httpHeaders = getHttpHeaders(); + + handler.onResponseHeaders(httpStream, 200, HttpHeaderBlock.MAIN.getValue(), + httpHeaders); + handler.onResponseHeadersDone(httpStream, 0); + handler.onResponseBody(httpStream, + RandomStringUtils.random(1 * 1024 * 1024).getBytes(StandardCharsets.UTF_8)); + + // This tracker tracks which of the two operation completes first + AtomicInteger whenCompleteTracker = new AtomicInteger(0); + CompletableFuture onResponseComplete = CompletableFuture.runAsync(() -> handler.onResponseComplete(httpStream, 0)) + .whenComplete((r, t) -> whenCompleteTracker.compareAndSet(0, 1)); + + CompletableFuture writeComplete = CompletableFuture.runAsync(() -> future.complete(null)) + .whenComplete((r, t) -> whenCompleteTracker.compareAndSet(0, 2)); + requestFuture.join(); + + CompletableFuture.allOf(onResponseComplete, writeComplete).join(); + + if (whenCompleteTracker.get() == 1) { + // onResponseComplete finishes first + verify(httpStream, never()).incrementWindow(anyInt()); + } else { + verify(httpStream).incrementWindow(anyInt()); + } + + verify(crtConn, never()).shutdown(); + verify(crtConn).close(); + verify(httpStream).close(); + } + static HttpHeader[] getHttpHeaders() { HttpHeader[] httpHeaders = new HttpHeader[1]; httpHeaders[0] = new HttpHeader("Content-Length", "1"); diff --git a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/CrtResponseHandlerTest.java b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/CrtResponseHandlerTest.java index 8668e0fc0054..0628efaf15b2 100644 --- a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/CrtResponseHandlerTest.java +++ b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/CrtResponseHandlerTest.java @@ -41,6 +41,7 @@ import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler; import software.amazon.awssdk.http.crt.internal.response.CrtResponseAdapter; import software.amazon.awssdk.http.crt.internal.response.InputStreamAdaptingHttpStreamResponseHandler; +import software.amazon.awssdk.utils.async.SimplePublisher; public class CrtResponseHandlerTest extends BaseHttpStreamResponseHandlerTest { @@ -53,6 +54,15 @@ HttpStreamResponseHandler responseHandler() { return CrtResponseAdapter.toCrtResponseHandler(crtConn, requestFuture, responseHandler); } + @Override + HttpStreamResponseHandler responseHandlerWithMockedPublisher(SimplePublisher simplePublisher) { + AsyncResponseHandler responseHandler = new AsyncResponseHandler<>((response, + executionAttributes) -> null, Function.identity(), new ExecutionAttributes()); + + responseHandler.prepare(); + return new CrtResponseAdapter(crtConn, requestFuture, responseHandler, simplePublisher); + } + @Test void publisherFailedToDeliverEvents_shouldShutDownConnection() { SdkAsyncHttpResponseHandler responseHandler = new TestAsyncHttpResponseHandler(); diff --git a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/InputStreamAdaptingHttpStreamResponseHandlerTest.java b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/InputStreamAdaptingHttpStreamResponseHandlerTest.java index b5ef71dbc5cd..45c7fdccbfe9 100644 --- a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/InputStreamAdaptingHttpStreamResponseHandlerTest.java +++ b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/InputStreamAdaptingHttpStreamResponseHandlerTest.java @@ -17,12 +17,17 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import io.reactivex.Completable; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; @@ -34,6 +39,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import software.amazon.awssdk.crt.http.HttpClientConnection; import software.amazon.awssdk.crt.http.HttpHeader; @@ -45,6 +51,8 @@ import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler; import software.amazon.awssdk.http.crt.internal.response.CrtResponseAdapter; import software.amazon.awssdk.http.crt.internal.response.InputStreamAdaptingHttpStreamResponseHandler; +import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.async.SimplePublisher; public class InputStreamAdaptingHttpStreamResponseHandlerTest extends BaseHttpStreamResponseHandlerTest { @@ -53,6 +61,11 @@ HttpStreamResponseHandler responseHandler() { return new InputStreamAdaptingHttpStreamResponseHandler(crtConn, requestFuture); } + @Override + HttpStreamResponseHandler responseHandlerWithMockedPublisher(SimplePublisher simplePublisher) { + return new InputStreamAdaptingHttpStreamResponseHandler(crtConn, requestFuture, simplePublisher); + } + @Test void abortStream_shouldShutDownConnection() throws IOException { HttpHeader[] httpHeaders = getHttpHeaders();