Skip to content

Commit 31ad1d0

Browse files
authored
Merge pull request #241 from surefyresystems/compare_func
Option for custom compare function
2 parents a00f352 + d587cf7 commit 31ad1d0

10 files changed

+429
-21
lines changed

conftest.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,21 @@ def nested_b_t2():
6262
def nested_b_result():
6363
with open(os.path.join(FIXTURES_DIR, 'nested_b_result.json')) as the_file:
6464
return json.load(the_file)
65+
66+
67+
@pytest.fixture(scope='class')
68+
def compare_func_t1():
69+
with open(os.path.join(FIXTURES_DIR, 'compare_func_t1.json')) as the_file:
70+
return json.load(the_file)
71+
72+
73+
@pytest.fixture(scope='class')
74+
def compare_func_t2():
75+
with open(os.path.join(FIXTURES_DIR, 'compare_func_t2.json')) as the_file:
76+
return json.load(the_file)
77+
78+
79+
@pytest.fixture(scope='class')
80+
def compare_func_result():
81+
with open(os.path.join(FIXTURES_DIR, 'compare_func_result.json')) as the_file:
82+
return json.load(the_file)

deepdiff/delta.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,14 @@ def _del_elem(self, parent, parent_to_obj_elem, parent_to_obj_action,
260260
value=obj, action=parent_to_obj_action)
261261

262262
def _do_iterable_item_added(self):
263-
iterable_item_added = self.diff.get('iterable_item_added')
263+
iterable_item_added = self.diff.get('iterable_item_added', {})
264+
iterable_item_moved = self.diff.get('iterable_item_moved')
265+
if iterable_item_moved:
266+
added_dict = {v["new_path"]: v["new_value"] for k, v in iterable_item_moved.items()}
267+
iterable_item_added.update(added_dict)
268+
264269
if iterable_item_added:
265-
self._do_item_added(iterable_item_added)
270+
self._do_item_added(iterable_item_added, insert=True)
266271

267272
def _do_dictionary_item_added(self):
268273
dictionary_item_added = self.diff.get('dictionary_item_added')
@@ -274,7 +279,7 @@ def _do_attribute_added(self):
274279
if attribute_added:
275280
self._do_item_added(attribute_added)
276281

277-
def _do_item_added(self, items, sort=True):
282+
def _do_item_added(self, items, sort=True, insert=False):
278283
if sort:
279284
# sorting items by their path so that the items with smaller index
280285
# are applied first (unless `sort` is `False` so that order of
@@ -289,6 +294,11 @@ def _do_item_added(self, items, sort=True):
289294
elements, parent, parent_to_obj_elem, parent_to_obj_action, obj, elem, action = elem_and_details
290295
else:
291296
continue # pragma: no cover. Due to cPython peephole optimizer, this line doesn't get covered. https://github.com/nedbat/coveragepy/issues/198
297+
298+
# Insert is only true for iterables, make sure it is a valid index.
299+
if(insert and elem < len(obj)):
300+
obj.insert(elem, None)
301+
292302
self._set_new_value(parent, parent_to_obj_elem, parent_to_obj_action,
293303
obj, elements, path, elem, action, new_value)
294304

@@ -397,10 +407,18 @@ def _do_item_removed(self, items):
397407
self._do_verify_changes(path, expected_old_value, current_old_value)
398408

399409
def _do_iterable_item_removed(self):
400-
iterable_item_removed = self.diff.get('iterable_item_removed')
410+
iterable_item_removed = self.diff.get('iterable_item_removed', {})
411+
412+
iterable_item_moved = self.diff.get('iterable_item_moved')
413+
if iterable_item_moved:
414+
# These will get added back during items_added
415+
removed_dict = {k: v["new_value"] for k, v in iterable_item_moved.items()}
416+
iterable_item_removed.update(removed_dict)
417+
401418
if iterable_item_removed:
402419
self._do_item_removed(iterable_item_removed)
403420

421+
404422
def _do_dictionary_item_removed(self):
405423
dictionary_item_removed = self.diff.get('dictionary_item_removed')
406424
if dictionary_item_removed:

