Skip to content

Commit ff37493

Browse files
author
Steve Riesenberg
committed
Verify ReactorContext when using Virtual Threads
Closes gh-12791
1 parent d6fac11 commit ff37493

File tree

4 files changed

+190
-4
lines changed

4 files changed

+190
-4
lines changed

config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java

+62-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,14 +17,20 @@
1717
package org.springframework.security.config.annotation.web.configuration;
1818

1919
import java.net.URI;
20+
import java.util.Arrays;
2021
import java.util.HashMap;
2122
import java.util.Map;
23+
import java.util.concurrent.Executors;
24+
import java.util.concurrent.Future;
25+
import java.util.concurrent.ThreadFactory;
2226

2327
import jakarta.servlet.http.HttpServletRequest;
2428
import jakarta.servlet.http.HttpServletResponse;
2529
import org.junit.jupiter.api.AfterEach;
2630
import org.junit.jupiter.api.BeforeEach;
2731
import org.junit.jupiter.api.Test;
32+
import org.junit.jupiter.api.condition.DisabledOnJre;
33+
import org.junit.jupiter.api.condition.JRE;
2834
import org.junit.jupiter.api.extension.ExtendWith;
2935
import reactor.core.CoreSubscriber;
3036
import reactor.core.publisher.BaseSubscriber;
@@ -35,6 +41,8 @@
3541

3642
import org.springframework.context.annotation.Bean;
3743
import org.springframework.context.annotation.Configuration;
44+
import org.springframework.core.task.SimpleAsyncTaskExecutor;
45+
import org.springframework.core.task.VirtualThreadTaskExecutor;
3846
import org.springframework.http.HttpMethod;
3947
import org.springframework.http.HttpStatus;
4048
import org.springframework.mock.web.MockHttpServletRequest;
@@ -46,6 +54,7 @@
4654
import org.springframework.security.config.test.SpringTestContext;
4755
import org.springframework.security.config.test.SpringTestContextExtension;
4856
import org.springframework.security.core.Authentication;
57+
import org.springframework.security.core.context.SecurityContext;
4958
import org.springframework.security.core.context.SecurityContextHolder;
5059
import org.springframework.security.core.context.SecurityContextHolderStrategy;
5160
import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
@@ -271,6 +280,58 @@ public void createPublisherWhenCustomSecurityContextHolderStrategyThenUses() {
271280
verify(strategy, times(2)).getContext();
272281
}
273282

