Skip to content

Commit 2bc9371

Browse files
committed
Reactive support in MethodValidationInterceptor
Closes gh-20781
1 parent b110a39 commit 2bc9371

File tree

4 files changed

+227
-0
lines changed

4 files changed

+227
-0
lines changed

spring-context/spring-context.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ dependencies {
4646
testImplementation("org.awaitility:awaitility")
4747
testImplementation("org.jetbrains.kotlinx:kotlinx-coroutines-core")
4848
testImplementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactor")
49+
testImplementation("io.projectreactor:reactor-test")
4950
testImplementation("io.reactivex.rxjava3:rxjava")
5051
testImplementation('io.micrometer:context-propagation')
5152
testImplementation("io.micrometer:micrometer-observation-test")

spring-context/src/main/java/org/springframework/validation/beanvalidation/MethodValidationAdapter.java

+7
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,13 @@ private static Supplier<SpringValidatorAdapter> initValidatorAdapter(Supplier<Va
140140
}
141141

142142

143+
/**
144+
* Return the {@link SpringValidatorAdapter} configured for use.
145+
*/
146+
public Supplier<SpringValidatorAdapter> getSpringValidatorAdapter() {
147+
return this.validatorAdapter;
148+
}
149+
143150
/**
144151
* Set the strategy to use to determine message codes for violations.
145152
* <p>Default is a DefaultMessageCodesResolver.

spring-context/src/main/java/org/springframework/validation/beanvalidation/MethodValidationInterceptor.java

+98
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,39 @@
1717
package org.springframework.validation.beanvalidation;
1818

1919
import java.lang.reflect.Method;
20+
import java.lang.reflect.Parameter;
21+
import java.util.Collections;
22+
import java.util.List;
2023
import java.util.Set;
2124
import java.util.function.Supplier;
2225

2326
import jakarta.validation.ConstraintViolation;
2427
import jakarta.validation.ConstraintViolationException;
28+
import jakarta.validation.Valid;
2529
import jakarta.validation.Validator;
2630
import jakarta.validation.ValidatorFactory;
2731
import org.aopalliance.intercept.MethodInterceptor;
2832
import org.aopalliance.intercept.MethodInvocation;
33+
import reactor.core.publisher.Flux;
34+
import reactor.core.publisher.Mono;
2935

3036
import org.springframework.aop.ProxyMethodInvocation;
3137
import org.springframework.beans.factory.FactoryBean;
3238
import org.springframework.beans.factory.SmartFactoryBean;
39+
import org.springframework.core.MethodParameter;
40+
import org.springframework.core.ReactiveAdapter;
41+
import org.springframework.core.ReactiveAdapterRegistry;
42+
import org.springframework.core.annotation.AnnotationUtils;
3343
import org.springframework.lang.Nullable;
3444
import org.springframework.util.Assert;
3545
import org.springframework.util.ClassUtils;
46+
import org.springframework.validation.BeanPropertyBindingResult;
47+
import org.springframework.validation.Errors;
3648
import org.springframework.validation.annotation.Validated;
3749
import org.springframework.validation.method.MethodValidationException;
3850
import org.springframework.validation.method.MethodValidationResult;
51+
import org.springframework.validation.method.ParameterErrors;
52+
import org.springframework.validation.method.ParameterValidationResult;
3953

4054
/**
4155
* An AOP Alliance {@link MethodInterceptor} implementation that delegates to a
@@ -65,6 +79,10 @@
6579
*/
6680
public class MethodValidationInterceptor implements MethodInterceptor {
6781

82+
private static final boolean REACTOR_PRESENT =
83+
ClassUtils.isPresent("reactor.core.publisher.Mono", MethodValidationInterceptor.class.getClassLoader());
84+
85+
6886
private final MethodValidationAdapter validationAdapter;
6987

7088
private final boolean adaptViolations;
@@ -135,6 +153,12 @@ public Object invoke(MethodInvocation invocation) throws Throwable {
135153
Object[] arguments = invocation.getArguments();
136154
Class<?>[] groups = determineValidationGroups(invocation);
137155

156+
if (REACTOR_PRESENT) {
157+
arguments = ReactorValidationHelper.insertAsyncValidation(
158+
this.validationAdapter.getSpringValidatorAdapter(), this.adaptViolations,
159+
target, method, arguments);
160+
}
161+
138162
Set<ConstraintViolation<Object>> violations;
139163

140164
if (this.adaptViolations) {
@@ -206,4 +230,78 @@ protected Class<?>[] determineValidationGroups(MethodInvocation invocation) {
206230
return this.validationAdapter.determineValidationGroups(target, invocation.getMethod());
207231
}
208232

233+
234+
/**
235+
* Helper class to decorate reactive arguments with async validation.
236+
*/
237+
private final static class ReactorValidationHelper {
238+
239+
private static final ReactiveAdapterRegistry reactiveAdapterRegistry =
240+
ReactiveAdapterRegistry.getSharedInstance();
241+
242+
243+
public static Object[] insertAsyncValidation(
244+
Supplier<SpringValidatorAdapter> validatorAdapterSupplier, boolean adaptViolations,
245+
Object target, Method method, Object[] arguments) {
246+
247+
for (int i = 0; i < method.getParameterCount(); i++) {
248+
if (arguments[i] == null) {
249+
continue;
250+
}
251+
Class<?> parameterType = method.getParameterTypes()[i];
252+
ReactiveAdapter reactiveAdapter = reactiveAdapterRegistry.getAdapter(parameterType);
253+
if (reactiveAdapter == null || reactiveAdapter.isNoValue()) {
254+
continue;
255+
}
256+
Class<?>[] groups = determineValidationGroups(method.getParameters()[i]);
257+
if (groups == null) {
258+
continue;
259+
}
260+
SpringValidatorAdapter validatorAdapter = validatorAdapterSupplier.get();
261+
MethodParameter param = new MethodParameter(method, i);
262+
arguments[i] = (reactiveAdapter.isMultiValue() ?
263+
Flux.from(reactiveAdapter.toPublisher(arguments[i])).doOnNext(value ->
264+
validate(validatorAdapter, adaptViolations, target, method, param, value, groups)) :
265+
Mono.from(reactiveAdapter.toPublisher(arguments[i])).doOnNext(value ->
266+
validate(validatorAdapter, adaptViolations, target, method, param, value, groups)));
267+
}
268+
return arguments;
269+
}
270+
271+
@Nullable
272+
private static Class<?>[] determineValidationGroups(Parameter parameter) {
273+
Validated validated = AnnotationUtils.findAnnotation(parameter, Validated.class);
274+
if (validated != null) {
275+
return validated.value();
276+
}
277+
Valid valid = AnnotationUtils.findAnnotation(parameter, Valid.class);
278+
if (valid != null) {
279+
return new Class<?>[0];
280+
}
281+
return null;
282+
}
283+
284+
@SuppressWarnings("unchecked")
285+
private static <T> void validate(
286+
SpringValidatorAdapter validatorAdapter, boolean adaptViolations,
287+
Object target, Method method, MethodParameter parameter, Object argument, Class<?>[] groups) {
288+
289+
if (adaptViolations) {
290+
Errors errors = new BeanPropertyBindingResult(argument, argument.getClass().getSimpleName());
291+
validatorAdapter.validate(argument, errors);
292+
if (errors.hasErrors()) {
293+
ParameterErrors paramErrors = new ParameterErrors(parameter, argument, errors, null, null, null);
294+
List<ParameterValidationResult> results = Collections.singletonList(paramErrors);
295+
throw new MethodValidationException(MethodValidationResult.create(target, method, results));
296+
}
297+
}
298+
else {
299+
Set<ConstraintViolation<T>> violations = validatorAdapter.validate((T) argument, groups);
300+
if (!violations.isEmpty()) {
301+
throw new ConstraintViolationException(violations);
302+
}
303+
}
304+
}
305+
}
306+
209307
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright 2002-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.validation.beanvalidation;
18+
19+
import java.util.Set;
20+
21+
import jakarta.validation.ConstraintViolation;
22+
import jakarta.validation.ConstraintViolationException;
23+
import jakarta.validation.Valid;
24+
import jakarta.validation.Validation;
25+
import jakarta.validation.Validator;
26+
import jakarta.validation.constraints.Size;
27+
import org.junit.jupiter.api.Test;
28+
import reactor.core.publisher.Flux;
29+
import reactor.core.publisher.Mono;
30+
import reactor.test.StepVerifier;
31+
32+
import org.springframework.aop.framework.ProxyFactory;
33+
import org.springframework.validation.method.MethodValidationException;
34+
import org.springframework.validation.method.ParameterErrors;
35+
36+
import static org.assertj.core.api.Assertions.assertThat;
37+
38+
/**
39+
*
40+
*/
41+
public class MethodValidationProxyReactorTests {
42+
43+
@Test
44+
void validMonoArgument() {
45+
MyService myService = initProxy(new MyService(), false);
46+
Mono<Person> personMono = Mono.just(new Person("Faustino1234"));
47+
48+
StepVerifier.create(myService.addPerson(personMono))
49+
.expectErrorSatisfies(t -> {
50+
ConstraintViolationException ex = (ConstraintViolationException) t;
51+
Set<ConstraintViolation<?>> violations = ex.getConstraintViolations();
52+
assertThat(violations).hasSize(1);
53+
assertThat(violations.iterator().next().getMessage()).isEqualTo("size must be between 1 and 10");
54+
})
55+
.verify();
56+
}
57+
58+
@Test
59+
void validFluxArgument() {
60+
MyService myService = initProxy(new MyService(), false);
61+
Flux<Person> personFlux = Flux.just(new Person("Faust"), new Person("Faustino1234"));
62+
63+
StepVerifier.create(myService.addPersons(personFlux))
64+
.expectErrorSatisfies(t -> {
65+
ConstraintViolationException ex = (ConstraintViolationException) t;
66+
Set<ConstraintViolation<?>> violations = ex.getConstraintViolations();
67+
assertThat(violations).hasSize(1);
68+
assertThat(violations.iterator().next().getMessage()).isEqualTo("size must be between 1 and 10");
69+
})
70+
.verify();
71+
}
72+
73+
@Test
74+
void validMonoArgumentWithAdaptedViolations() {
75+
MyService myService = initProxy(new MyService(), true);
76+
Mono<Person> personMono = Mono.just(new Person("Faustino1234"));
77+
78+
StepVerifier.create(myService.addPerson(personMono))
79+
.expectErrorSatisfies(t -> {
80+
MethodValidationException ex = (MethodValidationException) t;
81+
assertThat(ex.getAllValidationResults()).hasSize(1);
82+
83+
ParameterErrors errors = ex.getBeanResults().get(0);
84+
assertThat(errors.getErrorCount()).isEqualTo(1);
85+
assertThat(errors.getFieldErrors().get(0).toString()).isEqualTo("""
86+
Field error in object 'Person' on field 'name': rejected value [Faustino1234]; \
87+
codes [Size.Person.name,Size.name,Size.java.lang.String,Size]; \
88+
arguments [org.springframework.context.support.DefaultMessageSourceResolvable: \
89+
codes [Person.name,name]; arguments []; default message [name],10,1]; \
90+
default message [size must be between 1 and 10]""");
91+
})
92+
.verify();
93+
}
94+
95+
private static MyService initProxy(Object target, boolean adaptViolations) {
96+
Validator validator = Validation.buildDefaultValidatorFactory().getValidator();
97+
MethodValidationInterceptor interceptor = new MethodValidationInterceptor(() -> validator, adaptViolations);
98+
ProxyFactory factory = new ProxyFactory(target);
99+
factory.addAdvice(interceptor);
100+
return (MyService) factory.getProxy();
101+
}
102+
103+
104+
@SuppressWarnings("unused")
105+
static class MyService {
106+
107+
public Mono<Void> addPerson(@Valid Mono<Person> personMono) {
108+
return personMono.then();
109+
}
110+
111+
public Mono<Void> addPersons(@Valid Flux<Person> personFlux) {
112+
return personFlux.then();
113+
}
114+
}
115+
116+
117+
@SuppressWarnings("unused")
118+
record Person(@Size(min = 1, max = 10) String name) {
119+
}
120+
121+
}

0 commit comments

Comments
 (0)