diff --git a/pom.xml b/pom.xml index 2e2a2b577f..852ccc3df6 100644 --- a/pom.xml +++ b/pom.xml @@ -24,7 +24,7 @@ org.springframework.data spring-data-neo4j - 6.0.6-SNAPSHOT + 6.0.6-GH-2177-SNAPSHOT Spring Data Neo4j Next generation Object-Graph-Mapping for Spring Data. diff --git a/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java b/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java index b4cd3289bc..987491e45e 100644 --- a/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java +++ b/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java @@ -253,7 +253,7 @@ private T saveImpl(T instance, @Nullable String inDatabase) { propertyAccessor.setProperty(entityMetaData.getRequiredIdProperty(), optionalInternalId.get()); entityToBeSaved = propertyAccessor.getBean(); } - return processRelations(entityMetaData, entityToBeSaved, isEntityNew, inDatabase); + return processRelations(entityMetaData, entityToBeSaved, isEntityNew, inDatabase, instance); } private DynamicLabels determineDynamicLabels(T entityToBeSaved, Neo4jPersistentEntity entityMetaData, @@ -284,9 +284,9 @@ public List saveAll(Iterable instances) { String databaseName = getDatabaseName(); - Collection entities; + List entities; if (instances instanceof Collection) { - entities = (Collection) instances; + entities = new ArrayList<>((Collection) instances); } else { entities = new ArrayList<>(); instances.forEach(entities::add); @@ -323,8 +323,11 @@ public List saveAll(Iterable instances) { .bind(entityList).to(Constants.NAME_OF_ENTITY_LIST_PARAM).run(); // Save related - entitiesToBeSaved.forEach(entityToBeSaved -> processRelations(entityMetaData, entityToBeSaved, - isNewIndicator.get(entitiesToBeSaved.indexOf(entityToBeSaved)), databaseName)); + entitiesToBeSaved.forEach(entityToBeSaved -> { + int positionInList = entitiesToBeSaved.indexOf(entityToBeSaved); + processRelations(entityMetaData, entityToBeSaved, isNewIndicator.get(positionInList), databaseName, + entities.get(positionInList)); + }); SummaryCounters counters = resultSummary.counters(); log.debug(() -> String.format( @@ -434,10 +437,10 @@ private ExecutableQuery createExecutableQuery(Class domainType, String } private T processRelations(Neo4jPersistentEntity neo4jPersistentEntity, Object parentObject, - boolean isParentObjectNew, @Nullable String inDatabase) { + boolean isParentObjectNew, @Nullable String inDatabase, Object parentEntity) { return processNestedRelations(neo4jPersistentEntity, parentObject, isParentObjectNew, inDatabase, - new NestedRelationshipProcessingStateMachine()); + new NestedRelationshipProcessingStateMachine(parentEntity)); } private T processNestedRelations(Neo4jPersistentEntity sourceEntity, Object parentObject, @@ -468,7 +471,7 @@ private T processNestedRelations(Neo4jPersistentEntity sourceEntity, Obje // break recursive procession and deletion of previously created relationships ProcessState processState = stateMachine.getStateOf(relationshipDescriptionObverse, relatedValuesToStore); - if (processState == ProcessState.PROCESSED_ALL_RELATIONSHIPS) { + if (processState == ProcessState.PROCESSED_ALL_RELATIONSHIPS || processState == ProcessState.PROCESSED_BOTH) { return; } @@ -517,8 +520,14 @@ private T processNestedRelations(Neo4jPersistentEntity sourceEntity, Obje relatedNode = eventSupport.maybeCallBeforeBind(relatedNode); - Long relatedInternalId = saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(), - targetEntity, inDatabase); + Long relatedInternalId; + // No need to save values if processed + if (processState == ProcessState.PROCESSED_ALL_VALUES) { + relatedInternalId = queryRelatedNode(relatedNode, targetEntity, inDatabase); + } else { + relatedInternalId = saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(), + targetEntity, inDatabase); + } CreateRelationshipStatementHolder statementHolder = neo4jMappingContext.createStatement( sourceEntity, relationshipContext, relatedValueToStore); @@ -553,6 +562,23 @@ private T processNestedRelations(Neo4jPersistentEntity sourceEntity, Obje return (T) propertyAccessor.getBean(); } + private Long queryRelatedNode(Object entity, Neo4jPersistentEntity targetNodeDescription, + @Nullable String inDatabase) { + + Neo4jPersistentProperty requiredIdProperty = targetNodeDescription.getRequiredIdProperty(); + PersistentPropertyAccessor targetPropertyAccessor = targetNodeDescription.getPropertyAccessor(entity); + Object idValue = targetPropertyAccessor.getProperty(requiredIdProperty); + + return neo4jClient.query(() -> + renderer.render(cypherGenerator.prepareMatchOf(targetNodeDescription, + targetNodeDescription.getIdExpression().isEqualTo(parameter(Constants.NAME_OF_ID))) + .returning(Constants.NAME_OF_INTERNAL_ID) + .build()) + ) + .in(inDatabase).bindAll(Collections.singletonMap(Constants.NAME_OF_ID, idValue)) + .fetchAs(Long.class).one().get(); + } + private Long saveRelatedNode(Object entity, Class entityType, NodeDescription targetNodeDescription, @Nullable String inDatabase) { diff --git a/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java b/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java index a340c11751..87ac44e994 100644 --- a/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java +++ b/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java @@ -249,7 +249,7 @@ private Mono saveImpl(T instance, @Nullable String inDatabase) { })); if (!entityMetaData.isUsingInternalIds()) { - return idMono.then(processRelations(entityMetaData, entity, isNewEntity, inDatabase)) + return idMono.then(processRelations(entityMetaData, entity, isNewEntity, inDatabase, instance)) .thenReturn(entity); } else { return idMono.map(internalId -> { @@ -258,7 +258,7 @@ private Mono saveImpl(T instance, @Nullable String inDatabase) { return propertyAccessor.getBean(); }).flatMap( - savedEntity -> processRelations(entityMetaData, savedEntity, isNewEntity, inDatabase) + savedEntity -> processRelations(entityMetaData, savedEntity, isNewEntity, inDatabase, instance) .thenReturn(savedEntity)); } })); @@ -290,9 +290,9 @@ private Mono> determineDynamicLabels(T entityToBeSa @Override public Flux saveAll(Iterable instances) { - Collection entities; + List entities; if (instances instanceof Collection) { - entities = (Collection) instances; + entities = new ArrayList<>((Collection) instances); } else { entities = new ArrayList<>(); instances.forEach(entities::add); @@ -338,7 +338,7 @@ public Flux saveAll(Iterable instances) { T entityToBeSaved = t.getT2(); boolean isNew = isNewIndicator.get(Math.toIntExact(t.getT1())); return processRelations(entityMetaData, entityToBeSaved, isNew, - databaseName.getValue()) + databaseName.getValue(), entities.get(Math.toIntExact(t.getT1()))) .then(Mono.just(entityToBeSaved)); } ); @@ -566,10 +566,10 @@ Publisher, Collection>>> iterateAndMapNextLevel( } private Mono processRelations(Neo4jPersistentEntity neo4jPersistentEntity, Object parentObject, - boolean isParentObjectNew, @Nullable String inDatabase) { + boolean isParentObjectNew, @Nullable String inDatabase, Object parentEntity) { return processNestedRelations(neo4jPersistentEntity, parentObject, isParentObjectNew, inDatabase, - new NestedRelationshipProcessingStateMachine()); + new NestedRelationshipProcessingStateMachine(parentEntity)); } private Mono processNestedRelations(Neo4jPersistentEntity sourceEntity, Object parentObject, @@ -602,7 +602,7 @@ private Mono processNestedRelations(Neo4jPersistentEntity sourceEntity, // break recursive procession and deletion of previously created relationships ProcessState processState = stateMachine.getStateOf(relationshipDescriptionObverse, relatedValuesToStore); - if (processState == ProcessState.PROCESSED_ALL_RELATIONSHIPS) { + if (processState == ProcessState.PROCESSED_ALL_RELATIONSHIPS || processState == ProcessState.PROCESSED_BOTH) { return; } @@ -652,9 +652,16 @@ private Mono processNestedRelations(Neo4jPersistentEntity sourceEntity, .flatMap(relatedNode -> { Neo4jPersistentEntity targetEntity = neo4jMappingContext .getPersistentEntity(relatedNodePreEvt.getClass()); - return Mono.just(targetEntity.isNew(relatedNode)).flatMap(isNew -> - saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(), - targetEntity, inDatabase).flatMap(relatedInternalId -> { + return Mono.just(targetEntity.isNew(relatedNode)).flatMap(isNew -> { + Mono relatedIdMono; + + if (processState == ProcessState.PROCESSED_ALL_VALUES) { + relatedIdMono = queryRelatedNode(relatedNode, targetEntity, inDatabase); + } else { + relatedIdMono = saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(), + targetEntity, inDatabase); + } + return relatedIdMono.flatMap(relatedInternalId -> { // if an internal id is used this must get set to link this entity in the next iteration PersistentPropertyAccessor targetPropertyAccessor = targetEntity @@ -691,7 +698,8 @@ private Mono processNestedRelations(Neo4jPersistentEntity sourceEntity, } else { return relationshipCreationMonoNested.checkpoint().then(); } - }).checkpoint()); + }).checkpoint(); + }); }); relationshipCreationMonos.add(createRelationship); } @@ -701,6 +709,23 @@ private Mono processNestedRelations(Neo4jPersistentEntity sourceEntity, }); } + private Mono queryRelatedNode(Object entity, Neo4jPersistentEntity targetNodeDescription, + @Nullable String inDatabase) { + + Neo4jPersistentProperty requiredIdProperty = targetNodeDescription.getRequiredIdProperty(); + PersistentPropertyAccessor targetPropertyAccessor = targetNodeDescription.getPropertyAccessor(entity); + Object idValue = targetPropertyAccessor.getProperty(requiredIdProperty); + + return neo4jClient.query(() -> + renderer.render(cypherGenerator.prepareMatchOf(targetNodeDescription, + targetNodeDescription.getIdExpression().isEqualTo(parameter(Constants.NAME_OF_ID))) + .returning(Constants.NAME_OF_INTERNAL_ID) + .build()) + ) + .in(inDatabase).bindAll(Collections.singletonMap(Constants.NAME_OF_ID, idValue)) + .fetchAs(Long.class).one(); + } + private Mono saveRelatedNode(Object relatedNode, Class entityType, NodeDescription targetNodeDescription, @Nullable String inDatabase) { diff --git a/src/main/java/org/springframework/data/neo4j/core/mapping/NestedRelationshipProcessingStateMachine.java b/src/main/java/org/springframework/data/neo4j/core/mapping/NestedRelationshipProcessingStateMachine.java index 4c2f27ab71..16b0e02d3c 100644 --- a/src/main/java/org/springframework/data/neo4j/core/mapping/NestedRelationshipProcessingStateMachine.java +++ b/src/main/java/org/springframework/data/neo4j/core/mapping/NestedRelationshipProcessingStateMachine.java @@ -55,6 +55,10 @@ public enum ProcessState { */ private final Set processedObjects = new HashSet<>(); + public NestedRelationshipProcessingStateMachine(Object initialObject) { + processedObjects.add(initialObject); + } + /** * @param relationshipDescription Check whether this relationship description has been processed * @param valuesToStore Check whether all the values in the collection have been processed diff --git a/src/test/java/org/springframework/data/neo4j/integration/imperative/OptimisticLockingIT.java b/src/test/java/org/springframework/data/neo4j/integration/imperative/OptimisticLockingIT.java index f7bdda8fe8..c529d1a155 100644 --- a/src/test/java/org/springframework/data/neo4j/integration/imperative/OptimisticLockingIT.java +++ b/src/test/java/org/springframework/data/neo4j/integration/imperative/OptimisticLockingIT.java @@ -25,6 +25,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.neo4j.driver.Driver; +import org.neo4j.driver.Record; import org.neo4j.driver.Session; import org.neo4j.driver.SessionConfig; import org.neo4j.driver.Transaction; @@ -284,6 +285,28 @@ void immutablesShouldWork(@Autowired Neo4jTemplate neo4jTemplate) { assertThatExceptionOfType(OptimisticLockingFailureException.class).isThrownBy(() -> neo4jTemplate.save(copy)); } + @Test + void shouldDoThings(@Autowired VersionedThingRepository repository) { + VersionedThing thing1 = new VersionedThing("Thing1"); + VersionedThing thing2 = new VersionedThing("Thing2"); + + thing1.setOtherVersionedThings(Collections.singletonList(thing2)); + repository.save(thing1); + + thing1 = repository.findById(thing1.getId()).get(); + thing2 = repository.findById(thing2.getId()).get(); + + thing2.setOtherVersionedThings(Collections.singletonList(thing1)); + repository.save(thing2); + + try (Session session = driver.session()) { + List result = session + .run("MATCH (t:VersionedThing{name:'Thing1'})-[:HAS]->(:VersionedThing{name:'Thing2'}) return t") + .list(); + assertThat(result).hasSize(1); + } + } + interface VersionedThingRepository extends Neo4jRepository {} interface VersionedThingWithAssignedIdRepository extends Neo4jRepository {} diff --git a/src/test/java/org/springframework/data/neo4j/integration/imperative/RepositoryIT.java b/src/test/java/org/springframework/data/neo4j/integration/imperative/RepositoryIT.java index c5cd6573b4..0e8608c6e4 100644 --- a/src/test/java/org/springframework/data/neo4j/integration/imperative/RepositoryIT.java +++ b/src/test/java/org/springframework/data/neo4j/integration/imperative/RepositoryIT.java @@ -3542,7 +3542,7 @@ void findAndInstantiateRelationshipsWithExtendingSuperRootEntity( @Autowired SuperBaseClassWithRelationshipRepository repository) { Inheritance.ConcreteClassA ccA = new Inheritance.ConcreteClassA("cc1", "test"); - Inheritance.ConcreteClassB ccB1 = new Inheritance.ConcreteClassB("cc2a", 42); + Inheritance.ConcreteClassB ccB1 = new Inheritance.ConcreteClassB("cc2a", 41); Inheritance.ConcreteClassB ccB2 = new Inheritance.ConcreteClassB("cc2b", 42); List things = new ArrayList<>(); diff --git a/src/test/java/org/springframework/data/neo4j/integration/shared/common/VersionedThing.java b/src/test/java/org/springframework/data/neo4j/integration/shared/common/VersionedThing.java index 264326b066..bfb717fd78 100644 --- a/src/test/java/org/springframework/data/neo4j/integration/shared/common/VersionedThing.java +++ b/src/test/java/org/springframework/data/neo4j/integration/shared/common/VersionedThing.java @@ -16,6 +16,7 @@ package org.springframework.data.neo4j.integration.shared.common; import java.util.List; +import java.util.Objects; import org.springframework.data.annotation.Version; import org.springframework.data.neo4j.core.schema.GeneratedValue; @@ -41,6 +42,10 @@ public VersionedThing(String name) { this.name = name; } + public Long getId() { + return id; + } + public Long getMyVersion() { return myVersion; } @@ -56,4 +61,21 @@ public List getOtherVersionedThings() { public void setOtherVersionedThings(List otherVersionedThings) { this.otherVersionedThings = otherVersionedThings; } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + VersionedThing that = (VersionedThing) o; + return Objects.equals(id, that.id) && name.equals(that.name); + } + + @Override + public int hashCode() { + return Objects.hash(id, name); + } }