From 3465ef5eb075578f7d9a54f4688b3c158cca32d5 Mon Sep 17 00:00:00 2001 From: Artem Bilan Date: Thu, 14 Sep 2023 16:24:41 -0400 Subject: [PATCH 1/4] Fix ThreadSPropagationChInterceptor for stacking Related SO thread: https://stackoverflow.com/questions/77058188/multiple-threadstatepropagationchannelinterceptors-not-possible The current `ThreadStatePropagationChannelInterceptor` logic is to wrap one message to another (`MessageWithThreadState`), essentially stacking contexts. The `postReceive()` logic is to unwrap a `MessageWithThreadState`, therefore we deal with the latest pushed context which leads to the `ClassCastException` * Rework `ThreadStatePropagationChannelInterceptor` logic to reuse existing `MessageWithThreadState` and add the current context to its `stateQueue`. Therefore, the `postReceive()` will `poll()` the oldest context which is, essentially, the one populated by this interceptor before, according to the interceptors order * Fix `AbstractMessageChannel.setInterceptors()` to not modify provided list of interceptors * The new `ThreadStatePropagationChannelInterceptorTests` demonstrates the problem described in that mentioned SO question and verifies that context are propagated in the order they have been populated **Cherry-pick to `6.1.x` & `6.0.x`** --- .../channel/AbstractMessageChannel.java | 6 +- ...eadStatePropagationChannelInterceptor.java | 45 ++++++--- ...atePropagationChannelInterceptorTests.java | 95 +++++++++++++++++++ 3 files changed, 130 insertions(+), 16 deletions(-) create mode 100644 spring-integration-core/src/test/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptorTests.java diff --git a/spring-integration-core/src/main/java/org/springframework/integration/channel/AbstractMessageChannel.java b/spring-integration-core/src/main/java/org/springframework/integration/channel/AbstractMessageChannel.java index 1517392ed36..2ddc0c0d241 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/channel/AbstractMessageChannel.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/channel/AbstractMessageChannel.java @@ -17,6 +17,7 @@ package org.springframework.integration.channel; import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; @@ -164,8 +165,9 @@ public void setDatatypes(Class... datatypes) { */ @Override public void setInterceptors(List interceptors) { - interceptors.sort(this.orderComparator); - this.interceptors.set(interceptors); + List interceptorsToUse = new ArrayList<>(interceptors); + interceptorsToUse.sort(this.orderComparator); + this.interceptors.set(interceptorsToUse); } /** diff --git a/spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java b/spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java index e8ca1124559..0d8511ac4d4 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java @@ -16,6 +16,9 @@ package org.springframework.integration.channel.interceptor; +import java.util.ArrayDeque; +import java.util.Queue; + import io.micrometer.common.lang.Nullable; import org.springframework.integration.support.MessageDecorator; @@ -58,20 +61,27 @@ public abstract class ThreadStatePropagationChannelInterceptor implements Exe public final Message preSend(Message message, MessageChannel channel) { S threadContext = obtainPropagatingContext(message, channel); if (threadContext != null) { - return new MessageWithThreadState<>(message, threadContext); - } - else { - return message; + if (message instanceof MessageWithThreadState messageWithThreadState) { + messageWithThreadState.stateQueue.add(threadContext); + } + else { + return new MessageWithThreadState(message, threadContext); + } } + + return message; } @Override @SuppressWarnings("unchecked") public final Message postReceive(Message message, MessageChannel channel) { - if (message instanceof MessageWithThreadState) { - MessageWithThreadState messageWithThreadState = (MessageWithThreadState) message; - Message messageToHandle = messageWithThreadState.message; - populatePropagatedContext(messageWithThreadState.state, messageToHandle, channel); + if (message instanceof MessageWithThreadState messageWithThreadState) { + Object threadContext = messageWithThreadState.stateQueue.poll(); + Message messageToHandle = messageWithThreadState; + if (messageWithThreadState.stateQueue.isEmpty()) { + messageToHandle = messageWithThreadState.message; + } + populatePropagatedContext((S) threadContext, messageToHandle, channel); return messageToHandle; } return message; @@ -88,16 +98,23 @@ public final Message beforeHandle(Message message, MessageChannel channel, protected abstract void populatePropagatedContext(@Nullable S state, Message message, MessageChannel channel); - private static final class MessageWithThreadState implements Message, MessageDecorator { + private static final class MessageWithThreadState implements Message, MessageDecorator { private final Message message; - private final S state; + private final Queue stateQueue; + + @SuppressWarnings("unchecked") + MessageWithThreadState(Message message, Object state) { + this.message = (Message) message; + this.stateQueue = new ArrayDeque<>(); + this.stateQueue.add(state); + } @SuppressWarnings("unchecked") - MessageWithThreadState(Message message, S state) { + private MessageWithThreadState(Message message, Queue stateQueue) { this.message = (Message) message; - this.state = state; + this.stateQueue = stateQueue; } @Override @@ -112,14 +129,14 @@ public MessageHeaders getHeaders() { @Override public Message decorateMessage(Message message) { - return new MessageWithThreadState<>(message, this.state); + return new MessageWithThreadState(message, this.stateQueue); } @Override public String toString() { return "MessageWithThreadState{" + "message=" + this.message + - ", state=" + this.state + + ", state=" + this.stateQueue + '}'; } diff --git a/spring-integration-core/src/test/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptorTests.java b/spring-integration-core/src/test/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptorTests.java new file mode 100644 index 00000000000..5d48a0271a1 --- /dev/null +++ b/spring-integration-core/src/test/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptorTests.java @@ -0,0 +1,95 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.channel.interceptor; + +import java.util.ArrayList; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.core.task.SyncTaskExecutor; +import org.springframework.integration.channel.ExecutorChannel; +import org.springframework.integration.util.ErrorHandlingTaskExecutor; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.support.GenericMessage; +import org.springframework.util.ReflectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +/** + * @author Artem Bilan + * + * @since 6.2 + */ +public class ThreadStatePropagationChannelInterceptorTests { + + @Test + void ThreadStatePropagationChannelInterceptorsCanBeStacked() { + TestContext1 ctx1 = new TestContext1(); + TestContext2 ctx2 = new TestContext2(); + + List propagatedContexts = new ArrayList<>(); + + var interceptor1 = new ThreadStatePropagationChannelInterceptor() { + @Override + protected TestContext1 obtainPropagatingContext(Message message, MessageChannel channel) { + return ctx1; + } + + @Override + protected void populatePropagatedContext(TestContext1 state, Message message, MessageChannel channel) { + propagatedContexts.add(state); + } + + }; + + var interceptor2 = new ThreadStatePropagationChannelInterceptor() { + @Override + protected TestContext2 obtainPropagatingContext(Message message, MessageChannel channel) { + return ctx2; + } + + @Override + protected void populatePropagatedContext(TestContext2 state, Message message, MessageChannel channel) { + propagatedContexts.add(state); + } + + }; + + ExecutorChannel testChannel = new ExecutorChannel( + new ErrorHandlingTaskExecutor(new SyncTaskExecutor(), ReflectionUtils::rethrowRuntimeException)); + testChannel.setInterceptors(List.of(interceptor1, interceptor2)); + testChannel.setBeanFactory(mock()); + testChannel.afterPropertiesSet(); + testChannel.subscribe(m -> { + }); + + testChannel.send(new GenericMessage<>("test data")); + + assertThat(propagatedContexts.get(0)).isEqualTo(ctx1); + assertThat(propagatedContexts.get(1)).isEqualTo(ctx2); + } + + private record TestContext1() { + } + + private record TestContext2() { + } + +} From 662f5fe838242cdf20fa0a6f2a344f5f427f9663 Mon Sep 17 00:00:00 2001 From: Artem Bilan Date: Thu, 14 Sep 2023 17:51:51 -0400 Subject: [PATCH 2/4] * Fix `ThreadStatePropagationChannelInterceptor` for publish-subscribe scenario. Essentially, copy the state queue to a new decorated message * Fix `BroadcastingDispatcher` to always decorate message, even if not `applySequence` --- ...eadStatePropagationChannelInterceptor.java | 7 ++-- .../dispatcher/BroadcastingDispatcher.java | 38 +++++++++---------- 2 files changed, 20 insertions(+), 25 deletions(-) diff --git a/spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java b/spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java index 0d8511ac4d4..020017ec0a8 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java @@ -17,6 +17,7 @@ package org.springframework.integration.channel.interceptor; import java.util.ArrayDeque; +import java.util.LinkedList; import java.util.Queue; import io.micrometer.common.lang.Nullable; @@ -104,17 +105,15 @@ private static final class MessageWithThreadState implements Message, Me private final Queue stateQueue; - @SuppressWarnings("unchecked") MessageWithThreadState(Message message, Object state) { - this.message = (Message) message; - this.stateQueue = new ArrayDeque<>(); + this(message, new LinkedList<>()); this.stateQueue.add(state); } @SuppressWarnings("unchecked") private MessageWithThreadState(Message message, Queue stateQueue) { this.message = (Message) message; - this.stateQueue = stateQueue; + this.stateQueue = new LinkedList<>(stateQueue); } @Override diff --git a/spring-integration-core/src/main/java/org/springframework/integration/dispatcher/BroadcastingDispatcher.java b/spring-integration-core/src/main/java/org/springframework/integration/dispatcher/BroadcastingDispatcher.java index ec843b2444c..93658e91c07 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/dispatcher/BroadcastingDispatcher.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/dispatcher/BroadcastingDispatcher.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -57,13 +57,13 @@ public class BroadcastingDispatcher extends AbstractDispatcher implements BeanFa private final boolean requireSubscribers; - private volatile boolean ignoreFailures; + private final Executor executor; - private volatile boolean applySequence; + private boolean ignoreFailures; - private final Executor executor; + private boolean applySequence; - private volatile int minSubscribers; + private int minSubscribers; private MessageHandlingTaskDecorator messageHandlingTaskDecorator = task -> task; @@ -149,24 +149,20 @@ public boolean dispatch(Message message) { int dispatched = 0; int sequenceNumber = 1; Collection handlers = this.getHandlers(); - if (this.requireSubscribers && handlers.size() == 0) { + if (this.requireSubscribers && handlers.isEmpty()) { throw new MessageDispatchingException(message, "Dispatcher has no subscribers"); } int sequenceSize = handlers.size(); Message messageToSend = message; - UUID sequenceId = null; - if (this.applySequence) { - sequenceId = message.getHeaders().getId(); - } for (MessageHandler handler : handlers) { if (this.applySequence) { messageToSend = getMessageBuilderFactory() .fromMessage(message) - .pushSequenceDetails(sequenceId, sequenceNumber++, sequenceSize) + .pushSequenceDetails(message.getHeaders().getId(), sequenceNumber++, sequenceSize) .build(); - if (message instanceof MessageDecorator) { - messageToSend = ((MessageDecorator) message).decorateMessage(messageToSend); - } + } + if (message instanceof MessageDecorator messageDecorator) { + messageToSend = messageDecorator.decorateMessage(messageToSend); } if (this.executor != null) { @@ -175,7 +171,7 @@ public boolean dispatch(Message message) { dispatched++; } else { - if (this.invokeHandler(handler, messageToSend)) { + if (invokeHandler(handler, messageToSend)) { dispatched++; } } @@ -222,15 +218,15 @@ private boolean invokeHandler(MessageHandler handler, Message message) { handler.handleMessage(message); return true; } - catch (RuntimeException e) { + catch (RuntimeException ex) { if (!this.ignoreFailures) { - if (e instanceof MessagingException && ((MessagingException) e).getFailedMessage() == null) { // NOSONAR - throw new MessagingException(message, "Failed to handle Message", e); + if (ex instanceof MessagingException exception && exception.getFailedMessage() == null) { // NOSONAR + throw new MessagingException(message, "Failed to handle Message", ex); } - throw e; + throw ex; } - else if (this.logger.isWarnEnabled()) { - logger.warn("Suppressing Exception since 'ignoreFailures' is set to TRUE.", e); + else { + logger.warn("Suppressing Exception since 'ignoreFailures' is set to TRUE.", ex); } return false; } From 9621886d1fc0ef2800b2cc1ebd56a8ca073dda47 Mon Sep 17 00:00:00 2001 From: Artem Bilan Date: Thu, 14 Sep 2023 17:57:47 -0400 Subject: [PATCH 3/4] * Fix unused import in the `BroadcastingDispatcher` --- .../integration/dispatcher/BroadcastingDispatcher.java | 1 - 1 file changed, 1 deletion(-) diff --git a/spring-integration-core/src/main/java/org/springframework/integration/dispatcher/BroadcastingDispatcher.java b/spring-integration-core/src/main/java/org/springframework/integration/dispatcher/BroadcastingDispatcher.java index 93658e91c07..43dec106abb 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/dispatcher/BroadcastingDispatcher.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/dispatcher/BroadcastingDispatcher.java @@ -17,7 +17,6 @@ package org.springframework.integration.dispatcher; import java.util.Collection; -import java.util.UUID; import java.util.concurrent.Executor; import org.springframework.beans.BeansException; From d5af3b6d48c28e1844890e2507cc1b70b73d3bde Mon Sep 17 00:00:00 2001 From: Artem Bilan Date: Thu, 14 Sep 2023 18:11:53 -0400 Subject: [PATCH 4/4] * Fix unused import in the `ThreadStatePropagationChannelInterceptor` --- .../interceptor/ThreadStatePropagationChannelInterceptor.java | 1 - 1 file changed, 1 deletion(-) diff --git a/spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java b/spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java index 020017ec0a8..19408c21249 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java @@ -16,7 +16,6 @@ package org.springframework.integration.channel.interceptor; -import java.util.ArrayDeque; import java.util.LinkedList; import java.util.Queue;