deepdiff/diff.py

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
number_to_string, datetime_normalize, KEY_TO_VAL_STR, booleans,
2222
np_ndarray, get_numpy_ndarray_rows, OrderedSetPlus, RepeatedTimer,
2323
TEXT_VIEW, TREE_VIEW, DELTA_VIEW,
24-
np, get_truncate_datetime, dict_)
24+
np, get_truncate_datetime, dict_, CannotCompare)
2525
from deepdiff.serialization import SerializationMixin
2626
from deepdiff.distance import DistanceMixin
2727
from deepdiff.model import (
@@ -139,6 +139,7 @@ def __init__(self,
139139
truncate_datetime=None,
140140
verbose_level=1,
141141
view=TEXT_VIEW,
142+
iterable_compare_func=None,
142143
_original_type=None,
143144
_parameters=None,
144145
_shared_parameters=None,
@@ -154,7 +155,8 @@ def __init__(self,
154155
"view, hasher, hashes, max_passes, max_diffs, "
155156
"cutoff_distance_for_pairs, cutoff_intersection_for_pairs, log_frequency_in_sec, cache_size, "
156157
"cache_tuning_sample_size, get_deep_distance, group_by, cache_purge_level, "
157-
"math_epsilon, _original_type, _parameters and _shared_parameters.") % ', '.join(kwargs.keys()))
158+
"math_epsilon, iterable_compare_func, _original_type, "
159+
"_parameters and _shared_parameters.") % ', '.join(kwargs.keys()))
158160

