Skip to content

Commit d9bb2bb

Browse files
committed
GH-2138 - Fix generic loading.
1 parent 8b636ef commit d9bb2bb

File tree

10 files changed

+340
-203
lines changed

10 files changed

+340
-203
lines changed

src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
import java.util.Set;
3232
import java.util.function.Consumer;
3333
import java.util.function.Function;
34-
import java.util.function.Predicate;
3534
import java.util.stream.Collectors;
3635

3736
import org.apache.commons.logging.LogFactory;
@@ -702,13 +701,7 @@ private GenericQueryAndParameters createQueryAndParameters(Neo4jPersistentEntity
702701
final Set<Long> relationshipIds = new HashSet<>();
703702
final Set<Long> relatedNodeIds = new HashSet<>();
704703

705-
Predicate<RelationshipDescription> relationshipFilter = ((Predicate<RelationshipDescription>) relationshipDescription ->
706-
queryFragments.includeField(relationshipDescription.getFieldName())).negate();
707-
708-
for (RelationshipDescription relationshipDescription : entityMetaData.getRelationships()) {
709-
if (relationshipFilter.test(relationshipDescription)) {
710-
continue;
711-
}
704+
for (RelationshipDescription relationshipDescription : entityMetaData.getRelationshipsUpAndDown(fieldName -> queryFragments.includeField(fieldName))) {
712705

713706
Statement statement = cypherGenerator
714707
.prepareMatchOf(entityMetaData, relationshipDescription, queryFragments.getMatchOn(), queryFragments.getCondition())
@@ -727,7 +720,7 @@ private GenericQueryAndParameters createQueryAndParameters(Neo4jPersistentEntity
727720
private void iterateNextLevel(Collection<Long> nodeIds, Neo4jPersistentEntity<?> target, Set<Long> relationshipIds,
728721
Set<Long> relatedNodeIds) {
729722

730-
Collection<RelationshipDescription> relationships = target.getRelationships();
723+
Collection<RelationshipDescription> relationships = target.getRelationshipsUpAndDown(s -> true);
731724
for (RelationshipDescription relationshipDescription : relationships) {
732725

733726
Node node = anyNode(Constants.NAME_OF_ROOT_NODE);

src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@
6969
import java.util.Set;
7070
import java.util.concurrent.ConcurrentHashMap;
7171
import java.util.function.Function;
72-
import java.util.function.Predicate;
7372
import java.util.stream.Collectors;
7473

7574
import static org.neo4j.cypherdsl.core.Cypher.anyNode;
@@ -465,16 +464,12 @@ private <T> Mono<ExecutableQuery<T>> createExecutableQuery(Class<T> domainType,
465464
private Mono<GenericQueryAndParameters> createQueryAndParameters(Neo4jPersistentEntity<?> entityMetaData,
466465
QueryFragmentsAndParameters.QueryFragments queryFragments, Map<String, Object> parameters) {
467466

468-
Predicate<RelationshipDescription> relationshipFilter = relationshipDescription ->
469-
queryFragments.includeField(relationshipDescription.getFieldName());
470-
471467
return getDatabaseName().flatMap(databaseName -> {
472468
return Mono.deferContextual(ctx -> {
473469
Set<Long> rootNodeIds = ctx.get("rootNodes");
474470
Set<Long> processedRelationshipIds = ctx.get("processedRelationships");
475471
Set<Long> processedNodeIds = ctx.get("processedNodes");
476-
return Flux.fromIterable(entityMetaData.getRelationships())
477-
.filter(relationshipFilter)
472+
return Flux.fromIterable(entityMetaData.getRelationshipsUpAndDown(fieldName -> queryFragments.includeField(fieldName)))
478473
.flatMap(relationshipDescription -> {
479474

480475
Statement statement = cypherGenerator.prepareMatchOf(entityMetaData, relationshipDescription,
@@ -514,7 +509,7 @@ private Flux<Tuple2<Collection<Long>, Collection<Long>>> iterateNextLevel(Collec
514509

515510
NodeDescription<?> target = relationshipDescription.getTarget();
516511

517-
return Flux.fromIterable(target.getRelationships())
512+
return Flux.fromIterable(target.getRelationshipsUpAndDown(s -> true))
518513
.flatMap(relDe -> {
519514
Node node = anyNode(Constants.NAME_OF_ROOT_NODE);
520515

src/main/java/org/springframework/data/neo4j/core/mapping/CypherGenerator.java

Lines changed: 1 addition & 175 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,16 @@
3636
import org.springframework.data.domain.Sort;
3737
import org.springframework.data.mapping.MappingException;
3838
import org.springframework.data.mapping.PersistentProperty;
39-
import org.springframework.data.neo4j.core.schema.Relationship.Direction;
4039
import org.springframework.lang.NonNull;
4140
import org.springframework.lang.Nullable;
4241
import org.springframework.util.Assert;
4342

4443
import java.util.ArrayList;
4544
import java.util.Arrays;
4645
import java.util.Collection;
47-
import java.util.HashSet;
4846
import java.util.List;
49-
import java.util.Set;
5047
import java.util.function.Predicate;
5148
import java.util.function.UnaryOperator;
52-
import java.util.stream.Collectors;
5349

5450
import static org.neo4j.cypherdsl.core.Cypher.anyNode;
5551
import static org.neo4j.cypherdsl.core.Cypher.listBasedOn;
@@ -506,183 +502,13 @@ private MapProjection projectPropertiesAndRelationships(NodeDescription<?> nodeD
506502
List<Object> propertiesProjection = projectNodeProperties(nodeDescription, nodeName, includedProperties);
507503
List<Object> contentOfProjection = new ArrayList<>(propertiesProjection);
508504

509-
Collection<RelationshipDescription> relationships = getRelationshipDescriptionsUpAndDown(nodeDescription, includedProperties);
505+
Collection<RelationshipDescription> relationships = nodeDescription.getRelationshipsUpAndDown(includedProperties);
510506
relationships.removeIf(r -> !includedProperties.test(r.getFieldName()));
511507

512508
contentOfProjection.addAll(generateListsFor(relationships, nodeName, processedRelationships));
513509
return Cypher.anyNode(nodeName).project(contentOfProjection);
514510
}
515511

516-
@NonNull
517-
static Collection<RelationshipDescription> getRelationshipDescriptionsUpAndDown(NodeDescription<?> nodeDescription,
518-
Predicate<String> includedProperties) {
519-
520-
Collection<RelationshipDescription> relationships = new HashSet<>(nodeDescription.getRelationships());
521-
for (NodeDescription<?> childDescription : nodeDescription.getChildNodeDescriptionsInHierarchy()) {
522-
childDescription.getRelationships().forEach(concreteRelationship -> {
523-
524-
String fieldName = concreteRelationship.getFieldName();
525-
526-
if (relationships.stream().noneMatch(relationship -> relationship.getFieldName().equals(fieldName))) {
527-
relationships.add(concreteRelationship);
528-
}
529-
});
530-
}
531-
532-
return relationships.stream().filter(relationshipDescription ->
533-
includedProperties.test(relationshipDescription.getFieldName()))
534-
.collect(Collectors.toSet());
535-
}
536-
537-
private RelationshipPattern createRelationships(Node node, Collection<RelationshipDescription> relationshipDescriptions) {
538-
RelationshipPattern relationship;
539-
540-
Direction determinedDirection = determineDirection(relationshipDescriptions);
541-
if (Direction.OUTGOING.equals(determinedDirection)) {
542-
relationship = node.relationshipTo(anyNode(), collectFirstLevelRelationshipTypes(relationshipDescriptions))
543-
.min(0).max(1);
544-
} else if (Direction.INCOMING.equals(determinedDirection)) {
545-
relationship = node.relationshipFrom(anyNode(), collectFirstLevelRelationshipTypes(relationshipDescriptions))
546-
.min(0).max(1);
547-
} else {
548-
relationship = node.relationshipBetween(anyNode(), collectFirstLevelRelationshipTypes(relationshipDescriptions))
549-
.min(0).max(1);
550-
}
551-
552-
Set<RelationshipDescription> processedRelationshipDescriptions = new HashSet<>(relationshipDescriptions);
553-
for (RelationshipDescription relationshipDescription : relationshipDescriptions) {
554-
Collection<RelationshipDescription> relationships = relationshipDescription.getTarget().getRelationships();
555-
if (relationships.size() > 0) {
556-
relationship = createRelationships(relationship, relationships, processedRelationshipDescriptions)
557-
.relationship;
558-
}
559-
}
560-
561-
return relationship;
562-
}
563-
564-
private RelationshipProcessState createRelationships(RelationshipPattern existingRelationship,
565-
Collection<RelationshipDescription> relationshipDescriptions,
566-
Set<RelationshipDescription> processedRelationshipDescriptions) {
567-
568-
RelationshipPattern relationship = existingRelationship;
569-
String[] relationshipTypes = collectAllRelationshipTypes(relationshipDescriptions);
570-
if (processedRelationshipDescriptions.containsAll(relationshipDescriptions)) {
571-
return new RelationshipProcessState(
572-
relationship.relationshipBetween(anyNode(),
573-
relationshipTypes).unbounded().min(0), true);
574-
}
575-
processedRelationshipDescriptions.addAll(relationshipDescriptions);
576-
577-
// we can process through the path
578-
if (relationshipDescriptions.size() == 1) {
579-
RelationshipDescription relationshipDescription = relationshipDescriptions.iterator().next();
580-
switch (relationshipDescription.getDirection()) {
581-
case OUTGOING:
582-
relationship = existingRelationship.relationshipTo(anyNode(),
583-
collectFirstLevelRelationshipTypes(relationshipDescriptions)).unbounded().min(0).max(1);
584-
break;
585-
case INCOMING:
586-
relationship = existingRelationship.relationshipFrom(anyNode(),
587-
collectFirstLevelRelationshipTypes(relationshipDescriptions)).unbounded().min(0).max(1);
588-
break;
589-
default:
590-
relationship = existingRelationship.relationshipBetween(anyNode(),
591-
collectFirstLevelRelationshipTypes(relationshipDescriptions)).unbounded().min(0).max(1);
592-
}
593-
594-
RelationshipProcessState relationships = createRelationships(relationship,
595-
relationshipDescription.getTarget().getRelationships(), processedRelationshipDescriptions);
596-
597-
if (!relationships.done) {
598-
relationship = relationships.relationship;
599-
}
600-
} else {
601-
Direction determinedDirection = determineDirection(relationshipDescriptions);
602-
if (Direction.OUTGOING.equals(determinedDirection)) {
603-
relationship = existingRelationship.relationshipTo(anyNode(), relationshipTypes).unbounded().min(0);
604-
} else if (Direction.INCOMING.equals(determinedDirection)) {
605-
relationship = existingRelationship.relationshipFrom(anyNode(), relationshipTypes).unbounded().min(0);
606-
} else {
607-
relationship = existingRelationship.relationshipBetween(anyNode(), relationshipTypes).unbounded().min(0);
608-
}
609-
return new RelationshipProcessState(relationship, true);
610-
}
611-
return new RelationshipProcessState(relationship, false);
612-
}
613-
614-
@Nullable
615-
Direction determineDirection(Collection<RelationshipDescription> relationshipDescriptions) {
616-
617-
Direction direction = null;
618-
for (RelationshipDescription relationshipDescription : relationshipDescriptions) {
619-
if (direction == null) {
620-
direction = relationshipDescription.getDirection();
621-
}
622-
if (!direction.equals(relationshipDescription.getDirection())) {
623-
return null;
624-
}
625-
}
626-
return direction;
627-
}
628-
629-
private String[] collectFirstLevelRelationshipTypes(Collection<RelationshipDescription> relationshipDescriptions) {
630-
Set<String> relationshipTypes = new HashSet<>();
631-
632-
for (RelationshipDescription relationshipDescription : relationshipDescriptions) {
633-
String relationshipType = relationshipDescription.getType();
634-
if (relationshipTypes.contains(relationshipType)) {
635-
continue;
636-
}
637-
if (relationshipDescription.isDynamic()) {
638-
handleDynamicRelationship(relationshipTypes, (DefaultRelationshipDescription) relationshipDescription);
639-
continue;
640-
}
641-
relationshipTypes.add(relationshipType);
642-
}
643-
return relationshipTypes.toArray(new String[0]);
644-
}
645-
646-
private String[] collectAllRelationshipTypes(Collection<RelationshipDescription> relationshipDescriptions) {
647-
Set<String> relationshipTypes = new HashSet<>();
648-
649-
for (RelationshipDescription relationshipDescription : relationshipDescriptions) {
650-
String relationshipType = relationshipDescription.getType();
651-
if (relationshipDescription.isDynamic()) {
652-
handleDynamicRelationship(relationshipTypes, (DefaultRelationshipDescription) relationshipDescription);
653-
continue;
654-
}
655-
relationshipTypes.add(relationshipType);
656-
collectAllRelationshipTypes(relationshipDescription.getTarget(), relationshipTypes, new HashSet<>(relationshipDescriptions));
657-
}
658-
return relationshipTypes.toArray(new String[0]);
659-
}
660-
661-
private void handleDynamicRelationship(Set<String> relationshipTypes, DefaultRelationshipDescription relationshipDescription) {
662-
Class<?> componentType = relationshipDescription.getInverse().getComponentType();
663-
if (componentType != null && componentType.isEnum()) {
664-
Arrays.stream(componentType.getEnumConstants())
665-
.forEach(constantName -> relationshipTypes.add(constantName.toString()));
666-
} else {
667-
relationshipTypes.clear();
668-
}
669-
}
670-
671-
private void collectAllRelationshipTypes(NodeDescription<?> nodeDescription, Set<String> relationshipTypes,
672-
Collection<RelationshipDescription> processedRelationshipDescriptions) {
673-
674-
for (RelationshipDescription relationshipDescription : nodeDescription.getRelationships()) {
675-
String relationshipType = relationshipDescription.getType();
676-
if (processedRelationshipDescriptions.contains(relationshipDescription)) {
677-
continue;
678-
}
679-
relationshipTypes.add(relationshipType);
680-
processedRelationshipDescriptions.add(relationshipDescription);
681-
collectAllRelationshipTypes(relationshipDescription.getTarget(), relationshipTypes,
682-
processedRelationshipDescriptions);
683-
}
684-
}
685-
686512
/**
687513
* Creates a list of objects that represents a very basic of {@code MapEntry<String, Object>} with the exception that
688514
* this list can also contain two "keys" in a row. The {@link MapProjection} will take care to handle them as

src/main/java/org/springframework/data/neo4j/core/mapping/DefaultNeo4jEntityConverter.java

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,8 @@ private <ET> ET map(MapAccessor queryResult, MapAccessor allValues, Neo4jPersist
250250

251251
Predicate<String> includeAllFields = (field) -> true;
252252

253-
Collection<RelationshipDescription> relationships = CypherGenerator
254-
.getRelationshipDescriptionsUpAndDown(nodeDescription, includeAllFields);
253+
Collection<RelationshipDescription> relationships = nodeDescription
254+
.getRelationshipsUpAndDown(includeAllFields);
255255

256256
ET instance = instantiate(concreteNodeDescription, queryResult, allValues, relationships,
257257
nodeDescriptionAndLabels.getDynamicLabels(), lastMappedEntity);
@@ -314,6 +314,8 @@ private List<String> getLabels(MapAccessor queryResult, @Nullable NodeDescriptio
314314
} else if (queryResult instanceof Node) {
315315
Node nodeRepresentation = (Node) queryResult;
316316
nodeRepresentation.labels().forEach(labels::add);
317+
} else if (!queryResult.get(Constants.NAME_OF_SYNTHESIZED_ROOT_NODE).isNull()) {
318+
queryResult.get(Constants.NAME_OF_SYNTHESIZED_ROOT_NODE).asNode().labels().forEach(labels::add);
317319
} else if (nodeDescription != null) {
318320
labels.addAll(nodeDescription.getStaticLabels());
319321
}
@@ -394,12 +396,6 @@ private Optional<Object> createInstanceOfRelationships(Neo4jPersistentProperty p
394396
Neo4jPersistentEntity<?> genericTargetNodeDescription = (Neo4jPersistentEntity<?>) relationshipDescription
395397
.getTarget();
396398

397-
List<String> allLabels = getLabels(values, null);
398-
NodeDescriptionAndLabels nodeDescriptionAndLabels = NodeDescriptionStore
399-
.deriveConcreteNodeDescription(genericTargetNodeDescription, allLabels);
400-
Neo4jPersistentEntity<?> concreteTargetNodeDescription = (Neo4jPersistentEntity<?>) nodeDescriptionAndLabels
401-
.getNodeDescription();
402-
403399
List<Object> value = new ArrayList<>();
404400
Map<Object, Object> dynamicValue = new HashMap<>();
405401

@@ -461,6 +457,13 @@ private Optional<Object> createInstanceOfRelationships(Neo4jPersistentProperty p
461457

462458
for (Relationship possibleRelationship : allMatchingTypeRelationshipsInResult) {
463459
if (targetIdSelector.apply(possibleRelationship) == targetNodeId && sourceIdSelector.apply(possibleRelationship).equals(sourceNodeId)) {
460+
461+
List<String> allLabels = getLabels(possibleValueNode, null);
462+
NodeDescriptionAndLabels nodeDescriptionAndLabels = NodeDescriptionStore
463+
.deriveConcreteNodeDescription(genericTargetNodeDescription, allLabels);
464+
Neo4jPersistentEntity<?> concreteTargetNodeDescription = (Neo4jPersistentEntity<?>) nodeDescriptionAndLabels
465+
.getNodeDescription();
466+
464467
Object mappedObject = map(possibleValueNode, allValues, concreteTargetNodeDescription);
465468
if (relationshipDescription.hasRelationshipProperties()) {
466469

@@ -480,6 +483,12 @@ private Optional<Object> createInstanceOfRelationships(Neo4jPersistentProperty p
480483
} else {
481484
for (Value relatedEntity : list.asList(Function.identity())) {
482485

486+
List<String> allLabels = getLabels(relatedEntity, null);
487+
NodeDescriptionAndLabels nodeDescriptionAndLabels = NodeDescriptionStore
488+
.deriveConcreteNodeDescription(genericTargetNodeDescription, allLabels);
489+
Neo4jPersistentEntity<?> concreteTargetNodeDescription = (Neo4jPersistentEntity<?>) nodeDescriptionAndLabels
490+
.getNodeDescription();
491+
483492
Object valueEntry = map(relatedEntity, allValues, concreteTargetNodeDescription);
484493

485494
if (relationshipDescription.hasRelationshipProperties()) {

src/main/java/org/springframework/data/neo4j/core/mapping/DefaultNeo4jPersistentEntity.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,26 @@ public Collection<RelationshipDescription> getRelationships() {
421421
return Collections.unmodifiableCollection(relationships);
422422
}
423423

424+
@NonNull
425+
public Collection<RelationshipDescription> getRelationshipsUpAndDown(Predicate<String> propertyFilter) {
426+
427+
Collection<RelationshipDescription> relationships = new HashSet<>(getRelationships());
428+
for (NodeDescription<?> childDescription : getChildNodeDescriptionsInHierarchy()) {
429+
childDescription.getRelationships().forEach(concreteRelationship -> {
430+
431+
String fieldName = concreteRelationship.getFieldName();
432+
433+
if (relationships.stream().noneMatch(relationship -> relationship.getFieldName().equals(fieldName))) {
434+
relationships.add(concreteRelationship);
435+
}
436+
});
437+
}
438+
439+
return relationships.stream().filter(relationshipDescription ->
440+
propertyFilter.test(relationshipDescription.getFieldName()))
441+
.collect(Collectors.toSet());
442+
}
443+
424444
private Collection<GraphPropertyDescription> computeGraphProperties() {
425445

426446
final List<GraphPropertyDescription> computedGraphProperties = new ArrayList<>();
@@ -476,7 +496,7 @@ public boolean containsPossibleCircles(Predicate<String> includeField) {
476496
}
477497

478498
private boolean calculatePossibleCircles(Predicate<String> includeField) {
479-
Collection<RelationshipDescription> relationships = getRelationships();
499+
Collection<RelationshipDescription> relationships = new HashSet<>(getRelationshipsUpAndDown(includeField));
480500

481501
Set<RelationshipDescription> processedRelationships = new HashSet<>();
482502
for (RelationshipDescription relationship : relationships) {
@@ -495,7 +515,7 @@ private boolean calculatePossibleCircles(Predicate<String> includeField) {
495515
}
496516

497517
private boolean calculatePossibleCircles(NodeDescription<?> nodeDescription, Set<RelationshipDescription> processedRelationships) {
498-
Collection<RelationshipDescription> relationships = nodeDescription.getRelationships();
518+
Collection<RelationshipDescription> relationships = nodeDescription.getRelationshipsUpAndDown(s -> true);
499519

500520
for (RelationshipDescription relationship : relationships) {
501521
if (processedRelationships.contains(relationship)) {

src/main/java/org/springframework/data/neo4j/core/mapping/NestedRelationshipContext.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public static NestedRelationshipContext of(Association<Neo4jPersistentProperty>
124124
Object value = propertyAccessor.getProperty(inverse);
125125
boolean inverseValueIsEmpty = value == null;
126126

127-
RelationshipDescription relationship = neo4jPersistentEntity.getRelationships().stream()
127+
RelationshipDescription relationship = neo4jPersistentEntity.getRelationshipsUpAndDown(s -> true).stream()
128128
.filter(r -> r.getFieldName().equals(inverse.getName())).findFirst().get();
129129

130130
// if we have a relationship with properties, the targetNodeType is the map key

src/main/java/org/springframework/data/neo4j/core/mapping/NodeDescription.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ default boolean isUsingInternalIds() {
103103
*/
104104
Collection<RelationshipDescription> getRelationships();
105105

106+
Collection<RelationshipDescription> getRelationshipsUpAndDown(Predicate<String> propertyPredicate);
107+
106108
/**
107109
* Register a direct child node description for this entity.
108110
*

0 commit comments

Comments
 (0)