Commit 044b96d

improv(batch): Propagate trace entity to worker threads during parallel batch processing. (#2300)
1 parent f449a57 commit 044b96d

5 files changed: +288 -61 lines changed
docs/utilities/batch.md

Lines changed: 81 additions & 1 deletion
@@ -484,7 +484,9 @@ used with SQS FIFO. In that case, an `UnsupportedOperationException` is thrown.
 in most cases the defaults work well, and changing them is more likely to decrease performance
 (see [here](https://www.baeldung.com/java-when-to-use-parallel-stream#fork-join-framework)
 and [here](https://dzone.com/articles/be-aware-of-forkjoinpoolcommonpool)).
-In situations where this may be useful - such as performing IO-bound work in parallel - make sure to measure before and after!
+In situations where this may be useful, such as performing IO-bound work in parallel, make sure to measure before and after!
+
+When using parallel processing with X-Ray tracing enabled, the Tracing utility automatically handles trace context propagation to worker threads. This ensures that subsegments created during parallel message processing appear under the correct parent segment in your X-Ray trace, maintaining proper trace hierarchy and visibility into your batch processing performance.
 
 
 === "Example with SQS"
@@ -536,6 +538,84 @@ used with SQS FIFO. In that case, an `UnsupportedOperationException` is thrown.
     }
     ```
 
+=== "Example with X-Ray Tracing"
+
+    ```java hl_lines="12 17"
+    public class SqsBatchHandler implements RequestHandler<SQSEvent, SQSBatchResponse> {
+
+        private final BatchMessageHandler<SQSEvent, SQSBatchResponse> handler;
+
+        public SqsBatchHandler() {
+            handler = new BatchMessageHandlerBuilder()
+                    .withSqsBatchHandler()
+                    .buildWithMessageHandler(this::processMessage, Product.class);
+        }
+
+        @Override
+        @Tracing
+        public SQSBatchResponse handleRequest(SQSEvent sqsEvent, Context context) {
+            return handler.processBatchInParallel(sqsEvent, context);
+        }
+
+        @Tracing // This will appear correctly under the handleRequest subsegment
+        private void processMessage(Product p, Context c) {
+            // Process the product - subsegments will appear under handleRequest
+        }
+    }
+    ```
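
The capture-then-run pattern described above can be illustrated without the X-Ray SDK. The following is a minimal sketch, not the library's implementation: it assumes a plain `ThreadLocal<String>` as a stand-in for the trace context, and the names `ContextPropagationSketch`, `CURRENT_ENTITY`, `runWithEntity` and `observedEntities` are hypothetical. It captures the context once on the calling thread, installs it on whichever thread runs each item, and restores the previous value afterwards.

```java
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

final class ContextPropagationSketch {

    // Hypothetical stand-in for a thread-local trace context (not the X-Ray SDK).
    private static final ThreadLocal<String> CURRENT_ENTITY = new ThreadLocal<>();

    // Runs the task with the captured entity installed on the current thread,
    // then restores whatever that thread held before.
    static void runWithEntity(String captured, Runnable task) {
        String previous = CURRENT_ENTITY.get();
        CURRENT_ENTITY.set(captured);
        try {
            task.run();
        } finally {
            if (previous == null) {
                CURRENT_ENTITY.remove();
            } else {
                CURRENT_ENTITY.set(previous);
            }
        }
    }

    // Processes items in parallel and records the entity each worker observed.
    static List<String> observedEntities(int items) {
        CURRENT_ENTITY.set("handleRequest");
        String captured = CURRENT_ENTITY.get(); // capture once, on the calling thread
        ConcurrentLinkedQueue<String> seen = new ConcurrentLinkedQueue<>();
        try {
            IntStream.range(0, items).parallel().forEach(i ->
                    runWithEntity(captured, () -> seen.add(String.valueOf(CURRENT_ENTITY.get()))));
        } finally {
            CURRENT_ENTITY.remove();
        }
        return seen.stream().collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(observedEntities(8));
    }
}
```

Without the `runWithEntity` wrapper, pool threads would see an empty thread-local, which is the situation that leaves subsegments detached from their parent.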
+
+### Choosing the right concurrency model
+
+The `processBatchInParallel` method has two overloads with different concurrency characteristics:
+
+#### Without custom executor (parallelStream)
+
+When you call `processBatchInParallel(event, context)` without providing an executor, the implementation uses Java's `parallelStream()` which leverages the common `ForkJoinPool`.
+
+**Best for: CPU-bound workloads**
+
+- Thread pool size matches available CPU cores
+- Optimized for computational tasks (data transformation, calculations, parsing)
+- Main thread participates in work-stealing
+- Simple to use with no configuration needed
+
+```java
+// Good for CPU-intensive processing
+return handler.processBatchInParallel(sqsEvent, context);
+```
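
To see how the common pool behaves on a given runtime, a small standalone program can help. This is a sketch under stated assumptions: `CommonPoolSketch` and `sumOfSquares` are hypothetical names, and the workload is a toy CPU-bound computation rather than batch processing. By default the common pool's parallelism is one less than the number of available processors, and the calling thread makes up the difference by taking part in the work.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.IntStream;

final class CommonPoolSketch {

    // Sums squares on a parallel stream and records which threads did the work.
    static long sumOfSquares(int n, Set<String> threadNames) {
        return IntStream.rangeClosed(1, n)
                .parallel()
                .mapToLong(i -> {
                    threadNames.add(Thread.currentThread().getName());
                    return (long) i * i;
                })
                .sum();
    }

    public static void main(String[] args) {
        Set<String> threads = ConcurrentHashMap.newKeySet();
        long total = sumOfSquares(1_000, threads);
        System.out.println("Common pool parallelism: " + ForkJoinPool.commonPool().getParallelism());
        System.out.println("Sum of squares: " + total);
        System.out.println("Threads used: " + threads);
    }
}
```

On a function with little CPU allocated, the reported parallelism may be as low as 1, which is one reason to measure before relying on this overload.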
+
+#### With custom executor (CompletableFuture)
+
+When you call `processBatchInParallel(event, context, executor)` with a custom executor, the implementation uses `CompletableFuture` which gives you full control over the thread pool.
+
+**Best for: I/O-bound workloads**
+
+- You control thread pool size and characteristics
+- Ideal for I/O operations (HTTP calls, database queries, S3 operations)
+- Can use larger thread pools since threads spend time waiting, not computing
+- Main thread only waits; worker threads do all processing
+
+```java
+// Good for I/O-intensive processing (API calls, DB queries, etc.)
+ExecutorService executor = Executors.newFixedThreadPool(50);
+return handler.processBatchInParallel(sqsEvent, context, executor);
+```
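
The executor-based flow can be sketched in isolation as follows. This is an illustration under assumptions, not the library's code: `ExecutorBatchSketch`, `process` and `failedItems` are hypothetical, and plain strings stand in for batch records. Failures are collected in a `ConcurrentLinkedQueue` because several worker threads may add to the collection at the same time.

```java
import java.util.List;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;

final class ExecutorBatchSketch {

    // Hypothetical per-item work: items starting with "bad" are treated as failures.
    static boolean process(String item) {
        return !item.startsWith("bad");
    }

    // Runs every item on the supplied executor and returns the items that failed.
    static List<String> failedItems(List<String> items, ExecutorService executor) {
        Queue<String> failures = new ConcurrentLinkedQueue<>(); // safe for concurrent adds
        List<CompletableFuture<Void>> futures = items.stream()
                .map(item -> CompletableFuture.runAsync(() -> {
                    if (!process(item)) {
                        failures.add(item);
                    }
                }, executor))
                .collect(Collectors.toList());
        futures.forEach(CompletableFuture::join); // the caller only waits here
        return failures.stream().sorted().collect(Collectors.toList());
    }

    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(4);
        try {
            System.out.println(failedItems(List.of("a", "bad-1", "b", "bad-2"), executor));
        } finally {
            executor.shutdown();
        }
    }
}
```

In a Lambda handler it is usually preferable to create the executor once, for example as a field initialised in the constructor, so that warm invocations reuse the same pool instead of allocating a new one per request.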
+
+**For Java 21+: Virtual Threads**
+
+If you're using Java 21 or later, virtual threads are ideal for I/O-bound workloads:
+
+```java
+ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
+return handler.processBatchInParallel(sqsEvent, context, executor);
+```
+
+Virtual threads are lightweight and can handle thousands of concurrent I/O operations efficiently without the overhead of platform threads.
+
+**Recommendation for typical Lambda SQS processing:**
+
+Most Lambda functions processing SQS messages perform I/O operations (calling APIs, querying databases, writing to S3). For these workloads, use the custom executor approach with a thread pool sized appropriately for your I/O operations or virtual threads for Java 21+.
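
One common starting point for sizing a pool for I/O-bound work is the heuristic threads = cores × target utilisation × (1 + wait time / compute time). The sketch below only encodes that rule of thumb; `PoolSizeSketch` and `suggestedThreads` are hypothetical names, the timings are made-up inputs, and the result is a first guess to validate by measurement rather than a recommendation from the library.

```java
final class PoolSizeSketch {

    // Heuristic pool size: cores * utilisation * (1 + wait / compute), rounded up, at least 1.
    static int suggestedThreads(int cores, double targetUtilization,
                                double waitMillis, double computeMillis) {
        if (cores < 1 || computeMillis <= 0 || waitMillis < 0) {
            throw new IllegalArgumentException("cores and computeMillis must be positive, waitMillis non-negative");
        }
        double raw = cores * targetUtilization * (1 + waitMillis / computeMillis);
        return Math.max(1, (int) Math.ceil(raw));
    }

    public static void main(String[] args) {
        int cores = Runtime.getRuntime().availableProcessors();
        // Assumed profile: each message waits ~90 ms on I/O and computes for ~10 ms.
        System.out.println("Suggested threads: " + suggestedThreads(cores, 1.0, 90, 10));
    }
}
```

With 2 cores and the assumed 90 ms wait to 10 ms compute profile, the rule suggests 20 threads; a batch smaller than that would not benefit from a larger pool.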
+
 
 ## Handling Messages
 

powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/DynamoDbBatchMessageHandler.java

Lines changed: 38 additions & 16 deletions
@@ -14,21 +14,25 @@
 
 package software.amazon.lambda.powertools.batch.handler;
 
-import com.amazonaws.services.lambda.runtime.Context;
-import com.amazonaws.services.lambda.runtime.events.DynamodbEvent;
-import com.amazonaws.services.lambda.runtime.events.StreamsEventResponse;
-
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Executor;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+
+import com.amazonaws.services.lambda.runtime.Context;
+import com.amazonaws.services.lambda.runtime.events.DynamodbEvent;
+import com.amazonaws.services.lambda.runtime.events.StreamsEventResponse;
+
 import software.amazon.lambda.powertools.batch.internal.MultiThreadMDC;
+import software.amazon.lambda.powertools.batch.internal.XRayTraceEntityPropagator;
 
 /**
  * A batch message processor for DynamoDB Streams batches.
@@ -43,8 +47,8 @@ public class DynamoDbBatchMessageHandler implements BatchMessageHandler<Dynamodb
     private final BiConsumer<DynamodbEvent.DynamodbStreamRecord, Context> rawMessageHandler;
 
     public DynamoDbBatchMessageHandler(Consumer<DynamodbEvent.DynamodbStreamRecord> successHandler,
-            BiConsumer<DynamodbEvent.DynamodbStreamRecord, Throwable> failureHandler,
-            BiConsumer<DynamodbEvent.DynamodbStreamRecord, Context> rawMessageHandler) {
+            BiConsumer<DynamodbEvent.DynamodbStreamRecord, Throwable> failureHandler,
+            BiConsumer<DynamodbEvent.DynamodbStreamRecord, Context> rawMessageHandler) {
         this.successHandler = successHandler;
         this.failureHandler = failureHandler;
         this.rawMessageHandler = rawMessageHandler;
@@ -65,14 +69,23 @@ public StreamsEventResponse processBatch(DynamodbEvent event, Context context) {
     @Override
     public StreamsEventResponse processBatchInParallel(DynamodbEvent event, Context context) {
         MultiThreadMDC multiThreadMDC = new MultiThreadMDC();
+        Object capturedSubsegment = XRayTraceEntityPropagator.captureTraceEntity();
 
         List<StreamsEventResponse.BatchItemFailure> batchItemFailures = event.getRecords()
                 .parallelStream() // Parallel processing
                 .map(eventRecord -> {
-                    multiThreadMDC.copyMDCToThread(Thread.currentThread().getName());
-                    Optional<StreamsEventResponse.BatchItemFailure> failureOpt = processBatchItem(eventRecord, context);
-                    multiThreadMDC.removeThread(Thread.currentThread().getName());
-                    return failureOpt;
+                    AtomicReference<Optional<StreamsEventResponse.BatchItemFailure>> result = new AtomicReference<>();
+
+                    XRayTraceEntityPropagator.runWithEntity(capturedSubsegment, () -> {
+                        multiThreadMDC.copyMDCToThread(Thread.currentThread().getName());
+                        try {
+                            result.set(processBatchItem(eventRecord, context));
+                        } finally {
+                            multiThreadMDC.removeThread(Thread.currentThread().getName());
+                        }
+                    });
+
+                    return result.get();
                 })
                 .filter(Optional::isPresent)
                 .map(Optional::get)
@@ -84,21 +97,29 @@ public StreamsEventResponse processBatchInParallel(DynamodbEvent event, Context
     @Override
     public StreamsEventResponse processBatchInParallel(DynamodbEvent event, Context context, Executor executor) {
         MultiThreadMDC multiThreadMDC = new MultiThreadMDC();
+        Object capturedSubsegment = XRayTraceEntityPropagator.captureTraceEntity();
 
         List<StreamsEventResponse.BatchItemFailure> batchItemFailures = new ArrayList<>();
         List<CompletableFuture<Void>> futures = event.getRecords().stream()
                 .map(eventRecord -> CompletableFuture.runAsync(() -> {
-                    multiThreadMDC.copyMDCToThread(Thread.currentThread().getName());
-                    Optional<StreamsEventResponse.BatchItemFailure> failureOpt = processBatchItem(eventRecord, context);
-                    failureOpt.ifPresent(batchItemFailures::add);
-                    multiThreadMDC.removeThread(Thread.currentThread().getName());
+                    XRayTraceEntityPropagator.runWithEntity(capturedSubsegment, () -> {
+                        multiThreadMDC.copyMDCToThread(Thread.currentThread().getName());
+                        try {
+                            Optional<StreamsEventResponse.BatchItemFailure> failureOpt = processBatchItem(eventRecord,
+                                    context);
+                            failureOpt.ifPresent(batchItemFailures::add);
+                        } finally {
+                            multiThreadMDC.removeThread(Thread.currentThread().getName());
+                        }
+                    });
                 }, executor))
                 .collect(Collectors.toList());
         futures.forEach(CompletableFuture::join);
         return StreamsEventResponse.builder().withBatchItemFailures(batchItemFailures).build();
     }
 
-    private Optional<StreamsEventResponse.BatchItemFailure> processBatchItem(DynamodbEvent.DynamodbStreamRecord streamRecord, Context context) {
+    private Optional<StreamsEventResponse.BatchItemFailure> processBatchItem(
+            DynamodbEvent.DynamodbStreamRecord streamRecord, Context context) {
         try {
             LOGGER.debug("Processing item {}", streamRecord.getEventID());
 
@@ -124,7 +145,8 @@ private Optional<StreamsEventResponse.BatchItemFailure> processBatchItem(Dynamod
                     LOGGER.warn("failureHandler threw handling failure", e2);
                 }
             }
-            return Optional.of(StreamsEventResponse.BatchItemFailure.builder().withItemIdentifier(sequenceNumber).build());
+            return Optional
+                    .of(StreamsEventResponse.BatchItemFailure.builder().withItemIdentifier(sequenceNumber).build());
         }
     }
 }

powertools-batch/src/main/java/software/amazon/lambda/powertools/batch/handler/KinesisStreamsBatchMessageHandler.java

Lines changed: 40 additions & 20 deletions
@@ -14,22 +14,25 @@
 
 package software.amazon.lambda.powertools.batch.handler;
 
-
-import com.amazonaws.services.lambda.runtime.Context;
-import com.amazonaws.services.lambda.runtime.events.KinesisEvent;
-import com.amazonaws.services.lambda.runtime.events.StreamsEventResponse;
-
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Executor;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+
+import com.amazonaws.services.lambda.runtime.Context;
+import com.amazonaws.services.lambda.runtime.events.KinesisEvent;
+import com.amazonaws.services.lambda.runtime.events.StreamsEventResponse;
+
 import software.amazon.lambda.powertools.batch.internal.MultiThreadMDC;
+import software.amazon.lambda.powertools.batch.internal.XRayTraceEntityPropagator;
 import software.amazon.lambda.powertools.utilities.EventDeserializer;
 
 /**
@@ -49,10 +52,10 @@ public class KinesisStreamsBatchMessageHandler<M> implements BatchMessageHandler
     private final BiConsumer<KinesisEvent.KinesisEventRecord, Throwable> failureHandler;
 
     public KinesisStreamsBatchMessageHandler(BiConsumer<KinesisEvent.KinesisEventRecord, Context> rawMessageHandler,
-            BiConsumer<M, Context> messageHandler,
-            Class<M> messageClass,
-            Consumer<KinesisEvent.KinesisEventRecord> successHandler,
-            BiConsumer<KinesisEvent.KinesisEventRecord, Throwable> failureHandler) {
+            BiConsumer<M, Context> messageHandler,
+            Class<M> messageClass,
+            Consumer<KinesisEvent.KinesisEventRecord> successHandler,
+            BiConsumer<KinesisEvent.KinesisEventRecord, Throwable> failureHandler) {
 
         this.rawMessageHandler = rawMessageHandler;
         this.messageHandler = messageHandler;
@@ -76,14 +79,23 @@ public StreamsEventResponse processBatch(KinesisEvent event, Context context) {
     @Override
     public StreamsEventResponse processBatchInParallel(KinesisEvent event, Context context) {
         MultiThreadMDC multiThreadMDC = new MultiThreadMDC();
+        Object capturedSubsegment = XRayTraceEntityPropagator.captureTraceEntity();
 
         List<StreamsEventResponse.BatchItemFailure> batchItemFailures = event.getRecords()
                 .parallelStream() // Parallel processing
                 .map(eventRecord -> {
-                    multiThreadMDC.copyMDCToThread(Thread.currentThread().getName());
-                    Optional<StreamsEventResponse.BatchItemFailure> failureOpt = processBatchItem(eventRecord, context);
-                    multiThreadMDC.removeThread(Thread.currentThread().getName());
-                    return failureOpt;
+                    AtomicReference<Optional<StreamsEventResponse.BatchItemFailure>> result = new AtomicReference<>();
+
+                    XRayTraceEntityPropagator.runWithEntity(capturedSubsegment, () -> {
+                        multiThreadMDC.copyMDCToThread(Thread.currentThread().getName());
+                        try {
+                            result.set(processBatchItem(eventRecord, context));
+                        } finally {
+                            multiThreadMDC.removeThread(Thread.currentThread().getName());
+                        }
+                    });
+
+                    return result.get();
                 })
                 .filter(Optional::isPresent)
                 .map(Optional::get)
@@ -95,21 +107,29 @@ public StreamsEventResponse processBatchInParallel(KinesisEvent event, Context c
     @Override
     public StreamsEventResponse processBatchInParallel(KinesisEvent event, Context context, Executor executor) {
         MultiThreadMDC multiThreadMDC = new MultiThreadMDC();
+        Object capturedSubsegment = XRayTraceEntityPropagator.captureTraceEntity();
 
         List<StreamsEventResponse.BatchItemFailure> batchItemFailures = new ArrayList<>();
         List<CompletableFuture<Void>> futures = event.getRecords().stream()
                 .map(eventRecord -> CompletableFuture.runAsync(() -> {
-                    multiThreadMDC.copyMDCToThread(Thread.currentThread().getName());
-                    Optional<StreamsEventResponse.BatchItemFailure> failureOpt = processBatchItem(eventRecord, context);
-                    failureOpt.ifPresent(batchItemFailures::add);
-                    multiThreadMDC.removeThread(Thread.currentThread().getName());
+                    XRayTraceEntityPropagator.runWithEntity(capturedSubsegment, () -> {
+                        multiThreadMDC.copyMDCToThread(Thread.currentThread().getName());
+                        try {
+                            Optional<StreamsEventResponse.BatchItemFailure> failureOpt = processBatchItem(eventRecord,
+                                    context);
+                            failureOpt.ifPresent(batchItemFailures::add);
+                        } finally {
+                            multiThreadMDC.removeThread(Thread.currentThread().getName());
+                        }
+                    });
                 }, executor))
                 .collect(Collectors.toList());
         futures.forEach(CompletableFuture::join);
         return StreamsEventResponse.builder().withBatchItemFailures(batchItemFailures).build();
     }
 
-    private Optional<StreamsEventResponse.BatchItemFailure> processBatchItem(KinesisEvent.KinesisEventRecord eventRecord, Context context) {
+    private Optional<StreamsEventResponse.BatchItemFailure> processBatchItem(
+            KinesisEvent.KinesisEventRecord eventRecord, Context context) {
         try {
             LOGGER.debug("Processing item {}", eventRecord.getEventID());
 
@@ -141,8 +161,8 @@ private Optional<StreamsEventResponse.BatchItemFailure> processBatchItem(Kinesis
                 }
             }
 
-            return Optional.of(StreamsEventResponse.BatchItemFailure.builder().withItemIdentifier(eventRecord.getKinesis().getSequenceNumber()).build());
+            return Optional.of(StreamsEventResponse.BatchItemFailure.builder()
+                    .withItemIdentifier(eventRecord.getKinesis().getSequenceNumber()).build());
         }
     }
 }
-