Skip to content

Commit 6aef5d9

Browse files
committed
GH-2177 - Improve processing state machine behaviour.
1 parent d84991b commit 6aef5d9

File tree

6 files changed

+123
-25
lines changed

6 files changed

+123
-25
lines changed

src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java

+35-10
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ private <T> T saveImpl(T instance) {
269269
propertyAccessor.setProperty(entityMetaData.getRequiredIdProperty(), optionalInternalId.get());
270270
entityToBeSaved = propertyAccessor.getBean();
271271
}
272-
return processRelations(entityMetaData, entityToBeSaved, isEntityNew);
272+
return processRelations(entityMetaData, entityToBeSaved, isEntityNew, instance);
273273
}
274274

275275
private <T> DynamicLabels determineDynamicLabels(T entityToBeSaved, Neo4jPersistentEntity<?> entityMetaData) {
@@ -297,9 +297,9 @@ private <T> DynamicLabels determineDynamicLabels(T entityToBeSaved, Neo4jPersist
297297
@Override
298298
public <T> List<T> saveAll(Iterable<T> instances) {
299299

300-
Collection<T> entities;
300+
List<T> entities;
301301
if (instances instanceof Collection) {
302-
entities = (Collection<T>) instances;
302+
entities = new ArrayList<>((Collection<T>) instances);
303303
} else {
304304
entities = new ArrayList<>();
305305
instances.forEach(entities::add);
@@ -335,8 +335,11 @@ public <T> List<T> saveAll(Iterable<T> instances) {
335335
.bind(entityList).to(Constants.NAME_OF_ENTITY_LIST_PARAM).run();
336336

337337
// Save related
338-
entitiesToBeSaved.forEach(entityToBeSaved -> processRelations(entityMetaData, entityToBeSaved,
339-
isNewIndicator.get(entitiesToBeSaved.indexOf(entityToBeSaved))));
338+
entitiesToBeSaved.forEach(entityToBeSaved -> {
339+
int positionInList = entitiesToBeSaved.indexOf(entityToBeSaved);
340+
processRelations(entityMetaData, entityToBeSaved, isNewIndicator.get(positionInList),
341+
entities.get(positionInList));
342+
});
340343

341344
SummaryCounters counters = resultSummary.counters();
342345
log.debug(() -> String.format(
@@ -446,10 +449,10 @@ private <T> ExecutableQuery<T> createExecutableQuery(Class<T> domainType, String
446449
}
447450

448451
private <T> T processRelations(Neo4jPersistentEntity<?> neo4jPersistentEntity, Object parentObject,
449-
boolean isParentObjectNew) {
452+
boolean isParentObjectNew, Object parentEntity) {
450453

451454
return processNestedRelations(neo4jPersistentEntity, parentObject, isParentObjectNew,
452-
new NestedRelationshipProcessingStateMachine());
455+
new NestedRelationshipProcessingStateMachine(parentEntity));
453456
}
454457

455458
private <T> T processNestedRelations(Neo4jPersistentEntity<?> sourceEntity, Object parentObject,
@@ -480,7 +483,7 @@ private <T> T processNestedRelations(Neo4jPersistentEntity<?> sourceEntity, Obje
480483

481484
// break recursive procession and deletion of previously created relationships
482485
ProcessState processState = stateMachine.getStateOf(relationshipDescriptionObverse, relatedValuesToStore);
483-
if (processState == ProcessState.PROCESSED_ALL_RELATIONSHIPS) {
486+
if (processState == ProcessState.PROCESSED_ALL_RELATIONSHIPS || processState == ProcessState.PROCESSED_BOTH) {
484487
return;
485488
}
486489

@@ -529,8 +532,14 @@ private <T> T processNestedRelations(Neo4jPersistentEntity<?> sourceEntity, Obje
529532

530533
relatedNode = eventSupport.maybeCallBeforeBind(relatedNode);
531534

532-
Long relatedInternalId = saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(),
533-
targetEntity);
535+
Long relatedInternalId;
536+
// No need to save values if processed
537+
if (processState == ProcessState.PROCESSED_ALL_VALUES) {
538+
relatedInternalId = queryRelatedNode(relatedNode, targetEntity);
539+
} else {
540+
relatedInternalId = saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(),
541+
targetEntity);
542+
}
534543

535544
CreateRelationshipStatementHolder statementHolder = neo4jMappingContext.createStatement(
536545
sourceEntity, relationshipContext, relatedValueToStore);
@@ -565,6 +574,22 @@ private <T> T processNestedRelations(Neo4jPersistentEntity<?> sourceEntity, Obje
565574
return (T) propertyAccessor.getBean();
566575
}
567576

577+
private <Y> Long queryRelatedNode(Object entity, Neo4jPersistentEntity<?> targetNodeDescription) {
578+
579+
Neo4jPersistentProperty requiredIdProperty = targetNodeDescription.getRequiredIdProperty();
580+
PersistentPropertyAccessor<Object> targetPropertyAccessor = targetNodeDescription.getPropertyAccessor(entity);
581+
Object idValue = targetPropertyAccessor.getProperty(requiredIdProperty);
582+
583+
return neo4jClient.query(() ->
584+
renderer.render(cypherGenerator.prepareMatchOf(targetNodeDescription,
585+
targetNodeDescription.getIdExpression().isEqualTo(parameter(Constants.NAME_OF_ID)))
586+
.returning(Constants.NAME_OF_INTERNAL_ID)
587+
.build())
588+
)
589+
.bindAll(Collections.singletonMap(Constants.NAME_OF_ID, idValue))
590+
.fetchAs(Long.class).one().get();
591+
}
592+
568593
private <Y> Long saveRelatedNode(Object entity, Class<Y> entityType, NodeDescription targetNodeDescription) {
569594

570595
DynamicLabels dynamicLabels = determineDynamicLabels(entity, (Neo4jPersistentEntity) targetNodeDescription);

src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java

+38-14
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ private <T> Mono<T> saveImpl(T instance) {
254254
}));
255255

256256
if (!entityMetaData.isUsingInternalIds()) {
257-
return idMono.then(processRelations(entityMetaData, entity, isNewEntity))
257+
return idMono.then(processRelations(entityMetaData, entity, isNewEntity, instance))
258258
.thenReturn(entity);
259259
} else {
260260
return idMono.map(internalId -> {
@@ -263,7 +263,7 @@ private <T> Mono<T> saveImpl(T instance) {
263263

264264
return propertyAccessor.getBean();
265265
}).flatMap(
266-
savedEntity -> processRelations(entityMetaData, savedEntity, isNewEntity)
266+
savedEntity -> processRelations(entityMetaData, savedEntity, isNewEntity, instance)
267267
.thenReturn(savedEntity));
268268
}
269269
}));
@@ -295,9 +295,9 @@ private <T> Mono<Tuple2<T, DynamicLabels>> determineDynamicLabels(T entityToBeSa
295295
@Override
296296
public <T> Flux<T> saveAll(Iterable<T> instances) {
297297

298-
Collection<T> entities;
298+
List<T> entities;
299299
if (instances instanceof Collection) {
300-
entities = (Collection<T>) instances;
300+
entities = new ArrayList<>((Collection<T>) instances);
301301
} else {
302302
entities = new ArrayList<>();
303303
instances.forEach(entities::add);
@@ -341,7 +341,8 @@ public <T> Flux<T> saveAll(Iterable<T> instances) {
341341
.flatMap(t -> {
342342
T entityToBeSaved = t.getT2();
343343
boolean isNew = isNewIndicator.get(Math.toIntExact(t.getT1()));
344-
return processRelations(entityMetaData, entityToBeSaved, isNew)
344+
return processRelations(entityMetaData, entityToBeSaved, isNew,
345+
entities.get(Math.toIntExact(t.getT1())))
345346
.then(Mono.just(entityToBeSaved));
346347
}
347348
);
@@ -563,10 +564,10 @@ Publisher<Tuple2<Collection<Long>, Collection<Long>>>> iterateAndMapNextLevel(
563564
}
564565

565566
private Mono<Void> processRelations(Neo4jPersistentEntity<?> neo4jPersistentEntity, Object parentObject,
566-
boolean isParentObjectNew) {
567+
boolean isParentObjectNew, Object parentEntity) {
567568

568569
return processNestedRelations(neo4jPersistentEntity, parentObject, isParentObjectNew,
569-
new NestedRelationshipProcessingStateMachine());
570+
new NestedRelationshipProcessingStateMachine(parentEntity));
570571
}
571572

572573
private Mono<Void> processNestedRelations(Neo4jPersistentEntity<?> sourceEntity, Object parentObject,
@@ -599,7 +600,7 @@ private Mono<Void> processNestedRelations(Neo4jPersistentEntity<?> sourceEntity,
599600

600601
// break recursive procession and deletion of previously created relationships
601602
ProcessState processState = stateMachine.getStateOf(relationshipDescriptionObverse, relatedValuesToStore);
602-
if (processState == ProcessState.PROCESSED_ALL_RELATIONSHIPS) {
603+
if (processState == ProcessState.PROCESSED_ALL_RELATIONSHIPS || processState == ProcessState.PROCESSED_BOTH) {
603604
return;
604605
}
605606

@@ -649,9 +650,16 @@ private Mono<Void> processNestedRelations(Neo4jPersistentEntity<?> sourceEntity,
649650
.flatMap(relatedNode -> {
650651
Neo4jPersistentEntity<?> targetEntity = neo4jMappingContext
651652
.getPersistentEntity(relatedNodePreEvt.getClass());
652-
return Mono.just(targetEntity.isNew(relatedNode)).flatMap(isNew ->
653-
saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(),
654-
targetEntity).flatMap(relatedInternalId -> {
653+
return Mono.just(targetEntity.isNew(relatedNode)).flatMap(isNew -> {
654+
Mono<Long> relatedIdMono;
655+
656+
if (processState == ProcessState.PROCESSED_ALL_VALUES) {
657+
relatedIdMono = queryRelatedNode(relatedNode, targetEntity);
658+
} else {
659+
relatedIdMono = saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(),
660+
targetEntity);
661+
}
662+
return relatedIdMono.flatMap(relatedInternalId -> {
655663

656664
// if an internal id is used this must get set to link this entity in the next iteration
657665
PersistentPropertyAccessor<?> targetPropertyAccessor = targetEntity
@@ -688,7 +696,8 @@ private Mono<Void> processNestedRelations(Neo4jPersistentEntity<?> sourceEntity,
688696
} else {
689697
return relationshipCreationMonoNested.checkpoint().then();
690698
}
691-
}).checkpoint());
699+
}).checkpoint();
700+
});
692701
});
693702
relationshipCreationMonos.add(createRelationship);
694703
}
@@ -698,8 +707,23 @@ private Mono<Void> processNestedRelations(Neo4jPersistentEntity<?> sourceEntity,
698707
});
699708
}
700709

701-
private <Y> Mono<Long> saveRelatedNode(Object relatedNode, Class<Y> entityType,
702-
NodeDescription targetNodeDescription) {
710+
private <Y> Mono<Long> queryRelatedNode(Object entity, Neo4jPersistentEntity<?> targetNodeDescription) {
711+
712+
Neo4jPersistentProperty requiredIdProperty = targetNodeDescription.getRequiredIdProperty();
713+
PersistentPropertyAccessor<Object> targetPropertyAccessor = targetNodeDescription.getPropertyAccessor(entity);
714+
Object idValue = targetPropertyAccessor.getProperty(requiredIdProperty);
715+
716+
return neo4jClient.query(() ->
717+
renderer.render(cypherGenerator.prepareMatchOf(targetNodeDescription,
718+
targetNodeDescription.getIdExpression().isEqualTo(parameter(Constants.NAME_OF_ID)))
719+
.returning(Constants.NAME_OF_INTERNAL_ID)
720+
.build())
721+
)
722+
.bindAll(Collections.singletonMap(Constants.NAME_OF_ID, idValue))
723+
.fetchAs(Long.class).one();
724+
}
725+
726+
private <Y> Mono<Long> saveRelatedNode(Object relatedNode, Class<Y> entityType, NodeDescription targetNodeDescription) {
703727

704728
return determineDynamicLabels((Y) relatedNode, (Neo4jPersistentEntity<?>) targetNodeDescription)
705729
.flatMap(t -> {

src/main/java/org/springframework/data/neo4j/core/mapping/NestedRelationshipProcessingStateMachine.java

+4
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ public enum ProcessState {
5555
*/
5656
private final Set<Object> processedObjects = new HashSet<>();
5757

58+
public NestedRelationshipProcessingStateMachine(Object initialObject) {
59+
processedObjects.add(initialObject);
60+
}
61+
5862
/**
5963
* @param relationshipDescription Check whether this relationship description has been processed
6064
* @param valuesToStore Check whether all the values in the collection have been processed

src/test/java/org/springframework/data/neo4j/integration/imperative/OptimisticLockingIT.java

+23
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.junit.jupiter.api.BeforeEach;
2626
import org.junit.jupiter.api.Test;
2727
import org.neo4j.driver.Driver;
28+
import org.neo4j.driver.Record;
2829
import org.neo4j.driver.Session;
2930
import org.neo4j.driver.SessionConfig;
3031
import org.neo4j.driver.Transaction;
@@ -284,6 +285,28 @@ void immutablesShouldWork(@Autowired Neo4jTemplate neo4jTemplate) {
284285
assertThatExceptionOfType(OptimisticLockingFailureException.class).isThrownBy(() -> neo4jTemplate.save(copy));
285286
}
286287

288+
@Test
289+
void shouldDoThings(@Autowired VersionedThingRepository repository) {
290+
VersionedThing thing1 = new VersionedThing("Thing1");
291+
VersionedThing thing2 = new VersionedThing("Thing2");
292+
293+
thing1.setOtherVersionedThings(Collections.singletonList(thing2));
294+
repository.save(thing1);
295+
296+
thing1 = repository.findById(thing1.getId()).get();
297+
thing2 = repository.findById(thing2.getId()).get();
298+
299+
thing2.setOtherVersionedThings(Collections.singletonList(thing1));
300+
repository.save(thing2);
301+
302+
try (Session session = driver.session()) {
303+
List<Record> result = session
304+
.run("MATCH (t:VersionedThing{name:'Thing1'})-[:HAS]->(:VersionedThing{name:'Thing2'}) return t")
305+
.list();
306+
assertThat(result).hasSize(1);
307+
}
308+
}
309+
287310
interface VersionedThingRepository extends Neo4jRepository<VersionedThing, Long> {}
288311

289312
interface VersionedThingWithAssignedIdRepository extends Neo4jRepository<VersionedThingWithAssignedId, Long> {}

src/test/java/org/springframework/data/neo4j/integration/imperative/RepositoryIT.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -3650,7 +3650,7 @@ void findAndInstantiateRelationshipsWithExtendingSuperRootEntity(
36503650
@Autowired SuperBaseClassWithRelationshipRepository repository) {
36513651

36523652
Inheritance.ConcreteClassA ccA = new Inheritance.ConcreteClassA("cc1", "test");
3653-
Inheritance.ConcreteClassB ccB1 = new Inheritance.ConcreteClassB("cc2a", 42);
3653+
Inheritance.ConcreteClassB ccB1 = new Inheritance.ConcreteClassB("cc2a", 41);
36543654
Inheritance.ConcreteClassB ccB2 = new Inheritance.ConcreteClassB("cc2b", 42);
36553655

36563656
List<Inheritance.SuperBaseClass> things = new ArrayList<>();

src/test/java/org/springframework/data/neo4j/integration/shared/common/VersionedThing.java

+22
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package org.springframework.data.neo4j.integration.shared.common;
1717

1818
import java.util.List;
19+
import java.util.Objects;
1920

2021
import org.springframework.data.annotation.Version;
2122
import org.springframework.data.neo4j.core.schema.GeneratedValue;
@@ -41,6 +42,10 @@ public VersionedThing(String name) {
4142
this.name = name;
4243
}
4344

45+
public Long getId() {
46+
return id;
47+
}
48+
4449
public Long getMyVersion() {
4550
return myVersion;
4651
}
@@ -56,4 +61,21 @@ public List<VersionedThing> getOtherVersionedThings() {
5661
public void setOtherVersionedThings(List<VersionedThing> otherVersionedThings) {
5762
this.otherVersionedThings = otherVersionedThings;
5863
}
64+
65+
@Override
66+
public boolean equals(Object o) {
67+
if (this == o) {
68+
return true;
69+
}
70+
if (o == null || getClass() != o.getClass()) {
71+
return false;
72+
}
73+
VersionedThing that = (VersionedThing) o;
74+
return Objects.equals(id, that.id) && name.equals(that.name);
75+
}
76+
77+
@Override
78+
public int hashCode() {
79+
return Objects.hash(id, name);
80+
}
5981
}

0 commit comments

Comments
 (0)