Skip to content

Option for custom compare function #241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,21 @@ def nested_b_t2():
def nested_b_result():
with open(os.path.join(FIXTURES_DIR, 'nested_b_result.json')) as the_file:
return json.load(the_file)


@pytest.fixture(scope='class')
def compare_func_t1():
    """Return the parsed content of the ``compare_func_t1.json`` fixture file."""
    fixture_path = os.path.join(FIXTURES_DIR, 'compare_func_t1.json')
    with open(fixture_path) as fixture_file:
        loaded = json.load(fixture_file)
    return loaded


@pytest.fixture(scope='class')
def compare_func_t2():
    """Return the parsed content of the ``compare_func_t2.json`` fixture file."""
    fixture_path = os.path.join(FIXTURES_DIR, 'compare_func_t2.json')
    with open(fixture_path) as fixture_file:
        loaded = json.load(fixture_file)
    return loaded


@pytest.fixture(scope='class')
def compare_func_result():
    """Return the parsed content of the ``compare_func_result.json`` fixture file."""
    fixture_path = os.path.join(FIXTURES_DIR, 'compare_func_result.json')
    with open(fixture_path) as fixture_file:
        loaded = json.load(fixture_file)
    return loaded
26 changes: 22 additions & 4 deletions deepdiff/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,14 @@ def _del_elem(self, parent, parent_to_obj_elem, parent_to_obj_action,
value=obj, action=parent_to_obj_action)

def _do_iterable_item_added(self):
    """
    Apply the added iterable items of the delta, including moved items.

    Every entry of ``iterable_item_moved`` is folded in as
    ``new_path -> new_value`` so that a moved item is re-created at its
    destination. The combined mapping is handed to ``_do_item_added`` with
    ``insert=True``.
    """
    # Work on a copy: merging the moved items into the dict stored in
    # self.diff would permanently alter the delta's own diff data.
    iterable_item_added = dict(self.diff.get('iterable_item_added') or {})

    iterable_item_moved = self.diff.get('iterable_item_moved')
    if iterable_item_moved:
        for moved in iterable_item_moved.values():
            iterable_item_added[moved["new_path"]] = moved["new_value"]

    if iterable_item_added:
        self._do_item_added(iterable_item_added, insert=True)

def _do_dictionary_item_added(self):
dictionary_item_added = self.diff.get('dictionary_item_added')
Expand All @@ -274,7 +279,7 @@ def _do_attribute_added(self):
if attribute_added:
self._do_item_added(attribute_added)

def _do_item_added(self, items, sort=True):
def _do_item_added(self, items, sort=True, insert=False):
if sort:
# sorting items by their path so that the items with smaller index
# are applied first (unless `sort` is `False` so that order of
Expand All @@ -289,6 +294,11 @@ def _do_item_added(self, items, sort=True):
elements, parent, parent_to_obj_elem, parent_to_obj_action, obj, elem, action = elem_and_details
else:
continue # pragma: no cover. Due to cPython peephole optimizer, this line doesn't get covered. https://github.com/nedbat/coveragepy/issues/198

# Insert is only true for iterables, make sure it is a valid index.
if(insert and elem < len(obj)):
obj.insert(elem, None)

self._set_new_value(parent, parent_to_obj_elem, parent_to_obj_action,
obj, elements, path, elem, action, new_value)

Expand Down Expand Up @@ -397,10 +407,18 @@ def _do_item_removed(self, items):
self._do_verify_changes(path, expected_old_value, current_old_value)

def _do_iterable_item_removed(self):
    """
    Apply the removed iterable items of the delta, including moved items.

    Every entry of ``iterable_item_moved`` is folded in under its original
    path so the item is removed from its old position; it is added back at
    the new position when the added items are processed.
    """
    # Work on a copy: merging the moved items into the dict stored in
    # self.diff would permanently alter the delta's own diff data.
    iterable_item_removed = dict(self.diff.get('iterable_item_removed') or {})

    iterable_item_moved = self.diff.get('iterable_item_moved')
    if iterable_item_moved:
        # These will get added back during items_added.
        # NOTE(review): the value recorded here is the moved item's
        # `new_value`; if the removal step verifies it against the value
        # currently at the old path, a moved-and-changed item may not
        # match -- confirm against _do_item_removed.
        for old_path, moved in iterable_item_moved.items():
            iterable_item_removed[old_path] = moved["new_value"]

    if iterable_item_removed:
        self._do_item_removed(iterable_item_removed)


