1
1
/*
2
- * Copyright 2002-2020 the original author or authors.
2
+ * Copyright 2002-2021 the original author or authors.
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
16
16
17
17
package org .springframework .security .config .annotation .web .configuration ;
18
18
19
+ import java .util .Collection ;
19
20
import java .util .Collections ;
20
21
import java .util .HashMap ;
21
22
import java .util .Map ;
23
+ import java .util .Set ;
24
+ import java .util .concurrent .ConcurrentHashMap ;
22
25
import java .util .function .Function ;
26
+ import java .util .function .Supplier ;
23
27
24
28
import jakarta .servlet .http .HttpServletRequest ;
25
29
import jakarta .servlet .http .HttpServletResponse ;
36
40
import org .springframework .context .annotation .Bean ;
37
41
import org .springframework .context .annotation .Configuration ;
38
42
import org .springframework .security .core .Authentication ;
39
- import org .springframework .security .core .context .SecurityContext ;
40
43
import org .springframework .security .core .context .SecurityContextHolder ;
41
44
import org .springframework .web .context .request .RequestAttributes ;
42
45
import org .springframework .web .context .request .RequestContextHolder ;
@@ -68,17 +71,22 @@ static class SecurityReactorContextSubscriberRegistrar implements InitializingBe
68
71
69
72
private static final String SECURITY_REACTOR_CONTEXT_OPERATOR_KEY = "org.springframework.security.SECURITY_REACTOR_CONTEXT_OPERATOR" ;
70
73
74
+ private static final Map <Object , Supplier <Object >> CONTEXT_ATTRIBUTE_VALUE_LOADERS = new HashMap <>();
75
+
76
+ static {
77
+ CONTEXT_ATTRIBUTE_VALUE_LOADERS .put (HttpServletRequest .class ,
78
+ SecurityReactorContextSubscriberRegistrar ::getRequest );
79
+ CONTEXT_ATTRIBUTE_VALUE_LOADERS .put (HttpServletResponse .class ,
80
+ SecurityReactorContextSubscriberRegistrar ::getResponse );
81
+ CONTEXT_ATTRIBUTE_VALUE_LOADERS .put (Authentication .class ,
82
+ SecurityReactorContextSubscriberRegistrar ::getAuthentication );
83
+ }
84
+
71
85
@ Override
72
86
public void afterPropertiesSet () throws Exception {
73
87
Function <? super Publisher <Object >, ? extends Publisher <Object >> lifter = Operators
74
88
.liftPublisher ((pub , sub ) -> createSubscriberIfNecessary (sub ));
75
- Hooks .onLastOperator (SECURITY_REACTOR_CONTEXT_OPERATOR_KEY , (pub ) -> {
76
- if (!contextAttributesAvailable ()) {
77
- // No need to decorate so return original Publisher
78
- return pub ;
79
- }
80
- return lifter .apply (pub );
81
- });
89
+ Hooks .onLastOperator (SECURITY_REACTOR_CONTEXT_OPERATOR_KEY , lifter ::apply );
82
90
}
83
91
84
92
@ Override
@@ -94,45 +102,30 @@ <T> CoreSubscriber<T> createSubscriberIfNecessary(CoreSubscriber<T> delegate) {
94
102
return new SecurityReactorContextSubscriber <>(delegate , getContextAttributes ());
95
103
}
96
104
97
- private static boolean contextAttributesAvailable () {
98
- SecurityContext context = SecurityContextHolder .peekContext ();
99
- Authentication authentication = null ;
100
- if (context != null ) {
101
- authentication = context .getAuthentication ();
102
- }
103
- return authentication != null
104
- || RequestContextHolder .getRequestAttributes () instanceof ServletRequestAttributes ;
105
+ private static Map <Object , Object > getContextAttributes () {
106
+ return new LoadingMap <>(CONTEXT_ATTRIBUTE_VALUE_LOADERS );
105
107
}
106
108
107
- private static Map <Object , Object > getContextAttributes () {
108
- HttpServletRequest servletRequest = null ;
109
- HttpServletResponse servletResponse = null ;
109
+ private static HttpServletRequest getRequest () {
110
110
RequestAttributes requestAttributes = RequestContextHolder .getRequestAttributes ();
111
111
if (requestAttributes instanceof ServletRequestAttributes ) {
112
112
ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes ) requestAttributes ;
113
- servletRequest = servletRequestAttributes .getRequest ();
114
- servletResponse = servletRequestAttributes .getResponse (); // possible null
115
- }
116
- SecurityContext context = SecurityContextHolder .peekContext ();
117
- Authentication authentication = null ;
118
- if (context != null ) {
119
- authentication = context .getAuthentication ();
120
- }
121
- if (authentication == null && servletRequest == null ) {
122
- return Collections .emptyMap ();
123
- }
124
- Map <Object , Object > contextAttributes = new HashMap <>();
125
- if (servletRequest != null ) {
126
- contextAttributes .put (HttpServletRequest .class , servletRequest );
127
- }
128
- if (servletResponse != null ) {
129
- contextAttributes .put (HttpServletResponse .class , servletResponse );
113
+ return servletRequestAttributes .getRequest ();
130
114
}
131
- if (authentication != null ) {
132
- contextAttributes .put (Authentication .class , authentication );
115
+ return null ;
116
+ }
117
+
118
+ private static HttpServletResponse getResponse () {
119
+ RequestAttributes requestAttributes = RequestContextHolder .getRequestAttributes ();
120
+ if (requestAttributes instanceof ServletRequestAttributes ) {
121
+ ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes ) requestAttributes ;
122
+ return servletRequestAttributes .getResponse (); // possible null
133
123
}
124
+ return null ;
125
+ }
134
126
135
- return contextAttributes ;
127
+ private static Authentication getAuthentication () {
128
+ return SecurityContextHolder .getContext ().getAuthentication ();
136
129
}
137
130
138
131
}
@@ -185,4 +178,112 @@ public void onComplete() {
185
178
186
179
}
187
180
181
+ /**
182
+ * A map that computes each value when {@link #get} is invoked
183
+ */
184
+ static class LoadingMap <K , V > implements Map <K , V > {
185
+
186
+ private final Map <K , V > loaded = new ConcurrentHashMap <>();
187
+
188
+ private final Map <K , Supplier <V >> loaders ;
189
+
190
+ LoadingMap (Map <K , Supplier <V >> loaders ) {
191
+ this .loaders = Collections .unmodifiableMap (new HashMap <>(loaders ));
192
+ }
193
+
194
+ @ Override
195
+ public int size () {
196
+ return this .loaders .size ();
197
+ }
198
+
199
+ @ Override
200
+ public boolean isEmpty () {
201
+ return this .loaders .isEmpty ();
202
+ }
203
+
204
+ @ Override
205
+ public boolean containsKey (Object key ) {
206
+ return this .loaders .containsKey (key );
207
+ }
208
+
209
+ @ Override
210
+ public Set <K > keySet () {
211
+ return this .loaders .keySet ();
212
+ }
213
+
214
+ @ Override
215
+ public V get (Object key ) {
216
+ if (!this .loaders .containsKey (key )) {
217
+ throw new IllegalArgumentException (
218
+ "This map only supports the following keys: " + this .loaders .keySet ());
219
+ }
220
+ return this .loaded .computeIfAbsent ((K ) key , (k ) -> this .loaders .get (k ).get ());
221
+ }
222
+
223
+ @ Override
224
+ public V put (K key , V value ) {
225
+ if (!this .loaders .containsKey (key )) {
226
+ throw new IllegalArgumentException (
227
+ "This map only supports the following keys: " + this .loaders .keySet ());
228
+ }
229
+ return this .loaded .put (key , value );
230
+ }
231
+
232
+ @ Override
233
+ public V remove (Object key ) {
234
+ if (!this .loaders .containsKey (key )) {
235
+ throw new IllegalArgumentException (
236
+ "This map only supports the following keys: " + this .loaders .keySet ());
237
+ }
238
+ return this .loaded .remove (key );
239
+ }
240
+
241
+ @ Override
242
+ public void putAll (Map <? extends K , ? extends V > m ) {
243
+ for (Map .Entry <? extends K , ? extends V > entry : m .entrySet ()) {
244
+ put (entry .getKey (), entry .getValue ());
245
+ }
246
+ }
247
+
248
+ @ Override
249
+ public void clear () {
250
+ this .loaded .clear ();
251
+ }
252
+
253
+ @ Override
254
+ public boolean containsValue (Object value ) {
255
+ return this .loaded .containsValue (value );
256
+ }
257
+
258
+ @ Override
259
+ public Collection <V > values () {
260
+ return this .loaded .values ();
261
+ }
262
+
263
+ @ Override
264
+ public Set <Entry <K , V >> entrySet () {
265
+ return this .loaded .entrySet ();
266
+ }
267
+
268
+ @ Override
269
+ public boolean equals (Object o ) {
270
+ if (this == o ) {
271
+ return true ;
272
+ }
273
+ if (o == null || getClass () != o .getClass ()) {
274
+ return false ;
275
+ }
276
+
277
+ LoadingMap <?, ?> that = (LoadingMap <?, ?>) o ;
278
+
279
+ return this .loaded .equals (that .loaded );
280
+ }
281
+
282
+ @ Override
283
+ public int hashCode () {
284
+ return this .loaded .hashCode ();
285
+ }
286
+
287
+ }
288
+
188
289
}
0 commit comments