|
1 | 1 | /*
|
2 |
| - * Copyright 2002-2022 the original author or authors. |
| 2 | + * Copyright 2002-2023 the original author or authors. |
3 | 3 | *
|
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | * you may not use this file except in compliance with the License.
|
|
17 | 17 | package org.springframework.security.config.annotation.web.configuration;
|
18 | 18 |
|
19 | 19 | import java.net.URI;
|
| 20 | +import java.util.Arrays; |
20 | 21 | import java.util.HashMap;
|
21 | 22 | import java.util.Map;
|
| 23 | +import java.util.concurrent.Executors; |
| 24 | +import java.util.concurrent.Future; |
| 25 | +import java.util.concurrent.ThreadFactory; |
22 | 26 |
|
23 | 27 | import jakarta.servlet.http.HttpServletRequest;
|
24 | 28 | import jakarta.servlet.http.HttpServletResponse;
|
25 | 29 | import org.junit.jupiter.api.AfterEach;
|
26 | 30 | import org.junit.jupiter.api.BeforeEach;
|
27 | 31 | import org.junit.jupiter.api.Test;
|
| 32 | +import org.junit.jupiter.api.condition.DisabledOnJre; |
| 33 | +import org.junit.jupiter.api.condition.JRE; |
28 | 34 | import org.junit.jupiter.api.extension.ExtendWith;
|
29 | 35 | import reactor.core.CoreSubscriber;
|
30 | 36 | import reactor.core.publisher.BaseSubscriber;
|
|
35 | 41 |
|
36 | 42 | import org.springframework.context.annotation.Bean;
|
37 | 43 | import org.springframework.context.annotation.Configuration;
|
| 44 | +import org.springframework.core.task.SimpleAsyncTaskExecutor; |
| 45 | +import org.springframework.core.task.VirtualThreadTaskExecutor; |
38 | 46 | import org.springframework.http.HttpMethod;
|
39 | 47 | import org.springframework.http.HttpStatus;
|
40 | 48 | import org.springframework.mock.web.MockHttpServletRequest;
|
|
46 | 54 | import org.springframework.security.config.test.SpringTestContext;
|
47 | 55 | import org.springframework.security.config.test.SpringTestContextExtension;
|
48 | 56 | import org.springframework.security.core.Authentication;
|
| 57 | +import org.springframework.security.core.context.SecurityContext; |
49 | 58 | import org.springframework.security.core.context.SecurityContextHolder;
|
50 | 59 | import org.springframework.security.core.context.SecurityContextHolderStrategy;
|
51 | 60 | import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
|
@@ -271,6 +280,58 @@ public void createPublisherWhenCustomSecurityContextHolderStrategyThenUses() {
|
271 | 280 | verify(strategy, times(2)).getContext();
|
272 | 281 | }
|
273 | 282 |
|
| 283 | + @Test |
| 284 | + public void createPublisherWhenThreadFactoryIsPlatformThenSecurityContextAttributesAvailable() throws Exception { |
| 285 | + this.spring.register(SecurityConfig.class).autowire(); |
| 286 | + |
| 287 | + ThreadFactory threadFactory = Executors.defaultThreadFactory(); |
| 288 | + assertContextAttributesAvailable(threadFactory); |
| 289 | + } |
| 290 | + |
| 291 | + @Test |
| 292 | + @DisabledOnJre(JRE.JAVA_17) |
| 293 | + public void createPublisherWhenThreadFactoryIsVirtualThenSecurityContextAttributesAvailable() throws Exception { |
| 294 | + this.spring.register(SecurityConfig.class).autowire(); |
| 295 | + |
| 296 | + ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory(); |
| 297 | + assertContextAttributesAvailable(threadFactory); |
| 298 | + } |
| 299 | + |
| 300 | + private void assertContextAttributesAvailable(ThreadFactory threadFactory) throws Exception { |
| 301 | + Map<Object, Object> expectedContextAttributes = new HashMap<>(); |
| 302 | + expectedContextAttributes.put(HttpServletRequest.class, this.servletRequest); |
| 303 | + expectedContextAttributes.put(HttpServletResponse.class, this.servletResponse); |
| 304 | + expectedContextAttributes.put(Authentication.class, this.authentication); |
| 305 | + |
| 306 | + try (SimpleAsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(threadFactory)) { |
| 307 | + Future<Map<Object, Object>> future = taskExecutor.submit(this::propagateRequestAttributes); |
| 308 | + assertThat(future.get()).isEqualTo(expectedContextAttributes); |
| 309 | + } |
| 310 | + } |
| 311 | + |
| 312 | + private Map<Object, Object> propagateRequestAttributes() { |
| 313 | + RequestAttributes requestAttributes = new ServletRequestAttributes(this.servletRequest, this.servletResponse); |
| 314 | + RequestContextHolder.setRequestAttributes(requestAttributes); |
| 315 | + |
| 316 | + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); |
| 317 | + securityContext.setAuthentication(this.authentication); |
| 318 | + SecurityContextHolder.setContext(securityContext); |
| 319 | + |
| 320 | + // @formatter:off |
| 321 | + return Mono.deferContextual(Mono::just) |
| 322 | + .filter((ctx) -> ctx.hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES)) |
| 323 | + .map((ctx) -> ctx.<Map<Object, Object>>get(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES)) |
| 324 | + .map((attributes) -> { |
| 325 | + Map<Object, Object> map = new HashMap<>(); |
| 326 | + // Copy over items from lazily loaded map |
| 327 | + Arrays.asList(HttpServletRequest.class, HttpServletResponse.class, Authentication.class) |
| 328 | + .forEach((key) -> map.put(key, attributes.get(key))); |
| 329 | + return map; |
| 330 | + }) |
| 331 | + .block(); |
| 332 | + // @formatter:on |
| 333 | + } |
| 334 | + |
274 | 335 | @Configuration
|
275 | 336 | @EnableWebSecurity
|
276 | 337 | static class SecurityConfig {
|
|
0 commit comments