def _do_dictionary_item_removed(self):
dictionary_item_removed = self.diff.get('dictionary_item_removed')
if dictionary_item_removed:
Expand Down
92 changes: 85 additions & 7 deletions deepdiff/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
number_to_string, datetime_normalize, KEY_TO_VAL_STR, booleans,
np_ndarray, get_numpy_ndarray_rows, OrderedSetPlus, RepeatedTimer,
TEXT_VIEW, TREE_VIEW, DELTA_VIEW,
np, get_truncate_datetime, dict_)
np, get_truncate_datetime, dict_, CannotCompare)
from deepdiff.serialization import SerializationMixin
from deepdiff.distance import DistanceMixin
from deepdiff.model import (
Expand Down Expand Up @@ -139,6 +139,7 @@ def __init__(self,
truncate_datetime=None,
verbose_level=1,
view=TEXT_VIEW,
iterable_compare_func=None,
_original_type=None,
_parameters=None,
_shared_parameters=None,
Expand All @@ -154,7 +155,8 @@ def __init__(self,
"view, hasher, hashes, max_passes, max_diffs, "
"cutoff_distance_for_pairs, cutoff_intersection_for_pairs, log_frequency_in_sec, cache_size, "
"cache_tuning_sample_size, get_deep_distance, group_by, cache_purge_level, "
"math_epsilon, _original_type, _parameters and _shared_parameters.") % ', '.join(kwargs.keys()))
"math_epsilon, iterable_compare_func, _original_type, "
"_parameters and _shared_parameters.") % ', '.join(kwargs.keys()))

