Skip to content

Commit 409921b

Browse files
committed
This allows for a custom compare function to compare iterable items.
The new compare function takes two items of an iterable and should return True it matching, False if no match, and raise CannotCompare if unable to compare the two items. The default behavior is the same as before which is comparing each item in order. If the compare function raises CannotCompare then behavior reverts back to the default in order. This also introduces a new report key which is `iterable_item_moved` to track if iterable items have moved.
1 parent 5b66b96 commit 409921b

File tree

5 files changed

+314
-48
lines changed

5 files changed

+314
-48
lines changed

deepdiff/delta.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,14 @@ def _del_elem(self, parent, parent_to_obj_elem, parent_to_obj_action,
260260
value=obj, action=parent_to_obj_action)
261261

262262
def _do_iterable_item_added(self):
263-
iterable_item_added = self.diff.get('iterable_item_added')
263+
iterable_item_added = self.diff.get('iterable_item_added', {})
264+
iterable_item_moved = self.diff.get('iterable_item_moved')
265+
if iterable_item_moved:
266+
added_dict = {v["new_path"]: v["new_value"] for k, v in iterable_item_moved.items()}
267+
iterable_item_added.update(added_dict)
268+
264269
if iterable_item_added:
265-
self._do_item_added(iterable_item_added)
270+
self._do_item_added(iterable_item_added, insert=True)
266271

267272
def _do_dictionary_item_added(self):
268273
dictionary_item_added = self.diff.get('dictionary_item_added')
@@ -274,7 +279,7 @@ def _do_attribute_added(self):
274279
if attribute_added:
275280
self._do_item_added(attribute_added)
276281

277-
def _do_item_added(self, items, sort=True):
282+
def _do_item_added(self, items, sort=True, insert=False):
278283
if sort:
279284
# sorting items by their path so that the items with smaller index
280285
# are applied first (unless `sort` is `False` so that order of
@@ -289,6 +294,11 @@ def _do_item_added(self, items, sort=True):
289294
elements, parent, parent_to_obj_elem, parent_to_obj_action, obj, elem, action = elem_and_details
290295
else:
291296
continue # pragma: no cover. Due to cPython peephole optimizer, this line doesn't get covered. https://github.com/nedbat/coveragepy/issues/198
297+
298+
# Insert is only true for iterables, make sure it is a valid index.
299+
if(insert and elem < len(obj)):
300+
obj.insert(elem, None)
301+
292302
self._set_new_value(parent, parent_to_obj_elem, parent_to_obj_action,
293303
obj, elements, path, elem, action, new_value)
294304

@@ -397,10 +407,18 @@ def _do_item_removed(self, items):
397407
self._do_verify_changes(path, expected_old_value, current_old_value)
398408

399409
def _do_iterable_item_removed(self):
400-
iterable_item_removed = self.diff.get('iterable_item_removed')
410+
iterable_item_removed = self.diff.get('iterable_item_removed', {})
411+
412+
iterable_item_moved = self.diff.get('iterable_item_moved')
413+
if iterable_item_moved:
414+
# These will get added back during items_added
415+
removed_dict = {k: v["new_value"] for k, v in iterable_item_moved.items()}
416+
iterable_item_removed.update(removed_dict)
417+
401418
if iterable_item_removed:
402419
self._do_item_removed(iterable_item_removed)
403420

421+
404422
def _do_dictionary_item_removed(self):
405423
dictionary_item_removed = self.diff.get('dictionary_item_removed')
406424
if dictionary_item_removed:

deepdiff/diff.py

Lines changed: 96 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
number_to_string, datetime_normalize, KEY_TO_VAL_STR, booleans,
2222
np_ndarray, get_numpy_ndarray_rows, OrderedSetPlus, RepeatedTimer,
2323
TEXT_VIEW, TREE_VIEW, DELTA_VIEW,
24-
np, get_truncate_datetime, dict_)
24+
np, get_truncate_datetime, dict_, CannotCompare)
2525
from deepdiff.serialization import SerializationMixin
2626
from deepdiff.distance import DistanceMixin
2727
from deepdiff.model import (
@@ -139,6 +139,7 @@ def __init__(self,
139139
truncate_datetime=None,
140140
verbose_level=1,
141141
view=TEXT_VIEW,
142+
iterable_compare_func=None,
142143
_original_type=None,
143144
_parameters=None,
144145
_shared_parameters=None,
@@ -154,7 +155,8 @@ def __init__(self,
154155
"view, hasher, hashes, max_passes, max_diffs, "
155156
"cutoff_distance_for_pairs, cutoff_intersection_for_pairs, log_frequency_in_sec, cache_size, "
156157
"cache_tuning_sample_size, get_deep_distance, group_by, cache_purge_level, "
157-
"math_epsilon, _original_type, _parameters and _shared_parameters.") % ', '.join(kwargs.keys()))
158+
"math_epsilon, iterable_compare_func, _original_type, "
159+
"_parameters and _shared_parameters.") % ', '.join(kwargs.keys()))
158160

