Skip to content

Commit 5ae2730

Browse files
committed
Fix query execution mode detection for aggregate types that implement Streamable.
We now short-circuit the QueryMethod.isCollectionQuery() algorithm in case we find the concrete domain type or any subclass of it. Fixes #2869.
1 parent 38ea46c commit 5ae2730

File tree

5 files changed

+121
-7
lines changed

5 files changed

+121
-7
lines changed

src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,13 @@ public TypeInformation<?> getReturnType(Method method) {
105105
* (non-Javadoc)
106106
* @see org.springframework.data.repository.core.RepositoryMetadata#getReturnedDomainClass(java.lang.reflect.Method)
107107
*/
108+
@Override
108109
public Class<?> getReturnedDomainClass(Method method) {
109110

110111
TypeInformation<?> returnType = getReturnType(method);
112+
returnType = ReactiveWrapperConverters.unwrapWrapperTypes(returnType);
111113

112-
return QueryExecutionConverters.unwrapWrapperTypes(ReactiveWrapperConverters.unwrapWrapperTypes(returnType))
113-
.getType();
114+
return QueryExecutionConverters.unwrapWrapperTypes(returnType, getDomainTypeInformation()).getType();
114115
}
115116

116117
/*

src/main/java/org/springframework/data/repository/query/QueryMethod.java

+10-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.springframework.data.repository.util.ReactiveWrapperConverters;
3333
import org.springframework.data.util.ClassTypeInformation;
3434
import org.springframework.data.util.Lazy;
35+
import org.springframework.data.util.NullableWrapperConverters;
3536
import org.springframework.data.util.TypeInformation;
3637
import org.springframework.util.Assert;
3738

@@ -265,7 +266,15 @@ private boolean calculateIsCollectionQuery() {
265266
return false;
266267
}
267268

268-
Class<?> returnType = metadata.getReturnType(method).getType();
269+
TypeInformation<?> returnTypeInformation = metadata.getReturnType(method);
270+
271+
// Check against simple wrapper types first
272+
if (metadata.getDomainTypeInformation()
273+
.isAssignableFrom(NullableWrapperConverters.unwrapActualType(returnTypeInformation))) {
274+
return false;
275+
}
276+
277+
Class<?> returnType = returnTypeInformation.getType();
269278

270279
if (QueryExecutionConverters.supports(returnType) && !QueryExecutionConverters.isSingleValue(returnType)) {
271280
return true;

src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java

+21-3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.springframework.data.domain.Page;
4040
import org.springframework.data.domain.Slice;
4141
import org.springframework.data.geo.GeoResults;
42+
import org.springframework.data.util.ClassTypeInformation;
4243
import org.springframework.data.util.CustomCollections;
4344
import org.springframework.data.util.NullableWrapper;
4445
import org.springframework.data.util.NullableWrapperConverters;
@@ -86,6 +87,7 @@ public abstract class QueryExecutionConverters {
8687
private static final Set<Class<?>> ALLOWED_PAGEABLE_TYPES = new HashSet<>();
8788
private static final Map<Class<?>, ExecutionAdapter> EXECUTION_ADAPTER = new HashMap<>();
8889
private static final Map<Class<?>, Boolean> supportsCache = new ConcurrentReferenceHashMap<>();
90+
private static final TypeInformation<Void> VOID_INFORMATION = ClassTypeInformation.from(Void.class);
8991

9092
static {
9193

@@ -235,15 +237,21 @@ public static Object unwrap(@Nullable Object source) {
235237
}
236238

237239
/**
238-
* Recursively unwraps well known wrapper types from the given {@link TypeInformation}.
240+
* Recursively unwraps well known wrapper types from the given {@link TypeInformation} but aborts at the given
241+
* reference type.
239242
*
240243
* @param type must not be {@literal null}.
244+
* @param reference must not be {@literal null}.
241245
* @return will never be {@literal null}.
242246
*/
243-
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
247+
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type, TypeInformation<?> reference) {
244248

245249
Assert.notNull(type, "type must not be null");
246250

251+
if (reference.isAssignableFrom(type)) {
252+
return type;
253+
}
254+
247255
Class<?> rawType = type.getType();
248256

249257
boolean needToUnwrap = type.isCollectionLike() //
@@ -253,7 +261,17 @@ public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
253261
|| supports(rawType) //
254262
|| Stream.class.isAssignableFrom(rawType);
255263

256-
return needToUnwrap ? unwrapWrapperTypes(type.getRequiredComponentType()) : type;
264+
return needToUnwrap ? unwrapWrapperTypes(type.getRequiredComponentType(), reference) : type;
265+
}
266+
267+
/**
268+
* Recursively unwraps well known wrapper types from the given {@link TypeInformation}.
269+
*
270+
* @param type must not be {@literal null}.
271+
* @return will never be {@literal null}.
272+
*/
273+
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
274+
return unwrapWrapperTypes(type, VOID_INFORMATION);
257275
}
258276

