2222import io .netty .channel .Channel ;
2323import io .netty .channel .EventLoopGroup ;
2424
25+ import java .util .HashMap ;
26+ import java .util .Iterator ;
27+ import java .util .Map ;
2528import java .util .Set ;
2629import java .util .concurrent .CompletableFuture ;
2730import java .util .concurrent .CompletionException ;
2831import java .util .concurrent .CompletionStage ;
29- import java .util .concurrent .ConcurrentHashMap ;
30- import java .util .concurrent .ConcurrentMap ;
3132import java .util .concurrent .TimeUnit ;
3233import java .util .concurrent .TimeoutException ;
3334import java .util .concurrent .atomic .AtomicBoolean ;
35+ import java .util .concurrent .locks .Lock ;
36+ import java .util .concurrent .locks .ReadWriteLock ;
37+ import java .util .concurrent .locks .ReentrantReadWriteLock ;
38+ import java .util .function .Supplier ;
3439
3540import org .neo4j .driver .Logger ;
3641import org .neo4j .driver .Logging ;
@@ -63,7 +68,8 @@ public class ConnectionPoolImpl implements ConnectionPool
6368 private final MetricsListener metricsListener ;
6469 private final boolean ownsEventLoopGroup ;
6570
66- private final ConcurrentMap <BoltServerAddress ,ExtendedChannelPool > pools = new ConcurrentHashMap <>();
71+ private final ReadWriteLock addressToPoolLock = new ReentrantReadWriteLock ();
72+ private final Map <BoltServerAddress ,ExtendedChannelPool > addressToPool = new HashMap <>();
6773 private final AtomicBoolean closed = new AtomicBoolean ();
6874 private final CompletableFuture <Void > closeFuture = new CompletableFuture <>();
6975 private final ConnectionFactory connectionFactory ;
@@ -126,25 +132,32 @@ public CompletionStage<Connection> acquire( BoltServerAddress address )
126132 @ Override
127133 public void retainAll ( Set <BoltServerAddress > addressesToRetain )
128134 {
129- for ( BoltServerAddress address : pools . keySet () )
135+ executeWithLock ( addressToPoolLock . writeLock (), () ->
130136 {
131- if ( !addressesToRetain .contains ( address ) )
137+ Iterator <Map .Entry <BoltServerAddress ,ExtendedChannelPool >> entryIterator = addressToPool .entrySet ().iterator ();
138+ while ( entryIterator .hasNext () )
132139 {
133- int activeChannels = nettyChannelTracker .inUseChannelCount ( address );
134- if ( activeChannels == 0 )
140+ Map .Entry <BoltServerAddress ,ExtendedChannelPool > entry = entryIterator .next ();
141+ BoltServerAddress address = entry .getKey ();
142+ if ( !addressesToRetain .contains ( address ) )
135143 {
136- // address is not present in updated routing table and has no active connections
137- // it's now safe to terminate corresponding connection pool and forget about it
138- ExtendedChannelPool pool = pools .remove ( address );
139- if ( pool != null )
144+ int activeChannels = nettyChannelTracker .inUseChannelCount ( address );
145+ if ( activeChannels == 0 )
140146 {
141- log .info ( "Closing connection pool towards %s, it has no active connections " +
142- "and is not in the routing table registry." , address );
143- closePoolInBackground ( address , pool );
147+ // address is not present in updated routing table and has no active connections
148+ // it's now safe to terminate corresponding connection pool and forget about it
149+ ExtendedChannelPool pool = entry .getValue ();
150+ entryIterator .remove ();
151+ if ( pool != null )
152+ {
153+ log .info ( "Closing connection pool towards %s, it has no active connections " +
154+ "and is not in the routing table registry." , address );
155+ closePoolInBackground ( address , pool );
156+ }
144157 }
145158 }
146159 }
147- }
160+ } );
148161 }
149162
150163 @ Override
@@ -165,35 +178,40 @@ public CompletionStage<Void> close()
165178 if ( closed .compareAndSet ( false , true ) )
166179 {
167180 nettyChannelTracker .prepareToCloseChannels ();
168- CompletableFuture <Void > allPoolClosedFuture = closeAllPools ();
169181
170- // We can only shutdown event loop group when all netty pools are fully closed,
171- // otherwise the netty pools might missing threads (from event loop group) to execute clean ups.
172- allPoolClosedFuture .whenComplete ( ( ignored , pollCloseError ) -> {
173- pools .clear ();
174- if ( !ownsEventLoopGroup )
175- {
176- completeWithNullIfNoError ( closeFuture , pollCloseError );
177- }
178- else
179- {
180- shutdownEventLoopGroup ( pollCloseError );
181- }
182- } );
182+ executeWithLockAsync ( addressToPoolLock .writeLock (),
183+ () ->
184+ {
185+ // We can only shutdown event loop group when all netty pools are fully closed,
186+ // otherwise the netty pools might missing threads (from event loop group) to execute clean ups.
187+ return closeAllPools ().whenComplete (
188+ ( ignored , pollCloseError ) ->
189+ {
190+ addressToPool .clear ();
191+ if ( !ownsEventLoopGroup )
192+ {
193+ completeWithNullIfNoError ( closeFuture , pollCloseError );
194+ }
195+ else
196+ {
197+ shutdownEventLoopGroup ( pollCloseError );
198+ }
199+ } );
200+ } );
183201 }
184202 return closeFuture ;
185203 }
186204
187205 @ Override
188206 public boolean isOpen ( BoltServerAddress address )
189207 {
190- return pools . containsKey ( address );
208+ return executeWithLock ( addressToPoolLock . readLock (), () -> addressToPool . containsKey ( address ) );
191209 }
192210
193211 @ Override
194212 public String toString ()
195213 {
196- return "ConnectionPoolImpl{" + "pools=" + pools + '}' ;
214+ return executeWithLock ( addressToPoolLock . readLock (), () -> "ConnectionPoolImpl{" + "pools=" + addressToPool + '}' ) ;
197215 }
198216
199217 private void processAcquisitionError ( ExtendedChannelPool pool , BoltServerAddress serverAddress , Throwable error )
@@ -239,15 +257,15 @@ private void assertNotClosed( BoltServerAddress address, Channel channel, Extend
239257 {
240258 pool .release ( channel );
241259 closePoolInBackground ( address , pool );
242- pools . remove ( address );
260+ executeWithLock ( addressToPoolLock . writeLock (), () -> addressToPool . remove ( address ) );
243261 assertNotClosed ();
244262 }
245263 }
246264
247265 // for testing only
248266 ExtendedChannelPool getPool ( BoltServerAddress address )
249267 {
250- return pools . get ( address );
268+ return executeWithLock ( addressToPoolLock . readLock (), () -> addressToPool . get ( address ) );
251269 }
252270
253271 ExtendedChannelPool newPool ( BoltServerAddress address )
@@ -258,12 +276,22 @@ ExtendedChannelPool newPool( BoltServerAddress address )
258276
259277 private ExtendedChannelPool getOrCreatePool ( BoltServerAddress address )
260278 {
261- return pools .computeIfAbsent ( address , ignored -> {
262- ExtendedChannelPool pool = newPool ( address );
263- // before the connection pool is added I can add the metrics for the pool.
264- metricsListener .putPoolMetrics ( pool .id (), address , this );
265- return pool ;
266- } );
279+ ExtendedChannelPool existingPool = executeWithLock ( addressToPoolLock .readLock (), () -> addressToPool .get ( address ) );
280+ return existingPool != null
281+ ? existingPool
282+ : executeWithLock ( addressToPoolLock .writeLock (),
283+ () ->
284+ {
285+ ExtendedChannelPool pool = addressToPool .get ( address );
286+ if ( pool == null )
287+ {
288+ pool = newPool ( address );
289+ // before the connection pool is added I can add the metrics for the pool.
290+ metricsListener .putPoolMetrics ( pool .id (), address , this );
291+ addressToPool .put ( address , pool );
292+ }
293+ return pool ;
294+ } );
267295 }
268296
269297 private CompletionStage <Void > closePool ( ExtendedChannelPool pool )
@@ -305,12 +333,45 @@ private void shutdownEventLoopGroup( Throwable pollCloseError )
305333 private CompletableFuture <Void > closeAllPools ()
306334 {
307335 return CompletableFuture .allOf (
308- pools .entrySet ().stream ().map ( entry -> {
309- BoltServerAddress address = entry .getKey ();
310- ExtendedChannelPool pool = entry .getValue ();
311- log .info ( "Closing connection pool towards %s" , address );
312- // Wait for all pools to be closed.
313- return closePool ( pool ).toCompletableFuture ();
314- } ).toArray ( CompletableFuture []::new ) );
336+ addressToPool .entrySet ().stream ()
337+ .map ( entry ->
338+ {
339+ BoltServerAddress address = entry .getKey ();
340+ ExtendedChannelPool pool = entry .getValue ();
341+ log .info ( "Closing connection pool towards %s" , address );
342+ // Wait for all pools to be closed.
343+ return closePool ( pool ).toCompletableFuture ();
344+ } )
345+ .toArray ( CompletableFuture []::new ) );
346+ }
347+
348+ private void executeWithLock ( Lock lock , Runnable runnable )
349+ {
350+ executeWithLock ( lock , () ->
351+ {
352+ runnable .run ();
353+ return null ;
354+ } );
355+ }
356+
357+ private <T > T executeWithLock ( Lock lock , Supplier <T > supplier )
358+ {
359+ lock .lock ();
360+ try
361+ {
362+ return supplier .get ();
363+ }
364+ finally
365+ {
366+ lock .unlock ();
367+ }
368+ }
369+
370+ private <T > void executeWithLockAsync ( Lock lock , Supplier <CompletionStage <T >> stageSupplier )
371+ {
372+ lock .lock ();
373+ CompletableFuture .completedFuture ( lock )
374+ .thenCompose ( ignored -> stageSupplier .get () )
375+ .whenComplete ( ( ignored , throwable ) -> lock .unlock () );
315376 }
316377}
0 commit comments