Skip to content

Commit e8085a7

Browse files
authored
BUG: crosstab with duplicate column or index labels (#37997)
1 parent dfa9d6f commit e8085a7

File tree

3 files changed

+105
-21
lines changed

3 files changed

+105
-21
lines changed

doc/source/whatsnew/v1.2.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,7 @@ Groupby/resample/rolling
730730
Reshaping
731731
^^^^^^^^^
732732

733+
- Bug in :func:`crosstab` returning incorrect results on inputs with duplicate row names, duplicate column names or duplicate names between row and column labels (:issue:`22529`)
733734
- Bug in :meth:`DataFrame.pivot_table` with ``aggfunc='count'`` or ``aggfunc='sum'`` returning ``NaN`` for missing categories when pivoted on a ``Categorical``. Now returning ``0`` (:issue:`31422`)
734735
- Bug in :func:`concat` and :class:`DataFrame` constructor where input index names are not preserved in some cases (:issue:`13475`)
735736
- Bug in :func:`crosstab` when using multiple columns with ``margins=True`` and ``normalize=True`` (:issue:`35144`)

pandas/core/reshape/pivot.py

Lines changed: 79 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
List,
66
Optional,
77
Sequence,
8+
Set,
89
Tuple,
910
Union,
1011
cast,
@@ -578,29 +579,37 @@ def crosstab(
578579
b 0 1 0
579580
c 0 0 0
580581
"""
582+
if values is None and aggfunc is not None:
583+
raise ValueError("aggfunc cannot be used without values.")
584+
585+
if values is not None and aggfunc is None:
586+
raise ValueError("values cannot be used without an aggfunc.")
587+
581588
index = com.maybe_make_list(index)
582589
columns = com.maybe_make_list(columns)
583590

584-
rownames = _get_names(index, rownames, prefix="row")
585-
colnames = _get_names(columns, colnames, prefix="col")
586-
587591
common_idx = None
588592
pass_objs = [x for x in index + columns if isinstance(x, (ABCSeries, ABCDataFrame))]
589593
if pass_objs:
590594
common_idx = get_objs_combined_axis(pass_objs, intersect=True, sort=False)
591595

592-
data: Dict = {}
593-
data.update(zip(rownames, index))
594-
data.update(zip(colnames, columns))
595-
596-
if values is None and aggfunc is not None:
597-
raise ValueError("aggfunc cannot be used without values.")
596+
rownames = _get_names(index, rownames, prefix="row")
597+
colnames = _get_names(columns, colnames, prefix="col")
598598

599-
if values is not None and aggfunc is None:
600-
raise ValueError("values cannot be used without an aggfunc.")
599+
# duplicate names mapped to unique names for pivot op
600+
(
601+
rownames_mapper,
602+
unique_rownames,
603+
colnames_mapper,
604+
unique_colnames,
605+
) = _build_names_mapper(rownames, colnames)
601606

602607
from pandas import DataFrame
603608

609+
data = {
610+
**dict(zip(unique_rownames, index)),
611+
**dict(zip(unique_colnames, columns)),
612+
}
604613
df = DataFrame(data, index=common_idx)
605614
original_df_cols = df.columns
606615

@@ -613,8 +622,8 @@ def crosstab(
613622

614623
table = df.pivot_table(
615624
["__dummy__"],
616-
index=rownames,
617-
columns=colnames,
625+
index=unique_rownames,
626+
columns=unique_colnames,
618627
margins=margins,
619628
margins_name=margins_name,
620629
dropna=dropna,
@@ -633,6 +642,9 @@ def crosstab(
633642
table, normalize=normalize, margins=margins, margins_name=margins_name
634643
)
635644

645+
table = table.rename_axis(index=rownames_mapper, axis=0)
646+
table = table.rename_axis(columns=colnames_mapper, axis=1)
647+
636648
return table
637649

638650

@@ -731,3 +743,57 @@ def _get_names(arrs, names, prefix: str = "row"):
731743
names = list(names)
732744

733745
return names
746+
747+
748+
def _build_names_mapper(
749+
rownames: List[str], colnames: List[str]
750+
) -> Tuple[Dict[str, str], List[str], Dict[str, str], List[str]]:
751+
"""
752+
Given the names of a DataFrame's rows and columns, returns a set of unique row
753+
and column names and mappers that convert to original names.
754+
755+
A row or column name is replaced if it is duplicate among the rows of the inputs,
756+
among the columns of the inputs or between the rows and the columns.
757+
758+
Paramters
759+
---------
760+
rownames: list[str]
761+
colnames: list[str]
762+
763+
Returns
764+
-------
765+
Tuple(Dict[str, str], List[str], Dict[str, str], List[str])
766+
767+
rownames_mapper: dict[str, str]
768+
a dictionary with new row names as keys and original rownames as values
769+
unique_rownames: list[str]
770+
a list of rownames with duplicate names replaced by dummy names
771+
colnames_mapper: dict[str, str]
772+
a dictionary with new column names as keys and original column names as values
773+
unique_colnames: list[str]
774+
a list of column names with duplicate names replaced by dummy names
775+
776+
"""
777+
778+
def get_duplicates(names):
779+
seen: Set = set()
780+
return {name for name in names if name not in seen}
781+
782+
shared_names = set(rownames).intersection(set(colnames))
783+
dup_names = get_duplicates(rownames) | get_duplicates(colnames) | shared_names
784+
785+
rownames_mapper = {
786+
f"row_{i}": name for i, name in enumerate(rownames) if name in dup_names
787+
}
788+
unique_rownames = [
789+
f"row_{i}" if name in dup_names else name for i, name in enumerate(rownames)
790+
]
791+
792+
colnames_mapper = {
793+
f"col_{i}": name for i, name in enumerate(colnames) if name in dup_names
794+
}
795+
unique_colnames = [
796+
f"col_{i}" if name in dup_names else name for i, name in enumerate(colnames)
797+
]
798+
799+
return rownames_mapper, unique_rownames, colnames_mapper, unique_colnames

pandas/tests/reshape/test_crosstab.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -535,15 +535,32 @@ def test_crosstab_with_numpy_size(self):
535535
)
536536
tm.assert_frame_equal(result, expected)
537537

538-
def test_crosstab_dup_index_names(self):
539-
# GH 13279
540-
s = Series(range(3), name="foo")
538+
def test_crosstab_duplicate_names(self):
539+
# GH 13279 / 22529
540+
541+
s1 = Series(range(3), name="foo")
542+
s2_foo = Series(range(1, 4), name="foo")
543+
s2_bar = Series(range(1, 4), name="bar")
544+
s3 = Series(range(3), name="waldo")
545+
546+
# check result computed with duplicate labels against
547+
# result computed with unique labels, then relabelled
548+
mapper = {"bar": "foo"}
549+
550+
# duplicate row, column labels
551+
result = crosstab(s1, s2_foo)
552+
expected = crosstab(s1, s2_bar).rename_axis(columns=mapper, axis=1)
553+
tm.assert_frame_equal(result, expected)
554+
555+
# duplicate row, unique column labels
556+
result = crosstab([s1, s2_foo], s3)
557+
expected = crosstab([s1, s2_bar], s3).rename_axis(index=mapper, axis=0)
558+
tm.assert_frame_equal(result, expected)
559+
560+
# unique row, duplicate column labels
561+
result = crosstab(s3, [s1, s2_foo])
562+
expected = crosstab(s3, [s1, s2_bar]).rename_axis(columns=mapper, axis=1)
541563

542-
result = crosstab(s, s)
543-
expected_index = Index(range(3), name="foo")
544-
expected = DataFrame(
545-
np.eye(3, dtype=np.int64), index=expected_index, columns=expected_index
546-
)
547564
tm.assert_frame_equal(result, expected)
548565

549566
@pytest.mark.parametrize("names", [["a", ("b", "c")], [("a", "b"), "c"]])

0 commit comments

Comments
 (0)