259277
/**

src/test/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadataUnitTests.java

+42-1
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,20 @@
2121
import java.lang.reflect.Method;
2222
import java.util.List;
2323
import java.util.Map;
24+
import java.util.Map.Entry;
25+
import java.util.Optional;
26+
import java.util.stream.Stream;
2427

28+
import org.junit.jupiter.api.DynamicTest;
2529
import org.junit.jupiter.api.Test;
26-
import org.springframework.core.ResolvableType;
30+
import org.junit.jupiter.api.TestFactory;
2731
import org.springframework.data.domain.Page;
2832
import org.springframework.data.domain.Pageable;
2933
import org.springframework.data.querydsl.User;
3034
import org.springframework.data.repository.PagingAndSortingRepository;
3135
import org.springframework.data.repository.Repository;
3236
import org.springframework.data.repository.core.RepositoryMetadata;
37+
import org.springframework.data.util.Streamable;
3338

3439
/**
3540
* Unit tests for {@link AbstractRepositoryMetadata}.
@@ -112,6 +117,25 @@ void doesNotUnwrapCustomTypeImplementingIterable() throws Exception {
112117
assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(Container.class);
113118
}
114119

120+
@TestFactory // GH-2869
121+
Stream<DynamicTest> detectsReturnTypesForStreamableAggregates() throws Exception {
122+
123+
RepositoryMetadata metadata = AbstractRepositoryMetadata.getMetadata(StreamableAggregateRepository.class);
124+
Stream<Entry<String, Class<?>>> methods = Stream.of(
125+
Map.entry("findBy", StreamableAggregate.class),
126+
Map.entry("findSubTypeBy", StreamableAggregateSubType.class),
127+
Map.entry("findAllBy", StreamableAggregate.class),
128+
Map.entry("findOptional", StreamableAggregate.class));
129+
130+
return DynamicTest.stream(methods, //
131+
it -> it.getKey() + "'s returned domain class is " + it.getValue(), //
132+
it -> {
133+
134+
Method method = StreamableAggregateRepository.class.getMethod(it.getKey());
135+
assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(it.getValue());
136+
});
137+
}
138+
115139
interface UserRepository extends Repository<User, Long> {
116140

117141
User findSingle();
@@ -153,4 +177,21 @@ abstract class Container implements Iterable<Element> {}
153177
interface ContainerRepository extends Repository<Container, Long> {
154178
Container someMethod();
155179
}
180+
181+
// GH-2869
182+
183+
static abstract class StreamableAggregate implements Streamable<Object> {}
184+
185+
interface StreamableAggregateRepository extends Repository<StreamableAggregate, Object> {
186+
187+
StreamableAggregate findBy();
188+
189+
StreamableAggregateSubType findSubTypeBy();
190+
191+
Streamable<StreamableAggregate> findAllBy();
192+
193+
Optional<StreamableAggregate> findOptional();
194+
}
195+
196+
static abstract class StreamableAggregateSubType extends StreamableAggregate {}
156197
}

src/test/java/org/springframework/data/repository/query/QueryMethodUnitTests.java

+45
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,17 @@
2424
import java.io.Serializable;
2525
import java.lang.reflect.Method;
2626
import java.util.List;
27+
import java.util.Map;
28+
import java.util.Map.Entry;
29+
import java.util.Optional;
2730
import java.util.concurrent.CompletableFuture;
2831
import java.util.concurrent.Future;
2932
import java.util.stream.Stream;
3033

3134
import org.eclipse.collections.api.list.ImmutableList;
35+
import org.junit.jupiter.api.DynamicTest;
3236
import org.junit.jupiter.api.Test;
37+
import org.junit.jupiter.api.TestFactory;
3338
import org.springframework.data.domain.Page;
3439
import org.springframework.data.domain.Pageable;
3540
import org.springframework.data.domain.Slice;
@@ -39,6 +44,7 @@
3944
import org.springframework.data.repository.core.RepositoryMetadata;
4045
import org.springframework.data.repository.core.support.AbstractRepositoryMetadata;
4146
import org.springframework.data.repository.core.support.DefaultRepositoryMetadata;
47+
import org.springframework.data.util.Streamable;
4248

4349
/**
4450
* Unit tests for {@link QueryMethod}.
@@ -258,6 +264,28 @@ void considersEclipseCollectionCollectionQuery() throws Exception {
258264
assertThat(queryMethod.isCollectionQuery()).isTrue();
259265
}
260266

267+
@TestFactory // GH-2869
268+
Stream<DynamicTest> doesNotConsiderQueryMethodReturningAggregateImplementingStreamableACollectionQuery()
269+
throws Exception {
270+
271+
RepositoryMetadata metadata = AbstractRepositoryMetadata.getMetadata(StreamableAggregateRepository.class);
272+
Stream<Entry<String, Boolean>> stream = Stream.of(
273+
Map.entry("findBy", false),
274+
Map.entry("findSubTypeBy", false),
275+
Map.entry("findAllBy", true),
276+
Map.entry("findOptionalBy", false));
277+
278+
return DynamicTest.stream(stream, //
279+
it -> it.getKey() + " considered collection query -> " + it.getValue(), //
280+
it -> {
281+
282+
Method method = StreamableAggregateRepository.class.getMethod(it.getKey());
283+
QueryMethod queryMethod = new QueryMethod(method, metadata, factory);
284+
285+
assertThat(queryMethod.isCollectionQuery()).isEqualTo(it.getValue());
286+
});
287+
}
288+
261289
interface SampleRepository extends Repository<User, Serializable> {
262290

263291
String pagingMethodWithInvalidReturnType(Pageable pageable);
@@ -325,4 +353,21 @@ abstract class Container implements Iterable<Element> {}
325353
interface ContainerRepository extends Repository<Container, Long> {
326354
Container someMethod();
327355
}
356+
357+
// GH-2869
358+
359+
static abstract class StreamableAggregate implements Streamable<Object> {}
360+
361+
interface StreamableAggregateRepository extends Repository<StreamableAggregate, Object> {
362+
363+
StreamableAggregate findBy();
364+
365+
StreamableAggregateSubType findSubTypeBy();
366+
367+
Optional<StreamableAggregate> findOptionalBy();
368+
369+
Streamable<StreamableAggregate> findAllBy();
370+
}
371+
372+
static abstract class StreamableAggregateSubType extends StreamableAggregate {}
328373
}

0 commit comments

Comments
 (0)