Skip to content

Make state machine for saving process more robust. #2181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

<groupId>org.springframework.data</groupId>
<artifactId>spring-data-neo4j</artifactId>
<version>6.0.6-SNAPSHOT</version>
<version>6.0.6-GH-2177-SNAPSHOT</version>

<name>Spring Data Neo4j</name>
<description>Next generation Object-Graph-Mapping for Spring Data.</description>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ private <T> T saveImpl(T instance, @Nullable String inDatabase) {
propertyAccessor.setProperty(entityMetaData.getRequiredIdProperty(), optionalInternalId.get());
entityToBeSaved = propertyAccessor.getBean();
}
return processRelations(entityMetaData, entityToBeSaved, isEntityNew, inDatabase);
return processRelations(entityMetaData, entityToBeSaved, isEntityNew, inDatabase, instance);
}

private <T> DynamicLabels determineDynamicLabels(T entityToBeSaved, Neo4jPersistentEntity<?> entityMetaData,
Expand Down Expand Up @@ -284,9 +284,9 @@ public <T> List<T> saveAll(Iterable<T> instances) {

String databaseName = getDatabaseName();

Collection<T> entities;
List<T> entities;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not that I mind, but is it necessary?

if (instances instanceof Collection) {
entities = (Collection<T>) instances;
entities = new ArrayList<>((Collection<T>) instances);
} else {
entities = new ArrayList<>();
instances.forEach(entities::add);
Expand Down Expand Up @@ -323,8 +323,11 @@ public <T> List<T> saveAll(Iterable<T> instances) {
.bind(entityList).to(Constants.NAME_OF_ENTITY_LIST_PARAM).run();

// Save related
entitiesToBeSaved.forEach(entityToBeSaved -> processRelations(entityMetaData, entityToBeSaved,
isNewIndicator.get(entitiesToBeSaved.indexOf(entityToBeSaved)), databaseName));
entitiesToBeSaved.forEach(entityToBeSaved -> {
int positionInList = entitiesToBeSaved.indexOf(entityToBeSaved);
processRelations(entityMetaData, entityToBeSaved, isNewIndicator.get(positionInList), databaseName,
entities.get(positionInList));
});

SummaryCounters counters = resultSummary.counters();
log.debug(() -> String.format(
Expand Down Expand Up @@ -434,10 +437,10 @@ private <T> ExecutableQuery<T> createExecutableQuery(Class<T> domainType, String
}

private <T> T processRelations(Neo4jPersistentEntity<?> neo4jPersistentEntity, Object parentObject,
boolean isParentObjectNew, @Nullable String inDatabase) {
boolean isParentObjectNew, @Nullable String inDatabase, Object parentEntity) {

return processNestedRelations(neo4jPersistentEntity, parentObject, isParentObjectNew, inDatabase,
new NestedRelationshipProcessingStateMachine());
new NestedRelationshipProcessingStateMachine(parentEntity));
}

private <T> T processNestedRelations(Neo4jPersistentEntity<?> sourceEntity, Object parentObject,
Expand Down Expand Up @@ -468,7 +471,7 @@ private <T> T processNestedRelations(Neo4jPersistentEntity<?> sourceEntity, Obje

// break recursive procession and deletion of previously created relationships
ProcessState processState = stateMachine.getStateOf(relationshipDescriptionObverse, relatedValuesToStore);
if (processState == ProcessState.PROCESSED_ALL_RELATIONSHIPS) {
if (processState == ProcessState.PROCESSED_ALL_RELATIONSHIPS || processState == ProcessState.PROCESSED_BOTH) {
return;
}

Expand Down Expand Up @@ -517,8 +520,14 @@ private <T> T processNestedRelations(Neo4jPersistentEntity<?> sourceEntity, Obje

relatedNode = eventSupport.maybeCallBeforeBind(relatedNode);

Long relatedInternalId = saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(),
targetEntity, inDatabase);
Long relatedInternalId;
// No need to save values if processed
if (processState == ProcessState.PROCESSED_ALL_VALUES) {
relatedInternalId = queryRelatedNode(relatedNode, targetEntity, inDatabase);
} else {
relatedInternalId = saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(),
targetEntity, inDatabase);
}

CreateRelationshipStatementHolder statementHolder = neo4jMappingContext.createStatement(
sourceEntity, relationshipContext, relatedValueToStore);
Expand Down Expand Up @@ -553,6 +562,23 @@ private <T> T processNestedRelations(Neo4jPersistentEntity<?> sourceEntity, Obje
return (T) propertyAccessor.getBean();
}

private <Y> Long queryRelatedNode(Object entity, Neo4jPersistentEntity<?> targetNodeDescription,
@Nullable String inDatabase) {

Neo4jPersistentProperty requiredIdProperty = targetNodeDescription.getRequiredIdProperty();
PersistentPropertyAccessor<Object> targetPropertyAccessor = targetNodeDescription.getPropertyAccessor(entity);
Object idValue = targetPropertyAccessor.getProperty(requiredIdProperty);

return neo4jClient.query(() ->
renderer.render(cypherGenerator.prepareMatchOf(targetNodeDescription,
targetNodeDescription.getIdExpression().isEqualTo(parameter(Constants.NAME_OF_ID)))
.returning(Constants.NAME_OF_INTERNAL_ID)
.build())
)
.in(inDatabase).bindAll(Collections.singletonMap(Constants.NAME_OF_ID, idValue))
.fetchAs(Long.class).one().get();
}

private <Y> Long saveRelatedNode(Object entity, Class<Y> entityType, NodeDescription targetNodeDescription,
@Nullable String inDatabase) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ private <T> Mono<T> saveImpl(T instance, @Nullable String inDatabase) {
}));

if (!entityMetaData.isUsingInternalIds()) {
return idMono.then(processRelations(entityMetaData, entity, isNewEntity, inDatabase))
return idMono.then(processRelations(entityMetaData, entity, isNewEntity, inDatabase, instance))
.thenReturn(entity);
} else {
return idMono.map(internalId -> {
Expand All @@ -258,7 +258,7 @@ private <T> Mono<T> saveImpl(T instance, @Nullable String inDatabase) {

return propertyAccessor.getBean();
}).flatMap(
savedEntity -> processRelations(entityMetaData, savedEntity, isNewEntity, inDatabase)
savedEntity -> processRelations(entityMetaData, savedEntity, isNewEntity, inDatabase, instance)
.thenReturn(savedEntity));
}
}));
Expand Down Expand Up @@ -290,9 +290,9 @@ private <T> Mono<Tuple2<T, DynamicLabels>> determineDynamicLabels(T entityToBeSa
@Override
public <T> Flux<T> saveAll(Iterable<T> instances) {

Collection<T> entities;
List<T> entities;
if (instances instanceof Collection) {
entities = (Collection<T>) instances;
entities = new ArrayList<>((Collection<T>) instances);
} else {
entities = new ArrayList<>();
instances.forEach(entities::add);
Expand Down Expand Up @@ -338,7 +338,7 @@ public <T> Flux<T> saveAll(Iterable<T> instances) {
T entityToBeSaved = t.getT2();
boolean isNew = isNewIndicator.get(Math.toIntExact(t.getT1()));
return processRelations(entityMetaData, entityToBeSaved, isNew,
databaseName.getValue())
databaseName.getValue(), entities.get(Math.toIntExact(t.getT1())))
.then(Mono.just(entityToBeSaved));
}
);
Expand Down Expand Up @@ -566,10 +566,10 @@ Publisher<Tuple2<Collection<Long>, Collection<Long>>>> iterateAndMapNextLevel(
}

private Mono<Void> processRelations(Neo4jPersistentEntity<?> neo4jPersistentEntity, Object parentObject,
boolean isParentObjectNew, @Nullable String inDatabase) {
boolean isParentObjectNew, @Nullable String inDatabase, Object parentEntity) {

return processNestedRelations(neo4jPersistentEntity, parentObject, isParentObjectNew, inDatabase,
new NestedRelationshipProcessingStateMachine());
new NestedRelationshipProcessingStateMachine(parentEntity));
}

private Mono<Void> processNestedRelations(Neo4jPersistentEntity<?> sourceEntity, Object parentObject,
Expand Down Expand Up @@ -602,7 +602,7 @@ private Mono<Void> processNestedRelations(Neo4jPersistentEntity<?> sourceEntity,

// break recursive procession and deletion of previously created relationships
ProcessState processState = stateMachine.getStateOf(relationshipDescriptionObverse, relatedValuesToStore);
if (processState == ProcessState.PROCESSED_ALL_RELATIONSHIPS) {
if (processState == ProcessState.PROCESSED_ALL_RELATIONSHIPS || processState == ProcessState.PROCESSED_BOTH) {
return;
}

Expand Down Expand Up @@ -652,9 +652,16 @@ private Mono<Void> processNestedRelations(Neo4jPersistentEntity<?> sourceEntity,
.flatMap(relatedNode -> {
Neo4jPersistentEntity<?> targetEntity = neo4jMappingContext
.getPersistentEntity(relatedNodePreEvt.getClass());
return Mono.just(targetEntity.isNew(relatedNode)).flatMap(isNew ->
saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(),
targetEntity, inDatabase).flatMap(relatedInternalId -> {
return Mono.just(targetEntity.isNew(relatedNode)).flatMap(isNew -> {
Mono<Long> relatedIdMono;

if (processState == ProcessState.PROCESSED_ALL_VALUES) {
relatedIdMono = queryRelatedNode(relatedNode, targetEntity, inDatabase);
} else {
relatedIdMono = saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(),
targetEntity, inDatabase);
}
return relatedIdMono.flatMap(relatedInternalId -> {

// if an internal id is used this must get set to link this entity in the next iteration
PersistentPropertyAccessor<?> targetPropertyAccessor = targetEntity
Expand Down Expand Up @@ -691,7 +698,8 @@ private Mono<Void> processNestedRelations(Neo4jPersistentEntity<?> sourceEntity,
} else {
return relationshipCreationMonoNested.checkpoint().then();
}
}).checkpoint());
}).checkpoint();
});
});
relationshipCreationMonos.add(createRelationship);
}
Expand All @@ -701,6 +709,23 @@ private Mono<Void> processNestedRelations(Neo4jPersistentEntity<?> sourceEntity,
});
}

private <Y> Mono<Long> queryRelatedNode(Object entity, Neo4jPersistentEntity<?> targetNodeDescription,
@Nullable String inDatabase) {

Neo4jPersistentProperty requiredIdProperty = targetNodeDescription.getRequiredIdProperty();
PersistentPropertyAccessor<Object> targetPropertyAccessor = targetNodeDescription.getPropertyAccessor(entity);
Object idValue = targetPropertyAccessor.getProperty(requiredIdProperty);

return neo4jClient.query(() ->
renderer.render(cypherGenerator.prepareMatchOf(targetNodeDescription,
targetNodeDescription.getIdExpression().isEqualTo(parameter(Constants.NAME_OF_ID)))
.returning(Constants.NAME_OF_INTERNAL_ID)
.build())
)
.in(inDatabase).bindAll(Collections.singletonMap(Constants.NAME_OF_ID, idValue))
.fetchAs(Long.class).one();
}

private <Y> Mono<Long> saveRelatedNode(Object relatedNode, Class<Y> entityType, NodeDescription targetNodeDescription,
@Nullable String inDatabase) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ public enum ProcessState {
*/
private final Set<Object> processedObjects = new HashSet<>();

public NestedRelationshipProcessingStateMachine(Object initialObject) {
processedObjects.add(initialObject);
}

/**
* @param relationshipDescription Check whether this relationship description has been processed
* @param valuesToStore Check whether all the values in the collection have been processed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.neo4j.driver.Driver;
import org.neo4j.driver.Record;
import org.neo4j.driver.Session;
import org.neo4j.driver.SessionConfig;
import org.neo4j.driver.Transaction;
Expand Down Expand Up @@ -284,6 +285,28 @@ void immutablesShouldWork(@Autowired Neo4jTemplate neo4jTemplate) {
assertThatExceptionOfType(OptimisticLockingFailureException.class).isThrownBy(() -> neo4jTemplate.save(copy));
}

@Test
void shouldDoThings(@Autowired VersionedThingRepository repository) {
VersionedThing thing1 = new VersionedThing("Thing1");
VersionedThing thing2 = new VersionedThing("Thing2");

thing1.setOtherVersionedThings(Collections.singletonList(thing2));
repository.save(thing1);

thing1 = repository.findById(thing1.getId()).get();
thing2 = repository.findById(thing2.getId()).get();

thing2.setOtherVersionedThings(Collections.singletonList(thing1));
repository.save(thing2);

try (Session session = driver.session()) {
List<Record> result = session
.run("MATCH (t:VersionedThing{name:'Thing1'})-[:HAS]->(:VersionedThing{name:'Thing2'}) return t")
.list();
assertThat(result).hasSize(1);
}
}

interface VersionedThingRepository extends Neo4jRepository<VersionedThing, Long> {}

interface VersionedThingWithAssignedIdRepository extends Neo4jRepository<VersionedThingWithAssignedId, Long> {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3542,7 +3542,7 @@ void findAndInstantiateRelationshipsWithExtendingSuperRootEntity(
@Autowired SuperBaseClassWithRelationshipRepository repository) {

Inheritance.ConcreteClassA ccA = new Inheritance.ConcreteClassA("cc1", "test");
Inheritance.ConcreteClassB ccB1 = new Inheritance.ConcreteClassB("cc2a", 42);
Inheritance.ConcreteClassB ccB1 = new Inheritance.ConcreteClassB("cc2a", 41);
Inheritance.ConcreteClassB ccB2 = new Inheritance.ConcreteClassB("cc2b", 42);

List<Inheritance.SuperBaseClass> things = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.data.neo4j.integration.shared.common;

import java.util.List;
import java.util.Objects;

import org.springframework.data.annotation.Version;
import org.springframework.data.neo4j.core.schema.GeneratedValue;
Expand All @@ -41,6 +42,10 @@ public VersionedThing(String name) {
this.name = name;
}

public Long getId() {
return id;
}

public Long getMyVersion() {
return myVersion;
}
Expand All @@ -56,4 +61,21 @@ public List<VersionedThing> getOtherVersionedThings() {
public void setOtherVersionedThings(List<VersionedThing> otherVersionedThings) {
this.otherVersionedThings = otherVersionedThings;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
VersionedThing that = (VersionedThing) o;
return Objects.equals(id, that.id) && name.equals(that.name);
}

@Override
public int hashCode() {
return Objects.hash(id, name);
}
}