if _parameters:
self.__dict__.update(_parameters)
Expand Down Expand Up @@ -182,6 +184,7 @@ def __init__(self,
self.ignore_string_case = ignore_string_case
self.exclude_obj_callback = exclude_obj_callback
self.number_to_string = number_to_string_func or number_to_string
self.iterable_compare_func = iterable_compare_func
self.ignore_private_variables = ignore_private_variables
self.ignore_nan_inequality = ignore_nan_inequality
self.hasher = hasher
Expand Down Expand Up @@ -558,6 +561,72 @@ def _diff_iterable(self, level, parents_ids=frozenset(), _original_type=None):
else:
self._diff_iterable_in_order(level, parents_ids, _original_type=_original_type)

def _compare_in_order(self, level):
    """
    Default compare if `iterable_compare_func` is not provided.
    This will compare in sequence order.

    Returns a list of ``((index, index), (t1 item, t2 item))`` tuples, one
    per position. When one iterable is shorter than the other, the missing
    side is filled with `ListItemRemovedOrAdded`.
    """
    paired_items = zip_longest(
        level.t1, level.t2, fillvalue=ListItemRemovedOrAdded)
    return [((i, i), (x, y)) for i, (x, y) in enumerate(paired_items)]

def _get_matching_pairs(self, level):
    """
    Given a level get matching pairs. This returns a list of two-tuples in the form:
    [
        ((t1 index, t2 index), (t1 item, t2 item)),
        ...
    ]

    An item of t1 with no counterpart in t2 is reported with a t2 index of -1
    and `ListItemRemovedOrAdded` as the t2 item; an item of t2 with no
    counterpart in t1 is reported the other way around.

    Items are matched using the passed in `iterable_compare_func` if
    available. The default is to compare in order, which is also the
    fallback when the compare function raises `CannotCompare`.
    """
    if self.iterable_compare_func is None:
        # Match in order if there is no compare function provided
        return self._compare_in_order(level)

    try:
        matches = []
        matched_t2_indexes = set()
        for i, x in enumerate(level.t1):
            for j, y in enumerate(level.t2):
                if j in matched_t2_indexes:
                    # This ensures a one-to-one relationship of matches from t1 to t2.
                    # If this index in t2 has already been matched to another x
                    # it cannot have another match, so just continue.
                    continue

                if self.iterable_compare_func(x, y, level):
                    matched_t2_indexes.add(j)
                    matches.append(((i, j), (x, y)))
                    break
            else:
                # No item of t2 matched x: it only exists in t1.
                matches.append(((i, -1), (x, ListItemRemovedOrAdded)))

        # Whatever was never matched in t2 only exists in t2. Track this by
        # index, not by value: an unmatched item that is equal to an
        # already matched one must still be reported.
        for j, y in enumerate(level.t2):
            if j not in matched_t2_indexes:
                matches.append(((-1, j), (ListItemRemovedOrAdded, y)))
        return matches
    except CannotCompare:
        # CannotCompare is not logged here: it is expected to be raised
        # routinely (e.g. for plain lists of numbers or strings), so any
        # logging is left to the compare function itself.
        return self._compare_in_order(level)


def _diff_iterable_in_order(self, level, parents_ids=frozenset(), _original_type=None):
# We're handling both subscriptable and non-subscriptable iterables. Which one is it?
subscriptable = self._iterables_subscriptable(level.t1, level.t2)
Expand All @@ -566,10 +635,7 @@ def _diff_iterable_in_order(self, level, parents_ids=frozenset(), _original_type
else:
child_relationship_class = NonSubscriptableIterableRelationship

for i, (x, y) in enumerate(
zip_longest(
level.t1, level.t2, fillvalue=ListItemRemovedOrAdded)):

for (i, j), (x, y) in self._get_matching_pairs(level):
if self._count_diff() is StopIteration:
return # pragma: no cover. This is already covered for addition.

Expand All @@ -586,10 +652,22 @@ def _diff_iterable_in_order(self, level, parents_ids=frozenset(), _original_type
notpresent,
y,
child_relationship_class=child_relationship_class,
child_relationship_param=i)
child_relationship_param=j)
self._report_result('iterable_item_added', change_level)

else: # check if item value has changed

if (i != j):
# Item moved
change_level = level.branch_deeper(
x,
y,
child_relationship_class=child_relationship_class,
child_relationship_param=i,
child_relationship_param2=j
)
self._report_result('iterable_item_moved', change_level)

item_id = id(x)
if parents_ids and item_id in parents_ids:
continue
Expand Down
7 changes: 7 additions & 0 deletions deepdiff/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,13 @@ def __repr__(self):
__str__ = __repr__


class CannotCompare(Exception):
    """
    Exception when two items cannot be compared in the compare function.

    Meant to be raised from a user supplied compare function to signal that
    the given pair of items is not something it knows how to compare.
    """


unprocessed = Unprocessed()
skipped = Skipped()
not_hashed = NotHashed()
Expand Down
31 changes: 24 additions & 7 deletions deepdiff/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"unprocessed",
"iterable_item_added",
"iterable_item_removed",
"iterable_item_moved",
"attribute_added",
"attribute_removed",
"set_item_removed",
Expand Down Expand Up @@ -100,6 +101,7 @@ def __init__(self, tree_results=None, verbose_level=1):
"unprocessed": [],
"iterable_item_added": dict_(),
"iterable_item_removed": dict_(),
"iterable_item_moved": dict_(),
"attribute_added": self.__set_or_dict(),
"attribute_removed": self.__set_or_dict(),
"set_item_removed": PrettyOrderedSet(),
Expand All @@ -126,6 +128,7 @@ def _from_tree_results(self, tree):
self._from_tree_unprocessed(tree)
self._from_tree_default(tree, 'iterable_item_added')
self._from_tree_default(tree, 'iterable_item_removed')
self._from_tree_iterable_item_moved(tree)
self._from_tree_default(tree, 'attribute_added')
self._from_tree_default(tree, 'attribute_removed')
self._from_tree_set_item_removed(tree)
Expand Down Expand Up @@ -187,6 +190,13 @@ def _from_tree_value_changed(self, tree):
if 'diff' in change.additional:
the_changed.update({'diff': change.additional['diff']})

def _from_tree_iterable_item_moved(self, tree):
    """
    Copy the ``iterable_item_moved`` entries of the tree into this result.

    Each entry is keyed by the change's path (computed with
    ``force=FORCE_DEFAULT``) and records the path computed with
    ``use_t2=True`` as ``new_path`` and the t2 object as ``new_value``.
    """
    if 'iterable_item_moved' not in tree:
        return

    for change in tree['iterable_item_moved']:
        moved_entry = {
            'new_path': change.path(use_t2=True),
            'new_value': change.t2,
        }
        self['iterable_item_moved'][change.path(force=FORCE_DEFAULT)] = moved_entry

def _from_tree_unprocessed(self, tree):
if 'unprocessed' in tree:
for change in tree['unprocessed']:
Expand Down Expand Up @@ -244,6 +254,7 @@ def __init__(self, tree_results=None, ignore_order=None):
"values_changed": dict_(),
"iterable_item_added": dict_(),
"iterable_item_removed": dict_(),
"iterable_item_moved": dict_(),
"attribute_added": dict_(),
"attribute_removed": dict_(),
"set_item_removed": dict_(),
Expand Down Expand Up @@ -273,6 +284,7 @@ def _from_tree_results(self, tree):
else:
self._from_tree_default(tree, 'iterable_item_added')
self._from_tree_default(tree, 'iterable_item_removed')
self._from_tree_iterable_item_moved(tree)
self._from_tree_default(tree, 'attribute_added')
self._from_tree_default(tree, 'attribute_removed')
self._from_tree_set_item_removed(tree)
Expand Down Expand Up @@ -528,7 +540,7 @@ def __setattr__(self, key, value):
def repetition(self):
return self.additional['repetition']

def auto_generate_child_rel(self, klass, param):
def auto_generate_child_rel(self, klass, param, param2=None):
"""
Auto-populate self.child_rel1 and self.child_rel2.
This requires self.down to be another valid DiffLevel object.
Expand All @@ -542,7 +554,7 @@ def auto_generate_child_rel(self, klass, param):
klass=klass, parent=self.t1, child=self.down.t1, param=param)
if self.down.t2 is not notpresent:
self.t2_child_rel = ChildRelationship.create(
klass=klass, parent=self.t2, child=self.down.t2, param=param)
klass=klass, parent=self.t2, child=self.down.t2, param=param if param2 is None else param2)

@property
def all_up(self):
Expand Down Expand Up @@ -572,7 +584,7 @@ def all_down(self):
def _format_result(root, result):
return None if result is None else "{}{}".format(root, result)

def path(self, root="root", force=None, get_parent_too=False):
def path(self, root="root", force=None, get_parent_too=False, use_t2=False):
"""
A python syntax string describing how to descend to this level, assuming the top level object is called root.
Returns None if the path is not representable as a string.
Expand All @@ -594,7 +606,7 @@ def path(self, root="root", force=None, get_parent_too=False):
This will pretend all iterables are subscriptable, for example.
"""
# TODO: We could optimize this by building on top of self.up's path if it is cached there
cache_key = "{}{}".format(force, get_parent_too)
cache_key = "{}{}{}".format(force, get_parent_too, use_t2)
if cache_key in self._path:
cached = self._path[cache_key]
if get_parent_too:
Expand All @@ -609,7 +621,10 @@ def path(self, root="root", force=None, get_parent_too=False):
# traverse all levels of this relationship
while level and level is not self:
# get this level's relationship object
next_rel = level.t1_child_rel or level.t2_child_rel # next relationship object to get a formatted param from
if(use_t2):
next_rel = level.t2_child_rel
else:
next_rel = level.t1_child_rel or level.t2_child_rel # next relationship object to get a formatted param from

# t1 and t2 both are empty
if next_rel is None:
Expand Down Expand Up @@ -642,6 +657,7 @@ def create_deeper(self,
new_t2,
child_relationship_class,
child_relationship_param=None,
child_relationship_param2=None,
report_type=None):
"""
Start a new comparison level and correctly link it to this one.
Expand All @@ -653,14 +669,15 @@ def create_deeper(self,
new_t1, new_t2, down=None, up=level, report_type=report_type)
level.down = result
level.auto_generate_child_rel(
klass=child_relationship_class, param=child_relationship_param)
klass=child_relationship_class, param=child_relationship_param, param2=child_relationship_param2)
return result

def branch_deeper(self,
new_t1,
new_t2,
child_relationship_class,
child_relationship_param=None,
child_relationship_param2=None,
report_type=None):
"""
Branch this comparison: Do not touch this comparison line, but create a new one with exactly the same content,
Expand All @@ -670,7 +687,7 @@ def branch_deeper(self,
"""
branch = self.copy()
return branch.create_deeper(new_t1, new_t2, child_relationship_class,
child_relationship_param, report_type)
child_relationship_param, child_relationship_param2, report_type)

def copy(self):
"""
Expand Down
Loading