Skip to content

Commit 65ee2ae

Browse files
committed
GH-2177 - Improve processing state machine behaviour.
1 parent 1a6a2f3 commit 65ee2ae

File tree

6 files changed

+123
-23
lines changed

6 files changed

+123
-23
lines changed

src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java

+36-10
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ private <T> T saveImpl(T instance, @Nullable String inDatabase) {
253253
propertyAccessor.setProperty(entityMetaData.getRequiredIdProperty(), optionalInternalId.get());
254254
entityToBeSaved = propertyAccessor.getBean();
255255
}
256-
return processRelations(entityMetaData, entityToBeSaved, isEntityNew, inDatabase);
256+
return processRelations(entityMetaData, entityToBeSaved, isEntityNew, inDatabase, instance);
257257
}
258258

259259
private <T> DynamicLabels determineDynamicLabels(T entityToBeSaved, Neo4jPersistentEntity<?> entityMetaData,
@@ -284,9 +284,9 @@ public <T> List<T> saveAll(Iterable<T> instances) {
284284

285285
String databaseName = getDatabaseName();
286286

287-
Collection<T> entities;
287+
List<T> entities;
288288
if (instances instanceof Collection) {
289-
entities = (Collection<T>) instances;
289+
entities = new ArrayList<>((Collection<T>) instances);
290290
} else {
291291
entities = new ArrayList<>();
292292
instances.forEach(entities::add);
@@ -323,8 +323,11 @@ public <T> List<T> saveAll(Iterable<T> instances) {
323323
.bind(entityList).to(Constants.NAME_OF_ENTITY_LIST_PARAM).run();
324324

325325
// Save related
326-
entitiesToBeSaved.forEach(entityToBeSaved -> processRelations(entityMetaData, entityToBeSaved,
327-
isNewIndicator.get(entitiesToBeSaved.indexOf(entityToBeSaved)), databaseName));
326+
entitiesToBeSaved.forEach(entityToBeSaved -> {
327+
int positionInList = entitiesToBeSaved.indexOf(entityToBeSaved);
328+
processRelations(entityMetaData, entityToBeSaved, isNewIndicator.get(positionInList), databaseName,
329+
entities.get(positionInList));
330+
});
328331

329332
SummaryCounters counters = resultSummary.counters();
330333
log.debug(() -> String.format(
@@ -434,10 +437,10 @@ private <T> ExecutableQuery<T> createExecutableQuery(Class<T> domainType, String
434437
}
435438

436439
private <T> T processRelations(Neo4jPersistentEntity<?> neo4jPersistentEntity, Object parentObject,
437-
boolean isParentObjectNew, @Nullable String inDatabase) {
440+
boolean isParentObjectNew, @Nullable String inDatabase, Object parentEntity) {
438441

439442
return processNestedRelations(neo4jPersistentEntity, parentObject, isParentObjectNew, inDatabase,
440-
new NestedRelationshipProcessingStateMachine());
443+
new NestedRelationshipProcessingStateMachine(parentEntity));
441444
}
442445

443446
private <T> T processNestedRelations(Neo4jPersistentEntity<?> sourceEntity, Object parentObject,
@@ -468,7 +471,7 @@ private <T> T processNestedRelations(Neo4jPersistentEntity<?> sourceEntity, Obje
468471

469472
// break recursive procession and deletion of previously created relationships
470473
ProcessState processState = stateMachine.getStateOf(relationshipDescriptionObverse, relatedValuesToStore);
471-
if (processState == ProcessState.PROCESSED_ALL_RELATIONSHIPS) {
474+
if (processState == ProcessState.PROCESSED_ALL_RELATIONSHIPS || processState == ProcessState.PROCESSED_BOTH) {
472475
return;
473476
}
474477

@@ -517,8 +520,14 @@ private <T> T processNestedRelations(Neo4jPersistentEntity<?> sourceEntity, Obje
517520

518521
relatedNode = eventSupport.maybeCallBeforeBind(relatedNode);
519522

520-
Long relatedInternalId = saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(),
521-
targetEntity, inDatabase);
523+
Long relatedInternalId;
524+
// No need to save values if processed
525+
if (processState == ProcessState.PROCESSED_ALL_VALUES) {
526+
relatedInternalId = queryRelatedNode(relatedNode, targetEntity, inDatabase);
527+
} else {
528+
relatedInternalId = saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(),
529+
targetEntity, inDatabase);
530+
}
522531

523532
CreateRelationshipStatementHolder statementHolder = neo4jMappingContext.createStatement(
524533
sourceEntity, relationshipContext, relatedValueToStore);
@@ -553,6 +562,23 @@ private <T> T processNestedRelations(Neo4jPersistentEntity<?> sourceEntity, Obje
553562
return (T) propertyAccessor.getBean();
554563
}
555564

565+
private <Y> Long queryRelatedNode(Object entity, Neo4jPersistentEntity<?> targetNodeDescription,
566+
@Nullable String inDatabase) {
567+
568+
Neo4jPersistentProperty requiredIdProperty = targetNodeDescription.getRequiredIdProperty();
569+
PersistentPropertyAccessor<Object> targetPropertyAccessor = targetNodeDescription.getPropertyAccessor(entity);
570+
Object idValue = targetPropertyAccessor.getProperty(requiredIdProperty);
571+
572+
return neo4jClient.query(() ->
573+
renderer.render(cypherGenerator.prepareMatchOf(targetNodeDescription,
574+
targetNodeDescription.getIdExpression().isEqualTo(parameter(Constants.NAME_OF_ID)))
575+
.returning(Constants.NAME_OF_INTERNAL_ID)
576+
.build())
577+
)
578+
.in(inDatabase).bindAll(Collections.singletonMap(Constants.NAME_OF_ID, idValue))
579+
.fetchAs(Long.class).one().get();
580+
}
581+
556582
private <Y> Long saveRelatedNode(Object entity, Class<Y> entityType, NodeDescription targetNodeDescription,
557583
@Nullable String inDatabase) {
558584

src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java

+37-12
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ private <T> Mono<T> saveImpl(T instance, @Nullable String inDatabase) {
249249
}));
250250

251251
if (!entityMetaData.isUsingInternalIds()) {
252-
return idMono.then(processRelations(entityMetaData, entity, isNewEntity, inDatabase))
252+
return idMono.then(processRelations(entityMetaData, entity, isNewEntity, inDatabase, instance))
253253
.thenReturn(entity);
254254
} else {
255255
return idMono.map(internalId -> {
@@ -258,7 +258,7 @@ private <T> Mono<T> saveImpl(T instance, @Nullable String inDatabase) {
258258

259259
return propertyAccessor.getBean();
260260
}).flatMap(
261-
savedEntity -> processRelations(entityMetaData, savedEntity, isNewEntity, inDatabase)
261+
savedEntity -> processRelations(entityMetaData, savedEntity, isNewEntity, inDatabase, instance)
262262
.thenReturn(savedEntity));
263263
}
264264
}));
@@ -290,9 +290,9 @@ private <T> Mono<Tuple2<T, DynamicLabels>> determineDynamicLabels(T entityToBeSa
290290
@Override
291291
public <T> Flux<T> saveAll(Iterable<T> instances) {
292292

293-
Collection<T> entities;
293+
List<T> entities;
294294
if (instances instanceof Collection) {
295-
entities = (Collection<T>) instances;
295+
entities = new ArrayList<>((Collection<T>) instances);
296296
} else {
297297
entities = new ArrayList<>();
298298
instances.forEach(entities::add);
@@ -338,7 +338,7 @@ public <T> Flux<T> saveAll(Iterable<T> instances) {
338338
T entityToBeSaved = t.getT2();
339339
boolean isNew = isNewIndicator.get(Math.toIntExact(t.getT1()));
340340
return processRelations(entityMetaData, entityToBeSaved, isNew,
341-
databaseName.getValue())
341+
databaseName.getValue(), entities.get(Math.toIntExact(t.getT1())))
342342
.then(Mono.just(entityToBeSaved));
343343
}
344344
);
@@ -566,10 +566,10 @@ Publisher<Tuple2<Collection<Long>, Collection<Long>>>> iterateAndMapNextLevel(
566566
}
567567

568568
private Mono<Void> processRelations(Neo4jPersistentEntity<?> neo4jPersistentEntity, Object parentObject,
569-
boolean isParentObjectNew, @Nullable String inDatabase) {
569+
boolean isParentObjectNew, @Nullable String inDatabase, Object parentEntity) {
570570

571571
return processNestedRelations(neo4jPersistentEntity, parentObject, isParentObjectNew, inDatabase,
572-
new NestedRelationshipProcessingStateMachine());
572+
new NestedRelationshipProcessingStateMachine(parentEntity));
573573
}
574574

575575
private Mono<Void> processNestedRelations(Neo4jPersistentEntity<?> sourceEntity, Object parentObject,
@@ -602,7 +602,7 @@ private Mono<Void> processNestedRelations(Neo4jPersistentEntity<?> sourceEntity,
602602

603603
// break recursive procession and deletion of previously created relationships
604604
ProcessState processState = stateMachine.getStateOf(relationshipDescriptionObverse, relatedValuesToStore);
605-
if (processState == ProcessState.PROCESSED_ALL_RELATIONSHIPS) {
605+
if (processState == ProcessState.PROCESSED_ALL_RELATIONSHIPS || processState == ProcessState.PROCESSED_BOTH) {
606606
return;
607607
}
608608

@@ -652,9 +652,16 @@ private Mono<Void> processNestedRelations(Neo4jPersistentEntity<?> sourceEntity,
652652
.flatMap(relatedNode -> {
653653
Neo4jPersistentEntity<?> targetEntity = neo4jMappingContext
654654
.getPersistentEntity(relatedNodePreEvt.getClass());
655-
return Mono.just(targetEntity.isNew(relatedNode)).flatMap(isNew ->
656-
saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(),
657-
targetEntity, inDatabase).flatMap(relatedInternalId -> {
655+
return Mono.just(targetEntity.isNew(relatedNode)).flatMap(isNew -> {
656+
Mono<Long> relatedIdMono;
657+
658+
if (processState == ProcessState.PROCESSED_ALL_VALUES) {
659+
relatedIdMono = queryRelatedNode(relatedNode, targetEntity, inDatabase);
660+
} else {
661+
relatedIdMono = saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(),
662+
targetEntity, inDatabase);
663+
}
664+
return relatedIdMono.flatMap(relatedInternalId -> {
658665

659666
// if an internal id is used this must get set to link this entity in the next iteration
660667
PersistentPropertyAccessor<?> targetPropertyAccessor = targetEntity
@@ -691,7 +698,8 @@ private Mono<Void> processNestedRelations(Neo4jPersistentEntity<?> sourceEntity,
691698
} else {
692699
return relationshipCreationMonoNested.checkpoint().then();
693700
}
694-
}).checkpoint());
701+
}).checkpoint();
702+
});
695703
});
696704
relationshipCreationMonos.add(createRelationship);
697705
}
@@ -701,6 +709,23 @@ private Mono<Void> processNestedRelations(Neo4jPersistentEntity<?> sourceEntity,
701709
});
702710
}
703711

712+
private <Y> Mono<Long> queryRelatedNode(Object entity, Neo4jPersistentEntity<?> targetNodeDescription,
713+
@Nullable String inDatabase) {
714+
715+
Neo4jPersistentProperty requiredIdProperty = targetNodeDescription.getRequiredIdProperty();
716+
PersistentPropertyAccessor<Object> targetPropertyAccessor = targetNodeDescription.getPropertyAccessor(entity);
717+
Object idValue = targetPropertyAccessor.getProperty(requiredIdProperty);
718+
719+
return neo4jClient.query(() ->
720+
renderer.render(cypherGenerator.prepareMatchOf(targetNodeDescription,
721+
targetNodeDescription.getIdExpression().isEqualTo(parameter(Constants.NAME_OF_ID)))
722+
.returning(Constants.NAME_OF_INTERNAL_ID)
723+
.build())
724+
)
725+
.in(inDatabase).bindAll(Collections.singletonMap(Constants.NAME_OF_ID, idValue))
726+
.fetchAs(Long.class).one();
727+
}
728+
704729
private <Y> Mono<Long> saveRelatedNode(Object relatedNode, Class<Y> entityType, NodeDescription targetNodeDescription,
705730
@Nullable String inDatabase) {
706731

src/main/java/org/springframework/data/neo4j/core/mapping/NestedRelationshipProcessingStateMachine.java

+4
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ public enum ProcessState {
5555
*/
5656
private final Set<Object> processedObjects = new HashSet<>();
5757

58+
public NestedRelationshipProcessingStateMachine(Object initialObject) {
59+
processedObjects.add(initialObject);
60+
}
61+
5862
/**
5963
* @param relationshipDescription Check whether this relationship description has been processed
6064
* @param valuesToStore Check whether all the values in the collection have been processed

src/test/java/org/springframework/data/neo4j/integration/imperative/OptimisticLockingIT.java

+23
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.junit.jupiter.api.BeforeEach;
2626
import org.junit.jupiter.api.Test;
2727
import org.neo4j.driver.Driver;
28+
import org.neo4j.driver.Record;
2829
import org.neo4j.driver.Session;
2930
import org.neo4j.driver.SessionConfig;
3031
import org.neo4j.driver.Transaction;
@@ -284,6 +285,28 @@ void immutablesShouldWork(@Autowired Neo4jTemplate neo4jTemplate) {
284285
assertThatExceptionOfType(OptimisticLockingFailureException.class).isThrownBy(() -> neo4jTemplate.save(copy));
285286
}
286287

288+
@Test
289+
void shouldDoThings(@Autowired VersionedThingRepository repository) {
290+
VersionedThing thing1 = new VersionedThing("Thing1");
291+
VersionedThing thing2 = new VersionedThing("Thing2");
292+
293+
thing1.setOtherVersionedThings(Collections.singletonList(thing2));
294+
repository.save(thing1);
295+
296+
thing1 = repository.findById(thing1.getId()).get();
297+
thing2 = repository.findById(thing2.getId()).get();
298+
299+
thing2.setOtherVersionedThings(Collections.singletonList(thing1));
300+
repository.save(thing2);
301+
302+
try (Session session = driver.session()) {
303+
List<Record> result = session
304+
.run("MATCH (t:VersionedThing{name:'Thing1'})-[:HAS]->(:VersionedThing{name:'Thing2'}) return t")
305+
.list();
306+
assertThat(result).hasSize(1);
307+
}
308+
}
309+
287310
interface VersionedThingRepository extends Neo4jRepository<VersionedThing, Long> {}
288311

289312
interface VersionedThingWithAssignedIdRepository extends Neo4jRepository<VersionedThingWithAssignedId, Long> {}

src/test/java/org/springframework/data/neo4j/integration/imperative/RepositoryIT.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -3598,7 +3598,7 @@ void findAndInstantiateRelationshipsWithExtendingSuperRootEntity(
35983598
@Autowired SuperBaseClassWithRelationshipRepository repository) {
35993599

36003600
Inheritance.ConcreteClassA ccA = new Inheritance.ConcreteClassA("cc1", "test");
3601-
Inheritance.ConcreteClassB ccB1 = new Inheritance.ConcreteClassB("cc2a", 42);
3601+
Inheritance.ConcreteClassB ccB1 = new Inheritance.ConcreteClassB("cc2a", 41);
36023602
Inheritance.ConcreteClassB ccB2 = new Inheritance.ConcreteClassB("cc2b", 42);
36033603

36043604
List<Inheritance.SuperBaseClass> things = new ArrayList<>();

src/test/java/org/springframework/data/neo4j/integration/shared/common/VersionedThing.java

+22
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package org.springframework.data.neo4j.integration.shared.common;
1717

1818
import java.util.List;
19+
import java.util.Objects;
1920

2021
import org.springframework.data.annotation.Version;
2122
import org.springframework.data.neo4j.core.schema.GeneratedValue;
@@ -41,6 +42,10 @@ public VersionedThing(String name) {
4142
this.name = name;
4243
}
4344

45+
public Long getId() {
46+
return id;
47+
}
48+
4449
public Long getMyVersion() {
4550
return myVersion;
4651
}
@@ -56,4 +61,21 @@ public List<VersionedThing> getOtherVersionedThings() {
5661
public void setOtherVersionedThings(List<VersionedThing> otherVersionedThings) {
5762
this.otherVersionedThings = otherVersionedThings;
5863
}
64+
65+
@Override
66+
public boolean equals(Object o) {
67+
if (this == o) {
68+
return true;
69+
}
70+
if (o == null || getClass() != o.getClass()) {
71+
return false;
72+
}
73+
VersionedThing that = (VersionedThing) o;
74+
return Objects.equals(id, that.id) && name.equals(that.name);
75+
}
76+
77+
@Override
78+
public int hashCode() {
79+
return Objects.hash(id, name);
80+
}
5981
}

0 commit comments

Comments
 (0)