Skip to content
35 changes: 31 additions & 4 deletions devtools/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,42 @@ class ToReplace:
insert_assert_summary: ContextVar[list[str]] = ContextVar('insert_assert_summary')


def insert_assert(value: Any) -> int:
def sort_data_from_source(source: Any, value: Any) -> Any:
if isinstance(value, dict) and isinstance(source, dict):
new_dict = {}
used_keys = set()
for k, v in source.items():
if k in value:
new_dict[k] = sort_data_from_source(v, value[k])
used_keys.add(k)
for k, v in value.items():
if k not in used_keys:
new_dict[k] = v
return new_dict
elif isinstance(value, list) and isinstance(source, list):
new_list: list[Any] = []
for i, v in enumerate(value):
if i < len(source):
new_list.append(sort_data_from_source(source[i], v))
else:
new_list.append(v)
return new_list
else:
return value


def insert_assert(value: Any, prev: Any = None) -> int:
call_frame: FrameType = sys._getframe(1)
if sys.version_info < (3, 8): # pragma: no cover
raise RuntimeError('insert_assert() requires Python 3.8+')

if prev:
use_value = sort_data_from_source(prev, value)
else:
use_value = value
format_code = load_black()
ex = Source.for_frame(call_frame).executing(call_frame)
if ex.node is None: # pragma: no cover
python_code = format_code(str(custom_repr(value)))
python_code = format_code(str(custom_repr(use_value)))
raise RuntimeError(
f'insert_assert() was unable to find the frame from which it was called, called with:\n{python_code}'
)
Expand All @@ -55,7 +82,7 @@ def insert_assert(value: Any) -> int:
else:
arg = ' '.join(map(str.strip, ex.source.asttokens().get_text(ast_arg).splitlines()))

python_code = format_code(f'# insert_assert({arg})\nassert {arg} == {custom_repr(value)}')
python_code = format_code(f'# insert_assert({arg})\nassert {arg} == {custom_repr(use_value)}')

python_code = textwrap.indent(python_code, ex.node.col_offset * ' ')
to_replace.append(ToReplace(Path(call_frame.f_code.co_filename), ex.node.lineno, ex.node.end_lineno, python_code))
Expand Down
73 changes: 73 additions & 0 deletions tests/test_insert_assert.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,76 @@ def test_string_assert(x, insert_assert):
)
captured = capsys.readouterr()
assert '2 insert skipped because an assert statement on that line had already be inserted!\n' in captured.out


def test_insert_assert_sort_data(pytester_pretty):
os.environ.pop('CI', None)
pytester_pretty.makeconftest(config)
test_file = pytester_pretty.makepyfile(
"""
def test_dict(insert_assert):
old_data = {
"foo": 1,
"bar": [
{"name": "Pydantic", "tags": ["validation", "json"]},
{"name": "FastAPI", "description": "Web API framework in Python"},
{"name": "SQLModel"},
],
"baz": 3,
}
new_data = {
"bar": [
{
"description": "Data validation library",
"tags": ["validation", "json"],
"name": "Pydantic",
},
{"description": "Web API framework in Python", "name": "FastAPI"},
{"description": "DBs and Python", "name": "SQLModel"},
{"name": "ARQ"},
],
"baz": 6,
"foo": 1,
}
insert_assert(new_data, old_data)
"""
)
result = pytester_pretty.runpytest()
result.assert_outcomes(passed=1)
assert test_file.read_text() == (
"""def test_dict(insert_assert):
old_data = {
"foo": 1,
"bar": [
{"name": "Pydantic", "tags": ["validation", "json"]},
{"name": "FastAPI", "description": "Web API framework in Python"},
{"name": "SQLModel"},
],
"baz": 3,
}
new_data = {
"bar": [
{
"description": "Data validation library",
"tags": ["validation", "json"],
"name": "Pydantic",
},
{"description": "Web API framework in Python", "name": "FastAPI"},
{"description": "DBs and Python", "name": "SQLModel"},
{"name": "ARQ"},
],
"baz": 6,
"foo": 1,
}
# insert_assert(new_data)
assert new_data == {
'foo': 1,
'bar': [
{'name': 'Pydantic', 'tags': ['validation', 'json'], 'description': 'Data validation library'},
{'name': 'FastAPI', 'description': 'Web API framework in Python'},
{'name': 'SQLModel', 'description': 'DBs and Python'},
{'name': 'ARQ'},
],
'baz': 6,
}"""
)