159161
if _parameters:
160162
self.__dict__.update(_parameters)
@@ -182,6 +184,7 @@ def __init__(self,
182184
self.ignore_string_case = ignore_string_case
183185
self.exclude_obj_callback = exclude_obj_callback
184186
self.number_to_string = number_to_string_func or number_to_string
187+
self.iterable_compare_func = iterable_compare_func
185188
self.ignore_private_variables = ignore_private_variables
186189
self.ignore_nan_inequality = ignore_nan_inequality
187190
self.hasher = hasher
@@ -558,6 +561,72 @@ def _diff_iterable(self, level, parents_ids=frozenset(), _original_type=None):
558561
else:
559562
self._diff_iterable_in_order(level, parents_ids, _original_type=_original_type)
560563

564+
def _compare_in_order(self, level):
565+
"""
566+
Default compare if `iterable_compare_func` is not provided.
567+
This will compare in sequence order.
568+
"""
569+
570+
return [((i, i), (x, y)) for i, (x, y) in enumerate(
571+
zip_longest(
572+
level.t1, level.t2, fillvalue=ListItemRemovedOrAdded))]
573+
574+
def _get_matching_pairs(self, level):
575+
"""
576+
Given a level get matching pairs. This returns list of two tuples in the form:
577+
[
578+
(t1 index, t2 index), (t1 item, t2 item)
579+
]
580+
581+
This will compare using the passed in `iterable_compare_func` if available.
582+
Default it to compare in order
583+
"""
584+
585+
if(self.iterable_compare_func is None):
586+
# Match in order if there is no compare function provided
587+
return self._compare_in_order(level)
588+
try:
589+
matches = []
590+
y_matched = set()
591+
y_index_matched = set()
592+
for i, x in enumerate(level.t1):
593+
x_found = False
594+
for j, y in enumerate(level.t2):
595+
596+
if(j in y_index_matched):
597+
# This ensures a one-to-one relationship of matches from t1 to t2.
598+
# If y this index in t2 has already been matched to another x
599+
# it cannot have another match, so just continue.
600+
continue
601+
602+
if(self.iterable_compare_func(x, y, level)):
603+
deep_hash = DeepHash(y,
604+
hashes=self.hashes,
605+
apply_hash=True,
606+
**self.deephash_parameters,
607+
)
608+
y_index_matched.add(j)
609+
y_matched.add(deep_hash[y])
610+
matches.append(((i, j), (x, y)))
611+
x_found = True
612+
break
613+
614+
if(not x_found):
615+
matches.append(((i, -1), (x, ListItemRemovedOrAdded)))
616+
for j, y in enumerate(level.t2):
617+
618+
deep_hash = DeepHash(y,
619+
hashes=self.hashes,
620+
apply_hash=True,
621+
**self.deephash_parameters,
622+
)
623+
if(deep_hash[y] not in y_matched):
624+
matches.append(((-1, j), (ListItemRemovedOrAdded, y)))
625+
return matches
626+
except CannotCompare:
627+
return self._compare_in_order(level)
628+
629+
561630
def _diff_iterable_in_order(self, level, parents_ids=frozenset(), _original_type=None):
562631
# We're handling both subscriptable and non-subscriptable iterables. Which one is it?
563632
subscriptable = self._iterables_subscriptable(level.t1, level.t2)
@@ -566,10 +635,7 @@ def _diff_iterable_in_order(self, level, parents_ids=frozenset(), _original_type
566635
else:
567636
child_relationship_class = NonSubscriptableIterableRelationship
568637

569-
for i, (x, y) in enumerate(
570-
zip_longest(
571-
level.t1, level.t2, fillvalue=ListItemRemovedOrAdded)):
572-
638+
for (i, j), (x, y) in self._get_matching_pairs(level):
573639
if self._count_diff() is StopIteration:
574640
return # pragma: no cover. This is already covered for addition.
575641

@@ -586,10 +652,22 @@ def _diff_iterable_in_order(self, level, parents_ids=frozenset(), _original_type
586652
notpresent,
587653
y,
588654
child_relationship_class=child_relationship_class,
589-
child_relationship_param=i)
655+
child_relationship_param=j)
590656
self._report_result('iterable_item_added', change_level)
591657

592658
else: # check if item value has changed
659+
660+
if (i != j):
661+
# Item moved
662+
change_level = level.branch_deeper(
663+
x,
664+
y,
665+
child_relationship_class=child_relationship_class,
666+
child_relationship_param=i,
667+
child_relationship_param2=j
668+
)
669+
self._report_result('iterable_item_moved', change_level)
670+
593671
item_id = id(x)
594672
if parents_ids and item_id in parents_ids:
595673
continue

deepdiff/helper.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,13 @@ def __repr__(self):
190190
__str__ = __repr__
191191

192192

193+
class CannotCompare(Exception):
194+
"""
195+
Exception when two items cannot be compared in the compare function.
196+
"""
197+
pass
198+
199+
193200
unprocessed = Unprocessed()
194201
skipped = Skipped()
195202
not_hashed = NotHashed()

deepdiff/model.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"unprocessed",
1717
"iterable_item_added",
1818
"iterable_item_removed",
19+
"iterable_item_moved",
1920
"attribute_added",
2021
"attribute_removed",
2122
"set_item_removed",
@@ -100,6 +101,7 @@ def __init__(self, tree_results=None, verbose_level=1):
100101
"unprocessed": [],
101102
"iterable_item_added": dict_(),
102103
"iterable_item_removed": dict_(),
104+
"iterable_item_moved": dict_(),
103105
"attribute_added": self.__set_or_dict(),
104106
"attribute_removed": self.__set_or_dict(),
105107
"set_item_removed": PrettyOrderedSet(),
@@ -126,6 +128,7 @@ def _from_tree_results(self, tree):
126128
self._from_tree_unprocessed(tree)
127129
self._from_tree_default(tree, 'iterable_item_added')
128130
self._from_tree_default(tree, 'iterable_item_removed')
131+
self._from_tree_iterable_item_moved(tree)
129132
self._from_tree_default(tree, 'attribute_added')
130133
self._from_tree_default(tree, 'attribute_removed')
131134
self._from_tree_set_item_removed(tree)
@@ -187,6 +190,13 @@ def _from_tree_value_changed(self, tree):
187190
if 'diff' in change.additional:
188191
the_changed.update({'diff': change.additional['diff']})
189192

193+
def _from_tree_iterable_item_moved(self, tree):
194+
if 'iterable_item_moved' in tree:
195+
for change in tree['iterable_item_moved']:
196+
the_changed = {'new_path': change.path(use_t2=True), 'new_value': change.t2}
197+
self['iterable_item_moved'][change.path(
198+
force=FORCE_DEFAULT)] = the_changed
199+
190200
def _from_tree_unprocessed(self, tree):
191201
if 'unprocessed' in tree:
192202
for change in tree['unprocessed']:
@@ -244,6 +254,7 @@ def __init__(self, tree_results=None, ignore_order=None):
244254
"values_changed": dict_(),
245255
"iterable_item_added": dict_(),
246256
"iterable_item_removed": dict_(),
257+
"iterable_item_moved": dict_(),
247258
"attribute_added": dict_(),
248259
"attribute_removed": dict_(),
249260
"set_item_removed": dict_(),
@@ -273,6 +284,7 @@ def _from_tree_results(self, tree):
273284
else:
274285
self._from_tree_default(tree, 'iterable_item_added')
275286
self._from_tree_default(tree, 'iterable_item_removed')
287+
self._from_tree_iterable_item_moved(tree)
276288
self._from_tree_default(tree, 'attribute_added')
277289
self._from_tree_default(tree, 'attribute_removed')
278290
self._from_tree_set_item_removed(tree)
@@ -528,7 +540,7 @@ def __setattr__(self, key, value):
528540
def repetition(self):
529541
return self.additional['repetition']
530542

531-
def auto_generate_child_rel(self, klass, param):
543+
def auto_generate_child_rel(self, klass, param, param2=None):
532544
"""
533545
Auto-populate self.child_rel1 and self.child_rel2.
534546
This requires self.down to be another valid DiffLevel object.
@@ -542,7 +554,7 @@ def auto_generate_child_rel(self, klass, param):
542554
klass=klass, parent=self.t1, child=self.down.t1, param=param)
543555
if self.down.t2 is not notpresent:
544556
self.t2_child_rel = ChildRelationship.create(
545-
klass=klass, parent=self.t2, child=self.down.t2, param=param)
557+
klass=klass, parent=self.t2, child=self.down.t2, param=param if param2 is None else param2)
546558

547559
@property
548560
def all_up(self):
@@ -572,7 +584,7 @@ def all_down(self):
572584
def _format_result(root, result):
573585
return None if result is None else "{}{}".format(root, result)
574586

575-
def path(self, root="root", force=None, get_parent_too=False):
587+
def path(self, root="root", force=None, get_parent_too=False, use_t2=False):
576588
"""
577589
A python syntax string describing how to descend to this level, assuming the top level object is called root.
578590
Returns None if the path is not representable as a string.
@@ -594,7 +606,7 @@ def path(self, root="root", force=None, get_parent_too=False):
594606
This will pretend all iterables are subscriptable, for example.
595607
"""
596608
# TODO: We could optimize this by building on top of self.up's path if it is cached there
597-
cache_key = "{}{}".format(force, get_parent_too)
609+
cache_key = "{}{}{}".format(force, get_parent_too, use_t2)
598610
if cache_key in self._path:
599611
cached = self._path[cache_key]
600612
if get_parent_too:
@@ -609,7 +621,10 @@ def path(self, root="root", force=None, get_parent_too=False):
609621
# traverse all levels of this relationship
610622
while level and level is not self:
611623
# get this level's relationship object
612-
next_rel = level.t1_child_rel or level.t2_child_rel # next relationship object to get a formatted param from
624+
if(use_t2):
625+
next_rel = level.t2_child_rel
626+
else:
627+
next_rel = level.t1_child_rel or level.t2_child_rel # next relationship object to get a formatted param from
613628

614629
# t1 and t2 both are empty
615630
if next_rel is None:
@@ -642,6 +657,7 @@ def create_deeper(self,
642657
new_t2,
643658
child_relationship_class,
644659
child_relationship_param=None,
660+
child_relationship_param2=None,
645661
report_type=None):
646662
"""
647663
Start a new comparison level and correctly link it to this one.
@@ -653,14 +669,15 @@ def create_deeper(self,
653669
new_t1, new_t2, down=None, up=level, report_type=report_type)
654670
level.down = result
655671
level.auto_generate_child_rel(
656-
klass=child_relationship_class, param=child_relationship_param)
672+
klass=child_relationship_class, param=child_relationship_param, param2=child_relationship_param2)
657673
return result
658674

659675
def branch_deeper(self,
660676
new_t1,
661677
new_t2,
662678
child_relationship_class,
663679
child_relationship_param=None,
680+
child_relationship_param2=None,
664681
report_type=None):
665682
"""
666683
Branch this comparison: Do not touch this comparison line, but create a new one with exactly the same content,
@@ -670,7 +687,7 @@ def branch_deeper(self,
670687
"""
671688
branch = self.copy()
672689
return branch.create_deeper(new_t1, new_t2, child_relationship_class,
673-
child_relationship_param, report_type)
690+
child_relationship_param, child_relationship_param2, report_type)
674691

675692
def copy(self):
676693
"""

0 commit comments

Comments
 (0)