283+
@Test
284+
public void createPublisherWhenThreadFactoryIsPlatformThenSecurityContextAttributesAvailable() throws Exception {
285+
this.spring.register(SecurityConfig.class).autowire();
286+
287+
ThreadFactory threadFactory = Executors.defaultThreadFactory();
288+
assertContextAttributesAvailable(threadFactory);
289+
}
290+
291+
@Test
292+
@DisabledOnJre(JRE.JAVA_17)
293+
public void createPublisherWhenThreadFactoryIsVirtualThenSecurityContextAttributesAvailable() throws Exception {
294+
this.spring.register(SecurityConfig.class).autowire();
295+
296+
ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
297+
assertContextAttributesAvailable(threadFactory);
298+
}
299+
300+
private void assertContextAttributesAvailable(ThreadFactory threadFactory) throws Exception {
301+
Map<Object, Object> expectedContextAttributes = new HashMap<>();
302+
expectedContextAttributes.put(HttpServletRequest.class, this.servletRequest);
303+
expectedContextAttributes.put(HttpServletResponse.class, this.servletResponse);
304+
expectedContextAttributes.put(Authentication.class, this.authentication);
305+
306+
try (SimpleAsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(threadFactory)) {
307+
Future<Map<Object, Object>> future = taskExecutor.submit(this::propagateRequestAttributes);
308+
assertThat(future.get()).isEqualTo(expectedContextAttributes);
309+
}
310+
}
311+
312+
private Map<Object, Object> propagateRequestAttributes() {
313+
RequestAttributes requestAttributes = new ServletRequestAttributes(this.servletRequest, this.servletResponse);
314+
RequestContextHolder.setRequestAttributes(requestAttributes);
315+
316+
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
317+
securityContext.setAuthentication(this.authentication);
318+
SecurityContextHolder.setContext(securityContext);
319+
320+
// @formatter:off
321+
return Mono.deferContextual(Mono::just)
322+
.filter((ctx) -> ctx.hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
323+
.map((ctx) -> ctx.<Map<Object, Object>>get(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
324+
.map((attributes) -> {
325+
Map<Object, Object> map = new HashMap<>();
326+
// Copy over items from lazily loaded map
327+
Arrays.asList(HttpServletRequest.class, HttpServletResponse.class, Authentication.class)
328+
.forEach((key) -> map.put(key, attributes.get(key)));
329+
return map;
330+
})
331+
.block();
332+
// @formatter:on
333+
}
334+
274335
@Configuration
275336
@EnableWebSecurity
276337
static class SecurityConfig {

core/src/test/java/org/springframework/security/core/context/ReactiveSecurityContextHolderTests.java

+57-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2017 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,10 +16,17 @@
1616

1717
package org.springframework.security.core.context;
1818

19+
import java.util.concurrent.Executors;
20+
import java.util.concurrent.ThreadFactory;
21+
1922
import org.junit.jupiter.api.Test;
23+
import org.junit.jupiter.api.condition.DisabledOnJre;
24+
import org.junit.jupiter.api.condition.JRE;
2025
import reactor.core.publisher.Mono;
26+
import reactor.core.scheduler.Schedulers;
2127
import reactor.test.StepVerifier;
2228

29+
import org.springframework.core.task.VirtualThreadTaskExecutor;
2330
import org.springframework.security.authentication.TestingAuthenticationToken;
2431
import org.springframework.security.core.Authentication;
2532

@@ -99,4 +106,53 @@ public void setAuthenticationAndGetContextThenEmitsContext() {
99106
// @formatter:on
100107
}
101108

109+
@Test
110+
public void getContextWhenThreadFactoryIsPlatformThenPropagated() {
111+
verifySecurityContextIsPropagated(Executors.defaultThreadFactory());
112+
}
113+
114+
@Test
115+
@DisabledOnJre(JRE.JAVA_17)
116+
public void getContextWhenThreadFactoryIsVirtualThenPropagated() {
117+
verifySecurityContextIsPropagated(new VirtualThreadTaskExecutor().getVirtualThreadFactory());
118+
}
119+
120+
private static void verifySecurityContextIsPropagated(ThreadFactory threadFactory) {
121+
Authentication authentication = new TestingAuthenticationToken("user", null);
122+
123+
// @formatter:off
124+
Mono<Authentication> publisher = ReactiveSecurityContextHolder.getContext()
125+
.map(SecurityContext::getAuthentication)
126+
.contextWrite((context) -> ReactiveSecurityContextHolder.withAuthentication(authentication))
127+
.subscribeOn(Schedulers.newSingle(threadFactory));
128+
// @formatter:on
129+
130+
StepVerifier.create(publisher).expectNext(authentication).verifyComplete();
131+
}
132+
133+
@Test
134+
public void clearContextWhenThreadFactoryIsPlatformThenCleared() {
135+
verifySecurityContextIsCleared(Executors.defaultThreadFactory());
136+
}
137+
138+
@Test
139+
@DisabledOnJre(JRE.JAVA_17)
140+
public void clearContextWhenThreadFactoryIsVirtualThenCleared() {
141+
verifySecurityContextIsCleared(new VirtualThreadTaskExecutor().getVirtualThreadFactory());
142+
}
143+
144+
private static void verifySecurityContextIsCleared(ThreadFactory threadFactory) {
145+
Authentication authentication = new TestingAuthenticationToken("user", null);
146+
147+
// @formatter:off
148+
Mono<Authentication> publisher = ReactiveSecurityContextHolder.getContext()
149+
.map(SecurityContext::getAuthentication)
150+
.contextWrite(ReactiveSecurityContextHolder.clearContext())
151+
.contextWrite((context) -> ReactiveSecurityContextHolder.withAuthentication(authentication))
152+
.subscribeOn(Schedulers.newSingle(threadFactory));
153+
// @formatter:on
154+
155+
StepVerifier.create(publisher).verifyComplete();
156+
}
157+
102158
}

web/src/test/java/org/springframework/security/web/server/context/ReactorContextWebFilterTests.java

+35-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2017 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,17 +17,23 @@
1717
package org.springframework.security.web.server.context;
1818

1919
import java.util.List;
20+
import java.util.concurrent.Executors;
21+
import java.util.concurrent.ThreadFactory;
2022

2123
import org.junit.jupiter.api.BeforeEach;
2224
import org.junit.jupiter.api.Test;
25+
import org.junit.jupiter.api.condition.DisabledOnJre;
26+
import org.junit.jupiter.api.condition.JRE;
2327
import org.junit.jupiter.api.extension.ExtendWith;
2428
import org.mockito.Mock;
2529
import org.mockito.junit.jupiter.MockitoExtension;
2630
import reactor.core.publisher.Mono;
31+
import reactor.core.scheduler.Schedulers;
2732
import reactor.test.StepVerifier;
2833
import reactor.test.publisher.TestPublisher;
2934
import reactor.util.context.Context;
3035

36+
import org.springframework.core.task.VirtualThreadTaskExecutor;
3137
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
3238
import org.springframework.mock.web.server.MockServerWebExchange;
3339
import org.springframework.security.core.Authentication;
@@ -117,4 +123,32 @@ public void filterWhenMainContextThenDoesNotOverride() {
117123
StepVerifier.create(filter).expectAccessibleContext().hasKey(contextKey).then().verifyComplete();
118124
}
119125

126+
@Test
127+
public void filterWhenThreadFactoryIsPlatformThenSecurityContextLoaded() {
128+
ThreadFactory threadFactory = Executors.defaultThreadFactory();
129+
assertSecurityContextLoaded(threadFactory);
130+
}
131+
132+
@Test
133+
@DisabledOnJre(JRE.JAVA_17)
134+
public void filterWhenThreadFactoryIsVirtualThenSecurityContextLoaded() {
135+
ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
136+
assertSecurityContextLoaded(threadFactory);
137+
}
138+
139+
private void assertSecurityContextLoaded(ThreadFactory threadFactory) {
140+
SecurityContextImpl context = new SecurityContextImpl(this.principal);
141+
given(this.repository.load(any())).willReturn(Mono.just(context));
142+
// @formatter:off
143+
WebFilter subscribeOnThreadFactory = (exchange, chain) -> chain.filter(exchange)
144+
.subscribeOn(Schedulers.newSingle(threadFactory));
145+
WebFilter assertSecurityContext = (exchange, chain) -> ReactiveSecurityContextHolder.getContext()
146+
.map(SecurityContext::getAuthentication)
147+
.doOnSuccess((authentication) -> assertThat(authentication).isSameAs(this.principal))
148+
.then(chain.filter(exchange));
149+
// @formatter:on
150+
this.handler = WebTestHandler.bindToWebFilters(subscribeOnThreadFactory, this.filter, assertSecurityContext);
151+
this.handler.exchange(this.exchange);
152+
}
153+
120154
}

web/src/test/java/org/springframework/security/web/server/context/SecurityContextServerWebExchangeWebFilterTests.java

+36-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2017 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,17 +17,25 @@
1717
package org.springframework.security.web.server.context;
1818

1919
import java.util.Collections;
20+
import java.util.concurrent.Executors;
21+
import java.util.concurrent.ThreadFactory;
2022

2123
import org.junit.jupiter.api.Test;
24+
import org.junit.jupiter.api.condition.DisabledOnJre;
25+
import org.junit.jupiter.api.condition.JRE;
2226
import reactor.core.publisher.Mono;
27+
import reactor.core.scheduler.Schedulers;
2328
import reactor.test.StepVerifier;
2429

30+
import org.springframework.core.task.VirtualThreadTaskExecutor;
2531
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
2632
import org.springframework.mock.web.server.MockServerWebExchange;
2733
import org.springframework.security.authentication.TestingAuthenticationToken;
2834
import org.springframework.security.core.Authentication;
2935
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
36+
import org.springframework.security.test.web.reactive.server.WebTestHandler;
3037
import org.springframework.web.server.ServerWebExchange;
38+
import org.springframework.web.server.WebFilter;
3139
import org.springframework.web.server.handler.DefaultWebFilterChain;
3240

3341
import static org.assertj.core.api.Assertions.assertThat;
@@ -80,4 +88,31 @@ public void filterWhenPrincipalNullThenContextEmpty() {
8088
StepVerifier.create(result).verifyComplete();
8189
}
8290

91+
@Test
92+
public void filterWhenThreadFactoryIsPlatformThenContextPopulated() {
93+
ThreadFactory threadFactory = Executors.defaultThreadFactory();
94+
assertPrincipalPopulated(threadFactory);
95+
}
96+
97+
@Test
98+
@DisabledOnJre(JRE.JAVA_17)
99+
public void filterWhenThreadFactoryIsVirtualThenContextPopulated() {
100+
ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
101+
assertPrincipalPopulated(threadFactory);
102+
}
103+
104+
private void assertPrincipalPopulated(ThreadFactory threadFactory) {
105+
// @formatter:off
106+
WebFilter subscribeOnThreadFactory = (exchange, chain) -> chain.filter(exchange)
107+
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(this.principal))
108+
.subscribeOn(Schedulers.newSingle(threadFactory));
109+
WebFilter assertPrincipal = (exchange, chain) -> exchange.getPrincipal()
110+
.doOnSuccess((principal) -> assertThat(principal).isSameAs(this.principal))
111+
.then(chain.filter(exchange));
112+
// @formatter:on
113+
WebTestHandler handler = WebTestHandler.bindToWebFilters(subscribeOnThreadFactory, this.filter,
114+
assertPrincipal);
115+
handler.exchange(this.exchange);
116+
}
117+
83118
}

0 commit comments

Comments
 (0)