Skip to content

Commit ca9f9bf

Browse files
committed
Fix query execution mode detection for aggregate types that implement Streamable.
We now short-circuit the QueryMethod.isCollectionQuery() algorithm in case we find the concrete domain type or any subclass of it. Fixes #2869.
1 parent 05dd7ae commit ca9f9bf

File tree

5 files changed

+119
-8
lines changed

5 files changed

+119
-8
lines changed

src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,13 @@ public TypeInformation<?> getReturnType(Method method) {
9696
return returnType;
9797
}
9898

99+
@Override
99100
public Class<?> getReturnedDomainClass(Method method) {
100101

101102
TypeInformation<?> returnType = getReturnType(method);
103+
returnType = ReactiveWrapperConverters.unwrapWrapperTypes(returnType);
102104

103-
return QueryExecutionConverters.unwrapWrapperTypes(ReactiveWrapperConverters.unwrapWrapperTypes(returnType))
104-
.getType();
105+
return QueryExecutionConverters.unwrapWrapperTypes(returnType, getDomainTypeInformation()).getType();
105106
}
106107

107108
public Class<?> getRepositoryInterface() {

src/main/java/org/springframework/data/repository/query/QueryMethod.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,17 @@
2424

2525
import org.springframework.data.domain.Page;
2626
import org.springframework.data.domain.Pageable;
27-
import org.springframework.data.domain.Window;
2827
import org.springframework.data.domain.ScrollPosition;
2928
import org.springframework.data.domain.Slice;
3029
import org.springframework.data.domain.Sort;
30+
import org.springframework.data.domain.Window;
3131
import org.springframework.data.projection.ProjectionFactory;
3232
import org.springframework.data.repository.core.EntityMetadata;
3333
import org.springframework.data.repository.core.RepositoryMetadata;
3434
import org.springframework.data.repository.util.QueryExecutionConverters;
3535
import org.springframework.data.repository.util.ReactiveWrapperConverters;
3636
import org.springframework.data.util.Lazy;
37+
import org.springframework.data.util.NullableWrapperConverters;
3738
import org.springframework.data.util.ReactiveWrappers;
3839
import org.springframework.data.util.TypeInformation;
3940
import org.springframework.util.Assert;
@@ -296,7 +297,15 @@ private boolean calculateIsCollectionQuery() {
296297
return false;
297298
}
298299

299-
Class<?> returnType = metadata.getReturnType(method).getType();
300+
TypeInformation<?> returnTypeInformation = metadata.getReturnType(method);
301+
302+
// Check against simple wrapper types first
303+
if (metadata.getDomainTypeInformation()
304+
.isAssignableFrom(NullableWrapperConverters.unwrapActualType(returnTypeInformation))) {
305+
return false;
306+
}
307+
308+
Class<?> returnType = returnTypeInformation.getType();
300309

301310
if (QueryExecutionConverters.supports(returnType) && !QueryExecutionConverters.isSingleValue(returnType)) {
302311
return true;

src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@
3636
import org.springframework.core.convert.support.ConfigurableConversionService;
3737
import org.springframework.core.convert.support.DefaultConversionService;
3838
import org.springframework.data.domain.Page;
39-
import org.springframework.data.domain.Window;
4039
import org.springframework.data.domain.Slice;
40+
import org.springframework.data.domain.Window;
4141
import org.springframework.data.geo.GeoResults;
4242
import org.springframework.data.util.CustomCollections;
4343
import org.springframework.data.util.NullableWrapper;
@@ -85,6 +85,7 @@ public abstract class QueryExecutionConverters {
8585
private static final Set<Class<?>> ALLOWED_PAGEABLE_TYPES = new HashSet<>();
8686
private static final Map<Class<?>, ExecutionAdapter> EXECUTION_ADAPTER = new HashMap<>();
8787
private static final Map<Class<?>, Boolean> supportsCache = new ConcurrentReferenceHashMap<>();
88+
private static final TypeInformation<Void> VOID_INFORMATION = TypeInformation.of(Void.class);
8889

8990
static {
9091

@@ -235,15 +236,21 @@ public static Object unwrap(@Nullable Object source) {
235236
}
236237

237238
/**
238-
* Recursively unwraps well known wrapper types from the given {@link TypeInformation}.
239+
* Recursively unwraps well known wrapper types from the given {@link TypeInformation} but aborts at the given
240+
* reference type.
239241
*
240242
* @param type must not be {@literal null}.
243+
* @param reference must not be {@literal null}.
241244
* @return will never be {@literal null}.
242245
*/
243-
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
246+
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type, TypeInformation<?> reference) {
244247

245248
Assert.notNull(type, "type must not be null");
246249

250+
if (reference.isAssignableFrom(type)) {
251+
return type;
252+
}
253+
247254
Class<?> rawType = type.getType();
248255

249256
boolean needToUnwrap = type.isCollectionLike() //
@@ -253,7 +260,17 @@ public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
253260
|| supports(rawType) //
254261
|| Stream.class.isAssignableFrom(rawType);
255262

256-
return needToUnwrap ? unwrapWrapperTypes(type.getRequiredComponentType()) : type;
263+
return needToUnwrap ? unwrapWrapperTypes(type.getRequiredComponentType(), reference) : type;
264+
}
265+
266+
/**
267+
* Recursively unwraps well known wrapper types from the given {@link TypeInformation}.
268+
*
269+
* @param type must not be {@literal null}.
270+
* @return will never be {@literal null}.
271+
*/
272+
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
273+
return unwrapWrapperTypes(type, VOID_INFORMATION);
257274
}
258275

259276
/**

src/test/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadataUnitTests.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,19 @@
2121
import java.lang.reflect.Method;
2222
import java.util.List;
2323
import java.util.Map;
24+
import java.util.Optional;
25+
import java.util.stream.Stream;
2426

27+
import org.junit.jupiter.api.DynamicTest;
2528
import org.junit.jupiter.api.Test;
29+
import org.junit.jupiter.api.TestFactory;
2630
import org.springframework.data.domain.Page;
2731
import org.springframework.data.domain.Pageable;
2832
import org.springframework.data.querydsl.User;
2933
import org.springframework.data.repository.PagingAndSortingRepository;
3034
import org.springframework.data.repository.Repository;
3135
import org.springframework.data.repository.core.RepositoryMetadata;
36+
import org.springframework.data.util.Streamable;
3237

3338
/**
3439
* Unit tests for {@link AbstractRepositoryMetadata}.
@@ -111,6 +116,25 @@ void doesNotUnwrapCustomTypeImplementingIterable() throws Exception {
111116
assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(Container.class);
112117
}
113118

119+
@TestFactory // GH-2869
120+
Stream<DynamicTest> detectsReturnTypesForStreamableAggregates() throws Exception {
121+
122+
var metadata = AbstractRepositoryMetadata.getMetadata(StreamableAggregateRepository.class);
123+
var methods = Stream.of(
124+
Map.entry("findBy", StreamableAggregate.class),
125+
Map.entry("findSubTypeBy", StreamableAggregateSubType.class),
126+
Map.entry("findAllBy", StreamableAggregate.class),
127+
Map.entry("findOptional", StreamableAggregate.class));
128+
129+
return DynamicTest.stream(methods, //
130+
it -> it.getKey() + "'s returned domain class is " + it.getValue(), //
131+
it -> {
132+
133+
var method = StreamableAggregateRepository.class.getMethod(it.getKey());
134+
assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(it.getValue());
135+
});
136+
}
137+
114138
interface UserRepository extends Repository<User, Long> {
115139

116140
User findSingle();
@@ -155,4 +179,20 @@ interface ContainerRepository extends Repository<Container, Long> {
155179

156180
interface CompletePageableAndSortingRepository extends PagingAndSortingRepository<Container, Long> {}
157181

182+
// GH-2869
183+
184+
static abstract class StreamableAggregate implements Streamable<Object> {}
185+
186+
interface StreamableAggregateRepository extends Repository<StreamableAggregate, Object> {
187+
188+
StreamableAggregate findBy();
189+
190+
StreamableAggregateSubType findSubTypeBy();
191+
192+
Streamable<StreamableAggregate> findAllBy();
193+
194+
Optional<StreamableAggregate> findOptional();
195+
}
196+
197+
static abstract class StreamableAggregateSubType extends StreamableAggregate {}
158198
}

src/test/java/org/springframework/data/repository/query/QueryMethodUnitTests.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,16 @@
2424

2525
import java.io.Serializable;
2626
import java.util.List;
27+
import java.util.Map;
28+
import java.util.Optional;
2729
import java.util.concurrent.CompletableFuture;
2830
import java.util.concurrent.Future;
2931
import java.util.stream.Stream;
3032

3133
import org.eclipse.collections.api.list.ImmutableList;
34+
import org.junit.jupiter.api.DynamicTest;
3235
import org.junit.jupiter.api.Test;
36+
import org.junit.jupiter.api.TestFactory;
3337
import org.springframework.data.domain.Page;
3438
import org.springframework.data.domain.Pageable;
3539
import org.springframework.data.domain.ScrollPosition;
@@ -41,6 +45,7 @@
4145
import org.springframework.data.repository.core.RepositoryMetadata;
4246
import org.springframework.data.repository.core.support.AbstractRepositoryMetadata;
4347
import org.springframework.data.repository.core.support.DefaultRepositoryMetadata;
48+
import org.springframework.data.util.Streamable;
4449

4550
/**
4651
* Unit tests for {@link QueryMethod}.
@@ -302,6 +307,28 @@ void considersEclipseCollectionCollectionQuery() throws Exception {
302307
assertThat(queryMethod.isCollectionQuery()).isTrue();
303308
}
304309

310+
@TestFactory // GH-2869
311+
Stream<DynamicTest> doesNotConsiderQueryMethodReturningAggregateImplementingStreamableACollectionQuery()
312+
throws Exception {
313+
314+
var metadata = AbstractRepositoryMetadata.getMetadata(StreamableAggregateRepository.class);
315+
var stream = Stream.of(
316+
Map.entry("findBy", false),
317+
Map.entry("findSubTypeBy", false),
318+
Map.entry("findAllBy", true),
319+
Map.entry("findOptionalBy", false));
320+
321+
return DynamicTest.stream(stream, //
322+
it -> it.getKey() + " considered collection query -> " + it.getValue(), //
323+
it -> {
324+
325+
var method = StreamableAggregateRepository.class.getMethod(it.getKey());
326+
var queryMethod = new QueryMethod(method, metadata, factory);
327+
328+
assertThat(queryMethod.isCollectionQuery()).isEqualTo(it.getValue());
329+
});
330+
}
331+
305332
interface SampleRepository extends Repository<User, Serializable> {
306333

307334
String pagingMethodWithInvalidReturnType(Pageable pageable);
@@ -379,4 +406,21 @@ abstract class Container implements Iterable<Element> {}
379406
interface ContainerRepository extends Repository<Container, Long> {
380407
Container someMethod();
381408
}
409+
410+
// GH-2869
411+
412+
static abstract class StreamableAggregate implements Streamable<Object> {}
413+
414+
interface StreamableAggregateRepository extends Repository<StreamableAggregate, Object> {
415+
416+
StreamableAggregate findBy();
417+
418+
StreamableAggregateSubType findSubTypeBy();
419+
420+
Optional<StreamableAggregate> findOptionalBy();
421+
422+
Streamable<StreamableAggregate> findAllBy();
423+
}
424+
425+
static abstract class StreamableAggregateSubType extends StreamableAggregate {}
382426
}

0 commit comments

Comments
 (0)