17
17
18
18
19
19
import inspect
20
+ from collections import defaultdict
20
21
21
22
import pytest
22
23
50
51
ROUTER1_ADDRESS = ResolvedAddress (("1.2.3.1" , 9000 ), host_name = "host" )
51
52
ROUTER2_ADDRESS = ResolvedAddress (("1.2.3.1" , 9001 ), host_name = "host" )
52
53
ROUTER3_ADDRESS = ResolvedAddress (("1.2.3.1" , 9002 ), host_name = "host" )
53
- READER_ADDRESS = ResolvedAddress (("1.2.3.1" , 9010 ), host_name = "host" )
54
- WRITER_ADDRESS = ResolvedAddress (("1.2.3.1" , 9020 ), host_name = "host" )
54
+ READER1_ADDRESS = ResolvedAddress (("1.2.3.1" , 9010 ), host_name = "host" )
55
+ READER2_ADDRESS = ResolvedAddress (("1.2.3.1" , 9011 ), host_name = "host" )
56
+ READER3_ADDRESS = ResolvedAddress (("1.2.3.1" , 9012 ), host_name = "host" )
57
+ WRITER1_ADDRESS = ResolvedAddress (("1.2.3.1" , 9020 ), host_name = "host" )
55
58
56
59
57
60
@pytest .fixture
58
- def routing_failure_opener (async_fake_connection_generator , mocker ):
59
- def make_opener (failures = None ):
61
+ def custom_routing_opener (async_fake_connection_generator , mocker ):
62
+ def make_opener (failures = None , get_readers = None ):
60
63
def routing_side_effect (* args , ** kwargs ):
61
64
nonlocal failures
62
65
res = next (failures , None )
63
66
if res is None :
67
+ if get_readers is not None :
68
+ readers = get_readers (kwargs .get ("database" ))
69
+ else :
70
+ readers = [str (READER1_ADDRESS )]
64
71
return [{
65
72
"ttl" : 1000 ,
66
73
"servers" : [
67
74
{"addresses" : [str (ROUTER1_ADDRESS ),
68
75
str (ROUTER2_ADDRESS ),
69
76
str (ROUTER3_ADDRESS )],
70
77
"role" : "ROUTE" },
71
- {"addresses" : [ str ( READER_ADDRESS )] , "role" : "READ" },
72
- {"addresses" : [str (WRITER_ADDRESS )], "role" : "WRITE" },
78
+ {"addresses" : readers , "role" : "READ" },
79
+ {"addresses" : [str (WRITER1_ADDRESS )], "role" : "WRITE" },
73
80
],
74
81
}]
75
82
raise res
@@ -96,8 +103,8 @@ async def open_(addr, auth, timeout):
96
103
97
104
98
105
@pytest .fixture
99
- def opener (routing_failure_opener ):
100
- return routing_failure_opener ()
106
+ def opener (custom_routing_opener ):
107
+ return custom_routing_opener ()
101
108
102
109
103
110
def _pool_config ():
@@ -177,9 +184,9 @@ async def test_chooses_right_connection_type(opener, type_):
177
184
)
178
185
await pool .release (cx1 )
179
186
if type_ == "r" :
180
- assert cx1 .unresolved_address == READER_ADDRESS
187
+ assert cx1 .unresolved_address == READER1_ADDRESS
181
188
else :
182
- assert cx1 .unresolved_address == WRITER_ADDRESS
189
+ assert cx1 .unresolved_address == WRITER1_ADDRESS
183
190
184
191
185
192
@mark_async_test
@@ -298,9 +305,9 @@ async def test_acquire_performs_no_liveness_check_on_fresh_connection(
298
305
opener , liveness_timeout
299
306
):
300
307
pool = _simple_pool (opener )
301
- cx1 = await pool ._acquire (READER_ADDRESS , None , Deadline (30 ),
308
+ cx1 = await pool ._acquire (READER1_ADDRESS , None , Deadline (30 ),
302
309
liveness_timeout )
303
- assert cx1 .unresolved_address == READER_ADDRESS
310
+ assert cx1 .unresolved_address == READER1_ADDRESS
304
311
cx1 .reset .assert_not_called ()
305
312
306
313
@@ -311,11 +318,11 @@ async def test_acquire_performs_liveness_check_on_existing_connection(
311
318
):
312
319
pool = _simple_pool (opener )
313
320
# populate the pool with a connection
314
- cx1 = await pool ._acquire (READER_ADDRESS , None , Deadline (30 ),
321
+ cx1 = await pool ._acquire (READER1_ADDRESS , None , Deadline (30 ),
315
322
liveness_timeout )
316
323
317
324
# make sure we assume the right state
318
- assert cx1 .unresolved_address == READER_ADDRESS
325
+ assert cx1 .unresolved_address == READER1_ADDRESS
319
326
cx1 .is_idle_for .assert_not_called ()
320
327
cx1 .reset .assert_not_called ()
321
328
@@ -326,7 +333,7 @@ async def test_acquire_performs_liveness_check_on_existing_connection(
326
333
cx1 .reset .assert_not_called ()
327
334
328
335
# then acquire it again and assert the liveness check was performed
329
- cx2 = await pool ._acquire (READER_ADDRESS , None , Deadline (30 ),
336
+ cx2 = await pool ._acquire (READER1_ADDRESS , None , Deadline (30 ),
330
337
liveness_timeout )
331
338
assert cx1 is cx2
332
339
cx1 .is_idle_for .assert_called_once_with (liveness_timeout )
@@ -345,11 +352,11 @@ def liveness_side_effect(*args, **kwargs):
345
352
liveness_timeout = 1
346
353
pool = _simple_pool (opener )
347
354
# populate the pool with a connection
348
- cx1 = await pool ._acquire (READER_ADDRESS , None , Deadline (30 ),
355
+ cx1 = await pool ._acquire (READER1_ADDRESS , None , Deadline (30 ),
349
356
liveness_timeout )
350
357
351
358
# make sure we assume the right state
352
- assert cx1 .unresolved_address == READER_ADDRESS
359
+ assert cx1 .unresolved_address == READER1_ADDRESS
353
360
cx1 .is_idle_for .assert_not_called ()
354
361
cx1 .reset .assert_not_called ()
355
362
@@ -362,7 +369,7 @@ def liveness_side_effect(*args, **kwargs):
362
369
cx1 .reset .assert_not_called ()
363
370
364
371
# then acquire it again and assert the liveness check was performed
365
- cx2 = await pool ._acquire (READER_ADDRESS , None , Deadline (30 ),
372
+ cx2 = await pool ._acquire (READER1_ADDRESS , None , Deadline (30 ),
366
373
liveness_timeout )
367
374
assert cx1 is not cx2
368
375
assert cx1 .unresolved_address == cx2 .unresolved_address
@@ -384,14 +391,14 @@ def liveness_side_effect(*args, **kwargs):
384
391
liveness_timeout = 1
385
392
pool = _simple_pool (opener )
386
393
# populate the pool with a connection
387
- cx1 = await pool ._acquire (READER_ADDRESS , None , Deadline (30 ),
394
+ cx1 = await pool ._acquire (READER1_ADDRESS , None , Deadline (30 ),
388
395
liveness_timeout )
389
- cx2 = await pool ._acquire (READER_ADDRESS , None , Deadline (30 ),
396
+ cx2 = await pool ._acquire (READER1_ADDRESS , None , Deadline (30 ),
390
397
liveness_timeout )
391
398
392
399
# make sure we assume the right state
393
- assert cx1 .unresolved_address == READER_ADDRESS
394
- assert cx2 .unresolved_address == READER_ADDRESS
400
+ assert cx1 .unresolved_address == READER1_ADDRESS
401
+ assert cx2 .unresolved_address == READER1_ADDRESS
395
402
assert cx1 is not cx2
396
403
cx1 .is_idle_for .assert_not_called ()
397
404
cx2 .is_idle_for .assert_not_called ()
@@ -409,7 +416,7 @@ def liveness_side_effect(*args, **kwargs):
409
416
cx2 .reset .assert_not_called ()
410
417
411
418
# then acquire it again and assert the liveness check was performed
412
- cx3 = await pool ._acquire (READER_ADDRESS , None , Deadline (30 ),
419
+ cx3 = await pool ._acquire (READER1_ADDRESS , None , Deadline (30 ),
413
420
liveness_timeout )
414
421
assert cx3 is cx2
415
422
cx1 .is_idle_for .assert_called_once_with (liveness_timeout )
@@ -426,7 +433,7 @@ def mock_connection_breaks_on_close(cx):
426
433
async def close_side_effect ():
427
434
cx .closed .return_value = True
428
435
cx .defunct .return_value = True
429
- await pool .deactivate (READER_ADDRESS )
436
+ await pool .deactivate (READER1_ADDRESS )
430
437
431
438
cx .attach_mock (mocker .AsyncMock (side_effect = close_side_effect ),
432
439
"close" )
@@ -470,9 +477,9 @@ async def test__acquire_new_later_with_room(opener):
470
477
pool = AsyncNeo4jPool (
471
478
opener , config , WorkspaceConfig (), ROUTER1_ADDRESS
472
479
)
473
- assert pool .connections_reservations [READER_ADDRESS ] == 0
474
- creator = pool ._acquire_new_later (READER_ADDRESS , None , Deadline (1 ))
475
- assert pool .connections_reservations [READER_ADDRESS ] == 1
480
+ assert pool .connections_reservations [READER1_ADDRESS ] == 0
481
+ creator = pool ._acquire_new_later (READER1_ADDRESS , None , Deadline (1 ))
482
+ assert pool .connections_reservations [READER1_ADDRESS ] == 1
476
483
assert callable (creator )
477
484
if AsyncUtil .is_async_code :
478
485
assert inspect .iscoroutinefunction (creator )
@@ -487,9 +494,9 @@ async def test__acquire_new_later_without_room(opener):
487
494
)
488
495
_ = await pool .acquire (READ_ACCESS , 30 , "test_db" , None , None , None )
489
496
# pool is full now
490
- assert pool .connections_reservations [READER_ADDRESS ] == 0
491
- creator = pool ._acquire_new_later (READER_ADDRESS , None , Deadline (1 ))
492
- assert pool .connections_reservations [READER_ADDRESS ] == 0
497
+ assert pool .connections_reservations [READER1_ADDRESS ] == 0
498
+ creator = pool ._acquire_new_later (READER1_ADDRESS , None , Deadline (1 ))
499
+ assert pool .connections_reservations [READER1_ADDRESS ] == 0
493
500
assert creator is None
494
501
495
502
@@ -519,8 +526,8 @@ async def test_passes_pool_config_to_connection(mocker):
519
526
"Neo.ClientError.Security.AuthorizationExpired" ),
520
527
))
521
528
@mark_async_test
522
- async def test_discovery_is_retried (routing_failure_opener , error ):
523
- opener = routing_failure_opener ([
529
+ async def test_discovery_is_retried (custom_routing_opener , error ):
530
+ opener = custom_routing_opener ([
524
531
None , # first call to router for seeding the RT with more routers
525
532
error , # will be retried
526
533
])
@@ -563,8 +570,8 @@ async def test_discovery_is_retried(routing_failure_opener, error):
563
570
)
564
571
))
565
572
@mark_async_test
566
- async def test_fast_failing_discovery (routing_failure_opener , error ):
567
- opener = routing_failure_opener ([
573
+ async def test_fast_failing_discovery (custom_routing_opener , error ):
574
+ opener = custom_routing_opener ([
568
575
None , # first call to router for seeding the RT with more routers
569
576
error , # will be retried
570
577
])
@@ -648,3 +655,85 @@ async def test_connection_error_callback(
648
655
cx .mark_unauthenticated .assert_not_called ()
649
656
for cx in cxs_write :
650
657
cx .mark_unauthenticated .assert_not_called ()
658
+
659
+
660
+ @mark_async_test
661
+ async def test_pool_closes_connections_dropped_from_rt (custom_routing_opener ):
662
+ readers = {"db1" : [str (READER1_ADDRESS )]}
663
+
664
+ def get_readers (database ):
665
+ return readers [database ]
666
+
667
+ opener = custom_routing_opener (get_readers = get_readers )
668
+
669
+ pool = AsyncNeo4jPool (
670
+ opener , _pool_config (), WorkspaceConfig (), ROUTER1_ADDRESS
671
+ )
672
+ cx1 = await pool .acquire (READ_ACCESS , 30 , "db1" , None , None , None )
673
+ assert cx1 .unresolved_address == READER1_ADDRESS
674
+ await pool .release (cx1 )
675
+
676
+ cx1 .close .assert_not_called ()
677
+ assert len (pool .connections [READER1_ADDRESS ]) == 1
678
+
679
+ # force RT refresh, returning a different reader
680
+ del pool .routing_tables ["db1" ]
681
+ readers ["db1" ] = [str (READER2_ADDRESS )]
682
+
683
+ cx2 = await pool .acquire (READ_ACCESS , 30 , "db1" , None , None , None )
684
+ assert cx2 .unresolved_address == READER2_ADDRESS
685
+
686
+ cx1 .close .assert_awaited_once ()
687
+ assert len (pool .connections [READER1_ADDRESS ]) == 0
688
+
689
+ await pool .release (cx2 )
690
+ assert len (pool .connections [READER2_ADDRESS ]) == 1
691
+
692
+
693
+ @mark_async_test
694
+ async def test_pool_does_not_close_connections_dropped_from_rt_for_other_server (
695
+ custom_routing_opener
696
+ ):
697
+ readers = {
698
+ "db1" : [str (READER1_ADDRESS ), str (READER2_ADDRESS )],
699
+ "db2" : [str (READER1_ADDRESS )]
700
+ }
701
+
702
+ def get_readers (database ):
703
+ return readers [database ]
704
+
705
+ opener = custom_routing_opener (get_readers = get_readers )
706
+
707
+ pool = AsyncNeo4jPool (
708
+ opener , _pool_config (), WorkspaceConfig (), ROUTER1_ADDRESS
709
+ )
710
+ cx1 = await pool .acquire (READ_ACCESS , 30 , "db1" , None , None , None )
711
+ await pool .release (cx1 )
712
+ assert cx1 .unresolved_address in (READER1_ADDRESS , READER2_ADDRESS )
713
+ reader1_connection_count = len (pool .connections [READER1_ADDRESS ])
714
+ reader2_connection_count = len (pool .connections [READER2_ADDRESS ])
715
+ assert reader1_connection_count + reader2_connection_count == 1
716
+
717
+ cx2 = await pool .acquire (READ_ACCESS , 30 , "db2" , None , None , None )
718
+ await pool .release (cx2 )
719
+ assert cx2 .unresolved_address == READER1_ADDRESS
720
+ cx1 .close .assert_not_called ()
721
+ cx2 .close .assert_not_called ()
722
+ assert len (pool .connections [READER1_ADDRESS ]) == 1
723
+ assert len (pool .connections [READER2_ADDRESS ]) == reader2_connection_count
724
+
725
+
726
+ # force RT refresh, returning a different reader
727
+ del pool .routing_tables ["db2" ]
728
+ readers ["db2" ] = [str (READER3_ADDRESS )]
729
+
730
+ cx3 = await pool .acquire (READ_ACCESS , 30 , "db2" , None , None , None )
731
+ await pool .release (cx3 )
732
+ assert cx3 .unresolved_address == READER3_ADDRESS
733
+
734
+ cx1 .close .assert_not_called ()
735
+ cx2 .close .assert_not_called ()
736
+ cx3 .close .assert_not_called ()
737
+ assert len (pool .connections [READER1_ADDRESS ]) == 1
738
+ assert len (pool .connections [READER2_ADDRESS ]) == reader2_connection_count
739
+ assert len (pool .connections [READER3_ADDRESS ]) == 1
0 commit comments