159161
if _parameters:
160162
self.__dict__.update(_parameters)
@@ -182,6 +184,7 @@ def __init__(self,
182184
self.ignore_string_case = ignore_string_case
183185
self.exclude_obj_callback = exclude_obj_callback
184186
self.number_to_string = number_to_string_func or number_to_string
187+
self.iterable_compare_func = iterable_compare_func
185188
self.ignore_private_variables = ignore_private_variables
186189
self.ignore_nan_inequality = ignore_nan_inequality
187190
self.hasher = hasher
@@ -558,6 +561,53 @@ def _diff_iterable(self, level, parents_ids=frozenset(), _original_type=None):
558561
else:
559562
self._diff_iterable_in_order(level, parents_ids, _original_type=_original_type)
560563

564+
def _compare_in_order(self, level):
565+
"""
566+
Default compare if `iterable_compare_func` is not provided.
567+
This will compare in sequence order.
568+
"""
569+
570+
return [((i, i), (x, y)) for i, (x, y) in enumerate(
571+
zip_longest(
572+
level.t1, level.t2, fillvalue=ListItemRemovedOrAdded))]
573+
574+
def _get_matching_pairs(self, level):
575+
"""
576+
Given a level get matching pairs. This returns list of two tuples in the form:
577+
[
578+
(t1 index, t2 index), (t1 item, t2 item)
579+
]
580+
581+
This will compare using the passed in `iterable_compare_func` if available.
582+
Default it to compare in order
583+
"""
584+
585+
if(self.iterable_compare_func is None):
586+
# Match in order if there is no compare function provided
587+
return self._compare_in_order(level)
588+
try:
589+
matches = []
590+
y_matched = set()
591+
for i, x in enumerate(level.t1):
592+
x_found = False
593+
for j, y in enumerate(level.t2):
594+
595+
if(self.iterable_compare_func(x, y)):
596+
y_matched.add(id(y))
597+
matches.append(((i, j), (x, y)))
598+
x_found = True
599+
break
600+
601+
if(not x_found):
602+
matches.append(((i, -1), (x, ListItemRemovedOrAdded)))
603+
for j, y in enumerate(level.t2):
604+
if(id(y) not in y_matched):
605+
matches.append(((-1, j), (ListItemRemovedOrAdded, y)))
606+
return matches
607+
except CannotCompare:
608+
return self._compare_in_order(level)
609+
610+
561611
def _diff_iterable_in_order(self, level, parents_ids=frozenset(), _original_type=None):
562612
# We're handling both subscriptable and non-subscriptable iterables. Which one is it?
563613
subscriptable = self._iterables_subscriptable(level.t1, level.t2)
@@ -566,42 +616,53 @@ def _diff_iterable_in_order(self, level, parents_ids=frozenset(), _original_type
566616
else:
567617
child_relationship_class = NonSubscriptableIterableRelationship
568618

569-
for i, (x, y) in enumerate(
570-
zip_longest(
571-
level.t1, level.t2, fillvalue=ListItemRemovedOrAdded)):
572619

573-
if self._count_diff() is StopIteration:
574-
return # pragma: no cover. This is already covered for addition.
575620

576-
if y is ListItemRemovedOrAdded: # item removed completely
577-
change_level = level.branch_deeper(
578-
x,
579-
notpresent,
580-
child_relationship_class=child_relationship_class,
581-
child_relationship_param=i)
582-
self._report_result('iterable_item_removed', change_level)
621+
for (i, j), (x, y) in self._get_matching_pairs(level):
622+
if self._count_diff() is StopIteration:
623+
return # pragma: no cover. This is already covered for addition.
583624

584-
elif x is ListItemRemovedOrAdded: # new item added
585-
change_level = level.branch_deeper(
586-
notpresent,
587-
y,
588-
child_relationship_class=child_relationship_class,
589-
child_relationship_param=i)
590-
self._report_result('iterable_item_added', change_level)
591-
592-
else: # check if item value has changed
593-
item_id = id(x)
594-
if parents_ids and item_id in parents_ids:
595-
continue
596-
parents_ids_added = add_to_frozen_set(parents_ids, item_id)
597-
598-
# Go one level deeper
599-
next_level = level.branch_deeper(
600-
x,
601-
y,
602-
child_relationship_class=child_relationship_class,
603-
child_relationship_param=i)
604-
self._diff(next_level, parents_ids_added)
625+
if y is ListItemRemovedOrAdded: # item removed completely
626+
change_level = level.branch_deeper(
627+
x,
628+
notpresent,
629+
child_relationship_class=child_relationship_class,
630+
child_relationship_param=i)
631+
self._report_result('iterable_item_removed', change_level)
632+
633+
elif x is ListItemRemovedOrAdded: # new item added
634+
change_level = level.branch_deeper(
635+
notpresent,
636+
y,
637+
child_relationship_class=child_relationship_class,
638+
child_relationship_param=j)
639+
self._report_result('iterable_item_added', change_level)
640+
641+
else: # check if item value has changed
642+
643+
if (i != j):
644+
# Item moved
645+
change_level = level.branch_deeper(
646+
x,
647+
y,
648+
child_relationship_class=child_relationship_class,
649+
child_relationship_param=i,
650+
child_relationship_param2=j
651+
)
652+
self._report_result('iterable_item_moved', change_level)
653+
654+
item_id = id(x)
655+
if parents_ids and item_id in parents_ids:
656+
continue
657+
parents_ids_added = add_to_frozen_set(parents_ids, item_id)
658+
659+
# Go one level deeper
660+
next_level = level.branch_deeper(
661+
x,
662+
y,
663+
child_relationship_class=child_relationship_class,
664+
child_relationship_param=i)
665+
self._diff(next_level, parents_ids_added)
605666

606667
def _diff_str(self, level):
607668
"""Compare strings"""

deepdiff/helper.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,13 @@ def __repr__(self):
190190
__str__ = __repr__
191191

192192

193+
class CannotCompare(Exception):
194+
"""
195+
Exception when two items cannot be compared in the compare function.
196+
"""
197+
pass
198+
199+
193200
unprocessed = Unprocessed()
194201
skipped = Skipped()
195202
not_hashed = NotHashed()

deepdiff/model.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"unprocessed",
1717
"iterable_item_added",
1818
"iterable_item_removed",
19+
"iterable_item_moved",
1920
"attribute_added",
2021
"attribute_removed",
2122
"set_item_removed",
@@ -100,6 +101,7 @@ def __init__(self, tree_results=None, verbose_level=1):
100101
"unprocessed": [],
101102
"iterable_item_added": dict_(),
102103
"iterable_item_removed": dict_(),
104+
"iterable_item_moved": dict_(),
103105
"attribute_added": self.__set_or_dict(),
104106
"attribute_removed": self.__set_or_dict(),
105107
"set_item_removed": PrettyOrderedSet(),
@@ -126,6 +128,7 @@ def _from_tree_results(self, tree):
126128
self._from_tree_unprocessed(tree)
127129
self._from_tree_default(tree, 'iterable_item_added')
128130
self._from_tree_default(tree, 'iterable_item_removed')
131+
self._from_tree_iterable_item_moved(tree)
129132
self._from_tree_default(tree, 'attribute_added')
130133
self._from_tree_default(tree, 'attribute_removed')
131134
self._from_tree_set_item_removed(tree)
@@ -187,6 +190,13 @@ def _from_tree_value_changed(self, tree):
187190
if 'diff' in change.additional:
188191
the_changed.update({'diff': change.additional['diff']})
189192

193+
def _from_tree_iterable_item_moved(self, tree):
194+
if 'iterable_item_moved' in tree:
195+
for change in tree['iterable_item_moved']:
196+
the_changed = {'new_path': change.path(use_t2=True), 'new_value': change.t2}
197+
self['iterable_item_moved'][change.path(
198+
force=FORCE_DEFAULT)] = the_changed
199+
190200
def _from_tree_unprocessed(self, tree):
191201
if 'unprocessed' in tree:
192202
for change in tree['unprocessed']:
@@ -244,6 +254,7 @@ def __init__(self, tree_results=None, ignore_order=None):
244254
"values_changed": dict_(),
245255
"iterable_item_added": dict_(),
246256
"iterable_item_removed": dict_(),
257+
"iterable_item_moved": dict_(),
247258
"attribute_added": dict_(),
248259
"attribute_removed": dict_(),
249260
"set_item_removed": dict_(),
@@ -273,6 +284,7 @@ def _from_tree_results(self, tree):
273284
else:
274285
self._from_tree_default(tree, 'iterable_item_added')
275286
self._from_tree_default(tree, 'iterable_item_removed')
287+
self._from_tree_iterable_item_moved(tree)
276288
self._from_tree_default(tree, 'attribute_added')
277289
self._from_tree_default(tree, 'attribute_removed')
278290
self._from_tree_set_item_removed(tree)
@@ -528,7 +540,7 @@ def __setattr__(self, key, value):
528540
def repetition(self):
529541
return self.additional['repetition']
530542

531-
def auto_generate_child_rel(self, klass, param):
543+
def auto_generate_child_rel(self, klass, param, param2=None):
532544
"""
533545
Auto-populate self.child_rel1 and self.child_rel2.
534546
This requires self.down to be another valid DiffLevel object.
@@ -542,7 +554,7 @@ def auto_generate_child_rel(self, klass, param):
542554
klass=klass, parent=self.t1, child=self.down.t1, param=param)
543555
if self.down.t2 is not notpresent:
544556
self.t2_child_rel = ChildRelationship.create(
545-
klass=klass, parent=self.t2, child=self.down.t2, param=param)
557+
klass=klass, parent=self.t2, child=self.down.t2, param=param if param2 is None else param2)
546558

547559
@property
548560
def all_up(self):
@@ -572,7 +584,7 @@ def all_down(self):
572584
def _format_result(root, result):
573585
return None if result is None else "{}{}".format(root, result)
574586

575-
def path(self, root="root", force=None, get_parent_too=False):
587+
def path(self, root="root", force=None, get_parent_too=False, use_t2=False):
576588
"""
577589
A python syntax string describing how to descend to this level, assuming the top level object is called root.
578590
Returns None if the path is not representable as a string.
@@ -594,7 +606,7 @@ def path(self, root="root", force=None, get_parent_too=False):
594606
This will pretend all iterables are subscriptable, for example.
595607
"""
596608
# TODO: We could optimize this by building on top of self.up's path if it is cached there
597-
cache_key = "{}{}".format(force, get_parent_too)
609+
cache_key = "{}{}{}".format(force, get_parent_too, use_t2)
598610
if cache_key in self._path:
599611
cached = self._path[cache_key]
600612
if get_parent_too:
@@ -609,7 +621,10 @@ def path(self, root="root", force=None, get_parent_too=False):
609621
# traverse all levels of this relationship
610622
while level and level is not self:
611623
# get this level's relationship object
612-
next_rel = level.t1_child_rel or level.t2_child_rel # next relationship object to get a formatted param from
624+
if(use_t2):
625+
next_rel = level.t2_child_rel
626+
else:
627+
next_rel = level.t1_child_rel or level.t2_child_rel # next relationship object to get a formatted param from
613628

614629
# t1 and t2 both are empty
615630
if next_rel is None:
@@ -642,6 +657,7 @@ def create_deeper(self,
642657
new_t2,
643658
child_relationship_class,
644659
child_relationship_param=None,
660+
child_relationship_param2=None,
645661
report_type=None):
646662
"""
647663
Start a new comparison level and correctly link it to this one.
@@ -653,14 +669,15 @@ def create_deeper(self,
653669
new_t1, new_t2, down=None, up=level, report_type=report_type)
654670
level.down = result
655671
level.auto_generate_child_rel(
656-
klass=child_relationship_class, param=child_relationship_param)
672+
klass=child_relationship_class, param=child_relationship_param, param2=child_relationship_param2)
657673
return result
658674

659675
def branch_deeper(self,
660676
new_t1,
661677
new_t2,
662678
child_relationship_class,
663679
child_relationship_param=None,
680+
child_relationship_param2=None,
664681
report_type=None):
665682
"""
666683
Branch this comparison: Do not touch this comparison line, but create a new one with exactly the same content,
@@ -670,7 +687,7 @@ def branch_deeper(self,
670687
"""
671688
branch = self.copy()
672689
return branch.create_deeper(new_t1, new_t2, child_relationship_class,
673-
child_relationship_param, report_type)
690+
child_relationship_param, child_relationship_param2, report_type)
674691

675692
def copy(self):
676693
"""

0 commit comments

Comments
 (0)