Skip to content

Commit 8555dfb

Browse files
GH-2244 - Fix detection of common element types in collections of persistent entities.
This fixes #2244.
1 parent fd36b91 commit 8555dfb

File tree

5 files changed

+276
-13
lines changed

5 files changed

+276
-13
lines changed

src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@
7878
import org.springframework.lang.NonNull;
7979
import org.springframework.lang.Nullable;
8080
import org.springframework.util.Assert;
81-
import org.springframework.util.CollectionUtils;
8281

8382
/**
8483
* @author Michael J. Simons
@@ -401,7 +400,7 @@ private <T> List<T> saveAllImpl(Iterable<T> instances, @Nullable List<PropertyDe
401400
return Collections.emptyList();
402401
}
403402

404-
Class<T> domainClass = (Class<T>) CollectionUtils.findCommonElementType(entities);
403+
Class<T> domainClass = (Class<T>) TemplateSupport.findCommonElementType(entities);
405404
Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getPersistentEntity(domainClass);
406405
if (entityMetaData.isUsingInternalIds() || entityMetaData.hasVersionProperty()
407406
|| entityMetaData.getDynamicLabelsProperty().isPresent()) {

src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@
7979
import org.springframework.lang.NonNull;
8080
import org.springframework.lang.Nullable;
8181
import org.springframework.util.Assert;
82-
import org.springframework.util.CollectionUtils;
8382

8483
/**
8584
* @author Michael J. Simons
@@ -417,7 +416,7 @@ private <T> Flux<T> saveAllImpl(Iterable<T> instances, @Nullable List<PropertyDe
417416
return Flux.empty();
418417
}
419418

420-
Class<T> domainClass = (Class<T>) CollectionUtils.findCommonElementType(entities);
419+
Class<T> domainClass = (Class<T>) TemplateSupport.findCommonElementType(entities);
421420
Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getPersistentEntity(domainClass);
422421

423422
if (entityMetaData.isUsingInternalIds() || entityMetaData.hasVersionProperty()

src/main/java/org/springframework/data/neo4j/core/TemplateSupport.java

+45-9
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,28 @@
1616
package org.springframework.data.neo4j.core;
1717

1818
import java.beans.PropertyDescriptor;
19+
import java.util.Arrays;
1920
import java.util.HashMap;
21+
import java.util.HashSet;
2022
import java.util.List;
2123
import java.util.Map;
2224
import java.util.Set;
2325
import java.util.function.Predicate;
2426
import java.util.stream.Collectors;
27+
import java.util.stream.StreamSupport;
2528

2629
import org.apiguardian.api.API;
2730
import org.neo4j.cypherdsl.core.Statement;
2831
import org.springframework.lang.Nullable;
2932

3033
/**
3134
* Utilities for templates.
35+
*
3236
* @author Michael J. Simons
3337
* @soundtrack Metallica - Ride The Lightning
34-
* @since 6.1
38+
* @since 6.0.9
3539
*/
36-
@API(status = API.Status.INTERNAL, since = "6.1")
40+
@API(status = API.Status.INTERNAL, since = "6.0.9")
3741
final class TemplateSupport {
3842

3943
enum FetchType {
@@ -44,16 +48,47 @@ enum FetchType {
4448

4549
@Nullable
4650
static Class<?> findCommonElementType(Iterable<?> collection) {
51+
52+
List<Class<?>> allClasses = StreamSupport.stream(collection.spliterator(), true)
53+
.filter(o -> o != null)
54+
.map(Object::getClass).collect(Collectors.toList());
55+
4756
Class<?> candidate = null;
48-
for (Object val : collection) {
49-
if (val != null) {
50-
if (candidate == null) {
51-
candidate = val.getClass();
52-
} else if (candidate != val.getClass()) {
53-
return null;
57+
for (Class<?> type : allClasses) {
58+
if (candidate == null) {
59+
candidate = type;
60+
} else if (candidate != type) {
61+
candidate = null;
62+
break;
63+
}
64+
}
65+
66+
if (candidate != null) {
67+
return candidate;
68+
} else {
69+
Predicate<Class<?>> moveUp = c -> c != null && c != Object.class;
70+
Set<Class<?>> mostAbstractClasses = new HashSet<>();
71+
for (Class<?> type : allClasses) {
72+
while (moveUp.test(type.getSuperclass())) {
73+
type = type.getSuperclass();
5474
}
75+
mostAbstractClasses.add(type);
5576
}
77+
candidate = mostAbstractClasses.size() == 1 ? mostAbstractClasses.iterator().next() : null;
5678
}
79+
80+
if (candidate != null) {
81+
return candidate;
82+
} else {
83+
List<Set<Class<?>>> interfacesPerClass = allClasses.stream()
84+
.map(c -> Arrays.stream(c.getInterfaces()).collect(Collectors.toSet()))
85+
.collect(Collectors.toList());
86+
Set<Class<?>> allInterfaces = interfacesPerClass.stream().flatMap(Set::stream).collect(Collectors.toSet());
87+
interfacesPerClass
88+
.forEach(setOfInterfaces -> allInterfaces.removeIf(iface -> !setOfInterfaces.contains(iface)));
89+
candidate = allInterfaces.size() == 1 ? allInterfaces.iterator().next() : null;
90+
}
91+
5792
return candidate;
5893
}
5994

@@ -70,7 +105,8 @@ static Predicate<String> computeIncludePropertyPredicate(List<PropertyDescriptor
70105

71106
/**
72107
* Merges statement and explicit parameters. Statement parameters have a higher precedence
73-
* @param statement A statement that maybe has some stored parameters
108+
*
109+
* @param statement A statement that maybe has some stored parameters
74110
* @param parameters The original parameters
75111
* @return Merged parameters
76112
*/
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* Copyright 2011-2021 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.neo4j.core;
17+
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
20+
import java.util.Arrays;
21+
22+
import org.junit.jupiter.api.Test;
23+
24+
/**
25+
* @author Michael J. Simons
26+
*/
27+
class TemplateSupportTest {
28+
29+
static class A {
30+
}
31+
32+
static class B {
33+
}
34+
35+
static class A2 extends A {
36+
}
37+
38+
static class A3 extends A {
39+
}
40+
41+
static class A4 extends A {
42+
}
43+
44+
static class AA2 extends A2 {
45+
}
46+
47+
interface IA {
48+
}
49+
50+
interface IB {
51+
}
52+
53+
static class B1 implements IA {
54+
}
55+
56+
static class B2 implements IA {
57+
}
58+
59+
static class B3 implements IA, IB {
60+
}
61+
62+
static class B4 implements IA, IB {
63+
}
64+
65+
@Test
66+
void shouldFindCommonElementTypeOfHeterousCollection() {
67+
68+
Class<?> type = TemplateSupport.findCommonElementType(Arrays.asList(new A(), new A(), new A()));
69+
assertThat(type).isNotNull().isEqualTo(A.class);
70+
}
71+
72+
@Test
73+
void shouldNotFailWithNull() {
74+
75+
Class<?> type = TemplateSupport.findCommonElementType(Arrays.asList(new A(), null, new A()));
76+
assertThat(type).isNotNull().isEqualTo(A.class);
77+
}
78+
79+
@Test
80+
void shouldFindCommonElementTypeOfHumongousCollection() {
81+
82+
Class<?> type = TemplateSupport.findCommonElementType(Arrays.asList(new A2(), new A3(), new A4()));
83+
assertThat(type).isNotNull().isEqualTo(A.class);
84+
}
85+
86+
@Test
87+
void shouldFindCommonElementTypeOfHumongousDeepCollection() {
88+
89+
Class<?> type = TemplateSupport.findCommonElementType(Arrays.asList(new A2(), new AA2(), new A3(), new A4()));
90+
assertThat(type).isNotNull().isEqualTo(A.class);
91+
}
92+
93+
@Test
94+
void shouldFindCommonElementTypeOfHumongousInterfaceCollection() {
95+
96+
Class<?> type = TemplateSupport.findCommonElementType(Arrays.asList(new B1(), new B2()));
97+
assertThat(type).isNotNull().isEqualTo(IA.class);
98+
99+
type = TemplateSupport.findCommonElementType(Arrays.asList(new B1(), new B2(), new B3()));
100+
assertThat(type).isNotNull().isEqualTo(IA.class);
101+
}
102+
103+
@Test
104+
void shouldNotFindAmbiguousInterface() {
105+
106+
Class<?> type = TemplateSupport.findCommonElementType(Arrays.asList(new B3(), new B4()));
107+
assertThat(type).isNull();
108+
}
109+
110+
@Test
111+
void shouldNotFindCommonElementTypeWhenThereIsNone() {
112+
113+
Class<?> type = TemplateSupport.findCommonElementType(Arrays.asList(new A(), new A(), new B()));
114+
assertThat(type).isNull();
115+
116+
type = TemplateSupport.findCommonElementType(Arrays.asList(new A(), new B(), new A()));
117+
assertThat(type).isNull();
118+
119+
type = TemplateSupport.findCommonElementType(Arrays.asList(new B(), new A(), new A()));
120+
assertThat(type).isNull();
121+
}
122+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright 2011-2021 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.neo4j.integration.issues.gh2244;
17+
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
20+
import java.util.Arrays;
21+
import java.util.HashSet;
22+
import java.util.List;
23+
24+
import org.junit.jupiter.api.Test;
25+
import org.neo4j.driver.Driver;
26+
import org.springframework.beans.factory.annotation.Autowired;
27+
import org.springframework.context.annotation.Bean;
28+
import org.springframework.context.annotation.Configuration;
29+
import org.springframework.data.neo4j.config.AbstractNeo4jConfig;
30+
import org.springframework.data.neo4j.core.Neo4jTemplate;
31+
import org.springframework.data.neo4j.core.convert.Neo4jConversions;
32+
import org.springframework.data.neo4j.core.mapping.Neo4jMappingContext;
33+
import org.springframework.data.neo4j.core.schema.GeneratedValue;
34+
import org.springframework.data.neo4j.core.schema.Id;
35+
import org.springframework.data.neo4j.core.schema.Node;
36+
import org.springframework.data.neo4j.test.Neo4jExtension;
37+
import org.springframework.data.neo4j.test.Neo4jIntegrationTest;
38+
import org.springframework.transaction.annotation.EnableTransactionManagement;
39+
40+
/**
41+
* @author Michael J. Simons
42+
*/
43+
@Neo4jIntegrationTest
44+
class GH2244IT {
45+
46+
protected static Neo4jExtension.Neo4jConnectionSupport neo4jConnectionSupport;
47+
48+
@Test
49+
void safeAllWithSubTypesShouldWork(@Autowired Neo4jTemplate neo4jTemplate) {
50+
51+
List<Step> steps = Arrays.asList(new Step.Origin(), new Step.Chain(), new Step.End());
52+
steps = neo4jTemplate.saveAll(steps);
53+
assertThat(steps).allSatisfy(s -> assertThat(s.id).isNotNull());
54+
}
55+
56+
/**
57+
* Abstract domain class.
58+
*/
59+
@Node
60+
public static abstract class Step {
61+
62+
@Id @GeneratedValue
63+
private Long id;
64+
65+
/**
66+
* A step.
67+
*/
68+
@Node
69+
public static class Chain extends Step {
70+
}
71+
72+
/**
73+
* A step.
74+
*/
75+
@Node
76+
public static class End extends Step {
77+
}
78+
79+
/**
80+
* A step.
81+
*/
82+
@Node
83+
public static class Origin extends Step {
84+
}
85+
}
86+
87+
@Configuration
88+
@EnableTransactionManagement
89+
static class Config extends AbstractNeo4jConfig {
90+
91+
@Bean
92+
public Driver driver() {
93+
94+
return neo4jConnectionSupport.getDriver();
95+
}
96+
97+
@Override
98+
public Neo4jMappingContext neo4jMappingContext(Neo4jConversions neo4JConversions)
99+
throws ClassNotFoundException {
100+
101+
Neo4jMappingContext ctx = new Neo4jMappingContext(neo4JConversions);
102+
ctx.setInitialEntitySet(new HashSet<>(Arrays.asList(Step.class, Step.Chain.class, Step.End.class,
103+
Step.Origin.class)));
104+
return ctx;
105+
}
106+
}
107+
}

0 commit comments

Comments
 (0)