21
21
number_to_string , datetime_normalize , KEY_TO_VAL_STR , booleans ,
22
22
np_ndarray , get_numpy_ndarray_rows , OrderedSetPlus , RepeatedTimer ,
23
23
TEXT_VIEW , TREE_VIEW , DELTA_VIEW ,
24
- np , get_truncate_datetime , dict_ )
24
+ np , get_truncate_datetime , dict_ , CannotCompare )
25
25
from deepdiff .serialization import SerializationMixin
26
26
from deepdiff .distance import DistanceMixin
27
27
from deepdiff .model import (
@@ -139,6 +139,7 @@ def __init__(self,
139
139
truncate_datetime = None ,
140
140
verbose_level = 1 ,
141
141
view = TEXT_VIEW ,
142
+ iterable_compare_func = None ,
142
143
_original_type = None ,
143
144
_parameters = None ,
144
145
_shared_parameters = None ,
@@ -154,7 +155,8 @@ def __init__(self,
154
155
"view, hasher, hashes, max_passes, max_diffs, "
155
156
"cutoff_distance_for_pairs, cutoff_intersection_for_pairs, log_frequency_in_sec, cache_size, "
156
157
"cache_tuning_sample_size, get_deep_distance, group_by, cache_purge_level, "
157
- "math_epsilon, _original_type, _parameters and _shared_parameters." ) % ', ' .join (kwargs .keys ()))
158
+ "math_epsilon, iterable_compare_func, _original_type, "
159
+ "_parameters and _shared_parameters." ) % ', ' .join (kwargs .keys ()))
158
160
159
161
if _parameters :
160
162
self .__dict__ .update (_parameters )
@@ -182,6 +184,7 @@ def __init__(self,
182
184
self .ignore_string_case = ignore_string_case
183
185
self .exclude_obj_callback = exclude_obj_callback
184
186
self .number_to_string = number_to_string_func or number_to_string
187
+ self .iterable_compare_func = iterable_compare_func
185
188
self .ignore_private_variables = ignore_private_variables
186
189
self .ignore_nan_inequality = ignore_nan_inequality
187
190
self .hasher = hasher
@@ -558,6 +561,53 @@ def _diff_iterable(self, level, parents_ids=frozenset(), _original_type=None):
558
561
else :
559
562
self ._diff_iterable_in_order (level , parents_ids , _original_type = _original_type )
560
563
564
def _compare_in_order(self, level):
    """
    Pair the items of `level.t1` and `level.t2` by position.

    This is the pairing used when no `iterable_compare_func` is provided.
    Returns a list whose entries have the form
    ((index, index), (t1 item, t2 item)); both indexes are the same
    position. When one iterable is shorter than the other, the missing
    side is filled with `ListItemRemovedOrAdded`.
    """
    positional_pairs = zip_longest(
        level.t1, level.t2, fillvalue=ListItemRemovedOrAdded)

    ordered = []
    for position, items in enumerate(positional_pairs):
        # `items` is already the (t1 item, t2 item) tuple produced by zip_longest.
        ordered.append(((position, position), items))
    return ordered
573
+
574
+ def _get_matching_pairs (self , level ):
575
+ """
576
+ Given a level get matching pairs. This returns list of two tuples in the form:
577
+ [
578
+ (t1 index, t2 index), (t1 item, t2 item)
579
+ ]
580
+
581
+ This will compare using the passed in `iterable_compare_func` if available.
582
+ Default it to compare in order
583
+ """
584
+
585
+ if (self .iterable_compare_func is None ):
586
+ # Match in order if there is no compare function provided
587
+ return self ._compare_in_order (level )
588
+ try :
589
+ matches = []
590
+ y_matched = set ()
591
+ for i , x in enumerate (level .t1 ):
592
+ x_found = False
593
+ for j , y in enumerate (level .t2 ):
594
+
595
+ if (self .iterable_compare_func (x , y )):
596
+ y_matched .add (id (y ))
597
+ matches .append (((i , j ), (x , y )))
598
+ x_found = True
599
+ break
600
+
601
+ if (not x_found ):
602
+ matches .append (((i , - 1 ), (x , ListItemRemovedOrAdded )))
603
+ for j , y in enumerate (level .t2 ):
604
+ if (id (y ) not in y_matched ):
605
+ matches .append (((- 1 , j ), (ListItemRemovedOrAdded , y )))
606
+ return matches
607
+ except CannotCompare :
608
+ return self ._compare_in_order (level )
609
+
610
+
561
611
def _diff_iterable_in_order (self , level , parents_ids = frozenset (), _original_type = None ):
562
612
# We're handling both subscriptable and non-subscriptable iterables. Which one is it?
563
613
subscriptable = self ._iterables_subscriptable (level .t1 , level .t2 )
@@ -566,42 +616,53 @@ def _diff_iterable_in_order(self, level, parents_ids=frozenset(), _original_type
566
616
else :
567
617
child_relationship_class = NonSubscriptableIterableRelationship
568
618
569
- for i , (x , y ) in enumerate (
570
- zip_longest (
571
- level .t1 , level .t2 , fillvalue = ListItemRemovedOrAdded )):
572
619
573
- if self ._count_diff () is StopIteration :
574
- return # pragma: no cover. This is already covered for addition.
575
620
576
- if y is ListItemRemovedOrAdded : # item removed completely
577
- change_level = level .branch_deeper (
578
- x ,
579
- notpresent ,
580
- child_relationship_class = child_relationship_class ,
581
- child_relationship_param = i )
582
- self ._report_result ('iterable_item_removed' , change_level )
621
+ for (i , j ), (x , y ) in self ._get_matching_pairs (level ):
622
+ if self ._count_diff () is StopIteration :
623
+ return # pragma: no cover. This is already covered for addition.
583
624
584
- elif x is ListItemRemovedOrAdded : # new item added
585
- change_level = level .branch_deeper (
586
- notpresent ,
587
- y ,
588
- child_relationship_class = child_relationship_class ,
589
- child_relationship_param = i )
590
- self ._report_result ('iterable_item_added' , change_level )
591
-
592
- else : # check if item value has changed
593
- item_id = id (x )
594
- if parents_ids and item_id in parents_ids :
595
- continue
596
- parents_ids_added = add_to_frozen_set (parents_ids , item_id )
597
-
598
- # Go one level deeper
599
- next_level = level .branch_deeper (
600
- x ,
601
- y ,
602
- child_relationship_class = child_relationship_class ,
603
- child_relationship_param = i )
604
- self ._diff (next_level , parents_ids_added )
625
+ if y is ListItemRemovedOrAdded : # item removed completely
626
+ change_level = level .branch_deeper (
627
+ x ,
628
+ notpresent ,
629
+ child_relationship_class = child_relationship_class ,
630
+ child_relationship_param = i )
631
+ self ._report_result ('iterable_item_removed' , change_level )
632
+
633
+ elif x is ListItemRemovedOrAdded : # new item added
634
+ change_level = level .branch_deeper (
635
+ notpresent ,
636
+ y ,
637
+ child_relationship_class = child_relationship_class ,
638
+ child_relationship_param = j )
639
+ self ._report_result ('iterable_item_added' , change_level )
640
+
641
+ else : # check if item value has changed
642
+
643
+ if (i != j ):
644
+ # Item moved
645
+ change_level = level .branch_deeper (
646
+ x ,
647
+ y ,
648
+ child_relationship_class = child_relationship_class ,
649
+ child_relationship_param = i ,
650
+ child_relationship_param2 = j
651
+ )
652
+ self ._report_result ('iterable_item_moved' , change_level )
653
+
654
+ item_id = id (x )
655
+ if parents_ids and item_id in parents_ids :
656
+ continue
657
+ parents_ids_added = add_to_frozen_set (parents_ids , item_id )
658
+
659
+ # Go one level deeper
660
+ next_level = level .branch_deeper (
661
+ x ,
662
+ y ,
663
+ child_relationship_class = child_relationship_class ,
664
+ child_relationship_param = i )
665
+ self ._diff (next_level , parents_ids_added )
605
666
606
667
def _diff_str (self , level ):
607
668
"""Compare strings"""
0 commit comments