diff --git a/core/aws-core/pom.xml b/core/aws-core/pom.xml
index 9f8c62d1795b..1f7e1306d479 100644
--- a/core/aws-core/pom.xml
+++ b/core/aws-core/pom.xml
@@ -113,6 +113,11 @@
software.amazon.eventstream
eventstream
+
+ software.amazon.awssdk
+ utils-lite
+ ${awsjavasdk.version}
+
software.amazon.awssdk
diff --git a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/interceptor/TraceIdExecutionInterceptor.java b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/interceptor/TraceIdExecutionInterceptor.java
index 95224228cfb4..d2b422c940eb 100644
--- a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/interceptor/TraceIdExecutionInterceptor.java
+++ b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/interceptor/TraceIdExecutionInterceptor.java
@@ -19,10 +19,12 @@
import software.amazon.awssdk.annotations.SdkProtectedApi;
import software.amazon.awssdk.awscore.internal.interceptor.TracingSystemSetting;
import software.amazon.awssdk.core.interceptor.Context;
+import software.amazon.awssdk.core.interceptor.ExecutionAttribute;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.utils.SystemSetting;
+import software.amazon.awssdk.utilslite.SdkInternalThreadLocal;
/**
* The {@code TraceIdExecutionInterceptor} copies the trace details to the {@link #TRACE_ID_HEADER} header, assuming we seem to
@@ -32,27 +34,57 @@
public class TraceIdExecutionInterceptor implements ExecutionInterceptor {
private static final String TRACE_ID_HEADER = "X-Amzn-Trace-Id";
private static final String LAMBDA_FUNCTION_NAME_ENVIRONMENT_VARIABLE = "AWS_LAMBDA_FUNCTION_NAME";
+ private static final String CONCURRENT_TRACE_ID_KEY = "AWS_LAMBDA_X_TRACE_ID";
+ private static final ExecutionAttribute TRACE_ID = new ExecutionAttribute<>("TraceId");
+
+ @Override
+ public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) {
+ String traceId = SdkInternalThreadLocal.get(CONCURRENT_TRACE_ID_KEY);
+ if (traceId != null) {
+ executionAttributes.putAttribute(TRACE_ID, traceId);
+ }
+ }
@Override
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
Optional traceIdHeader = traceIdHeader(context);
if (!traceIdHeader.isPresent()) {
Optional lambdafunctionName = lambdaFunctionNameEnvironmentVariable();
- Optional traceId = traceId();
+ Optional traceId = traceId(executionAttributes);
if (lambdafunctionName.isPresent() && traceId.isPresent()) {
return context.httpRequest().copy(r -> r.putHeader(TRACE_ID_HEADER, traceId.get()));
}
}
-
return context.httpRequest();
}
+ @Override
+ public void afterExecution(Context.AfterExecution context, ExecutionAttributes executionAttributes) {
+ saveTraceId(executionAttributes);
+ }
+
+ @Override
+ public void onExecutionFailure(Context.FailedExecution context, ExecutionAttributes executionAttributes) {
+ saveTraceId(executionAttributes);
+ }
+
+ private static void saveTraceId(ExecutionAttributes executionAttributes) {
+ String traceId = executionAttributes.getAttribute(TRACE_ID);
+ if (traceId != null) {
+ SdkInternalThreadLocal.put(CONCURRENT_TRACE_ID_KEY, executionAttributes.getAttribute(TRACE_ID));
+ }
+ }
+
private Optional traceIdHeader(Context.ModifyHttpRequest context) {
return context.httpRequest().firstMatchingHeader(TRACE_ID_HEADER);
}
- private Optional traceId() {
+ private Optional traceId(ExecutionAttributes executionAttributes) {
+ Optional traceId = Optional.ofNullable(executionAttributes.getAttribute(TRACE_ID));
+ if (traceId.isPresent()) {
+ return traceId;
+ }
return TracingSystemSetting._X_AMZN_TRACE_ID.getStringValue();
}
@@ -61,4 +93,4 @@ private Optional lambdaFunctionNameEnvironmentVariable() {
return SystemSetting.getStringValueFromEnvironmentVariable(LAMBDA_FUNCTION_NAME_ENVIRONMENT_VARIABLE);
// CHECKSTYLE:ON
}
-}
+}
\ No newline at end of file
diff --git a/core/aws-core/src/test/java/software/amazon/awssdk/awscore/interceptor/TraceIdExecutionInterceptorTest.java b/core/aws-core/src/test/java/software/amazon/awssdk/awscore/interceptor/TraceIdExecutionInterceptorTest.java
index b3f965a490fc..3c18d064cd0d 100644
--- a/core/aws-core/src/test/java/software/amazon/awssdk/awscore/interceptor/TraceIdExecutionInterceptorTest.java
+++ b/core/aws-core/src/test/java/software/amazon/awssdk/awscore/interceptor/TraceIdExecutionInterceptorTest.java
@@ -28,6 +28,7 @@
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.testutils.EnvironmentVariableHelper;
+import software.amazon.awssdk.utilslite.SdkInternalThreadLocal;
public class TraceIdExecutionInterceptorTest {
@Test
@@ -111,6 +112,78 @@ public void headerNotAddedIfNoTraceIdEnvVar() {
});
}
+ @Test
+ public void modifyHttpRequest_whenMultiConcurrencyModeWithInternalThreadLocal_shouldAddTraceIdHeader() {
+ EnvironmentVariableHelper.run(env -> {
+ resetRelevantEnvVars(env);
+ env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
+ SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", "SdkInternalThreadLocal-trace-123");
+
+ try {
+ TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
+ ExecutionAttributes executionAttributes = new ExecutionAttributes();
+
+ interceptor.beforeExecution(null, executionAttributes);
+ Context.ModifyHttpRequest context = context();
+
+ SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);
+ assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
+ } finally {
+ SdkInternalThreadLocal.remove("AWS_LAMBDA_X_TRACE_ID");
+ }
+ });
+ }
+
+ @Test
+ public void modifyHttpRequest_whenMultiConcurrencyModeWithBothInternalThreadLocalAndSystemProperty_shouldUseInternalThreadLocalValue() {
+ EnvironmentVariableHelper.run(env -> {
+ resetRelevantEnvVars(env);
+ env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
+
+ SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", "SdkInternalThreadLocal-trace-123");
+ Properties props = System.getProperties();
+ props.setProperty("com.amazonaws.xray.traceHeader", "sys-prop-345");
+
+ try {
+ TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
+ ExecutionAttributes executionAttributes = new ExecutionAttributes();
+
+ interceptor.beforeExecution(null, executionAttributes);
+
+ Context.ModifyHttpRequest context = context();
+ SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);
+
+ assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
+ } finally {
+ SdkInternalThreadLocal.remove("AWS_LAMBDA_X_TRACE_ID");
+ props.remove("com.amazonaws.xray.traceHeader");
+ }
+ });
+ }
+
+ @Test
+ public void modifyHttpRequest_whenNotInLambdaEnvironmentWithInternalThreadLocal_shouldNotAddHeader() {
+ EnvironmentVariableHelper.run(env -> {
+ resetRelevantEnvVars(env);
+
+ SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", "should-be-ignored");
+
+ try {
+ TraceIdExecutionInterceptor interceptor = new TraceIdExecutionInterceptor();
+ ExecutionAttributes executionAttributes = new ExecutionAttributes();
+
+ interceptor.beforeExecution(null, executionAttributes);
+
+ Context.ModifyHttpRequest context = context();
+ SdkHttpRequest request = interceptor.modifyHttpRequest(context, executionAttributes);
+
+ assertThat(request.firstMatchingHeader("X-Amzn-Trace-Id")).isEmpty();
+ } finally {
+ SdkInternalThreadLocal.remove("AWS_LAMBDA_X_TRACE_ID");
+ }
+ });
+ }
+
private Context.ModifyHttpRequest context() {
return context(SdkHttpRequest.builder()
.uri(URI.create("https://localhost"))
diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/TraceIdTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/TraceIdTest.java
index 3299e26ef876..a0747444292d 100644
--- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/TraceIdTest.java
+++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/TraceIdTest.java
@@ -17,17 +17,25 @@
import static org.assertj.core.api.Assertions.assertThat;
+import java.util.List;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider;
import software.amazon.awssdk.awscore.interceptor.TraceIdExecutionInterceptor;
+import software.amazon.awssdk.core.interceptor.Context;
+import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
+import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.http.AbortableInputStream;
import software.amazon.awssdk.http.HttpExecuteResponse;
+import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonAsyncClient;
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClient;
import software.amazon.awssdk.testutils.EnvironmentVariableHelper;
+import software.amazon.awssdk.testutils.service.http.MockAsyncHttpClient;
import software.amazon.awssdk.testutils.service.http.MockSyncHttpClient;
import software.amazon.awssdk.utils.StringInputStream;
+import software.amazon.awssdk.utilslite.SdkInternalThreadLocal;
/**
* Verifies that the {@link TraceIdExecutionInterceptor} is actually wired up for AWS services.
@@ -56,4 +64,181 @@ public void traceIdInterceptorIsEnabled() {
}
});
}
-}
+
+ @Test
+ public void traceIdInterceptorPreservesTraceIdAcrossRetries() {
+ EnvironmentVariableHelper.run(env -> {
+ env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
+ SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", "SdkInternalThreadLocal-trace-123");
+
+ try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
+ ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
+ .region(Region.US_WEST_2)
+ .credentialsProvider(AnonymousCredentialsProvider.create())
+ .httpClient(mockHttpClient)
+ .build()) {
+
+ mockHttpClient.stubResponses(
+ HttpExecuteResponse.builder()
+ .response(SdkHttpResponse.builder().statusCode(500).build())
+ .responseBody(AbortableInputStream.create(new StringInputStream("{}")))
+ .build(),
+ HttpExecuteResponse.builder()
+ .response(SdkHttpResponse.builder().statusCode(500).build())
+ .responseBody(AbortableInputStream.create(new StringInputStream("{}")))
+ .build(),
+ HttpExecuteResponse.builder().response(SdkHttpResponse.builder().statusCode(200).build())
+ .responseBody(AbortableInputStream.create(new StringInputStream("{}")))
+ .build());
+
+ client.allTypes().join();
+
+ List requests = mockHttpClient.getRequests();
+ assertThat(requests).hasSize(3);
+
+ assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
+ assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
+ assertThat(requests.get(2).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
+
+ } finally {
+ SdkInternalThreadLocal.clear();
+ }
+ });
+ }
+
+ @Test
+ public void traceIdInterceptorPreservesTraceIdAcrossChainedFutures() {
+ EnvironmentVariableHelper.run(env -> {
+ env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
+ SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", "SdkInternalThreadLocal-trace-123");
+
+ try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
+ ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
+ .region(Region.US_WEST_2)
+ .credentialsProvider(AnonymousCredentialsProvider.create())
+ .httpClient(mockHttpClient)
+ .build()) {
+
+ mockHttpClient.stubResponses(
+ HttpExecuteResponse.builder()
+ .response(SdkHttpResponse.builder().statusCode(200).build())
+ .responseBody(AbortableInputStream.create(new StringInputStream("{}")))
+ .build(),
+ HttpExecuteResponse.builder()
+ .response(SdkHttpResponse.builder().statusCode(200).build())
+ .responseBody(AbortableInputStream.create(new StringInputStream("{}")))
+ .build()
+ );
+
+ client.allTypes()
+ .thenRun(() -> {
+ client.allTypes().join();
+ })
+ .join();
+
+ List requests = mockHttpClient.getRequests();
+
+ assertThat(requests).hasSize(2);
+
+ assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
+ assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
+
+ } finally {
+ SdkInternalThreadLocal.clear();
+ }
+ });
+ }
+
+ @Test
+ public void traceIdInterceptorPreservesTraceIdAcrossExceptionallyCompletedFutures() {
+ EnvironmentVariableHelper.run(env -> {
+ env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
+ SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", "SdkInternalThreadLocal-trace-123");
+
+ try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
+ ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
+ .region(Region.US_WEST_2)
+ .credentialsProvider(AnonymousCredentialsProvider.create())
+ .httpClient(mockHttpClient)
+ .build()) {
+
+ mockHttpClient.stubResponses(
+ HttpExecuteResponse.builder()
+ .response(SdkHttpResponse.builder().statusCode(400).build())
+ .responseBody(AbortableInputStream.create(new StringInputStream("{}")))
+ .build(),
+ HttpExecuteResponse.builder()
+ .response(SdkHttpResponse.builder().statusCode(200).build())
+ .responseBody(AbortableInputStream.create(new StringInputStream("{}")))
+ .build()
+ );
+
+ client.allTypes()
+ .exceptionally(throwable -> {
+ client.allTypes().join();
+ return null;
+ }).join();
+
+ List requests = mockHttpClient.getRequests();
+
+ assertThat(requests).hasSize(2);
+
+ assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
+ assertThat(requests.get(1).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
+
+ } finally {
+ SdkInternalThreadLocal.clear();
+ }
+ });
+ }
+
+ @Test
+ public void traceIdInterceptorPreservesTraceIdAcrossExceptionallyCompletedFuturesThrownInPreExecution() {
+ EnvironmentVariableHelper.run(env -> {
+ env.set("AWS_LAMBDA_FUNCTION_NAME", "foo");
+ SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", "SdkInternalThreadLocal-trace-123");
+
+ ExecutionInterceptor throwingInterceptor = new ExecutionInterceptor() {
+ private boolean hasThrown = false;
+
+ @Override
+ public void beforeMarshalling(Context.BeforeMarshalling context, ExecutionAttributes executionAttributes) {
+ if (!hasThrown) {
+ hasThrown = true;
+ throw new RuntimeException("failing in pre execution");
+ }
+ }
+ };
+
+ try (MockAsyncHttpClient mockHttpClient = new MockAsyncHttpClient();
+ ProtocolRestJsonAsyncClient client = ProtocolRestJsonAsyncClient.builder()
+ .region(Region.US_WEST_2)
+ .credentialsProvider(AnonymousCredentialsProvider.create())
+ .overrideConfiguration(o -> o.addExecutionInterceptor(throwingInterceptor))
+ .httpClient(mockHttpClient)
+ .build()) {
+
+ mockHttpClient.stubResponses(
+ HttpExecuteResponse.builder()
+ .response(SdkHttpResponse.builder().statusCode(200).build())
+ .responseBody(AbortableInputStream.create(new StringInputStream("{}")))
+ .build()
+ );
+
+ client.allTypes()
+ .exceptionally(throwable -> {
+ client.allTypes().join();
+ return null;
+ }).join();
+
+ List requests = mockHttpClient.getRequests();
+
+ assertThat(requests).hasSize(1);
+ assertThat(requests.get(0).firstMatchingHeader("X-Amzn-Trace-Id")).hasValue("SdkInternalThreadLocal-trace-123");
+
+ } finally {
+ SdkInternalThreadLocal.clear();
+ }
+ });
+ }
+}
\ No newline at end of file