Skip to content

Commit d92aa8e

Browse files
authored
Fix ResourceWarning when canceling await driver.close() (#887)
1 parent ca55f0b commit d92aa8e

File tree

3 files changed

+66
-3
lines changed

3 files changed

+66
-3
lines changed

src/neo4j/_async/driver.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from __future__ import annotations
2020

21+
import asyncio
2122
import typing as t
2223

2324

@@ -466,7 +467,11 @@ def session(self, **config) -> AsyncSession:
466467
async def close(self) -> None:
467468
""" Shut down, closing any open connections in the pool.
468469
"""
469-
await self._pool.close()
470+
try:
471+
await self._pool.close()
472+
except asyncio.CancelledError:
473+
self._closed = True
474+
raise
470475
self._closed = True
471476

472477
if t.TYPE_CHECKING:

src/neo4j/_sync/driver.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from __future__ import annotations
2020

21+
import asyncio
2122
import typing as t
2223

2324

@@ -463,7 +464,11 @@ def session(self, **config) -> Session:
463464
def close(self) -> None:
464465
""" Shut down, closing any open connections in the pool.
465466
"""
466-
self._pool.close()
467+
try:
468+
self._pool.close()
469+
except asyncio.CancelledError:
470+
self._closed = True
471+
raise
467472
self._closed = True
468473

469474
if t.TYPE_CHECKING:

tests/integration/mixed/test_async_driver.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import neo4j
2525

2626
from ... import env
27+
from ..._async_compat import mark_async_test
2728

2829

2930
# TODO: Python 3.9: when support gets dropped, remove this mark
@@ -44,7 +45,7 @@ def test_can_create_async_driver_outside_of_loop(uri, auth):
4445

4546
async def return_1(tx: neo4j.AsyncManagedTransaction) -> None:
4647
nonlocal counter, was_full
47-
res = await tx.run("RETURN 1")
48+
res = await tx.run("UNWIND range(1, 10000) AS x RETURN x")
4849

4950
counter += 1
5051
while not was_full and counter < pool_size:
@@ -86,3 +87,55 @@ async def run(driver_: neo4j.AsyncDriver):
8687
loop.run_until_complete(coro)
8788
finally:
8889
loop.close()
90+
91+
92+
@mark_async_test
93+
async def test_cancel_driver_close(uri, auth):
94+
class Signal:
95+
queried = False
96+
released = False
97+
98+
async def fill_pool(driver_: neo4j.AsyncDriver, n=10):
99+
signals = [Signal() for _ in range(n)]
100+
await asyncio.gather(
101+
*(handle_session(driver_.session(), signals[i]) for i in range(n)),
102+
handle_signals(signals),
103+
return_exceptions=True,
104+
)
105+
106+
async def handle_signals(signals):
107+
while any(not signal.queried for signal in signals):
108+
await asyncio.sleep(0.001)
109+
await asyncio.sleep(0.1)
110+
for signal in signals:
111+
signal.released = True
112+
113+
async def handle_session(session, signal):
114+
async with session:
115+
await session.execute_read(work, signal)
116+
117+
async def work(tx: neo4j.AsyncManagedTransaction, signal: Signal) -> None:
118+
res = await tx.run("UNWIND range(1, 10000) AS x RETURN x")
119+
signal.queried = True
120+
while not signal.released:
121+
await asyncio.sleep(0.001)
122+
await res.consume()
123+
124+
def connection_count(driver_):
125+
return sum(len(v) for v in driver_._pool.connections.values())
126+
127+
driver = neo4j.AsyncGraphDatabase.driver(uri, auth=auth)
128+
await fill_pool(driver)
129+
# sanity check, there should be some connections
130+
assert connection_count(driver) >= 10
131+
132+
# start the close and give it some event loop iterations to kick off
133+
fut = asyncio.ensure_future(driver.close())
134+
await asyncio.sleep(0)
135+
136+
# cancel in the middle of closing connections
137+
fut.cancel()
138+
# give the driver a chance to close connections forcefully
139+
await asyncio.sleep(0)
140+
# driver should be marked as closed to not emmit a ResourceWarning later
141+
assert driver._closed == True

0 commit comments

Comments
 (0)