diff --git a/pom.xml b/pom.xml index 797b8633da..b58640aaf9 100644 --- a/pom.xml +++ b/pom.xml @@ -24,7 +24,7 @@ org.springframework.data spring-data-neo4j - 6.1.0-SNAPSHOT + 6.1.0-GH-2159-SNAPSHOT Spring Data Neo4j Next generation Object-Graph-Mapping for Spring Data. diff --git a/src/main/java/org/springframework/data/neo4j/config/AbstractNeo4jConfig.java b/src/main/java/org/springframework/data/neo4j/config/AbstractNeo4jConfig.java index 9f66735ae2..0cf7666102 100644 --- a/src/main/java/org/springframework/data/neo4j/config/AbstractNeo4jConfig.java +++ b/src/main/java/org/springframework/data/neo4j/config/AbstractNeo4jConfig.java @@ -54,15 +54,14 @@ public abstract class AbstractNeo4jConfig extends Neo4jConfigurationSupport { * @return A imperative Neo4j client. */ @Bean(Neo4jRepositoryConfigurationExtension.DEFAULT_NEO4J_CLIENT_BEAN_NAME) - public Neo4jClient neo4jClient(Driver driver) { - return Neo4jClient.create(driver); + public Neo4jClient neo4jClient(Driver driver, DatabaseSelectionProvider databaseSelectionProvider) { + return Neo4jClient.create(driver, databaseSelectionProvider); } @Bean(Neo4jRepositoryConfigurationExtension.DEFAULT_NEO4J_TEMPLATE_BEAN_NAME) - public Neo4jOperations neo4jTemplate(final Neo4jClient neo4jClient, final Neo4jMappingContext mappingContext, - DatabaseSelectionProvider databaseNameProvider) { + public Neo4jOperations neo4jTemplate(final Neo4jClient neo4jClient, final Neo4jMappingContext mappingContext) { - return new Neo4jTemplate(neo4jClient, mappingContext, databaseNameProvider); + return new Neo4jTemplate(neo4jClient, mappingContext); } /** diff --git a/src/main/java/org/springframework/data/neo4j/config/AbstractReactiveNeo4jConfig.java b/src/main/java/org/springframework/data/neo4j/config/AbstractReactiveNeo4jConfig.java index baf49ede0e..176823dc83 100644 --- a/src/main/java/org/springframework/data/neo4j/config/AbstractReactiveNeo4jConfig.java +++ b/src/main/java/org/springframework/data/neo4j/config/AbstractReactiveNeo4jConfig.java @@ -60,9 +60,9 @@ public ReactiveNeo4jClient neo4jClient(Driver driver) { @Bean(ReactiveNeo4jRepositoryConfigurationExtension.DEFAULT_NEO4J_TEMPLATE_BEAN_NAME) public ReactiveNeo4jTemplate neo4jTemplate(final ReactiveNeo4jClient neo4jClient, - final Neo4jMappingContext mappingContext, final ReactiveDatabaseSelectionProvider databaseNameProvider) { + final Neo4jMappingContext mappingContext) { - return new ReactiveNeo4jTemplate(neo4jClient, mappingContext, databaseNameProvider); + return new ReactiveNeo4jTemplate(neo4jClient, mappingContext); } /** diff --git a/src/main/java/org/springframework/data/neo4j/config/Neo4jCdiConfigurationSupport.java b/src/main/java/org/springframework/data/neo4j/config/Neo4jCdiConfigurationSupport.java index 50abfc0ce2..59dd78cb21 100644 --- a/src/main/java/org/springframework/data/neo4j/config/Neo4jCdiConfigurationSupport.java +++ b/src/main/java/org/springframework/data/neo4j/config/Neo4jCdiConfigurationSupport.java @@ -15,12 +15,6 @@ */ package org.springframework.data.neo4j.config; -import javax.enterprise.context.ApplicationScoped; -import javax.enterprise.inject.Any; -import javax.enterprise.inject.Instance; -import javax.enterprise.inject.Produces; -import javax.inject.Singleton; - import org.apiguardian.api.API; import org.neo4j.driver.Driver; import org.springframework.data.neo4j.core.DatabaseSelectionProvider; @@ -32,6 +26,12 @@ import org.springframework.data.neo4j.core.transaction.Neo4jTransactionManager; import org.springframework.transaction.PlatformTransactionManager; +import javax.enterprise.context.ApplicationScoped; +import javax.enterprise.inject.Any; +import javax.enterprise.inject.Instance; +import javax.enterprise.inject.Produces; +import javax.inject.Singleton; + /** * Support class that can be used as is for all necessary CDI beans or as a blueprint for custom producers. * @@ -66,10 +66,9 @@ public DatabaseSelectionProvider databaseSelectionProvider() { @Produces @Builtin @Singleton public Neo4jOperations neo4jOperations( @Any Instance neo4jClient, - @Any Instance mappingContext, - @Any Instance databaseNameProvider + @Any Instance mappingContext ) { - return new Neo4jTemplate(resolve(neo4jClient), resolve(mappingContext), resolve(databaseNameProvider)); + return new Neo4jTemplate(resolve(neo4jClient), resolve(mappingContext)); } @Produces @Singleton diff --git a/src/main/java/org/springframework/data/neo4j/core/DefaultNeo4jClient.java b/src/main/java/org/springframework/data/neo4j/core/DefaultNeo4jClient.java index 1062f5da45..7b03fa7861 100644 --- a/src/main/java/org/springframework/data/neo4j/core/DefaultNeo4jClient.java +++ b/src/main/java/org/springframework/data/neo4j/core/DefaultNeo4jClient.java @@ -56,13 +56,15 @@ class DefaultNeo4jClient implements Neo4jClient { private final Driver driver; private final TypeSystem typeSystem; + private final DatabaseSelectionProvider databaseSelectionProvider; private final ConversionService conversionService; private final Neo4jPersistenceExceptionTranslator persistenceExceptionTranslator = new Neo4jPersistenceExceptionTranslator(); - DefaultNeo4jClient(Driver driver) { + DefaultNeo4jClient(Driver driver, DatabaseSelectionProvider databaseSelectionProvider) { this.driver = driver; this.typeSystem = driver.defaultTypeSystem(); + this.databaseSelectionProvider = databaseSelectionProvider; this.conversionService = new DefaultConversionService(); new Neo4jConversions().registerConvertersIn((ConverterRegistry) conversionService); @@ -262,7 +264,11 @@ class DefaultRecordFetchSpec implements RecordFetchSpec, MappingSpec { DefaultRecordFetchSpec(String targetDatabase, RunnableStatement runnableStatement, BiFunction mappingFunction) { - this.targetDatabase = targetDatabase; + this.targetDatabase = targetDatabase != null + ? targetDatabase + : databaseSelectionProvider != null + ? databaseSelectionProvider.getDatabaseSelection().getValue() + : null; this.runnableStatement = runnableStatement; this.mappingFunction = mappingFunction; } diff --git a/src/main/java/org/springframework/data/neo4j/core/DefaultReactiveNeo4jClient.java b/src/main/java/org/springframework/data/neo4j/core/DefaultReactiveNeo4jClient.java index b485b13801..b730c7f70b 100644 --- a/src/main/java/org/springframework/data/neo4j/core/DefaultReactiveNeo4jClient.java +++ b/src/main/java/org/springframework/data/neo4j/core/DefaultReactiveNeo4jClient.java @@ -15,15 +15,6 @@ */ package org.springframework.data.neo4j.core; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.util.function.Tuple2; - -import java.util.Map; -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.function.Supplier; - import org.neo4j.driver.Driver; import org.neo4j.driver.Record; import org.neo4j.driver.reactive.RxQueryRunner; @@ -41,6 +32,14 @@ import org.springframework.data.neo4j.core.transaction.ReactiveNeo4jTransactionManager; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; + +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; /** * Reactive variant of the {@link Neo4jClient}. @@ -54,13 +53,15 @@ class DefaultReactiveNeo4jClient implements ReactiveNeo4jClient { private final Driver driver; private final TypeSystem typeSystem; + private final ReactiveDatabaseSelectionProvider databaseSelectionProvider; private final ConversionService conversionService; private final Neo4jPersistenceExceptionTranslator persistenceExceptionTranslator = new Neo4jPersistenceExceptionTranslator(); - DefaultReactiveNeo4jClient(Driver driver) { + DefaultReactiveNeo4jClient(Driver driver, @Nullable ReactiveDatabaseSelectionProvider databaseSelectionProvider) { this.driver = driver; this.typeSystem = driver.defaultTypeSystem(); + this.databaseSelectionProvider = databaseSelectionProvider; this.conversionService = new DefaultConversionService(); new Neo4jConversions().registerConvertersIn((ConverterRegistry) conversionService); } @@ -180,7 +181,7 @@ public Mono run() { class DefaultRecordFetchSpec implements RecordFetchSpec, MappingSpec { - private final String targetDatabase; + private final Mono targetDatabase; private final Supplier cypherSupplier; @@ -192,9 +193,18 @@ class DefaultRecordFetchSpec implements RecordFetchSpec, MappingSpec { this(targetDatabase, cypherSupplier, parameters, null); } - DefaultRecordFetchSpec(String targetDatabase, Supplier cypherSupplier, NamedParameters parameters, + DefaultRecordFetchSpec(@Nullable String targetDatabase, Supplier cypherSupplier, NamedParameters parameters, @Nullable BiFunction mappingFunction) { - this.targetDatabase = targetDatabase; + + this.targetDatabase = Mono.defer(() -> { + if (targetDatabase != null) { + return ReactiveDatabaseSelectionProvider.createStaticDatabaseSelectionProvider(targetDatabase) + .getDatabaseSelection(); + } else if (databaseSelectionProvider != null) { + return databaseSelectionProvider.getDatabaseSelection(); + } + return Mono.just(DatabaseSelection.undecided()); + }); this.cypherSupplier = cypherSupplier; this.parameters = parameters; this.mappingFunction = mappingFunction; @@ -229,33 +239,36 @@ Flux executeWith(Tuple2> t, RxQueryRunner runner) @Override public Mono one() { - return doInQueryRunnerForMono(targetDatabase, - (runner) -> prepareStatement().flatMapMany(t -> executeWith(t, runner)).singleOrEmpty()) + return targetDatabase.flatMap(databaseSelection -> doInQueryRunnerForMono(databaseSelection.getValue(), + (runner) -> prepareStatement().flatMapMany(t -> executeWith(t, runner)).singleOrEmpty())) .onErrorMap(RuntimeException.class, DefaultReactiveNeo4jClient.this::potentiallyConvertRuntimeException); } @Override public Mono first() { - return doInQueryRunnerForMono(targetDatabase, - runner -> prepareStatement().flatMapMany(t -> executeWith(t, runner)).next()) + return targetDatabase.flatMap(databaseSelection -> doInQueryRunnerForMono(databaseSelection.getValue(), + runner -> prepareStatement().flatMapMany(t -> executeWith(t, runner)).next())) .onErrorMap(RuntimeException.class, DefaultReactiveNeo4jClient.this::potentiallyConvertRuntimeException); } @Override public Flux all() { - return doInStatementRunnerForFlux(targetDatabase, - runner -> prepareStatement().flatMapMany(t -> executeWith(t, runner))).onErrorMap(RuntimeException.class, - DefaultReactiveNeo4jClient.this::potentiallyConvertRuntimeException); + return targetDatabase.flatMapMany(databaseSelection -> + doInStatementRunnerForFlux(databaseSelection.getValue(), + runner -> prepareStatement().flatMapMany(t -> executeWith(t, runner))) + ) + .onErrorMap(RuntimeException.class, DefaultReactiveNeo4jClient.this::potentiallyConvertRuntimeException); } Mono run() { - return doInQueryRunnerForMono(targetDatabase, runner -> prepareStatement().flatMap(t -> { - RxResult rxResult = runner.run(t.getT1(), t.getT2()); - return Flux.from(rxResult.records()).then(Mono.from(rxResult.consume()).map(ResultSummaries::process)); - })).onErrorMap(RuntimeException.class, DefaultReactiveNeo4jClient.this::potentiallyConvertRuntimeException); + return targetDatabase.flatMap(databaseSelection -> + doInQueryRunnerForMono(databaseSelection.getValue(), runner -> prepareStatement().flatMap(t -> { + RxResult rxResult = runner.run(t.getT1(), t.getT2()); + return Flux.from(rxResult.records()).then(Mono.from(rxResult.consume()).map(ResultSummaries::process)); + }))).onErrorMap(RuntimeException.class, DefaultReactiveNeo4jClient.this::potentiallyConvertRuntimeException); } } diff --git a/src/main/java/org/springframework/data/neo4j/core/Neo4jClient.java b/src/main/java/org/springframework/data/neo4j/core/Neo4jClient.java index ab9658edb5..60fbef290c 100644 --- a/src/main/java/org/springframework/data/neo4j/core/Neo4jClient.java +++ b/src/main/java/org/springframework/data/neo4j/core/Neo4jClient.java @@ -48,7 +48,12 @@ public interface Neo4jClient { static Neo4jClient create(Driver driver) { - return new DefaultNeo4jClient(driver); + return new DefaultNeo4jClient(driver, null); + } + + static Neo4jClient create(Driver driver, DatabaseSelectionProvider databaseSelectionProvider) { + + return new DefaultNeo4jClient(driver, databaseSelectionProvider); } /** diff --git a/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java b/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java index 157e3d0fd3..10126042bd 100644 --- a/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java +++ b/src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java @@ -98,31 +98,25 @@ public final class Neo4jTemplate implements Neo4jOperations, BeanFactoryAware { private EventSupport eventSupport; - private final DatabaseSelectionProvider databaseSelectionProvider; - public Neo4jTemplate(Neo4jClient neo4jClient) { - this(neo4jClient, new Neo4jMappingContext(), DatabaseSelectionProvider.getDefaultSelectionProvider()); + this(neo4jClient, new Neo4jMappingContext()); } - public Neo4jTemplate(Neo4jClient neo4jClient, Neo4jMappingContext neo4jMappingContext, - DatabaseSelectionProvider databaseSelectionProvider) { + public Neo4jTemplate(Neo4jClient neo4jClient, Neo4jMappingContext neo4jMappingContext) { - this(neo4jClient, neo4jMappingContext, databaseSelectionProvider, EntityCallbacks.create()); + this(neo4jClient, neo4jMappingContext, EntityCallbacks.create()); } public Neo4jTemplate(Neo4jClient neo4jClient, Neo4jMappingContext neo4jMappingContext, - DatabaseSelectionProvider databaseSelectionProvider, EntityCallbacks entityCallbacks) { + EntityCallbacks entityCallbacks) { Assert.notNull(neo4jClient, "The Neo4jClient is required"); Assert.notNull(neo4jMappingContext, "The Neo4jMappingContext is required"); - Assert.notNull(databaseSelectionProvider, "The database name provider is required"); this.neo4jClient = neo4jClient; this.neo4jMappingContext = neo4jMappingContext; this.cypherGenerator = CypherGenerator.INSTANCE; this.eventSupport = EventSupport.useExistingCallbacks(neo4jMappingContext, entityCallbacks); - - this.databaseSelectionProvider = databaseSelectionProvider; } @Override @@ -225,21 +219,20 @@ private Object convertIdValues(@Nullable Neo4jPersistentProperty idProperty, Obj @Override public T save(T instance) { - return saveImpl(instance, getDatabaseName()); + return saveImpl(instance); } - private T saveImpl(T instance, @Nullable String inDatabase) { + private T saveImpl(T instance) { Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getPersistentEntity(instance.getClass()); boolean isEntityNew = entityMetaData.isNew(instance); T entityToBeSaved = eventSupport.maybeCallBeforeBind(instance); - DynamicLabels dynamicLabels = determineDynamicLabels(entityToBeSaved, entityMetaData, inDatabase); + DynamicLabels dynamicLabels = determineDynamicLabels(entityToBeSaved, entityMetaData); Optional optionalInternalId = neo4jClient .query(() -> renderer.render(cypherGenerator.prepareSaveOf(entityMetaData, dynamicLabels))) - .in(inDatabase) .bind(entityToBeSaved) .with(neo4jMappingContext.getRequiredBinderFunctionFor((Class) entityToBeSaved.getClass())) .fetchAs(Long.class).one(); @@ -253,17 +246,16 @@ private T saveImpl(T instance, @Nullable String inDatabase) { propertyAccessor.setProperty(entityMetaData.getRequiredIdProperty(), optionalInternalId.get()); entityToBeSaved = propertyAccessor.getBean(); } - return processRelations(entityMetaData, entityToBeSaved, isEntityNew, inDatabase); + return processRelations(entityMetaData, entityToBeSaved, isEntityNew); } - private DynamicLabels determineDynamicLabels(T entityToBeSaved, Neo4jPersistentEntity entityMetaData, - @Nullable String inDatabase) { + private DynamicLabels determineDynamicLabels(T entityToBeSaved, Neo4jPersistentEntity entityMetaData) { return entityMetaData.getDynamicLabelsProperty().map(p -> { PersistentPropertyAccessor propertyAccessor = entityMetaData.getPropertyAccessor(entityToBeSaved); Neo4jClient.RunnableSpecTightToDatabase runnableQuery = neo4jClient .query(() -> renderer.render(cypherGenerator.createStatementReturningDynamicLabels(entityMetaData))) - .in(inDatabase).bind(propertyAccessor.getProperty(entityMetaData.getRequiredIdProperty())) + .bind(propertyAccessor.getProperty(entityMetaData.getRequiredIdProperty())) .to(Constants.NAME_OF_ID).bind(entityMetaData.getStaticLabels()) .to(Constants.NAME_OF_STATIC_LABELS_PARAM); @@ -282,8 +274,6 @@ private DynamicLabels determineDynamicLabels(T entityToBeSaved, Neo4jPersist @Override public List saveAll(Iterable instances) { - String databaseName = getDatabaseName(); - Collection entities; if (instances instanceof Collection) { entities = (Collection) instances; @@ -302,7 +292,7 @@ public List saveAll(Iterable instances) { || entityMetaData.getDynamicLabelsProperty().isPresent()) { log.debug("Saving entities using single statements."); - return entities.stream().map(e -> saveImpl(e, databaseName)).collect(Collectors.toList()); + return entities.stream().map(e -> saveImpl(e)).collect(Collectors.toList()); } // we need to determine the `isNew` state of the entities before calling the id generator @@ -319,12 +309,11 @@ public List saveAll(Iterable instances) { .collect(Collectors.toList()); ResultSummary resultSummary = neo4jClient .query(() -> renderer.render(cypherGenerator.prepareSaveOfMultipleInstancesOf(entityMetaData))) - .in(databaseName) .bind(entityList).to(Constants.NAME_OF_ENTITY_LIST_PARAM).run(); // Save related entitiesToBeSaved.forEach(entityToBeSaved -> processRelations(entityMetaData, entityToBeSaved, - isNewIndicator.get(entitiesToBeSaved.indexOf(entityToBeSaved)), databaseName)); + isNewIndicator.get(entitiesToBeSaved.indexOf(entityToBeSaved)))); SummaryCounters counters = resultSummary.counters(); log.debug(() -> String.format( @@ -345,7 +334,7 @@ public void deleteById(Object id, Class domainType) { log.debug(() -> String.format("Deleting entity with id %s ", id)); Statement statement = cypherGenerator.prepareDeleteOf(entityMetaData, condition); - ResultSummary summary = this.neo4jClient.query(renderer.render(statement)).in(getDatabaseName()) + ResultSummary summary = this.neo4jClient.query(renderer.render(statement)) .bind(convertIdValues(entityMetaData.getRequiredIdProperty(), id)) .to(nameOfParameter).run(); @@ -389,8 +378,8 @@ public void deleteAllById(Iterable ids, Class domainType) { log.debug(() -> String.format("Deleting all entities with the following ids: %s ", ids)); Statement statement = cypherGenerator.prepareDeleteOf(entityMetaData, condition); - ResultSummary summary = this.neo4jClient.query(renderer.render(statement)).in(getDatabaseName()).bind( - convertIdValues(entityMetaData.getRequiredIdProperty(), ids)) + ResultSummary summary = this.neo4jClient.query(renderer.render(statement)) + .bind(convertIdValues(entityMetaData.getRequiredIdProperty(), ids)) .to(nameOfParameter).run(); log.debug(() -> String.format("Deleted %d nodes and %d relationships.", summary.counters().nodesDeleted(), @@ -404,7 +393,7 @@ public void deleteAll(Class domainType) { log.debug(() -> String.format("Deleting all nodes with primary label %s", entityMetaData.getPrimaryLabel())); Statement statement = cypherGenerator.prepareDeleteOf(entityMetaData); - ResultSummary summary = this.neo4jClient.query(renderer.render(statement)).in(getDatabaseName()).run(); + ResultSummary summary = this.neo4jClient.query(renderer.render(statement)).run(); log.debug(() -> String.format("Deleted %d nodes and %d relationships.", summary.counters().nodesDeleted(), summary.counters().relationshipsDeleted())); @@ -434,14 +423,14 @@ private ExecutableQuery createExecutableQuery(Class domainType, String } private T processRelations(Neo4jPersistentEntity neo4jPersistentEntity, Object parentObject, - boolean isParentObjectNew, @Nullable String inDatabase) { + boolean isParentObjectNew) { - return processNestedRelations(neo4jPersistentEntity, parentObject, isParentObjectNew, inDatabase, + return processNestedRelations(neo4jPersistentEntity, parentObject, isParentObjectNew, new NestedRelationshipProcessingStateMachine()); } private T processNestedRelations(Neo4jPersistentEntity sourceEntity, Object parentObject, - boolean isParentObjectNew, @Nullable String inDatabase, NestedRelationshipProcessingStateMachine stateMachine) { + boolean isParentObjectNew, NestedRelationshipProcessingStateMachine stateMachine) { PersistentPropertyAccessor propertyAccessor = sourceEntity.getPropertyAccessor(parentObject); Object fromId = propertyAccessor.getProperty(sourceEntity.getRequiredIdProperty()); @@ -492,7 +481,7 @@ private T processNestedRelations(Neo4jPersistentEntity sourceEntity, Obje Statement relationshipRemoveQuery = cypherGenerator.prepareDeleteOf(sourceEntity, relationshipDescription); - neo4jClient.query(renderer.render(relationshipRemoveQuery)).in(inDatabase) + neo4jClient.query(renderer.render(relationshipRemoveQuery)) .bind(convertIdValues(sourceEntity.getIdProperty(), fromId)) // .to(Constants.FROM_ID_PARAMETER_NAME) // .bind(knownRelationshipsIds) // @@ -518,12 +507,12 @@ private T processNestedRelations(Neo4jPersistentEntity sourceEntity, Obje relatedNode = eventSupport.maybeCallBeforeBind(relatedNode); Long relatedInternalId = saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(), - targetEntity, inDatabase); + targetEntity); CreateRelationshipStatementHolder statementHolder = neo4jMappingContext.createStatement( sourceEntity, relationshipContext, relatedValueToStore); - Optional relationshipInternalId = neo4jClient.query(renderer.render(statementHolder.getStatement())).in(inDatabase) + Optional relationshipInternalId = neo4jClient.query(renderer.render(statementHolder.getStatement())) .bind(convertIdValues(sourceEntity.getRequiredIdProperty(), fromId)) // .to(Constants.FROM_ID_PARAMETER_NAME) .bind(relatedInternalId) // @@ -543,7 +532,7 @@ private T processNestedRelations(Neo4jPersistentEntity sourceEntity, Obje targetPropertyAccessor.setProperty(targetEntity.getRequiredIdProperty(), relatedInternalId); } if (processState != ProcessState.PROCESSED_ALL_VALUES) { - processNestedRelations(targetEntity, targetPropertyAccessor.getBean(), isEntityNew, inDatabase, stateMachine); + processNestedRelations(targetEntity, targetPropertyAccessor.getBean(), isEntityNew, stateMachine); } } @@ -553,14 +542,12 @@ private T processNestedRelations(Neo4jPersistentEntity sourceEntity, Obje return (T) propertyAccessor.getBean(); } - private Long saveRelatedNode(Object entity, Class entityType, NodeDescription targetNodeDescription, - @Nullable String inDatabase) { + private Long saveRelatedNode(Object entity, Class entityType, NodeDescription targetNodeDescription) { - DynamicLabels dynamicLabels = determineDynamicLabels(entity, (Neo4jPersistentEntity) targetNodeDescription, - inDatabase); + DynamicLabels dynamicLabels = determineDynamicLabels(entity, (Neo4jPersistentEntity) targetNodeDescription); Optional optionalSavedNodeId = neo4jClient .query(() -> renderer.render(cypherGenerator.prepareSaveOf(targetNodeDescription, dynamicLabels))) - .in(inDatabase).bind((Y) entity).with(neo4jMappingContext.getRequiredBinderFunctionFor(entityType)) + .bind((Y) entity).with(neo4jMappingContext.getRequiredBinderFunctionFor(entityType)) .fetchAs(Long.class).one(); if (((Neo4jPersistentEntity) targetNodeDescription).hasVersionProperty() && !optionalSavedNodeId.isPresent()) { @@ -570,11 +557,6 @@ private Long saveRelatedNode(Object entity, Class entityType, NodeDescrip return optionalSavedNodeId.get(); } - private String getDatabaseName() { - - return this.databaseSelectionProvider.getDatabaseSelection().getValue(); - } - @Override public void setBeanFactory(BeanFactory beanFactory) throws BeansException { @@ -676,7 +658,7 @@ private Optional> createFetchSpec() { } Neo4jClient.MappingSpec newMappingSpec = neo4jClient.query(cypherQuery) - .in(getDatabaseName()).bindAll(finalParameters).fetchAs(preparedQuery.getResultType()); + .bindAll(finalParameters).fetchAs(preparedQuery.getResultType()); return Optional.of(preparedQuery.getOptionalMappingFunction() .map(f -> newMappingSpec.mappedBy(f)).orElse(newMappingSpec)); } @@ -690,7 +672,7 @@ private GenericQueryAndParameters createQueryAndParameters(Neo4jPersistentEntity .returning(Constants.NAME_OF_SYNTHESIZED_ROOT_NODE).build(); final Collection rootNodeIds = new HashSet<>((Collection) neo4jClient - .query(renderer.render(rootNodesStatement)).in(getDatabaseName()) + .query(renderer.render(rootNodesStatement)) .bindAll(parameters) .fetch() .one() @@ -716,7 +698,7 @@ private GenericQueryAndParameters createQueryAndParameters(Neo4jPersistentEntity .prepareMatchOf(entityMetaData, relationshipDescription, queryFragments.getMatchOn(), queryFragments.getCondition()) .returning(cypherGenerator.createReturnStatementForMatch(entityMetaData)).build(); - neo4jClient.query(renderer.render(statement)).in(getDatabaseName()) + neo4jClient.query(renderer.render(statement)) .bindAll(parameters) .fetch() .one() @@ -739,7 +721,7 @@ private void iterateNextLevel(Collection nodeIds, Neo4jPersistentEntity Functions.id(node).in(Cypher.parameter(Constants.NAME_OF_IDS))) .returning(cypherGenerator.createGenericReturnStatement()).build(); - neo4jClient.query(renderer.render(statement)).in(getDatabaseName()) + neo4jClient.query(renderer.render(statement)) .bindAll(Collections.singletonMap(Constants.NAME_OF_IDS, nodeIds)) .fetch() .one() diff --git a/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jClient.java b/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jClient.java index dda2bc8f2b..d3aea6f0e2 100644 --- a/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jClient.java +++ b/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jClient.java @@ -48,7 +48,12 @@ public interface ReactiveNeo4jClient { static ReactiveNeo4jClient create(Driver driver) { - return new DefaultReactiveNeo4jClient(driver); + return new DefaultReactiveNeo4jClient(driver, null); + } + + static ReactiveNeo4jClient create(Driver driver, ReactiveDatabaseSelectionProvider databaseSelectionProvider) { + + return new DefaultReactiveNeo4jClient(driver, databaseSelectionProvider); } /** diff --git a/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java b/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java index 9b52a7c342..ecbc85a26e 100644 --- a/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java +++ b/src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java @@ -99,20 +99,15 @@ public final class ReactiveNeo4jTemplate implements ReactiveNeo4jOperations, Bea private ReactiveEventSupport eventSupport; - private final ReactiveDatabaseSelectionProvider databaseSelectionProvider; - - public ReactiveNeo4jTemplate(ReactiveNeo4jClient neo4jClient, Neo4jMappingContext neo4jMappingContext, - ReactiveDatabaseSelectionProvider databaseSelectionProvider) { + public ReactiveNeo4jTemplate(ReactiveNeo4jClient neo4jClient, Neo4jMappingContext neo4jMappingContext) { Assert.notNull(neo4jClient, "The Neo4jClient is required"); Assert.notNull(neo4jMappingContext, "The Neo4jMappingContext is required"); - Assert.notNull(databaseSelectionProvider, "The database selection provider is required"); this.neo4jClient = neo4jClient; this.neo4jMappingContext = neo4jMappingContext; this.cypherGenerator = CypherGenerator.INSTANCE; this.eventSupport = ReactiveEventSupport.useExistingCallbacks(neo4jMappingContext, ReactiveEntityCallbacks.create()); - this.databaseSelectionProvider = databaseSelectionProvider; } @Override @@ -226,21 +221,21 @@ private Object convertIdValues(@Nullable Neo4jPersistentProperty idProperty, Obj @Override public Mono save(T instance) { - return getDatabaseName().flatMap(databaseName -> saveImpl(instance, databaseName.getValue())); + return saveImpl(instance); } - private Mono saveImpl(T instance, @Nullable String inDatabase) { + private Mono saveImpl(T instance) { Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getPersistentEntity(instance.getClass()); return Mono.just(entityMetaData.isNew(instance)) .flatMap(isNewEntity -> Mono.just(instance).flatMap(eventSupport::maybeCallBeforeBind) - .flatMap(entity -> determineDynamicLabels(entity, entityMetaData, inDatabase)).flatMap(t -> { + .flatMap(entity -> determineDynamicLabels(entity, entityMetaData)).flatMap(t -> { T entity = t.getT1(); DynamicLabels dynamicLabels = t.getT2(); Statement saveStatement = cypherGenerator.prepareSaveOf(entityMetaData, dynamicLabels); - Mono idMono = this.neo4jClient.query(() -> renderer.render(saveStatement)).in(inDatabase) + Mono idMono = this.neo4jClient.query(() -> renderer.render(saveStatement)) .bind(entity).with(neo4jMappingContext.getRequiredBinderFunctionFor((Class) entity.getClass())) .fetchAs(Long.class).one().switchIfEmpty(Mono.defer(() -> { if (entityMetaData.hasVersionProperty()) { @@ -250,7 +245,7 @@ private Mono saveImpl(T instance, @Nullable String inDatabase) { })); if (!entityMetaData.isUsingInternalIds()) { - return idMono.then(processRelations(entityMetaData, entity, isNewEntity, inDatabase)) + return idMono.then(processRelations(entityMetaData, entity, isNewEntity)) .thenReturn(entity); } else { return idMono.map(internalId -> { @@ -259,20 +254,20 @@ private Mono saveImpl(T instance, @Nullable String inDatabase) { return propertyAccessor.getBean(); }).flatMap( - savedEntity -> processRelations(entityMetaData, savedEntity, isNewEntity, inDatabase) + savedEntity -> processRelations(entityMetaData, savedEntity, isNewEntity) .thenReturn(savedEntity)); } })); } private Mono> determineDynamicLabels(T entityToBeSaved, - Neo4jPersistentEntity entityMetaData, @Nullable String inDatabase) { + Neo4jPersistentEntity entityMetaData) { return entityMetaData.getDynamicLabelsProperty().map(p -> { PersistentPropertyAccessor propertyAccessor = entityMetaData.getPropertyAccessor(entityToBeSaved); ReactiveNeo4jClient.RunnableSpecTightToDatabase runnableQuery = neo4jClient .query(() -> renderer.render(cypherGenerator.createStatementReturningDynamicLabels(entityMetaData))) - .in(inDatabase).bind(propertyAccessor.getProperty(entityMetaData.getRequiredIdProperty())) + .bind(propertyAccessor.getProperty(entityMetaData.getRequiredIdProperty())) .to(Constants.NAME_OF_ID).bind(entityMetaData.getStaticLabels()).to(Constants.NAME_OF_STATIC_LABELS_PARAM); if (entityMetaData.hasVersionProperty()) { @@ -310,20 +305,19 @@ public Flux saveAll(Iterable instances) { || entityMetaData.getDynamicLabelsProperty().isPresent()) { log.debug("Saving entities using single statements."); - return getDatabaseName().flatMapMany( - databaseName -> Flux.fromIterable(entities).flatMap(e -> this.saveImpl(e, databaseName.getValue()))); + return Flux.fromIterable(entities).flatMap(e -> this.saveImpl(e)); } Function> binderFunction = neo4jMappingContext.getRequiredBinderFunctionFor(domainClass); String isNewIndicatorKey = "isNewIndicator"; - return getDatabaseName().flatMapMany(databaseName -> Flux.fromIterable(entities) + return Flux.fromIterable(entities) .flatMap(eventSupport::maybeCallBeforeBind).collectList().flatMapMany(entitiesToBeSaved -> Mono.defer(() -> { // Defer the actual save statement until the previous flux completes List> boundedEntityList = entitiesToBeSaved.stream().map(binderFunction) .collect(Collectors.toList()); return neo4jClient .query(() -> renderer.render(cypherGenerator.prepareSaveOfMultipleInstancesOf(entityMetaData))) - .in(databaseName.getValue()).bind(boundedEntityList).to(Constants.NAME_OF_ENTITY_LIST_PARAM).run(); + .bind(boundedEntityList).to(Constants.NAME_OF_ENTITY_LIST_PARAM).run(); }).doOnNext(resultSummary -> { SummaryCounters counters = resultSummary.counters(); log.debug(() -> String.format( @@ -338,13 +332,12 @@ public Flux saveAll(Iterable instances) { .flatMap(t -> { T entityToBeSaved = t.getT2(); boolean isNew = isNewIndicator.get(Math.toIntExact(t.getT1())); - return processRelations(entityMetaData, entityToBeSaved, isNew, - databaseName.getValue()) + return processRelations(entityMetaData, entityToBeSaved, isNew) .then(Mono.just(entityToBeSaved)); } ); }) - ))) + )) .contextWrite(ctx -> ctx.put(isNewIndicatorKey, entities.stream() .map(entity -> entityMetaData.isNew(entity)).collect(Collectors.toList()))); } @@ -357,8 +350,7 @@ public Mono deleteAllById(Iterable ids, Class domainType) { Condition condition = entityMetaData.getIdExpression().in(parameter(nameOfParameter)); Statement statement = cypherGenerator.prepareDeleteOf(entityMetaData, condition); - return getDatabaseName().flatMap(databaseName -> this.neo4jClient.query(() -> renderer.render(statement)) - .in(databaseName.getValue()) + return Mono.defer(() -> this.neo4jClient.query(() -> renderer.render(statement)) .bind(convertIdValues(entityMetaData.getRequiredIdProperty(), ids)) .to(nameOfParameter).run().then()); } @@ -373,8 +365,7 @@ public Mono deleteById(Object id, Class domainType) { Condition condition = entityMetaData.getIdExpression().isEqualTo(parameter(nameOfParameter)); Statement statement = cypherGenerator.prepareDeleteOf(entityMetaData, condition); - return getDatabaseName().flatMap(databaseName -> this.neo4jClient.query(() -> renderer.render(statement)) - .in(databaseName.getValue()) + return Mono.defer(() -> this.neo4jClient.query(() -> renderer.render(statement)) .bind(convertIdValues(entityMetaData.getRequiredIdProperty(), id)) .to(nameOfParameter).run().then()); } @@ -397,8 +388,7 @@ public Mono deleteByIdWithVersion(Object id, Class domainType, Neo4 parameters.put(nameOfParameter, convertIdValues(entityMetaData.getRequiredIdProperty(), id)); parameters.put(Constants.NAME_OF_VERSION_PARAM, versionValue); - return getDatabaseName().flatMap(databaseName -> this.neo4jClient.query(() -> renderer.render(statement)) - .in(databaseName.getValue()) + return Mono.defer(() -> this.neo4jClient.query(() -> renderer.render(statement)) .bindAll(parameters) .fetch().one().switchIfEmpty(Mono.defer(() -> { if (entityMetaData.hasVersionProperty()) { @@ -414,8 +404,7 @@ public Mono deleteAll(Class domainType) { Neo4jPersistentEntity entityMetaData = neo4jMappingContext.getPersistentEntity(domainType); Statement statement = cypherGenerator.prepareDeleteOf(entityMetaData); - return getDatabaseName().flatMap(databaseName -> this.neo4jClient.query(() -> renderer.render(statement)) - .in(databaseName.getValue()).run().then()); + return Mono.defer(() -> this.neo4jClient.query(() -> renderer.render(statement)).run().then()); } private Mono> createExecutableQuery(Class domainType, Statement statement) { @@ -474,7 +463,6 @@ private Mono createQueryAndParameters(Neo4jPersistent || queryFragments.getReturnTuple().getIncludedProperties().isEmpty() || queryFragments.getReturnTuple().getIncludedProperties().contains(relationshipDescription.getFieldName()); - return getDatabaseName().flatMap(databaseName -> { return Mono.deferContextual(ctx -> { Set rootNodeIds = ctx.get("rootNodes"); Set processedRelationshipIds = ctx.get("processedRelationships"); @@ -487,7 +475,7 @@ private Mono createQueryAndParameters(Neo4jPersistent queryFragments.getMatchOn(), queryFragments.getCondition()) .returning(cypherGenerator.createReturnStatementForMatch(entityMetaData)).build(); - return neo4jClient.query(renderer.render(statement)).in(databaseName.getValue()) + return neo4jClient.query(renderer.render(statement)) .bindAll(parameters) .fetch() .one() @@ -499,12 +487,11 @@ private Mono createQueryAndParameters(Neo4jPersistent return Tuples.of(newRelationshipIds, newRelatedNodeIds); }) - .expand(iterateAndMapNextLevel(relationshipDescription, databaseName.getValue())); + .expand(iterateAndMapNextLevel(relationshipDescription)); }) .collect(GenericQueryAndParameters::new, (genericQueryAndParameters, _not_used2) -> genericQueryAndParameters.with(rootNodeIds, processedRelationshipIds, processedNodeIds) ); - }); }) .contextWrite(ctx -> { return ctx @@ -516,7 +503,7 @@ private Mono createQueryAndParameters(Neo4jPersistent } private Flux, Collection>> iterateNextLevel(Collection relatedNodeIds, - RelationshipDescription relationshipDescription, String databaseName) { + RelationshipDescription relationshipDescription) { NodeDescription target = relationshipDescription.getTarget(); @@ -529,7 +516,7 @@ private Flux, Collection>> iterateNextLevel(Collec Functions.id(node).in(Cypher.parameter(Constants.NAME_OF_ID))) .returning(cypherGenerator.createGenericReturnStatement()).build(); - return neo4jClient.query(renderer.render(statement)).in(databaseName) + return neo4jClient.query(renderer.render(statement)) .bindAll(Collections.singletonMap(Constants.NAME_OF_ID, relatedNodeIds)) .fetch() @@ -540,7 +527,7 @@ private Flux, Collection>> iterateNextLevel(Collec return Tuples.of(newRelationshipIds, newRelatedNodeIds); }) - .expand(object -> iterateAndMapNextLevel(relDe, databaseName).apply(object)); + .expand(object -> iterateAndMapNextLevel(relDe).apply(object)); }); } @@ -548,7 +535,7 @@ private Flux, Collection>> iterateNextLevel(Collec @NonNull private Function, Collection>, Publisher, Collection>>> iterateAndMapNextLevel( - RelationshipDescription relationshipDescription, String databaseName) { + RelationshipDescription relationshipDescription) { return newRelationshipAndRelatedNodeIds -> { return Flux.deferContextual(ctx -> { @@ -571,20 +558,20 @@ Publisher, Collection>>> iterateAndMapNextLevel( return Mono.empty(); } - return iterateNextLevel(newRelatedNodeIds, relationshipDescription, databaseName); + return iterateNextLevel(newRelatedNodeIds, relationshipDescription); }); }; } private Mono processRelations(Neo4jPersistentEntity neo4jPersistentEntity, Object parentObject, - boolean isParentObjectNew, @Nullable String inDatabase) { + boolean isParentObjectNew) { - return processNestedRelations(neo4jPersistentEntity, parentObject, isParentObjectNew, inDatabase, + return processNestedRelations(neo4jPersistentEntity, parentObject, isParentObjectNew, new NestedRelationshipProcessingStateMachine()); } private Mono processNestedRelations(Neo4jPersistentEntity sourceEntity, Object parentObject, - boolean isParentObjectNew, @Nullable String inDatabase, NestedRelationshipProcessingStateMachine stateMachine) { + boolean isParentObjectNew, NestedRelationshipProcessingStateMachine stateMachine) { return Mono.defer(() -> { PersistentPropertyAccessor propertyAccessor = sourceEntity.getPropertyAccessor(parentObject); @@ -640,7 +627,7 @@ private Mono processNestedRelations(Neo4jPersistentEntity sourceEntity, Statement relationshipRemoveQuery = cypherGenerator.prepareDeleteOf(sourceEntity, relationshipDescription); relationshipCreationMonos.add( - neo4jClient.query(renderer.render(relationshipRemoveQuery)).in(inDatabase) + neo4jClient.query(renderer.render(relationshipRemoveQuery)) .bind(convertIdValues(sourceEntity.getIdProperty(), fromId)) // .to(Constants.FROM_ID_PARAMETER_NAME) // .bind(knownRelationshipsIds) // @@ -665,7 +652,7 @@ private Mono processNestedRelations(Neo4jPersistentEntity sourceEntity, .getPersistentEntity(relatedNodePreEvt.getClass()); return Mono.just(targetEntity.isNew(relatedNode)).flatMap(isNew -> saveRelatedNode(relatedNode, relationshipContext.getAssociationTargetType(), - targetEntity, inDatabase).flatMap(relatedInternalId -> { + targetEntity).flatMap(relatedInternalId -> { // if an internal id is used this must get set to link this entity in the next iteration PersistentPropertyAccessor targetPropertyAccessor = targetEntity @@ -680,7 +667,7 @@ private Mono processNestedRelations(Neo4jPersistentEntity sourceEntity, // in case of no properties the bind will just return an empty map Mono relationshipCreationMonoNested = neo4jClient - .query(renderer.render(statementHolder.getStatement())).in(inDatabase) + .query(renderer.render(statementHolder.getStatement())) .bind(convertIdValues(sourceEntity.getRequiredIdProperty(), fromId)) // .to(Constants.FROM_ID_PARAMETER_NAME) // .bind(relatedInternalId) // @@ -698,7 +685,7 @@ private Mono processNestedRelations(Neo4jPersistentEntity sourceEntity, if (processState != ProcessState.PROCESSED_ALL_VALUES) { return relationshipCreationMonoNested.checkpoint().then( processNestedRelations(targetEntity, targetPropertyAccessor.getBean(), - isNew, inDatabase, stateMachine)); + isNew, stateMachine)); } else { return relationshipCreationMonoNested.checkpoint().then(); } @@ -712,17 +699,17 @@ private Mono processNestedRelations(Neo4jPersistentEntity sourceEntity, }); } - private Mono saveRelatedNode(Object relatedNode, Class entityType, NodeDescription targetNodeDescription, - @Nullable String inDatabase) { + private Mono saveRelatedNode(Object relatedNode, Class entityType, + NodeDescription targetNodeDescription) { - return determineDynamicLabels((Y) relatedNode, (Neo4jPersistentEntity) targetNodeDescription, inDatabase) + return determineDynamicLabels((Y) relatedNode, (Neo4jPersistentEntity) targetNodeDescription) .flatMap(t -> { Y entity = t.getT1(); DynamicLabels dynamicLabels = t.getT2(); return neo4jClient .query(() -> renderer.render(cypherGenerator.prepareSaveOf(targetNodeDescription, dynamicLabels))) - .in(inDatabase).bind((Y) entity).with(neo4jMappingContext.getRequiredBinderFunctionFor(entityType)) + .bind((Y) entity).with(neo4jMappingContext.getRequiredBinderFunctionFor(entityType)) .fetchAs(Long.class).one(); }).switchIfEmpty(Mono.defer(() -> { if (((Neo4jPersistentEntity) targetNodeDescription).hasVersionProperty()) { @@ -732,16 +719,10 @@ private Mono saveRelatedNode(Object relatedNode, Class entityType, })); } - private Mono getDatabaseName() { - - return this.databaseSelectionProvider.getDatabaseSelection() - .switchIfEmpty(Mono.just(DatabaseSelection.undecided())); - } - @Override public Mono> toExecutableQuery(PreparedQuery preparedQuery) { - return getDatabaseName().map(databaseName -> { + return Mono.defer(() -> { Class resultType = preparedQuery.getResultType(); QueryFragmentsAndParameters queryFragmentsAndParameters = preparedQuery.getQueryFragmentsAndParameters(); String cypherQuery = queryFragmentsAndParameters.getCypherQuery(); @@ -772,12 +753,12 @@ public Mono> toExecutableQuery(PreparedQuery preparedQ } ReactiveNeo4jClient.MappingSpec mappingSpec = this.neo4jClient.query(cypherQuery) - .in(databaseName.getValue()).bindAll(finalParameters).fetchAs(resultType); + .bindAll(finalParameters).fetchAs(resultType); ReactiveNeo4jClient.RecordFetchSpec fetchSpec = preparedQuery.getOptionalMappingFunction() .map(mappingFunction -> mappingSpec.mappedBy(mappingFunction)).orElse(mappingSpec); - return new DefaultReactiveExecutableQuery<>(preparedQuery, fetchSpec); + return Mono.just(new DefaultReactiveExecutableQuery<>(preparedQuery, fetchSpec)); }); } diff --git a/src/test/java/org/springframework/data/neo4j/core/Neo4jClientTest.java b/src/test/java/org/springframework/data/neo4j/core/Neo4jClientTest.java index e39df7b5bd..b0580219c6 100644 --- a/src/test/java/org/springframework/data/neo4j/core/Neo4jClientTest.java +++ b/src/test/java/org/springframework/data/neo4j/core/Neo4jClientTest.java @@ -185,6 +185,32 @@ void databaseSelectionShouldPreventIllegalValues() { verify(driver).defaultTypeSystem(); } + @Test // GH-2159 + void databaseSelectionBeanShouldGetRespectedIfExisting() { + prepareMocks(); + when(session.run(anyString(), anyMap())).thenReturn(result); + when(result.stream()).thenReturn(Stream.of(record1, record2)); + when(result.consume()).thenReturn(resultSummary); + + String databaseName = "customDatabaseSelection"; + DatabaseSelectionProvider databaseSelection = DatabaseSelectionProvider + .createStaticDatabaseSelectionProvider(databaseName); + + Neo4jClient client = Neo4jClient.create(driver, databaseSelection); + + String query = "RETURN 1"; + client.query(query).fetch().first(); + verifyDatabaseSelection(databaseName); + + verify(session).run(eq(query), anyMap()); + verify(result).stream(); + verify(result).consume(); + verify(resultSummary).notifications(); + verify(record1).asMap(); + verify(session).close(); + + } + @Nested @DisplayName("Callback handling should feel good") class CallbackHandlingShouldFeelGood { diff --git a/src/test/java/org/springframework/data/neo4j/core/ReactiveNeo4jClientTest.java b/src/test/java/org/springframework/data/neo4j/core/ReactiveNeo4jClientTest.java index 574cdf8ddd..c7dc575428 100644 --- a/src/test/java/org/springframework/data/neo4j/core/ReactiveNeo4jClientTest.java +++ b/src/test/java/org/springframework/data/neo4j/core/ReactiveNeo4jClientTest.java @@ -15,26 +15,6 @@ */ package org.springframework.data.neo4j.core; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyMap; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -import java.time.LocalDate; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Nested; @@ -53,6 +33,25 @@ import org.neo4j.driver.reactive.RxTransaction; import org.neo4j.driver.summary.ResultSummary; import org.neo4j.driver.types.TypeSystem; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import java.time.LocalDate; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; /** * @author Michael J. Simons @@ -190,6 +189,40 @@ void databaseSelectionShouldPreventIllegalValues() { verify(driver).defaultTypeSystem(); } + @Test // GH-2159 + void databaseSelectionBeanShouldGetRespectedIfExisting() { + + prepareMocks(); + + when(transaction.run(anyString(), anyMap())).thenReturn(result); + when(transaction.commit()).thenReturn(Mono.empty()); + when(result.records()).thenReturn(Flux.just(record1, record2)); + when(result.consume()).thenReturn(Mono.just(resultSummary)); + + String databaseName = "customDatabaseSelection"; + String cypher = "RETURN 1"; + ReactiveDatabaseSelectionProvider databaseSelection = ReactiveDatabaseSelectionProvider + .createStaticDatabaseSelectionProvider(databaseName); + + + ReactiveNeo4jClient client = ReactiveNeo4jClient.create(driver, databaseSelection); + + StepVerifier.create(client.query(cypher).fetch().first()) + .expectNextCount(1L) + .verifyComplete(); + + verifyDatabaseSelection(databaseName); + + verify(transaction).run(eq(cypher), anyMap()); + verify(result).records(); + verify(result).consume(); + verify(resultSummary).notifications(); + verify(record1).asMap(); + verify(transaction).commit(); + verify(transaction).rollback(); + verify(session).close(); + } + @Nested @DisplayName("Callback handling should feel good") class CallbackHandlingShouldFeelGood { diff --git a/src/test/java/org/springframework/data/neo4j/core/TransactionHandlingTest.java b/src/test/java/org/springframework/data/neo4j/core/TransactionHandlingTest.java index ec4e545d0d..ae6db91c18 100644 --- a/src/test/java/org/springframework/data/neo4j/core/TransactionHandlingTest.java +++ b/src/test/java/org/springframework/data/neo4j/core/TransactionHandlingTest.java @@ -91,7 +91,7 @@ void shouldCallCloseOnSession() { when(driver.session(any(SessionConfig.class))).thenReturn(session); // Make template acquire session - DefaultNeo4jClient neo4jClient = new DefaultNeo4jClient(driver); + DefaultNeo4jClient neo4jClient = new DefaultNeo4jClient(driver, null); try (DefaultNeo4jClient.AutoCloseableQueryRunner s = neo4jClient.getQueryRunner("aDatabase")) { s.run("MATCH (n) RETURN n"); } @@ -124,7 +124,7 @@ void shouldNotInvokeCloseOnTransaction() { Neo4jTransactionManager txManager = new Neo4jTransactionManager(driver); TransactionTemplate txTemplate = new TransactionTemplate(txManager); - DefaultNeo4jClient neo4jClient = new DefaultNeo4jClient(driver); + DefaultNeo4jClient neo4jClient = new DefaultNeo4jClient(driver, null); txTemplate.execute(tx -> { try (DefaultNeo4jClient.AutoCloseableQueryRunner s = neo4jClient.getQueryRunner(null)) { s.run("MATCH (n) RETURN n"); @@ -154,7 +154,7 @@ class ReactiveNeo4jClientTest { @Test void shouldNotOpenTransactionsWithoutSubscription() { - DefaultReactiveNeo4jClient neo4jClient = new DefaultReactiveNeo4jClient(driver); + DefaultReactiveNeo4jClient neo4jClient = new DefaultReactiveNeo4jClient(driver, null); neo4jClient.query("RETURN 1").in("aDatabase").fetch().one(); verify(driver, never()).rxSession(any(SessionConfig.class)); @@ -169,7 +169,7 @@ void shouldCloseUnmanagedSessionOnComplete() { when(transaction.commit()).thenReturn(Mono.empty()); when(session.close()).thenReturn(Mono.empty()); - DefaultReactiveNeo4jClient neo4jClient = new DefaultReactiveNeo4jClient(driver); + DefaultReactiveNeo4jClient neo4jClient = new DefaultReactiveNeo4jClient(driver, null); Mono sequence = neo4jClient.doInQueryRunnerForMono("aDatabase", tx -> Mono.just("1")); @@ -191,7 +191,7 @@ void shouldCloseUnmanagedSessionOnError() { when(transaction.rollback()).thenReturn(Mono.empty()); when(session.close()).thenReturn(Mono.empty()); - DefaultReactiveNeo4jClient neo4jClient = new DefaultReactiveNeo4jClient(driver); + DefaultReactiveNeo4jClient neo4jClient = new DefaultReactiveNeo4jClient(driver, null); Mono sequence = neo4jClient.doInQueryRunnerForMono("aDatabase", tx -> Mono.error(new SomeException())); diff --git a/src/test/java/org/springframework/data/neo4j/integration/multiple_ctx_imperative/domain1/Domain1Config.java b/src/test/java/org/springframework/data/neo4j/integration/multiple_ctx_imperative/domain1/Domain1Config.java index f766ed4ca2..514909df20 100644 --- a/src/test/java/org/springframework/data/neo4j/integration/multiple_ctx_imperative/domain1/Domain1Config.java +++ b/src/test/java/org/springframework/data/neo4j/integration/multiple_ctx_imperative/domain1/Domain1Config.java @@ -63,10 +63,9 @@ public Neo4jClient domain1Client(@Qualifier("domain1Driver") Driver driver) { @Primary @Bean public Neo4jOperations domain1Template( @Qualifier("domain1Client") Neo4jClient domain1Client, - @Qualifier("domain1Context") Neo4jMappingContext domain1Context, - @Qualifier("domain1Selection") DatabaseSelectionProvider domain1Selection + @Qualifier("domain1Context") Neo4jMappingContext domain1Context ) { - return new Neo4jTemplate(domain1Client, domain1Context, domain1Selection); + return new Neo4jTemplate(domain1Client, domain1Context); } @Primary @Bean diff --git a/src/test/java/org/springframework/data/neo4j/integration/multiple_ctx_imperative/domain2/Domain2Config.java b/src/test/java/org/springframework/data/neo4j/integration/multiple_ctx_imperative/domain2/Domain2Config.java index 9aca2787ba..8eb26e4055 100644 --- a/src/test/java/org/springframework/data/neo4j/integration/multiple_ctx_imperative/domain2/Domain2Config.java +++ b/src/test/java/org/springframework/data/neo4j/integration/multiple_ctx_imperative/domain2/Domain2Config.java @@ -62,10 +62,9 @@ public Neo4jClient domain2Client(@Qualifier("domain2Driver") Driver driver) { @Bean public Neo4jOperations domain2Template( @Qualifier("domain2Client") Neo4jClient domain2Client, - @Qualifier("domain2Context") Neo4jMappingContext domain2Context, - @Qualifier("domain2Selection") DatabaseSelectionProvider domain2Selection + @Qualifier("domain2Context") Neo4jMappingContext domain2Context ) { - return new Neo4jTemplate(domain2Client, domain2Context, domain2Selection); + return new Neo4jTemplate(domain2Client, domain2Context); } @Bean