Skip to content

Commit 8b2eb8f

Browse files
committed
Fix query execution mode detection for aggregate types that implement Streamable.
We now short-circuit the QueryMethod.isCollectionQuery() algorithm in case we find the concrete domain type or any subclass of it. Fixes #2869.
1 parent 8708e25 commit 8b2eb8f

File tree

5 files changed

+120
-8
lines changed

5 files changed

+120
-8
lines changed

src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,13 @@ public TypeInformation<?> getReturnType(Method method) {
9696
return returnType;
9797
}
9898

99+
@Override
99100
public Class<?> getReturnedDomainClass(Method method) {
100101

101102
TypeInformation<?> returnType = getReturnType(method);
103+
returnType = ReactiveWrapperConverters.unwrapWrapperTypes(returnType);
102104

103-
return QueryExecutionConverters.unwrapWrapperTypes(ReactiveWrapperConverters.unwrapWrapperTypes(returnType))
104-
.getType();
105+
return QueryExecutionConverters.unwrapWrapperTypes(returnType, getDomainTypeInformation()).getType();
105106
}
106107

107108
public Class<?> getRepositoryInterface() {

src/main/java/org/springframework/data/repository/query/QueryMethod.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,17 @@
2424

2525
import org.springframework.data.domain.Page;
2626
import org.springframework.data.domain.Pageable;
27-
import org.springframework.data.domain.Window;
2827
import org.springframework.data.domain.ScrollPosition;
2928
import org.springframework.data.domain.Slice;
3029
import org.springframework.data.domain.Sort;
30+
import org.springframework.data.domain.Window;
3131
import org.springframework.data.projection.ProjectionFactory;
3232
import org.springframework.data.repository.core.EntityMetadata;
3333
import org.springframework.data.repository.core.RepositoryMetadata;
3434
import org.springframework.data.repository.util.QueryExecutionConverters;
3535
import org.springframework.data.repository.util.ReactiveWrapperConverters;
3636
import org.springframework.data.util.Lazy;
37+
import org.springframework.data.util.NullableWrapperConverters;
3738
import org.springframework.data.util.ReactiveWrappers;
3839
import org.springframework.data.util.TypeInformation;
3940
import org.springframework.util.Assert;
@@ -296,7 +297,15 @@ private boolean calculateIsCollectionQuery() {
296297
return false;
297298
}
298299

299-
Class<?> returnType = metadata.getReturnType(method).getType();
300+
TypeInformation<?> returnTypeInformation = metadata.getReturnType(method);
301+
302+
// Check against simple wrapper types first
303+
if (metadata.getDomainTypeInformation()
304+
.isAssignableFrom(NullableWrapperConverters.unwrapActualType(returnTypeInformation))) {
305+
return false;
306+
}
307+
308+
Class<?> returnType = returnTypeInformation.getType();
300309

301310
if (QueryExecutionConverters.supports(returnType) && !QueryExecutionConverters.isSingleValue(returnType)) {
302311
return true;

src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@
3636
import org.springframework.core.convert.support.ConfigurableConversionService;
3737
import org.springframework.core.convert.support.DefaultConversionService;
3838
import org.springframework.data.domain.Page;
39-
import org.springframework.data.domain.Window;
4039
import org.springframework.data.domain.Slice;
40+
import org.springframework.data.domain.Window;
4141
import org.springframework.data.geo.GeoResults;
4242
import org.springframework.data.util.CustomCollections;
4343
import org.springframework.data.util.NullableWrapper;
@@ -85,6 +85,7 @@ public abstract class QueryExecutionConverters {
8585
private static final Set<Class<?>> ALLOWED_PAGEABLE_TYPES = new HashSet<>();
8686
private static final Map<Class<?>, ExecutionAdapter> EXECUTION_ADAPTER = new HashMap<>();
8787
private static final Map<Class<?>, Boolean> supportsCache = new ConcurrentReferenceHashMap<>();
88+
private static final TypeInformation<Void> VOID_INFORMATION = TypeInformation.of(Void.class);
8889

8990
static {
9091

@@ -235,15 +236,21 @@ public static Object unwrap(@Nullable Object source) {
235236
}
236237

237238
/**
238-
* Recursively unwraps well known wrapper types from the given {@link TypeInformation}.
239+
* Recursively unwraps well known wrapper types from the given {@link TypeInformation} but aborts at the given
240+
* reference type.
239241
*
240242
* @param type must not be {@literal null}.
243+
* @param reference must not be {@literal null}.
241244
* @return will never be {@literal null}.
242245
*/
243-
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
246+
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type, TypeInformation<?> reference) {
244247

245248
Assert.notNull(type, "type must not be null");
246249

250+
if (reference.isAssignableFrom(type)) {
251+
return type;
252+
}
253+
247254
Class<?> rawType = type.getType();
248255

249256
boolean needToUnwrap = type.isCollectionLike() //
@@ -253,7 +260,17 @@ public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
253260
|| supports(rawType) //
254261
|| Stream.class.isAssignableFrom(rawType);
255262

256-
return needToUnwrap ? unwrapWrapperTypes(type.getRequiredComponentType()) : type;
263+
return needToUnwrap ? unwrapWrapperTypes(type.getRequiredComponentType(), reference) : type;
264+
}
265+
266+
/**
267+
* Recursively unwraps well known wrapper types from the given {@link TypeInformation}.
268+
*
269+
* @param type must not be {@literal null}.
270+
* @return will never be {@literal null}.
271+
*/
272+
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
273+
return unwrapWrapperTypes(type, VOID_INFORMATION);
257274
}
258275

259276
/**

src/test/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadataUnitTests.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,21 @@
2222
import java.util.Collections;
2323
import java.util.List;
2424
import java.util.Map;
25+
import java.util.Optional;
2526
import java.util.Set;
27+
import java.util.stream.Stream;
2628

29+
30+
import org.junit.jupiter.api.DynamicTest;
2731
import org.junit.jupiter.api.Test;
32+
import org.junit.jupiter.api.TestFactory;
2833
import org.springframework.data.domain.Page;
2934
import org.springframework.data.domain.Pageable;
3035
import org.springframework.data.querydsl.User;
3136
import org.springframework.data.repository.PagingAndSortingRepository;
3237
import org.springframework.data.repository.Repository;
3338
import org.springframework.data.repository.core.RepositoryMetadata;
39+
import org.springframework.data.util.Streamable;
3440

3541
/**
3642
* Unit tests for {@link AbstractRepositoryMetadata}.
@@ -113,6 +119,25 @@ void doesNotUnwrapCustomTypeImplementingIterable() throws Exception {
113119
assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(Container.class);
114120
}
115121

122+
@TestFactory // GH-2869
123+
Stream<DynamicTest> detectsReturnTypesForStreamableAggregates() throws Exception {
124+
125+
var metadata = AbstractRepositoryMetadata.getMetadata(StreamableAggregateRepository.class);
126+
var methods = Stream.of(
127+
Map.entry("findBy", StreamableAggregate.class),
128+
Map.entry("findSubTypeBy", StreamableAggregateSubType.class),
129+
Map.entry("findAllBy", StreamableAggregate.class),
130+
Map.entry("findOptional", StreamableAggregate.class));
131+
132+
return DynamicTest.stream(methods, //
133+
it -> it.getKey() + "'s returned domain class is " + it.getValue(), //
134+
it -> {
135+
136+
var method = StreamableAggregateRepository.class.getMethod(it.getKey());
137+
assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(it.getValue());
138+
});
139+
}
140+
116141
interface UserRepository extends Repository<User, Long> {
117142

118143
User findSingle();
@@ -157,4 +182,20 @@ interface ContainerRepository extends Repository<Container, Long> {
157182

158183
interface CompletePageableAndSortingRepository extends PagingAndSortingRepository<Container, Long> {}
159184

185+
// GH-2869
186+
187+
static abstract class StreamableAggregate implements Streamable<Object> {}
188+
189+
interface StreamableAggregateRepository extends Repository<StreamableAggregate, Object> {
190+
191+
StreamableAggregate findBy();
192+
193+
StreamableAggregateSubType findSubTypeBy();
194+
195+
Streamable<StreamableAggregate> findAllBy();
196+
197+
Optional<StreamableAggregate> findOptional();
198+
}
199+
200+
static abstract class StreamableAggregateSubType extends StreamableAggregate {}
160201
}

src/test/java/org/springframework/data/repository/query/QueryMethodUnitTests.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,16 @@
2525

2626
import java.io.Serializable;
2727
import java.util.List;
28+
import java.util.Map;
29+
import java.util.Optional;
2830
import java.util.concurrent.CompletableFuture;
2931
import java.util.concurrent.Future;
3032
import java.util.stream.Stream;
3133

3234
import org.eclipse.collections.api.list.ImmutableList;
35+
import org.junit.jupiter.api.DynamicTest;
3336
import org.junit.jupiter.api.Test;
37+
import org.junit.jupiter.api.TestFactory;
3438
import org.springframework.data.domain.Page;
3539
import org.springframework.data.domain.Pageable;
3640
import org.springframework.data.domain.ScrollPosition;
@@ -41,6 +45,7 @@
4145
import org.springframework.data.repository.core.RepositoryMetadata;
4246
import org.springframework.data.repository.core.support.AbstractRepositoryMetadata;
4347
import org.springframework.data.repository.core.support.DefaultRepositoryMetadata;
48+
import org.springframework.data.util.Streamable;
4449

4550
/**
4651
* Unit tests for {@link QueryMethod}.
@@ -302,6 +307,28 @@ void considersEclipseCollectionCollectionQuery() throws Exception {
302307
assertThat(queryMethod.isCollectionQuery()).isTrue();
303308
}
304309

310+
@TestFactory // GH-2869
311+
Stream<DynamicTest> doesNotConsiderQueryMethodReturningAggregateImplementingStreamableACollectionQuery()
312+
throws Exception {
313+
314+
var metadata = AbstractRepositoryMetadata.getMetadata(StreamableAggregateRepository.class);
315+
var stream = Stream.of(
316+
Map.entry("findBy", false),
317+
Map.entry("findSubTypeBy", false),
318+
Map.entry("findAllBy", true),
319+
Map.entry("findOptionalBy", false));
320+
321+
return DynamicTest.stream(stream, //
322+
it -> it.getKey() + " considered collection query -> " + it.getValue(), //
323+
it -> {
324+
325+
var method = StreamableAggregateRepository.class.getMethod(it.getKey());
326+
var queryMethod = new QueryMethod(method, metadata, factory);
327+
328+
assertThat(queryMethod.isCollectionQuery()).isEqualTo(it.getValue());
329+
});
330+
}
331+
305332
interface SampleRepository extends Repository<User, Serializable> {
306333

307334
String pagingMethodWithInvalidReturnType(Pageable pageable);
@@ -379,4 +406,21 @@ abstract class Container implements Iterable<Element> {}
379406
interface ContainerRepository extends Repository<Container, Long> {
380407
Container someMethod();
381408
}
409+
410+
// GH-2869
411+
412+
static abstract class StreamableAggregate implements Streamable<Object> {}
413+
414+
interface StreamableAggregateRepository extends Repository<StreamableAggregate, Object> {
415+
416+
StreamableAggregate findBy();
417+
418+
StreamableAggregateSubType findSubTypeBy();
419+
420+
Optional<StreamableAggregate> findOptionalBy();
421+
422+
Streamable<StreamableAggregate> findAllBy();
423+
}
424+
425+
static abstract class StreamableAggregateSubType extends StreamableAggregate {}
382426
}

0 commit comments

Comments
 (0)