diff --git a/driver/src/main/java/org/neo4j/driver/exceptions/AuthorizationExpiredException.java b/driver/src/main/java/org/neo4j/driver/exceptions/AuthorizationExpiredException.java new file mode 100644 index 0000000000..451ec7667d --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/exceptions/AuthorizationExpiredException.java @@ -0,0 +1,34 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.exceptions; + +/** + * The authorization info maintained on the server has expired. The client should reconnect. + *

+ * Error code: Neo.ClientError.Security.AuthorizationExpired + */ +public class AuthorizationExpiredException extends SecurityException +{ + public static final String DESCRIPTION = "Authorization information kept on the server has expired, this connection is no longer valid."; + + public AuthorizationExpiredException( String code, String message ) + { + super( code, message ); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java index a5e70c557f..76683b1d54 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java @@ -28,6 +28,7 @@ import org.neo4j.driver.Session; import org.neo4j.driver.TransactionConfig; import org.neo4j.driver.async.ResultCursor; +import org.neo4j.driver.exceptions.AuthorizationExpiredException; import org.neo4j.driver.exceptions.ClientException; import org.neo4j.driver.internal.BookmarkHolder; import org.neo4j.driver.internal.cursor.AsyncResultCursor; @@ -136,8 +137,14 @@ public CompletionStage beginAsync(Bookmark initialBookmark { if ( beginError != null ) { - // release connection if begin failed, transaction can't be started - connection.release(); + if ( beginError instanceof AuthorizationExpiredException ) + { + connection.terminateAndRelease( AuthorizationExpiredException.DESCRIPTION ); + } + else + { + connection.release(); + } throw Futures.asCompletionException( beginError ); } return this; @@ -169,8 +176,8 @@ else if ( state.value == State.ROLLED_BACK ) else { return resultCursors.retrieveNotConsumedError() - .thenCompose( error -> doCommitAsync().handle( handleCommitOrRollback( error ) ) ) - .whenComplete( ( ignore, error ) -> transactionClosed( error == null ) ); + .thenCompose( error -> doCommitAsync().handle( handleCommitOrRollback( error ) ) ) + .whenComplete( ( ignore, error ) -> handleTransactionCompletion( State.COMMITTED, error ) ); } } @@ -187,8 +194,8 @@ else if ( state.value == State.ROLLED_BACK ) else { return resultCursors.retrieveNotConsumedError() - .thenCompose( error -> doRollbackAsync().handle( handleCommitOrRollback( error ) ) ) - .whenComplete( ( ignore, error ) -> transactionClosed( false ) ); + .thenCompose( error -> doRollbackAsync().handle( handleCommitOrRollback( error ) ) ) + .whenComplete( ( ignore, error ) -> handleTransactionCompletion( State.ROLLED_BACK, error ) ); } } @@ -274,16 +281,16 @@ private static BiFunction handleCommitOrRollback( Throwable }; } - private void transactionClosed( boolean isCommitted ) + private void handleTransactionCompletion( State onSuccessState, Throwable throwable ) { - if ( isCommitted ) + if ( throwable instanceof AuthorizationExpiredException ) { - state = StateHolder.of( State.COMMITTED ); - } - else - { - state = StateHolder.of( State.ROLLED_BACK ); + markTerminated( throwable ); + connection.terminateAndRelease( AuthorizationExpiredException.DESCRIPTION ); + return; } + + state = StateHolder.of( onSuccessState ); connection.release(); // release in background } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/connection/AuthorizationStateListener.java b/driver/src/main/java/org/neo4j/driver/internal/async/connection/AuthorizationStateListener.java new file mode 100644 index 0000000000..3a7dac6297 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/async/connection/AuthorizationStateListener.java @@ -0,0 +1,37 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.async.connection; + +import io.netty.channel.Channel; + +import org.neo4j.driver.exceptions.AuthorizationExpiredException; + +/** + * Listener for authorization info state maintained on the server side. + */ +public interface AuthorizationStateListener +{ + /** + * Notifies the listener that the credentials stored on the server side have expired. + * + * @param e the {@link AuthorizationExpiredException} exception. + * @param channel the channel that received the error. + */ + void onExpired( AuthorizationExpiredException e, Channel channel ); +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/connection/ChannelAttributes.java b/driver/src/main/java/org/neo4j/driver/internal/async/connection/ChannelAttributes.java index 35fddcfc94..a8773211ef 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/connection/ChannelAttributes.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/connection/ChannelAttributes.java @@ -40,6 +40,7 @@ public final class ChannelAttributes private static final AttributeKey LAST_USED_TIMESTAMP = newInstance( "lastUsedTimestamp" ); private static final AttributeKey MESSAGE_DISPATCHER = newInstance( "messageDispatcher" ); private static final AttributeKey TERMINATION_REASON = newInstance( "terminationReason" ); + private static final AttributeKey AUTHORIZATION_STATE_LISTENER = newInstance( "authorizationStateListener" ); private ChannelAttributes() { @@ -145,6 +146,16 @@ public static void setTerminationReason( Channel channel, String reason ) setOnce( channel, TERMINATION_REASON, reason ); } + public static AuthorizationStateListener authorizationStateListener( Channel channel ) + { + return get( channel, AUTHORIZATION_STATE_LISTENER ); + } + + public static void setAuthorizationStateListener( Channel channel, AuthorizationStateListener authorizationStateListener ) + { + set( channel, AUTHORIZATION_STATE_LISTENER, authorizationStateListener ); + } + private static T get( Channel channel, AttributeKey key ) { return channel.attr( key ).get(); diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcher.java b/driver/src/main/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcher.java index 7f7f5236e7..dde5d9558b 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcher.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcher.java @@ -25,18 +25,19 @@ import java.util.Map; import java.util.Queue; -import org.neo4j.driver.exceptions.ServiceUnavailableException; +import org.neo4j.driver.Logger; +import org.neo4j.driver.Logging; +import org.neo4j.driver.Value; +import org.neo4j.driver.exceptions.AuthorizationExpiredException; +import org.neo4j.driver.exceptions.ClientException; import org.neo4j.driver.internal.handlers.ResetResponseHandler; import org.neo4j.driver.internal.logging.ChannelActivityLogger; import org.neo4j.driver.internal.messaging.ResponseMessageHandler; import org.neo4j.driver.internal.spi.ResponseHandler; import org.neo4j.driver.internal.util.ErrorUtil; -import org.neo4j.driver.Logger; -import org.neo4j.driver.Logging; -import org.neo4j.driver.Value; -import org.neo4j.driver.exceptions.ClientException; import static java.util.Objects.requireNonNull; +import static org.neo4j.driver.internal.async.connection.ChannelAttributes.authorizationStateListener; import static org.neo4j.driver.internal.messaging.request.ResetMessage.RESET; import static org.neo4j.driver.internal.util.ErrorUtil.addSuppressed; @@ -114,9 +115,17 @@ public void handleFailureMessage( String code, String message ) return; } - // write a RESET to "acknowledge" the failure - enqueue( new ResetResponseHandler( this ) ); - channel.writeAndFlush( RESET, channel.voidPromise() ); + Throwable currentError = this.currentError; + if ( currentError instanceof AuthorizationExpiredException ) + { + authorizationStateListener( channel ).onExpired( (AuthorizationExpiredException) currentError, channel ); + } + else + { + // write a RESET to "acknowledge" the failure + enqueue( new ResetResponseHandler( this ) ); + channel.writeAndFlush( RESET, channel.voidPromise() ); + } ResponseHandler handler = removeHandler(); handler.onFailure( currentError ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java b/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java index f88fdaafdc..d8def041cb 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java @@ -46,6 +46,7 @@ import org.neo4j.driver.internal.util.Futures; import static java.lang.String.format; +import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setAuthorizationStateListener; import static org.neo4j.driver.internal.util.Futures.combineErrors; import static org.neo4j.driver.internal.util.Futures.completeWithNullIfNoError; @@ -66,19 +67,22 @@ public class ConnectionPoolImpl implements ConnectionPool private final ConnectionFactory connectionFactory; public ConnectionPoolImpl( ChannelConnector connector, Bootstrap bootstrap, PoolSettings settings, MetricsListener metricsListener, Logging logging, - Clock clock, boolean ownsEventLoopGroup ) + Clock clock, boolean ownsEventLoopGroup ) { - this( connector, bootstrap, new NettyChannelTracker( metricsListener, bootstrap.config().group().next(), logging ), settings, metricsListener, logging, - clock, ownsEventLoopGroup, new NetworkConnectionFactory( clock, metricsListener ) ); + this( connector, bootstrap, new NettyChannelTracker( metricsListener, bootstrap.config().group().next(), logging ), + new NettyChannelHealthChecker( settings, clock, logging ), settings, metricsListener, logging, + clock, ownsEventLoopGroup, new NetworkConnectionFactory( clock, metricsListener ) ); } - public ConnectionPoolImpl( ChannelConnector connector, Bootstrap bootstrap, NettyChannelTracker nettyChannelTracker, PoolSettings settings, - MetricsListener metricsListener, Logging logging, Clock clock, boolean ownsEventLoopGroup, ConnectionFactory connectionFactory ) + protected ConnectionPoolImpl( ChannelConnector connector, Bootstrap bootstrap, NettyChannelTracker nettyChannelTracker, + NettyChannelHealthChecker nettyChannelHealthChecker, PoolSettings settings, + MetricsListener metricsListener, Logging logging, Clock clock, boolean ownsEventLoopGroup, + ConnectionFactory connectionFactory ) { this.connector = connector; this.bootstrap = bootstrap; this.nettyChannelTracker = nettyChannelTracker; - this.channelHealthChecker = new NettyChannelHealthChecker( settings, clock, logging ); + this.channelHealthChecker = nettyChannelHealthChecker; this.settings = settings; this.metricsListener = metricsListener; this.log = logging.getLog( ConnectionPool.class.getSimpleName() ); @@ -104,6 +108,7 @@ public CompletionStage acquire( BoltServerAddress address ) { processAcquisitionError( pool, address, error ); assertNotClosed( address, channel, pool ); + setAuthorizationStateListener( channel, channelHealthChecker ); Connection connection = connectionFactory.createConnection( channel, pool ); metricsListener.afterAcquiredOrCreated( pool.id(), acquireEvent ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthChecker.java b/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthChecker.java index 156c94fca7..792ad0bbd1 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthChecker.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthChecker.java @@ -23,27 +23,34 @@ import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Promise; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; + +import org.neo4j.driver.Logger; +import org.neo4j.driver.Logging; +import org.neo4j.driver.exceptions.AuthorizationExpiredException; +import org.neo4j.driver.internal.async.connection.AuthorizationStateListener; import org.neo4j.driver.internal.handlers.PingResponseHandler; import org.neo4j.driver.internal.messaging.request.ResetMessage; import org.neo4j.driver.internal.util.Clock; -import org.neo4j.driver.Logger; -import org.neo4j.driver.Logging; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.creationTimestamp; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.lastUsedTimestamp; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.messageDispatcher; -public class NettyChannelHealthChecker implements ChannelHealthChecker +public class NettyChannelHealthChecker implements ChannelHealthChecker, AuthorizationStateListener { private final PoolSettings poolSettings; private final Clock clock; private final Logger log; + private final AtomicReference> minCreationTimestampMillisOpt; public NettyChannelHealthChecker( PoolSettings poolSettings, Clock clock, Logging logging ) { this.poolSettings = poolSettings; this.clock = clock; this.log = logging.getLog( getClass().getSimpleName() ); + this.minCreationTimestampMillisOpt = new AtomicReference<>( Optional.empty() ); } @Override @@ -60,11 +67,27 @@ public Future isHealthy( Channel channel ) return ACTIVE.isHealthy( channel ); } + @Override + public void onExpired( AuthorizationExpiredException e, Channel channel ) + { + long ts = creationTimestamp( channel ); + // Override current value ONLY if the new one is greater + minCreationTimestampMillisOpt.getAndUpdate( prev -> Optional.of( prev.filter( prevTs -> ts <= prevTs ).orElse( ts ) ) ); + } + private boolean isTooOld( Channel channel ) { - if ( poolSettings.maxConnectionLifetimeEnabled() ) + long creationTimestampMillis = creationTimestamp( channel ); + Optional minCreationTimestampMillisOpt = this.minCreationTimestampMillisOpt.get(); + + if ( minCreationTimestampMillisOpt.isPresent() && creationTimestampMillis <= minCreationTimestampMillisOpt.get() ) + { + log.trace( "The channel %s is marked for closure as its creation timestamp is older than or equal to the acceptable minimum timestamp: %s <= %s", + channel, creationTimestampMillis, minCreationTimestampMillisOpt.get() ); + return true; + } + else if ( poolSettings.maxConnectionLifetimeEnabled() ) { - long creationTimestampMillis = creationTimestamp( channel ); long currentTimestampMillis = clock.millis(); long ageMillis = currentTimestampMillis - creationTimestampMillis; @@ -74,7 +97,7 @@ private boolean isTooOld( Channel channel ) if ( tooOld ) { log.trace( "Failed acquire channel %s from the pool because it is too old: %s > %s", - channel, ageMillis, maxAgeMillis ); + channel, ageMillis, maxAgeMillis ); } return tooOld; diff --git a/driver/src/main/java/org/neo4j/driver/internal/handlers/SessionPullResponseCompletionListener.java b/driver/src/main/java/org/neo4j/driver/internal/handlers/SessionPullResponseCompletionListener.java index 56a38f5709..73b9473829 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/handlers/SessionPullResponseCompletionListener.java +++ b/driver/src/main/java/org/neo4j/driver/internal/handlers/SessionPullResponseCompletionListener.java @@ -21,6 +21,7 @@ import java.util.Map; import org.neo4j.driver.Value; +import org.neo4j.driver.exceptions.AuthorizationExpiredException; import org.neo4j.driver.internal.BookmarkHolder; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.util.MetadataExtractor; @@ -48,7 +49,14 @@ public void afterSuccess( Map metadata ) @Override public void afterFailure( Throwable error ) { - releaseConnection(); + if ( error instanceof AuthorizationExpiredException ) + { + connection.terminateAndRelease( AuthorizationExpiredException.DESCRIPTION ); + } + else + { + releaseConnection(); + } } private void releaseConnection() diff --git a/driver/src/main/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3.java b/driver/src/main/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3.java index e54913fc35..8b3a19a0f5 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3.java +++ b/driver/src/main/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3.java @@ -21,7 +21,6 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelPromise; -import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; @@ -29,7 +28,6 @@ import org.neo4j.driver.Bookmark; import org.neo4j.driver.Query; import org.neo4j.driver.TransactionConfig; -import org.neo4j.driver.Value; import org.neo4j.driver.internal.BookmarkHolder; import org.neo4j.driver.internal.DatabaseName; import org.neo4j.driver.internal.async.UnmanagedTransaction; @@ -123,19 +121,10 @@ public CompletionStage beginTransaction( Connection connection, Bookmark b return Futures.failedFuture( error ); } + CompletableFuture beginTxFuture = new CompletableFuture<>(); BeginMessage beginMessage = new BeginMessage( bookmark, config, connection.databaseName(), connection.mode() ); - - if ( bookmark.isEmpty() ) - { - connection.write( beginMessage, NoOpResponseHandler.INSTANCE ); - return Futures.completedWithNull(); - } - else - { - CompletableFuture beginTxFuture = new CompletableFuture<>(); - connection.writeAndFlush( beginMessage, new BeginTxResponseHandler( beginTxFuture ) ); - return beginTxFuture; - } + connection.writeAndFlush( beginMessage, new BeginTxResponseHandler( beginTxFuture ) ); + return beginTxFuture; } @Override diff --git a/driver/src/main/java/org/neo4j/driver/internal/retry/ExponentialBackoffRetryLogic.java b/driver/src/main/java/org/neo4j/driver/internal/retry/ExponentialBackoffRetryLogic.java index 872238962a..3847142f79 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/retry/ExponentialBackoffRetryLogic.java +++ b/driver/src/main/java/org/neo4j/driver/internal/retry/ExponentialBackoffRetryLogic.java @@ -39,6 +39,7 @@ import org.neo4j.driver.Logger; import org.neo4j.driver.Logging; +import org.neo4j.driver.exceptions.AuthorizationExpiredException; import org.neo4j.driver.exceptions.ClientException; import org.neo4j.driver.exceptions.ServiceUnavailableException; import org.neo4j.driver.exceptions.SessionExpiredException; @@ -155,7 +156,8 @@ protected boolean canRetryOn( Throwable error ) @Experimental public static boolean isRetryable( Throwable error ) { - return error instanceof SessionExpiredException || error instanceof ServiceUnavailableException || isTransientError( error ); + return error instanceof SessionExpiredException || error instanceof ServiceUnavailableException || error instanceof AuthorizationExpiredException || + isTransientError( error ); } /** diff --git a/driver/src/main/java/org/neo4j/driver/internal/util/ErrorUtil.java b/driver/src/main/java/org/neo4j/driver/internal/util/ErrorUtil.java index ce24fa0c8a..d4b4fa4df7 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/util/ErrorUtil.java +++ b/driver/src/main/java/org/neo4j/driver/internal/util/ErrorUtil.java @@ -25,6 +25,7 @@ import java.util.stream.Stream; import org.neo4j.driver.exceptions.AuthenticationException; +import org.neo4j.driver.exceptions.AuthorizationExpiredException; import org.neo4j.driver.exceptions.ClientException; import org.neo4j.driver.exceptions.DatabaseException; import org.neo4j.driver.exceptions.FatalDiscoveryException; @@ -75,6 +76,10 @@ else if ( code.equalsIgnoreCase( "Neo.ClientError.Database.DatabaseNotFound" ) ) { return new FatalDiscoveryException( code, message ); } + else if ( code.equalsIgnoreCase( "Neo.ClientError.Security.AuthorizationExpired" ) ) + { + return new AuthorizationExpiredException( code, message ); + } else { return new ClientException( code, message ); diff --git a/driver/src/test/java/org/neo4j/driver/integration/ConnectionHandlingIT.java b/driver/src/test/java/org/neo4j/driver/integration/ConnectionHandlingIT.java index 322a468a6a..29376014ab 100644 --- a/driver/src/test/java/org/neo4j/driver/integration/ConnectionHandlingIT.java +++ b/driver/src/test/java/org/neo4j/driver/integration/ConnectionHandlingIT.java @@ -32,15 +32,17 @@ import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; import org.neo4j.driver.AuthToken; import org.neo4j.driver.Config; import org.neo4j.driver.Driver; import org.neo4j.driver.Logging; +import org.neo4j.driver.QueryRunner; import org.neo4j.driver.Record; -import org.neo4j.driver.Session; import org.neo4j.driver.Result; -import org.neo4j.driver.QueryRunner; +import org.neo4j.driver.Session; import org.neo4j.driver.Transaction; import org.neo4j.driver.exceptions.ClientException; import org.neo4j.driver.internal.BoltServerAddress; @@ -59,8 +61,8 @@ import org.neo4j.driver.internal.spi.ConnectionPool; import org.neo4j.driver.internal.util.Clock; import org.neo4j.driver.internal.util.EnabledOnNeo4jWith; -import org.neo4j.driver.reactive.RxSession; import org.neo4j.driver.reactive.RxResult; +import org.neo4j.driver.reactive.RxSession; import org.neo4j.driver.reactive.RxTransaction; import org.neo4j.driver.summary.ResultSummary; import org.neo4j.driver.util.DatabaseExtension; @@ -367,46 +369,80 @@ void resultSummaryShouldReleaseConnectionUsedBySessionRun() throws Throwable @Test @EnabledOnNeo4jWith( BOLT_V4 ) - void txCommitShouldReleaseConnectionUsedByBeginTx() throws Throwable + void txCommitShouldReleaseConnectionUsedByBeginTx() { - RxSession session = driver.rxSession(); - - StepVerifier.create( Mono.from( session.beginTransaction() ).doOnSuccess( tx -> { - Connection connection1 = connectionPool.lastAcquiredConnectionSpy; - verify( connection1, never() ).release(); - - RxResult result = tx.run( "UNWIND [1,2,3,4] AS a RETURN a" ); - StepVerifier.create( Flux.from( result.records() ).map( record -> record.get( "a" ).asInt() ) ) - .expectNext( 1, 2, 3, 4 ).verifyComplete(); + AtomicReference connection1Ref = new AtomicReference<>(); - StepVerifier.create( Mono.from( tx.commit() ) ).verifyComplete(); - Connection connection2 = connectionPool.lastAcquiredConnectionSpy; - assertSame( connection1, connection2 ); - verify( connection1 ).release(); - - } ) ).expectNextCount( 1 ).verifyComplete(); + Function> sessionToRecordPublisher = ( RxSession session ) -> Flux.usingWhen( + Mono.fromDirect( session.beginTransaction() ), + tx -> + { + connection1Ref.set( connectionPool.lastAcquiredConnectionSpy ); + verify( connection1Ref.get(), never() ).release(); + return tx.run( "UNWIND [1,2,3,4] AS a RETURN a" ).records(); + }, + RxTransaction::commit, + ( tx, error ) -> tx.rollback(), + RxTransaction::rollback + ); + + Flux resultsFlux = Flux.usingWhen( + Mono.fromSupplier( driver::rxSession ), + sessionToRecordPublisher, + session -> + { + Connection connection2 = connectionPool.lastAcquiredConnectionSpy; + assertSame( connection1Ref.get(), connection2 ); + verify( connection1Ref.get() ).release(); + return Mono.empty(); + }, + ( session, error ) -> session.close(), + RxSession::close + ).map( record -> record.get( "a" ).asInt() ); + + StepVerifier.create( resultsFlux ) + .expectNext( 1, 2, 3, 4 ) + .expectComplete() + .verify(); } @Test @EnabledOnNeo4jWith( BOLT_V4 ) - void txRollbackShouldReleaseConnectionUsedByBeginTx() throws Throwable + void txRollbackShouldReleaseConnectionUsedByBeginTx() { - RxSession session = driver.rxSession(); - - StepVerifier.create( Mono.from( session.beginTransaction() ).doOnSuccess( tx -> { - Connection connection1 = connectionPool.lastAcquiredConnectionSpy; - verify( connection1, never() ).release(); + AtomicReference connection1Ref = new AtomicReference<>(); - RxResult result = tx.run( "UNWIND [1,2,3,4] AS a RETURN a" ); - StepVerifier.create( Flux.from( result.records() ).map( record -> record.get( "a" ).asInt() ) ) - .expectNext( 1, 2, 3, 4 ).verifyComplete(); - - StepVerifier.create( Mono.from( tx.rollback() ) ).verifyComplete(); - Connection connection2 = connectionPool.lastAcquiredConnectionSpy; - assertSame( connection1, connection2 ); - verify( connection1 ).release(); - - } ) ).expectNextCount( 1 ).verifyComplete(); + Function> sessionToRecordPublisher = ( RxSession session ) -> Flux.usingWhen( + Mono.fromDirect( session.beginTransaction() ), + tx -> + { + connection1Ref.set( connectionPool.lastAcquiredConnectionSpy ); + verify( connection1Ref.get(), never() ).release(); + return tx.run( "UNWIND [1,2,3,4] AS a RETURN a" ).records(); + }, + RxTransaction::rollback, + ( tx, error ) -> tx.rollback(), + RxTransaction::rollback + ); + + Flux resultsFlux = Flux.usingWhen( + Mono.fromSupplier( driver::rxSession ), + sessionToRecordPublisher, + session -> + { + Connection connection2 = connectionPool.lastAcquiredConnectionSpy; + assertSame( connection1Ref.get(), connection2 ); + verify( connection1Ref.get() ).release(); + return Mono.empty(); + }, + ( session, error ) -> session.close(), + RxSession::close + ).map( record -> record.get( "a" ).asInt() ); + + StepVerifier.create( resultsFlux ) + .expectNext( 1, 2, 3, 4 ) + .expectComplete() + .verify(); } @Test diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java index 39bbd5330f..0347a193bf 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java @@ -26,11 +26,11 @@ import org.mockito.InOrder; import org.neo4j.driver.AccessMode; +import org.neo4j.driver.Bookmark; import org.neo4j.driver.Query; import org.neo4j.driver.TransactionConfig; import org.neo4j.driver.async.ResultCursor; import org.neo4j.driver.exceptions.ClientException; -import org.neo4j.driver.Bookmark; import org.neo4j.driver.internal.InternalBookmark; import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.messaging.request.PullMessage; @@ -65,12 +65,12 @@ import static org.neo4j.driver.util.TestUtil.connectionMock; import static org.neo4j.driver.util.TestUtil.newSession; import static org.neo4j.driver.util.TestUtil.setupFailingBegin; -import static org.neo4j.driver.util.TestUtil.setupSuccessfulRunRx; import static org.neo4j.driver.util.TestUtil.setupSuccessfulRunAndPull; +import static org.neo4j.driver.util.TestUtil.setupSuccessfulRunRx; import static org.neo4j.driver.util.TestUtil.verifyBeginTx; import static org.neo4j.driver.util.TestUtil.verifyRollbackTx; -import static org.neo4j.driver.util.TestUtil.verifyRunRx; import static org.neo4j.driver.util.TestUtil.verifyRunAndPull; +import static org.neo4j.driver.util.TestUtil.verifyRunRx; class NetworkSessionTest { @@ -271,7 +271,7 @@ void bookmarkIsPropagatedFromSession() UnmanagedTransaction tx = beginTransaction( session ); assertNotNull( tx ); - verifyBeginTx( connection, bookmark ); + verifyBeginTx( connection ); } @Test @@ -292,7 +292,7 @@ void bookmarkIsPropagatedBetweenTransactions() assertEquals( bookmark1, session.lastBookmark() ); UnmanagedTransaction tx2 = beginTransaction( session ); - verifyBeginTx( connection, bookmark1 ); + verifyBeginTx( connection, 2 ); await( tx2.commitAsync() ); assertEquals( bookmark2, session.lastBookmark() ); @@ -396,7 +396,7 @@ void shouldRunAfterBeginTxFailureOnBookmark() run( session, "RETURN 2" ); verify( connectionProvider, times( 2 ) ).acquireConnection( any( ConnectionContext.class ) ); - verifyBeginTx( connection1, bookmark ); + verifyBeginTx( connection1 ); verifyRunAndPull( connection2, "RETURN 2" ); } @@ -420,8 +420,8 @@ void shouldBeginTxAfterBeginTxFailureOnBookmark() beginTransaction( session ); verify( connectionProvider, times( 2 ) ).acquireConnection( any( ConnectionContext.class ) ); - verifyBeginTx( connection1, bookmark ); - verifyBeginTx( connection2, bookmark ); + verifyBeginTx( connection1 ); + verifyBeginTx( connection2 ); } @Test diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java index de85ccc8ca..831c579ffc 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java @@ -126,7 +126,7 @@ void shouldFlushWhenBookmarkGiven() beginTx( connection, bookmark ); - verifyBeginTx( connection, bookmark ); + verifyBeginTx( connection ); verify( connection, never() ).write( any(), any(), any(), any() ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/connection/ChannelAttributesTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/connection/ChannelAttributesTest.java index d63ad4d8cc..d3d294ead2 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/connection/ChannelAttributesTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/connection/ChannelAttributesTest.java @@ -30,6 +30,7 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.mock; +import static org.neo4j.driver.internal.async.connection.ChannelAttributes.authorizationStateListener; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.connectionId; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.creationTimestamp; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.lastUsedTimestamp; @@ -38,6 +39,7 @@ import static org.neo4j.driver.internal.async.connection.ChannelAttributes.serverAddress; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.serverAgent; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.serverVersion; +import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setAuthorizationStateListener; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setConnectionId; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setCreationTimestamp; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setLastUsedTimestamp; @@ -197,4 +199,23 @@ void shouldFailToSetTerminationReasonTwice() assertThrows( IllegalStateException.class, () -> setTerminationReason( channel, "Reason 2" ) ); } + + @Test + void shouldSetAndGetAuthorizationStateListener() + { + AuthorizationStateListener listener = mock( AuthorizationStateListener.class ); + setAuthorizationStateListener( channel, listener ); + assertEquals( listener, authorizationStateListener( channel ) ); + } + + @Test + void shouldAllowOverridingAuthorizationStateListener() + { + AuthorizationStateListener listener = mock( AuthorizationStateListener.class ); + setAuthorizationStateListener( channel, listener ); + assertEquals( listener, authorizationStateListener( channel ) ); + AuthorizationStateListener newListener = mock( AuthorizationStateListener.class ); + setAuthorizationStateListener( channel, newListener ); + assertEquals( newListener, authorizationStateListener( channel ) ); + } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImplTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImplTest.java index ab776aafb2..676dfc532f 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImplTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImplTest.java @@ -19,21 +19,27 @@ package org.neo4j.driver.internal.async.pool; import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import java.util.HashSet; +import java.util.concurrent.ExecutionException; import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.util.FakeClock; import static java.util.Arrays.asList; import static java.util.Collections.singleton; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; import static org.neo4j.driver.internal.BoltServerAddress.LOCAL_DEFAULT; +import static org.neo4j.driver.internal.async.connection.ChannelAttributes.authorizationStateListener; import static org.neo4j.driver.internal.logging.DevNullLogging.DEV_NULL_LOGGING; import static org.neo4j.driver.internal.metrics.InternalAbstractMetrics.DEV_NULL_METRICS; @@ -111,6 +117,21 @@ void shouldNotClosePoolsWithActiveConnectionsWhenRetaining() assertTrue( pool.getPool( ADDRESS_3 ).isClosed() ); } + @Test + void shouldRegisterAuthorizationStateListenerWithChannel() throws ExecutionException, InterruptedException + { + NettyChannelTracker nettyChannelTracker = mock( NettyChannelTracker.class ); + NettyChannelHealthChecker nettyChannelHealthChecker = mock( NettyChannelHealthChecker.class ); + ArgumentCaptor channelArgumentCaptor = ArgumentCaptor.forClass( Channel.class ); + TestConnectionPool pool = newConnectionPool( nettyChannelTracker, nettyChannelHealthChecker ); + + pool.acquire( ADDRESS_1 ).toCompletableFuture().get(); + verify( nettyChannelTracker ).channelAcquired( channelArgumentCaptor.capture() ); + Channel channel = channelArgumentCaptor.getValue(); + + assertEquals( nettyChannelHealthChecker, authorizationStateListener( channel ) ); + } + private static PoolSettings newSettings() { return new PoolSettings( 10, 5000, -1, -1 ); @@ -118,7 +139,13 @@ private static PoolSettings newSettings() private static TestConnectionPool newConnectionPool( NettyChannelTracker nettyChannelTracker ) { - return new TestConnectionPool( mock( Bootstrap.class ), nettyChannelTracker, newSettings(), DEV_NULL_METRICS, DEV_NULL_LOGGING, - new FakeClock(), true ); + return newConnectionPool( nettyChannelTracker, mock( NettyChannelHealthChecker.class ) ); + } + + private static TestConnectionPool newConnectionPool( NettyChannelTracker nettyChannelTracker, NettyChannelHealthChecker nettyChannelHealthChecker ) + { + return new TestConnectionPool( mock( Bootstrap.class ), nettyChannelTracker, nettyChannelHealthChecker, newSettings(), DEV_NULL_METRICS, + DEV_NULL_LOGGING, + new FakeClock(), true ); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthCheckerTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthCheckerTest.java index 503e7729e0..5c56a046c4 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthCheckerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthCheckerTest.java @@ -18,6 +18,7 @@ */ package org.neo4j.driver.internal.async.pool; +import io.netty.channel.Channel; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.util.concurrent.Future; import org.junit.jupiter.api.AfterEach; @@ -25,11 +26,16 @@ import org.junit.jupiter.api.Test; import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.neo4j.driver.Value; +import org.neo4j.driver.exceptions.AuthorizationExpiredException; import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher; import org.neo4j.driver.internal.messaging.request.ResetMessage; import org.neo4j.driver.internal.util.Clock; -import org.neo4j.driver.Value; import static org.hamcrest.Matchers.is; import static org.hamcrest.junit.MatcherAssert.assertThat; @@ -82,7 +88,7 @@ void shouldDropTooOldChannelsWhenMaxLifetimeEnabled() void shouldAllowVeryOldChannelsWhenMaxLifetimeDisabled() { PoolSettings settings = new PoolSettings( DEFAULT_MAX_CONNECTION_POOL_SIZE, - DEFAULT_CONNECTION_ACQUISITION_TIMEOUT, NOT_CONFIGURED, DEFAULT_IDLE_TIME_BEFORE_CONNECTION_TEST ); + DEFAULT_CONNECTION_ACQUISITION_TIMEOUT, NOT_CONFIGURED, DEFAULT_IDLE_TIME_BEFORE_CONNECTION_TEST ); NettyChannelHealthChecker healthChecker = newHealthChecker( settings, Clock.SYSTEM ); setCreationTimestamp( channel, 0 ); @@ -91,6 +97,55 @@ void shouldAllowVeryOldChannelsWhenMaxLifetimeDisabled() assertThat( await( healthy ), is( true ) ); } + @Test + void shouldFailAllConnectionsCreatedOnOrBeforeExpirationTimestamp() + { + PoolSettings settings = new PoolSettings( DEFAULT_MAX_CONNECTION_POOL_SIZE, + DEFAULT_CONNECTION_ACQUISITION_TIMEOUT, NOT_CONFIGURED, DEFAULT_IDLE_TIME_BEFORE_CONNECTION_TEST ); + Clock clock = Clock.SYSTEM; + NettyChannelHealthChecker healthChecker = newHealthChecker( settings, clock ); + + long initialTimestamp = clock.millis(); + List channels = IntStream.range( 0, 100 ).mapToObj( i -> + { + Channel channel = new EmbeddedChannel(); + setCreationTimestamp( channel, initialTimestamp + i ); + return channel; + } ).collect( Collectors.toList() ); + + int authorizationExpiredChannelIndex = channels.size() / 2 - 1; + healthChecker.onExpired( new AuthorizationExpiredException( "", "" ), channels.get( authorizationExpiredChannelIndex ) ); + + for ( int i = 0; i < channels.size(); i++ ) + { + Channel channel = channels.get( i ); + boolean health = Objects.requireNonNull( await( healthChecker.isHealthy( channel ) ) ); + boolean expectedHealth = i > authorizationExpiredChannelIndex; + assertEquals( expectedHealth, health, String.format( "Channel %d has failed the check", i ) ); + } + } + + @Test + void shouldUseGreatestExpirationTimestamp() + { + PoolSettings settings = new PoolSettings( DEFAULT_MAX_CONNECTION_POOL_SIZE, + DEFAULT_CONNECTION_ACQUISITION_TIMEOUT, NOT_CONFIGURED, DEFAULT_IDLE_TIME_BEFORE_CONNECTION_TEST ); + Clock clock = Clock.SYSTEM; + NettyChannelHealthChecker healthChecker = newHealthChecker( settings, clock ); + + long initialTimestamp = clock.millis(); + Channel channel1 = new EmbeddedChannel(); + Channel channel2 = new EmbeddedChannel(); + setCreationTimestamp( channel1, initialTimestamp ); + setCreationTimestamp( channel2, initialTimestamp + 100 ); + + healthChecker.onExpired( new AuthorizationExpiredException( "", "" ), channel2 ); + healthChecker.onExpired( new AuthorizationExpiredException( "", "" ), channel1 ); + + assertFalse( Objects.requireNonNull( await( healthChecker.isHealthy( channel1 ) ) ) ); + assertFalse( Objects.requireNonNull( await( healthChecker.isHealthy( channel2 ) ) ) ); + } + @Test void shouldKeepIdleConnectionWhenPingSucceeds() { diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/pool/TestConnectionPool.java b/driver/src/test/java/org/neo4j/driver/internal/async/pool/TestConnectionPool.java index 1836f89ffb..d317231aa0 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/pool/TestConnectionPool.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/pool/TestConnectionPool.java @@ -47,11 +47,13 @@ public class TestConnectionPool extends ConnectionPoolImpl final Map channelPoolsByAddress = new HashMap<>(); private final NettyChannelTracker nettyChannelTracker; - public TestConnectionPool( Bootstrap bootstrap, NettyChannelTracker nettyChannelTracker, PoolSettings settings, - MetricsListener metricsListener, Logging logging, Clock clock, boolean ownsEventLoopGroup ) + public TestConnectionPool( Bootstrap bootstrap, NettyChannelTracker nettyChannelTracker, NettyChannelHealthChecker nettyChannelHealthChecker, + PoolSettings settings, + MetricsListener metricsListener, Logging logging, Clock clock, boolean ownsEventLoopGroup ) { - super( mock( ChannelConnector.class ), bootstrap, nettyChannelTracker, settings, metricsListener, logging, clock, ownsEventLoopGroup, - newConnectionFactory() ); + super( mock( ChannelConnector.class ), bootstrap, nettyChannelTracker, nettyChannelHealthChecker, settings, metricsListener, logging, clock, + ownsEventLoopGroup, + newConnectionFactory() ); this.nettyChannelTracker = nettyChannelTracker; } diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/RoutingTableAndConnectionPoolTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/RoutingTableAndConnectionPoolTest.java index 4f235b5535..cff0a9bba9 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/RoutingTableAndConnectionPoolTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/RoutingTableAndConnectionPoolTest.java @@ -44,6 +44,7 @@ import org.neo4j.driver.exceptions.ProtocolException; import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.async.connection.BootstrapFactory; +import org.neo4j.driver.internal.async.pool.NettyChannelHealthChecker; import org.neo4j.driver.internal.async.pool.NettyChannelTracker; import org.neo4j.driver.internal.async.pool.PoolSettings; import org.neo4j.driver.internal.async.pool.TestConnectionPool; @@ -314,8 +315,9 @@ private ConnectionPool newConnectionPool() PoolSettings poolSettings = new PoolSettings( 10, 5000, -1, -1 ); Bootstrap bootstrap = BootstrapFactory.newBootstrap( 1 ); NettyChannelTracker channelTracker = new NettyChannelTracker( metrics, bootstrap.config().group().next(), logging ); + NettyChannelHealthChecker channelHealthChecker = new NettyChannelHealthChecker( poolSettings, clock, logging ); - return new TestConnectionPool( bootstrap, channelTracker, poolSettings, metrics, logging, clock, true ); + return new TestConnectionPool( bootstrap, channelTracker, channelHealthChecker, poolSettings, metrics, logging, clock, true ); } private RoutingTableRegistryImpl newRoutingTables( ConnectionPool connectionPool, Rediscovery rediscovery ) diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3Test.java index 1825fe6642..b06607fa3c 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3Test.java @@ -50,7 +50,6 @@ import org.neo4j.driver.internal.cursor.AsyncResultCursor; import org.neo4j.driver.internal.handlers.BeginTxResponseHandler; import org.neo4j.driver.internal.handlers.CommitTxResponseHandler; -import org.neo4j.driver.internal.handlers.NoOpResponseHandler; import org.neo4j.driver.internal.handlers.PullAllResponseHandler; import org.neo4j.driver.internal.handlers.RollbackTxResponseHandler; import org.neo4j.driver.internal.handlers.RunResponseHandler; @@ -194,7 +193,8 @@ void shouldBeginTransactionWithoutBookmark() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), TransactionConfig.empty() ); - verify( connection ).write( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE ), NoOpResponseHandler.INSTANCE ); + verify( connection ).writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -217,7 +217,8 @@ void shouldBeginTransactionWithConfig() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), txConfig ); - verify( connection ).write( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE ), NoOpResponseHandler.INSTANCE ); + verify( connection ) + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v4/BoltProtocolV4Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v4/BoltProtocolV4Test.java index 7c251799c1..5e75004476 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v4/BoltProtocolV4Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v4/BoltProtocolV4Test.java @@ -51,7 +51,6 @@ import org.neo4j.driver.internal.cursor.ResultCursorFactory; import org.neo4j.driver.internal.handlers.BeginTxResponseHandler; import org.neo4j.driver.internal.handlers.CommitTxResponseHandler; -import org.neo4j.driver.internal.handlers.NoOpResponseHandler; import org.neo4j.driver.internal.handlers.PullAllResponseHandler; import org.neo4j.driver.internal.handlers.RollbackTxResponseHandler; import org.neo4j.driver.internal.handlers.RunResponseHandler; @@ -187,7 +186,8 @@ void shouldBeginTransactionWithoutBookmark() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), TransactionConfig.empty() ); verify( connection ) - .write( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE ), NoOpResponseHandler.INSTANCE ); + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -211,7 +211,8 @@ void shouldBeginTransactionWithConfig() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), txConfig ); - verify( connection ).write( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE ), NoOpResponseHandler.INSTANCE ); + verify( connection ) + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -521,15 +522,7 @@ private void verifyBeginInvoked( Connection connection, Bookmark bookmark, Trans { ArgumentCaptor beginHandlerCaptor = ArgumentCaptor.forClass( ResponseHandler.class ); BeginMessage beginMessage = new BeginMessage( bookmark, config, databaseName, mode ); - - if( bookmark.isEmpty() ) - { - verify( connection ).write( eq( beginMessage ), eq( NoOpResponseHandler.INSTANCE ) ); - } - else - { - verify( connection ).write( eq( beginMessage ), beginHandlerCaptor.capture() ); - assertThat( beginHandlerCaptor.getValue(), instanceOf( BeginTxResponseHandler.class ) ); - } + verify( connection ).writeAndFlush( eq( beginMessage ), beginHandlerCaptor.capture() ); + assertThat( beginHandlerCaptor.getValue(), instanceOf( BeginTxResponseHandler.class ) ); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v41/BoltProtocolV41Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v41/BoltProtocolV41Test.java index fb0f2af866..84e026eeda 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v41/BoltProtocolV41Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v41/BoltProtocolV41Test.java @@ -51,7 +51,6 @@ import org.neo4j.driver.internal.cursor.ResultCursorFactory; import org.neo4j.driver.internal.handlers.BeginTxResponseHandler; import org.neo4j.driver.internal.handlers.CommitTxResponseHandler; -import org.neo4j.driver.internal.handlers.NoOpResponseHandler; import org.neo4j.driver.internal.handlers.PullAllResponseHandler; import org.neo4j.driver.internal.handlers.RollbackTxResponseHandler; import org.neo4j.driver.internal.handlers.RunResponseHandler; @@ -192,7 +191,8 @@ void shouldBeginTransactionWithoutBookmark() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), TransactionConfig.empty() ); verify( connection ) - .write( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE ), NoOpResponseHandler.INSTANCE ); + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -216,7 +216,8 @@ void shouldBeginTransactionWithConfig() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), txConfig ); - verify( connection ).write( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE ), NoOpResponseHandler.INSTANCE ); + verify( connection ) + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -517,16 +518,8 @@ private void verifyBeginInvoked( Connection connection, Bookmark bookmark, Trans { ArgumentCaptor beginHandlerCaptor = ArgumentCaptor.forClass( ResponseHandler.class ); BeginMessage beginMessage = new BeginMessage( bookmark, config, databaseName, mode ); - - if ( bookmark.isEmpty() ) - { - verify( connection ).write( eq( beginMessage ), eq( NoOpResponseHandler.INSTANCE ) ); - } - else - { - verify( connection ).write( eq( beginMessage ), beginHandlerCaptor.capture() ); - assertThat( beginHandlerCaptor.getValue(), instanceOf( BeginTxResponseHandler.class ) ); - } + verify( connection ).writeAndFlush( eq( beginMessage ), beginHandlerCaptor.capture() ); + assertThat( beginHandlerCaptor.getValue(), instanceOf( BeginTxResponseHandler.class ) ); } private static InternalAuthToken dummyAuthToken() diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v42/BoltProtocolV42Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v42/BoltProtocolV42Test.java index aa34a34d30..006da3ce60 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v42/BoltProtocolV42Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v42/BoltProtocolV42Test.java @@ -51,7 +51,6 @@ import org.neo4j.driver.internal.cursor.ResultCursorFactory; import org.neo4j.driver.internal.handlers.BeginTxResponseHandler; import org.neo4j.driver.internal.handlers.CommitTxResponseHandler; -import org.neo4j.driver.internal.handlers.NoOpResponseHandler; import org.neo4j.driver.internal.handlers.PullAllResponseHandler; import org.neo4j.driver.internal.handlers.RollbackTxResponseHandler; import org.neo4j.driver.internal.handlers.RunResponseHandler; @@ -192,7 +191,8 @@ void shouldBeginTransactionWithoutBookmark() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), TransactionConfig.empty() ); verify( connection ) - .write( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE ), NoOpResponseHandler.INSTANCE ); + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -216,7 +216,8 @@ void shouldBeginTransactionWithConfig() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), txConfig ); - verify( connection ).write( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE ), NoOpResponseHandler.INSTANCE ); + verify( connection ) + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -517,16 +518,8 @@ private void verifyBeginInvoked( Connection connection, Bookmark bookmark, Trans { ArgumentCaptor beginHandlerCaptor = ArgumentCaptor.forClass( ResponseHandler.class ); BeginMessage beginMessage = new BeginMessage( bookmark, config, databaseName, mode ); - - if ( bookmark.isEmpty() ) - { - verify( connection ).write( eq( beginMessage ), eq( NoOpResponseHandler.INSTANCE ) ); - } - else - { - verify( connection ).write( eq( beginMessage ), beginHandlerCaptor.capture() ); - assertThat( beginHandlerCaptor.getValue(), instanceOf( BeginTxResponseHandler.class ) ); - } + verify( connection ).writeAndFlush( eq( beginMessage ), beginHandlerCaptor.capture() ); + assertThat( beginHandlerCaptor.getValue(), instanceOf( BeginTxResponseHandler.class ) ); } private static InternalAuthToken dummyAuthToken() diff --git a/driver/src/test/java/org/neo4j/driver/internal/messaging/v43/BoltProtocolV43Test.java b/driver/src/test/java/org/neo4j/driver/internal/messaging/v43/BoltProtocolV43Test.java index a084cb817e..956b0a80e3 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/messaging/v43/BoltProtocolV43Test.java +++ b/driver/src/test/java/org/neo4j/driver/internal/messaging/v43/BoltProtocolV43Test.java @@ -51,7 +51,6 @@ import org.neo4j.driver.internal.cursor.ResultCursorFactory; import org.neo4j.driver.internal.handlers.BeginTxResponseHandler; import org.neo4j.driver.internal.handlers.CommitTxResponseHandler; -import org.neo4j.driver.internal.handlers.NoOpResponseHandler; import org.neo4j.driver.internal.handlers.PullAllResponseHandler; import org.neo4j.driver.internal.handlers.RollbackTxResponseHandler; import org.neo4j.driver.internal.handlers.RunResponseHandler; @@ -191,7 +190,8 @@ void shouldBeginTransactionWithoutBookmark() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), TransactionConfig.empty() ); verify( connection ) - .write( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE ), NoOpResponseHandler.INSTANCE ); + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), TransactionConfig.empty(), defaultDatabase(), WRITE ) ), + any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -215,7 +215,8 @@ void shouldBeginTransactionWithConfig() CompletionStage stage = protocol.beginTransaction( connection, InternalBookmark.empty(), txConfig ); - verify( connection ).write( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE ), NoOpResponseHandler.INSTANCE ); + verify( connection ) + .writeAndFlush( eq( new BeginMessage( InternalBookmark.empty(), txConfig, defaultDatabase(), WRITE ) ), any( BeginTxResponseHandler.class ) ); assertNull( await( stage ) ); } @@ -516,16 +517,8 @@ private void verifyBeginInvoked( Connection connection, Bookmark bookmark, Trans { ArgumentCaptor beginHandlerCaptor = ArgumentCaptor.forClass( ResponseHandler.class ); BeginMessage beginMessage = new BeginMessage( bookmark, config, databaseName, mode ); - - if ( bookmark.isEmpty() ) - { - verify( connection ).write( eq( beginMessage ), eq( NoOpResponseHandler.INSTANCE ) ); - } - else - { - verify( connection ).write( eq( beginMessage ), beginHandlerCaptor.capture() ); - assertThat( beginHandlerCaptor.getValue(), instanceOf( BeginTxResponseHandler.class ) ); - } + verify( connection ).writeAndFlush( eq( beginMessage ), beginHandlerCaptor.capture() ); + assertThat( beginHandlerCaptor.getValue(), instanceOf( BeginTxResponseHandler.class ) ); } private static InternalAuthToken dummyAuthToken() diff --git a/driver/src/test/java/org/neo4j/driver/internal/retry/ExponentialBackoffRetryLogicTest.java b/driver/src/test/java/org/neo4j/driver/internal/retry/ExponentialBackoffRetryLogicTest.java index 5218057e0e..1e49198a4c 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/retry/ExponentialBackoffRetryLogicTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/retry/ExponentialBackoffRetryLogicTest.java @@ -43,6 +43,7 @@ import org.neo4j.driver.Logger; import org.neo4j.driver.Logging; +import org.neo4j.driver.exceptions.AuthorizationExpiredException; import org.neo4j.driver.exceptions.ClientException; import org.neo4j.driver.exceptions.ServiceUnavailableException; import org.neo4j.driver.exceptions.SessionExpiredException; @@ -776,13 +777,35 @@ void doesRetryOnClientExceptionWithRetryableCause() AtomicBoolean exceptionThrown = new AtomicBoolean( false ); String result = logic.retry( () -> - { - if ( exceptionThrown.compareAndSet( false, true ) ) - { - throw clientExceptionWithValidTerminationCause(); - } - return "Done"; - } ); + { + if ( exceptionThrown.compareAndSet( false, true ) ) + { + throw clientExceptionWithValidTerminationCause(); + } + return "Done"; + } ); + + assertEquals( "Done", result ); + } + + @Test + void doesRetryOnAuthorizationExpiredException() + { + Clock clock = mock( Clock.class ); + Logging logging = mock( Logging.class ); + Logger logger = mock( Logger.class ); + when( logging.getLog( anyString() ) ).thenReturn( logger ); + ExponentialBackoffRetryLogic logic = new ExponentialBackoffRetryLogic( RetrySettings.DEFAULT, eventExecutor, clock, logging ); + + AtomicBoolean exceptionThrown = new AtomicBoolean( false ); + String result = logic.retry( () -> + { + if ( exceptionThrown.compareAndSet( false, true ) ) + { + throw authorizationExpiredException(); + } + return "Done"; + } ); assertEquals( "Done", result ); } @@ -851,6 +874,28 @@ void doesRetryOnClientExceptionWithRetryableCauseAsync() assertEquals( "Done", result ); } + @Test + void doesRetryOnAuthorizationExpiredExceptionAsync() + { + Clock clock = mock( Clock.class ); + Logging logging = mock( Logging.class ); + Logger logger = mock( Logger.class ); + when( logging.getLog( anyString() ) ).thenReturn( logger ); + ExponentialBackoffRetryLogic logic = new ExponentialBackoffRetryLogic( RetrySettings.DEFAULT, eventExecutor, clock, logging ); + + AtomicBoolean exceptionThrown = new AtomicBoolean( false ); + String result = await( logic.retryAsync( () -> + { + if ( exceptionThrown.compareAndSet( false, true ) ) + { + throw authorizationExpiredException(); + } + return CompletableFuture.completedFuture( "Done" ); + } ) ); + + assertEquals( "Done", result ); + } + @Test void doesNotRetryOnRandomClientExceptionAsync() { @@ -918,6 +963,28 @@ void doesRetryOnClientExceptionWithRetryableCauseRx() assertEquals( "Done", result ); } + @Test + void doesRetryOnAuthorizationExpiredExceptionRx() + { + Clock clock = mock( Clock.class ); + Logging logging = mock( Logging.class ); + Logger logger = mock( Logger.class ); + when( logging.getLog( anyString() ) ).thenReturn( logger ); + ExponentialBackoffRetryLogic logic = new ExponentialBackoffRetryLogic( RetrySettings.DEFAULT, eventExecutor, clock, logging ); + + AtomicBoolean exceptionThrown = new AtomicBoolean( false ); + String result = await( Mono.from( logic.retryRx( Mono.fromSupplier( () -> + { + if ( exceptionThrown.compareAndSet( false, true ) ) + { + throw authorizationExpiredException(); + } + return "Done"; + } ) ) ) ); + + assertEquals( "Done", result ); + } + @Test void doesNotRetryOnRandomClientExceptionRx() { @@ -1270,6 +1337,11 @@ private static TransientException transientException() return new TransientException( "", "" ); } + private static AuthorizationExpiredException authorizationExpiredException() + { + return new AuthorizationExpiredException( "", "" ); + } + @SuppressWarnings( "unchecked" ) private static Supplier newWorkMock() { @@ -1277,7 +1349,7 @@ private static Supplier newWorkMock() } private static void assertDelaysApproximatelyEqual( List expectedDelays, List actualDelays, - double delta ) + double delta ) { assertEquals( expectedDelays.size(), actualDelays.size() ); diff --git a/driver/src/test/java/org/neo4j/driver/internal/util/ErrorUtilTest.java b/driver/src/test/java/org/neo4j/driver/internal/util/ErrorUtilTest.java index ab256103f1..f4b588e0c3 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/util/ErrorUtilTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/util/ErrorUtilTest.java @@ -23,6 +23,7 @@ import java.io.IOException; import org.neo4j.driver.exceptions.AuthenticationException; +import org.neo4j.driver.exceptions.AuthorizationExpiredException; import org.neo4j.driver.exceptions.ClientException; import org.neo4j.driver.exceptions.DatabaseException; import org.neo4j.driver.exceptions.Neo4jException; @@ -161,4 +162,17 @@ void shouldCreateConnectionTerminatedErrorWithReason() assertThat( error.getMessage(), startsWith( "Connection to the database terminated" ) ); assertThat( error.getMessage(), containsString( reason ) ); } + + @Test + void shouldCreateAuthorizationExpiredException() + { + String code = "Neo.ClientError.Security.AuthorizationExpired"; + String message = "Expired authorization info"; + + Neo4jException error = newNeo4jError( code, message ); + + assertThat( error, instanceOf( AuthorizationExpiredException.class ) ); + assertEquals( code, error.code() ); + assertEquals( message, error.getMessage() ); + } } diff --git a/driver/src/test/java/org/neo4j/driver/util/TestUtil.java b/driver/src/test/java/org/neo4j/driver/util/TestUtil.java index 215f0d7e0d..835cb0c3c6 100644 --- a/driver/src/test/java/org/neo4j/driver/util/TestUtil.java +++ b/driver/src/test/java/org/neo4j/driver/util/TestUtil.java @@ -57,7 +57,6 @@ import org.neo4j.driver.internal.async.NetworkSession; import org.neo4j.driver.internal.async.connection.EventLoopGroupFactory; import org.neo4j.driver.internal.handlers.BeginTxResponseHandler; -import org.neo4j.driver.internal.handlers.NoOpResponseHandler; import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.messaging.BoltProtocolVersion; import org.neo4j.driver.internal.messaging.Message; @@ -88,7 +87,6 @@ import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -353,19 +351,12 @@ public static void verifyRollbackTx( Connection connection ) public static void verifyBeginTx( Connection connectionMock ) { - verifyBeginTx( connectionMock, empty() ); + verifyBeginTx( connectionMock, 1 ); } - public static void verifyBeginTx( Connection connectionMock, Bookmark bookmark ) + public static void verifyBeginTx( Connection connectionMock, int times ) { - if ( bookmark.isEmpty() ) - { - verify( connectionMock ).write( any( BeginMessage.class ), eq( NoOpResponseHandler.INSTANCE ) ); - } - else - { - verify( connectionMock ).writeAndFlush( any( BeginMessage.class ), any( BeginTxResponseHandler.class ) ); - } + verify( connectionMock, times( times ) ).writeAndFlush( any( BeginMessage.class ), any( BeginTxResponseHandler.class ) ); } public static void setupFailingRun( Connection connection, Throwable error ) diff --git a/driver/src/test/resources/database_shutdown_at_commit.script b/driver/src/test/resources/database_shutdown_at_commit.script index 9bfa248256..42c4465be6 100644 --- a/driver/src/test/resources/database_shutdown_at_commit.script +++ b/driver/src/test/resources/database_shutdown_at_commit.script @@ -4,11 +4,11 @@ !: AUTO GOODBYE C: BEGIN {} - RUN "CREATE (n {name:'Bob'})" {} {} +S: SUCCESS {} +C: RUN "CREATE (n {name:'Bob'})" {} {} PULL_ALL S: SUCCESS {} SUCCESS {} - SUCCESS {} C: COMMIT S: FAILURE {"code": "Neo.TransientError.General.DatabaseUnavailable", "message": "Database shut down."} S: diff --git a/driver/src/test/resources/read_tx_v4_discard.script b/driver/src/test/resources/read_tx_v4_discard.script index eb3e8a157b..ea13b249f1 100644 --- a/driver/src/test/resources/read_tx_v4_discard.script +++ b/driver/src/test/resources/read_tx_v4_discard.script @@ -4,9 +4,9 @@ !: AUTO GOODBYE C: BEGIN { "mode": "r" } - RUN "UNWIND [1,2,3,4] AS a RETURN a" {} {} S: SUCCESS {} - SUCCESS {"t_first": 110, "fields": ["a"], "qid": 0} +C: RUN "UNWIND [1,2,3,4] AS a RETURN a" {} {} +S: SUCCESS {"t_first": 110, "fields": ["a"], "qid": 0} C: DISCARD {"qid": 0, "n": -1} S: SUCCESS {"type": "r", "t_last": 3, "db": "neo4j"} C: COMMIT diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/TestkitRequest.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/TestkitRequest.java index 9445e8283e..c4d45f4770 100644 --- a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/TestkitRequest.java +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/TestkitRequest.java @@ -34,7 +34,8 @@ @JsonSubTypes.Type( SessionBeginTransaction.class ), @JsonSubTypes.Type( TransactionCommit.class ), @JsonSubTypes.Type( SessionLastBookmarks.class ), @JsonSubTypes.Type( SessionWriteTransaction.class ), @JsonSubTypes.Type( ResolverResolutionCompleted.class ), @JsonSubTypes.Type( CheckMultiDBSupport.class ), - @JsonSubTypes.Type( DomainNameResolutionCompleted.class ), @JsonSubTypes.Type( StartTest.class ) + @JsonSubTypes.Type( DomainNameResolutionCompleted.class ), @JsonSubTypes.Type( StartTest.class ), + @JsonSubTypes.Type( TransactionRollback.class ) } ) public interface TestkitRequest { diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/TransactionRollback.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/TransactionRollback.java new file mode 100644 index 0000000000..148a515886 --- /dev/null +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/TransactionRollback.java @@ -0,0 +1,61 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package neo4j.org.testkit.backend.messages.requests; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; +import neo4j.org.testkit.backend.TestkitState; +import neo4j.org.testkit.backend.messages.responses.TestkitResponse; +import neo4j.org.testkit.backend.messages.responses.Transaction; + +import java.util.Optional; + +@Getter +@NoArgsConstructor +@Setter +public class TransactionRollback implements TestkitRequest +{ + private TransactionRollbackBody data; + + @Override + public TestkitResponse process( TestkitState testkitState ) + { + return Optional.ofNullable( testkitState.getTransactions().get( data.txId ) ) + .map( tx -> + { + tx.rollback(); + return transaction( data.txId ); + } ) + .orElseThrow( () -> new RuntimeException( "Could not find transaction" ) ); + } + + private Transaction transaction( String txId ) + { + return Transaction.builder().data( Transaction.TransactionBody.builder().id( txId ).build() ).build(); + } + + @Getter + @NoArgsConstructor + @Setter + public static class TransactionRollbackBody + { + private String txId; + } +}