Skip to content

Commit e13df11

Browse files
committed
Adjust unit tests
1 parent 02cc9d5 commit e13df11

File tree

1 file changed

+24
-26
lines changed

1 file changed

+24
-26
lines changed

tests/unit/common/test_debug.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def setup_mock(*logger_names):
4848
loggers = [logging.getLogger(name) for name in logger_names]
4949
for logger in loggers:
5050
logger.addHandler = mocker.Mock()
51-
logger.addFilter = mocker.Mock(side_effect=logger.addFilter)
5251
logger.removeHandler = mocker.Mock()
5352
logger.setLevel = mocker.Mock()
5453
return loggers
@@ -179,6 +178,7 @@ def test_watcher_colour(logger_mocker, colour, thread, task) -> None:
179178
thread_info=thread, task_info=task)
180179
watcher.watch()
181180

181+
logger.addHandler.assert_called_once()
182182
(handler,), _ = logger.addHandler.call_args
183183
assert isinstance(handler, logging.Handler)
184184
assert isinstance(handler.formatter, logging.Formatter)
@@ -198,6 +198,7 @@ def test_watcher_format(logger_mocker, colour, thread, task) -> None:
198198
thread_info=thread, task_info=task)
199199
watcher.watch()
200200

201+
logger.addHandler.assert_called_once()
201202
(handler,), _ = logger.addHandler.call_args
202203
assert isinstance(handler, logging.Handler)
203204
assert isinstance(handler.formatter, logging.Formatter)
@@ -213,14 +214,13 @@ def test_watcher_format(logger_mocker, colour, thread, task) -> None:
213214
assert format_ == expected_format
214215

215216

216-
@pytest.mark.parametrize("colour", (True, False))
217-
@pytest.mark.parametrize("thread", (True, False))
218-
@pytest.mark.parametrize("task", (True, False))
219-
def test_watcher_task_injection(
220-
mocker, logger_mocker, colour, thread, task
217+
def _assert_task_injection(
218+
async_: bool, mocker, logger_mocker, colour: bool, thread: bool, task: bool
221219
) -> None:
220+
handler_cls_mock = mocker.patch("neo4j.debug.StreamHandler", autospec=True)
221+
handler_mock = handler_cls_mock.return_value
222222
logger_name = "neo4j"
223-
logger = logger_mocker(logger_name)[0]
223+
# logger = logger_mocker(logger_name)[0]
224224
watcher = neo4j_debug.Watcher(logger_name, colour=colour,
225225
thread_info=thread, task_info=task)
226226
record_mock = mocker.Mock(spec=logging.LogRecord)
@@ -229,34 +229,32 @@ def test_watcher_task_injection(
229229
watcher.watch()
230230

231231
if task:
232-
(filter_,), _ = logger.addFilter.call_args
232+
handler_mock.addFilter.assert_called_once()
233+
(filter_,), _ = handler_mock.addFilter.call_args
233234
assert isinstance(filter_, logging.Filter)
234235
filter_.filter(record_mock)
235-
assert record_mock.task is None
236+
if async_:
237+
assert record_mock.task == id(asyncio.current_task())
238+
else:
239+
assert record_mock.task is None
236240
else:
237-
logger.addFilter.assert_not_called()
241+
handler_mock.addFilter.assert_not_called()
238242

239243

240244
@pytest.mark.parametrize("colour", (True, False))
241245
@pytest.mark.parametrize("thread", (True, False))
242246
@pytest.mark.parametrize("task", (True, False))
243-
@mark_async_test
244-
async def test_async_watcher_task_injection(
247+
def test_watcher_task_injection(
245248
mocker, logger_mocker, colour, thread, task
246249
) -> None:
247-
logger_name = "neo4j"
248-
logger = logger_mocker(logger_name)[0]
249-
watcher = neo4j_debug.Watcher(logger_name, colour=colour,
250-
thread_info=thread, task_info=task)
251-
record_mock = mocker.Mock(spec=logging.LogRecord)
252-
assert not hasattr(record_mock, "task")
250+
_assert_task_injection(False, mocker, logger_mocker, colour, thread, task)
253251

254-
watcher.watch()
255252

256-
if task:
257-
(filter_,), _ = logger.addFilter.call_args
258-
assert isinstance(filter_, logging.Filter)
259-
filter_.filter(record_mock)
260-
assert record_mock.task == id(asyncio.current_task())
261-
else:
262-
logger.addFilter.assert_not_called()
253+
@pytest.mark.parametrize("colour", (True, False))
254+
@pytest.mark.parametrize("thread", (True, False))
255+
@pytest.mark.parametrize("task", (True, False))
256+
@mark_async_test
257+
async def test_async_watcher_task_injection(
258+
mocker, logger_mocker, colour, thread, task
259+
) -> None:
260+
_assert_task_injection(True, mocker, logger_mocker, colour, thread, task)

0 commit comments

Comments
 (0)