Bug fix for S3AsyncClient.putObject hangs if there is a connection reset while uploading of objects #3522

Closed
@@ -0,0 +1,6 @@
{
"type": "bugfix",
"category": "Netty NIO Http Client",
"contributor": "",
"description": "Fix for Netty based client request getting stuck if connection is reset after receiving Http Continue response."
}
@@ -16,6 +16,7 @@
package software.amazon.awssdk.http.nio.netty.internal;

import io.netty.channel.Channel;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.codec.http2.Http2Connection;
import io.netty.handler.codec.http2.Http2FrameStream;
@@ -73,6 +74,12 @@ public final class ChannelAttributeKey {
public static final AttributeKey<ChannelDiagnostics> CHANNEL_DIAGNOSTICS = NettyUtils.getOrCreateAttributeKey(
"aws.http.nio.netty.async.channelDiagnostics");

/**
* {@link AttributeKey} to keep track of whether we have received Continue in {@link HttpResponseStatus}.
*/
public static final AttributeKey<Boolean> RESPONSE_100_CONTINUE_MESSAGE = NettyUtils.getOrCreateAttributeKey(
"aws.http.nio.netty.async.100ContinueMessage");

/**
* {@link AttributeKey} to keep track of whether we should close the connection after this request
* has completed.
@@ -23,6 +23,7 @@
import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.KEEP_ALIVE;
import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.LAST_HTTP_CONTENT_RECEIVED_KEY;
import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.REQUEST_CONTEXT_KEY;
import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.RESPONSE_100_CONTINUE_MESSAGE;
import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.RESPONSE_COMPLETE_KEY;
import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.RESPONSE_CONTENT_LENGTH;
import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.RESPONSE_DATA_READ;
@@ -196,6 +197,7 @@ private void configureChannel() {
channel.attr(REQUEST_CONTEXT_KEY).set(context);
channel.attr(RESPONSE_COMPLETE_KEY).set(false);
channel.attr(LAST_HTTP_CONTENT_RECEIVED_KEY).set(false);
channel.attr(RESPONSE_100_CONTINUE_MESSAGE).set(false);
Contributor

Would it be possible to mark the completion of the request in HttpStreamHandler and get rid of LastHttpContentHandler? That way, we don't need to handle the 100 Continue message in this class.

Contributor Author

Raised #3535 to change the logic where LAST_HTTP_CONTENT_RECEIVED_KEY is set

channel.attr(RESPONSE_CONTENT_LENGTH).set(null);
channel.attr(RESPONSE_DATA_READ).set(null);
channel.attr(CHANNEL_DIAGNOSTICS).get().incrementRequestCount();
@@ -21,6 +21,7 @@
import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.KEEP_ALIVE;
import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.LAST_HTTP_CONTENT_RECEIVED_KEY;
import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.REQUEST_CONTEXT_KEY;
import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.RESPONSE_100_CONTINUE_MESSAGE;
import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.RESPONSE_COMPLETE_KEY;
import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.RESPONSE_CONTENT_LENGTH;
import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.RESPONSE_DATA_READ;
@@ -463,18 +464,28 @@ public void cancel() {
private void notifyIfResponseNotCompleted(ChannelHandlerContext handlerCtx) {
RequestContext requestCtx = handlerCtx.channel().attr(REQUEST_CONTEXT_KEY).get();
Boolean responseCompleted = handlerCtx.channel().attr(RESPONSE_COMPLETE_KEY).get();
Boolean lastHttpContentReceived = handlerCtx.channel().attr(LAST_HTTP_CONTENT_RECEIVED_KEY).get();
boolean isLastByteWithout100Continue = isLastByteWithout100Response(handlerCtx);
Contributor

Is there a more explanatory name for this boolean? I think the channel key name is correct, but here it'd be great if the name conveyed why a 100 Continue disables the last-byte check. isLastByteForResponseBody? (Not sure if this is 100% correct.)

Contributor Author

Updated with new PR to remove LastHttpContentHandler.
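
On the naming question in this thread: as far as this diff shows, the interim 100 Continue response is terminated by its own LastHttpContent, so the last-content flag can presumably already be true while the final response is still outstanding. The sketch below models the close-time decision with plain booleans in place of channel attributes. The class name `CloseDecisionSketch` and the method `shouldReportError` are hypothetical names for illustration, not part of the SDK.

```java
public class CloseDecisionSketch {

    // Mirrors isLastByteWithout100Response from the diff: the last-content flag
    // only counts when a 100 Continue was not recorded on the channel.
    static boolean isLastByteWithout100Response(Boolean lastHttpContentReceived, Boolean continueReceived) {
        return Boolean.TRUE.equals(lastHttpContentReceived) && Boolean.FALSE.equals(continueReceived);
    }

    // Hypothetical helper: true when a closed channel should fail the request with an IOException.
    static boolean shouldReportError(Boolean responseCompleted,
                                     Boolean lastHttpContentReceived,
                                     Boolean continueReceived) {
        return !Boolean.TRUE.equals(responseCompleted)
               && !isLastByteWithout100Response(lastHttpContentReceived, continueReceived);
    }

    public static void main(String[] args) {
        // Connection reset after 100 Continue, before the final response: report the error.
        System.out.println(shouldReportError(false, true, true));   // true
        // Last content of the final response already received: skip the error.
        System.out.println(shouldReportError(false, true, false));  // false
        // Nothing received yet: report the error.
        System.out.println(shouldReportError(false, false, false)); // true
        // Response already completed: nothing to report.
        System.out.println(shouldReportError(true, true, false));   // false
    }
}
```

The first case is the one the changelog entry describes: without the extra flag, the last-content marker alone would suppress the error and the request future would never complete.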

handlerCtx.channel().attr(KEEP_ALIVE).set(false);

- if (!Boolean.TRUE.equals(responseCompleted) && !Boolean.TRUE.equals(lastHttpContentReceived)) {
+ if (!Boolean.TRUE.equals(responseCompleted) && !isLastByteWithout100Continue) {
IOException err = new IOException(NettyUtils.closedChannelMessage(handlerCtx.channel()));
runAndLogError(handlerCtx.channel(), () -> "Fail to execute SdkAsyncHttpResponseHandler#onError",
() -> requestCtx.handler().onError(err));
executeFuture(handlerCtx).completeExceptionally(err);
runAndLogError(handlerCtx.channel(), () -> "Could not release channel", () -> closeAndRelease(handlerCtx));
} else if (!Boolean.TRUE.equals(responseCompleted)) {
Contributor

The logic is a bit confusing. Does it make sense to rearrange it with !Boolean.TRUE.equals(responseCompleted) in an outer if clause and then just a simple if/else checking isLastByteWithout100Response nested within?

if (!Boolean.TRUE.equals(responseCompleted)) {
    if (!isLastByteWithout100Continue) {
            IOException err = new IOException(NettyUtils.closedChannelMessage(handlerCtx.channel()));
            runAndLogError(handlerCtx.channel(), () -> "Fail to execute SdkAsyncHttpResponseHandler#onError",
                           () -> requestCtx.handler().onError(err));
            executeFuture(handlerCtx).completeExceptionally(err);
            runAndLogError(handlerCtx.channel(), () -> "Could not release channel", () -> closeAndRelease(handlerCtx));
        } else {
            log.trace(handlerCtx.channel(),
                      () -> "Run error skipped because lastHttpContentReceived is "
                            + handlerCtx.channel().attr(LAST_HTTP_CONTENT_RECEIVED_KEY).get() + " and 100ContinueMessage is "
                            + handlerCtx.channel().attr(RESPONSE_100_CONTINUE_MESSAGE).get());
        }
}

Contributor Author

Not required in the new changes.

log.trace(handlerCtx.channel(),
() -> "Run error skipped because lastHttpContentReceived is "
+ handlerCtx.channel().attr(LAST_HTTP_CONTENT_RECEIVED_KEY).get() + " and 100ContinueMessage is "
+ handlerCtx.channel().attr(RESPONSE_100_CONTINUE_MESSAGE).get());
}
}

private boolean isLastByteWithout100Response(ChannelHandlerContext handlerCtx) {
return Boolean.TRUE.equals(handlerCtx.channel().attr(LAST_HTTP_CONTENT_RECEIVED_KEY).get())
&& Boolean.FALSE.equals(handlerCtx.channel().attr(RESPONSE_100_CONTINUE_MESSAGE).get());
}

private static final class DataCountingPublisher implements Publisher<ByteBuffer> {
private final ChannelHandlerContext ctx;
private final Publisher<ByteBuffer> delegate;
@@ -15,6 +15,8 @@

package software.amazon.awssdk.http.nio.netty.internal.nrs;

import static software.amazon.awssdk.http.nio.netty.internal.ChannelAttributeKey.RESPONSE_100_CONTINUE_MESSAGE;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.FullHttpRequest;
@@ -168,6 +170,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
ReferenceCountUtil.release(msg);
if (msg instanceof LastHttpContent) {
ignoreResponseBody = false;
ctx.channel().attr(RESPONSE_100_CONTINUE_MESSAGE).set(true);
}
} else {
super.channelRead(ctx, msg);
@@ -0,0 +1,56 @@
package software.amazon.awssdk.http.nio.netty.fault;

import java.nio.ByteBuffer;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.http.async.SdkHttpContentPublisher;

public class SdkTestHttpContentPublisher implements SdkHttpContentPublisher {
private final byte[] body;
private final AtomicReference<Subscriber<? super ByteBuffer>> subscriber = new AtomicReference<>(null);
private final AtomicBoolean complete = new AtomicBoolean(false);

public SdkTestHttpContentPublisher(byte[] body) {
this.body = body;
}

@Override
public void subscribe(Subscriber<? super ByteBuffer> s) {
boolean wasFirstSubscriber = subscriber.compareAndSet(null, s);

SdkTestHttpContentPublisher publisher = this;

if (wasFirstSubscriber) {
s.onSubscribe(new Subscription() {
@Override
public void request(long n) {
publisher.request(n);
}

@Override
public void cancel() {
// Do nothing
}
});
} else {
s.onError(new RuntimeException("Only allow one subscriber"));
}
}

protected void request(long n) {
// Send the whole body if they request >0 ByteBuffers
if (n > 0 && !complete.get()) {
complete.set(true);
subscriber.get().onNext(ByteBuffer.wrap(body));
subscriber.get().onComplete();
}
}

@Override
public Optional<Long> contentLength() {
return Optional.of((long)body.length);
}
}
@@ -38,6 +38,8 @@
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.DefaultLastHttpContent;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
@@ -50,6 +52,7 @@
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.reactivex.Flowable;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collection;
@@ -93,6 +96,17 @@ public static Collection<TestCase> testCases() {
new TestCase(CloseTime.DURING_RESPONSE_PAYLOAD, "Response had content-length"));
}

public static Collection<TestCase> testCasesForHttpContinueResponse() {
return Arrays.asList(new TestCase(CloseTime.DURING_INIT, "The connection was closed during the request."),
new TestCase(CloseTime.BEFORE_SSL_HANDSHAKE, "The connection was closed during the request."),
new TestCase(CloseTime.DURING_SSL_HANDSHAKE, "The connection was closed during the request."),
new TestCase(CloseTime.BEFORE_REQUEST_PAYLOAD, "The connection was closed during the request."),
new TestCase(CloseTime.DURING_REQUEST_PAYLOAD, "The connection was closed during the request."),
new TestCase(CloseTime.BEFORE_RESPONSE_HEADERS, "The connection was closed during the request."),
new TestCase(CloseTime.BEFORE_RESPONSE_PAYLOAD, "The connection was closed during the request."),
new TestCase(CloseTime.DURING_RESPONSE_PAYLOAD, "The connection was closed during the request."));
}

@AfterEach
public void teardown() throws InterruptedException {
if (server != null) {
@@ -108,14 +122,23 @@ public void teardown() throws InterruptedException {

@ParameterizedTest
@MethodSource("testCases")
- public void closeTimeHasCorrectMessage(TestCase testCase) throws Exception {
+ void closeTimeHasCorrectMessage(TestCase testCase) throws Exception {
server = new Server(ServerConfig.builder().httpResponseStatus(HttpResponseStatus.OK).build());
setupTestCase(testCase);
server.closeTime = testCase.closeTime;
- assertThat(captureException()).hasMessageContaining(testCase.errorMessageSubstring);
+ assertThat(captureExceptionWithHttpOkResponse()).hasMessageContaining(testCase.errorMessageSubstring);
}

@ParameterizedTest
@MethodSource("testCasesForHttpContinueResponse")
void closeTimeHasCorrectMessageWith100ContinueResponse(TestCase testCase) throws Exception {
server = new Server(ServerConfig.builder().httpResponseStatus(HttpResponseStatus.CONTINUE).build());
setupTestCase(testCase);
server.closeTime = testCase.closeTime;
assertThat(captureExceptionWithHttpContinueResponse()).hasMessageContaining(testCase.errorMessageSubstring);
}

public void setupTestCase(TestCase testCase) throws Exception {
- server = new Server();
server.init(testCase.closeTime);

netty = NettyNioAsyncHttpClient.builder()
@@ -125,7 +148,7 @@ public void setupTestCase(TestCase testCase) throws Exception {
.buildWithDefaults(AttributeMap.builder().put(TRUST_ALL_CERTIFICATES, true).build());
}

- private Throwable captureException() {
+ private Throwable captureExceptionWithHttpOkResponse() {
try {
sendGetRequest().get(10, TimeUnit.SECONDS);
} catch (InterruptedException | TimeoutException e) {
@@ -137,6 +160,20 @@ private Throwable captureException() {
throw new AssertionError("Call did not fail as expected.");
}

private Throwable captureExceptionWithHttpContinueResponse() {
Contributor

Nit: Consider adding the future as a parameter to captureException and keep it to one method
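
A minimal sketch of what this nit might look like, using only JDK types: a single helper that takes the future as a parameter, so both the GET and the PUT tests can share it. The class name `CaptureExceptionSketch` and the `main` method are hypothetical scaffolding for illustration; the helper body follows the two methods in the diff.

```java
import java.io.IOException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public class CaptureExceptionSketch {

    // Waits for the future to fail and returns the cause of the failure.
    // Throws if the future completes normally or does not finish in time.
    static Throwable captureException(CompletableFuture<?> future) {
        try {
            future.get(10, TimeUnit.SECONDS);
        } catch (InterruptedException | TimeoutException e) {
            throw new Error(e);
        } catch (ExecutionException e) {
            return e.getCause();
        }

        throw new AssertionError("Call did not fail as expected.");
    }

    public static void main(String[] args) {
        // Stand-in for netty.execute(req): a future that has already failed.
        CompletableFuture<Void> failed = new CompletableFuture<>();
        failed.completeExceptionally(new IOException("The connection was closed during the request."));

        System.out.println(captureException(failed).getMessage());
    }
}
```

With a helper of this shape, the tests could call `captureException(sendGetRequest())` and `captureException(sendPutRequestWithExpectContinue())` instead of keeping two near-identical methods.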

try {
sendPutRequestWithExpectContinue().get(10, TimeUnit.SECONDS);
} catch (InterruptedException | TimeoutException e) {
throw new Error(e);
} catch (ExecutionException e) {
return e.getCause();
}

throw new AssertionError("Call did not fail as expected.");
}



private CompletableFuture<Void> sendGetRequest() {
AsyncExecuteRequest req = AsyncExecuteRequest.builder()
.responseHandler(new SdkAsyncHttpResponseHandler() {
@@ -169,6 +206,40 @@ public void onError(Throwable error) {
return netty.execute(req);
}

private CompletableFuture<Void> sendPutRequestWithExpectContinue() {
AsyncExecuteRequest req = AsyncExecuteRequest.builder()
.responseHandler(new SdkAsyncHttpResponseHandler() {
private SdkHttpResponse headers;

@Override
public void onHeaders(SdkHttpResponse headers) {
this.headers = headers;
}

@Override
public void onStream(Publisher<ByteBuffer> stream) {
Flowable.fromPublisher(stream).forEach(b -> {
});
}

@Override
public void onError(Throwable error) {
}
})
Comment on lines +211 to +228
Contributor

Nit: Could be simplified by extracting the response handler

.request(SdkHttpFullRequest.builder()
.method(SdkHttpMethod.PUT)
.protocol("https")
.putHeader(HttpHeaderNames.EXPECT.toString(),
HttpHeaderValues.CONTINUE.toString())
.host("localhost")
.port(server.port())
.build())
.requestContentPublisher(new SdkTestHttpContentPublisher("reqBody".getBytes(StandardCharsets.UTF_8)))
.build();

return netty.execute(req);
}

private static class TestCase {
private CloseTime closeTime;
private String errorMessageSubstring;
@@ -198,12 +269,37 @@ private enum CloseTime {
DURING_RESPONSE_PAYLOAD
}

private static class ServerConfig {
private final HttpResponseStatus httpResponseStatus;
public static Builder builder(){
return new Builder();
}
private ServerConfig(Builder builder) {
this.httpResponseStatus = builder.httpResponseStatus;
}
public static class Builder {
private HttpResponseStatus httpResponseStatus;
public Builder httpResponseStatus(HttpResponseStatus httpResponseStatus){
this.httpResponseStatus = httpResponseStatus;
return this;
}
public ServerConfig build() {
return new ServerConfig(this);
}
}
}

private static class Server extends ChannelInitializer<Channel> {
private final NioEventLoopGroup group = new NioEventLoopGroup();
private CloseTime closeTime;
private ServerBootstrap bootstrap;
private ServerSocketChannel serverSock;
private SslContext sslCtx;
private ServerConfig serverConfig;

public Server(ServerConfig serverConfig) {
this.serverConfig = serverConfig;
}

private void init(CloseTime closeTime) throws Exception {
SelfSignedCertificate ssc = new SelfSignedCertificate();
@@ -240,7 +336,9 @@ protected void initChannel(Channel ch) {
ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast(new SslHandler(new FaultInjectionSslEngine(sslCtx.newEngine(ch.alloc()), ch), false));
pipeline.addLast(new HttpServerCodec());
- pipeline.addLast(new FaultInjectionHttpHandler());
+ FaultInjectionHttpHandler faultInjectionHttpHandler = new FaultInjectionHttpHandler();
+ faultInjectionHttpHandler.setHttpResponseStatus(serverConfig.httpResponseStatus);
+ pipeline.addLast(faultInjectionHttpHandler);

LOGGER.info(() -> "Channel initialized " + ch);
}
@@ -287,6 +385,13 @@ private void closeChannel(String message) {
}

private class FaultInjectionHttpHandler extends SimpleChannelInboundHandler<Object> {

private HttpResponseStatus httpResponseStatus = HttpResponseStatus.OK;

public void setHttpResponseStatus(HttpResponseStatus httpResponseStatus) {
this.httpResponseStatus = httpResponseStatus;
}

@Override
protected void channelRead0(ChannelHandlerContext ctx, Object msg) {
LOGGER.info(() -> "Reading " + msg);
@@ -321,7 +426,7 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) {
private void writeResponse(ChannelHandlerContext ctx) {
int responseLength = 10 * 1024 * 1024; // 10 MB
HttpHeaders headers = new DefaultHttpHeaders().add("Content-Length", responseLength);
- ctx.writeAndFlush(new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, headers)).addListener(x -> {
+ ctx.writeAndFlush(new DefaultHttpResponse(HttpVersion.HTTP_1_1, this.httpResponseStatus, headers)).addListener(x -> {
if (closeTime == CloseTime.BEFORE_RESPONSE_PAYLOAD) {
LOGGER.info(() -> "Closing channel before response payload " + ctx.channel());
ctx.close();