Skip to content

REF: Simplify Index.copy #35592

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Aug 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from pandas._libs import Interval, Period, algos
from pandas._libs.tslibs import conversion
from pandas._typing import ArrayLike, DtypeObj
from pandas._typing import ArrayLike, DtypeObj, Optional

from pandas.core.dtypes.base import registry
from pandas.core.dtypes.dtypes import (
Expand Down Expand Up @@ -1732,6 +1732,32 @@ def _validate_date_like_dtype(dtype) -> None:
)


def validate_all_hashable(*args, error_name: Optional[str] = None) -> None:
"""
Return None if all args are hashable, else raise a TypeError.

Parameters
----------
*args
Arguments to validate.
error_name : str, optional
The name to use if error

Raises
------
TypeError : If an argument is not hashable

Returns
-------
None
"""
if not all(is_hashable(arg) for arg in args):
if error_name:
raise TypeError(f"{error_name} must be a hashable type")
else:
raise TypeError("All elements must be hashable")


def pandas_dtype(dtype) -> DtypeObj:
"""
Convert input into a pandas only dtype object or a numpy dtype object.
Expand Down
36 changes: 22 additions & 14 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
is_timedelta64_dtype,
is_unsigned_integer_dtype,
pandas_dtype,
validate_all_hashable,
)
from pandas.core.dtypes.concat import concat_compat
from pandas.core.dtypes.generic import (
Expand Down Expand Up @@ -812,13 +813,11 @@ def copy(self, name=None, deep=False, dtype=None, names=None):
In most cases, there should be no functional difference from using
``deep``, but if ``deep`` is passed it will attempt to deepcopy.
"""
name = self._validate_names(name=name, names=names, deep=deep)[0]
if deep:
new_index = self._shallow_copy(self._data.copy())
new_index = self._shallow_copy(self._data.copy(), name=name)
else:
new_index = self._shallow_copy()

names = self._validate_names(name=name, names=names, deep=deep)
new_index = new_index.set_names(names)
new_index = self._shallow_copy(name=name)

if dtype:
new_index = new_index.astype(dtype)
Expand Down Expand Up @@ -1186,7 +1185,7 @@ def name(self, value):
maybe_extract_name(value, None, type(self))
self._name = value

def _validate_names(self, name=None, names=None, deep: bool = False):
def _validate_names(self, name=None, names=None, deep: bool = False) -> List[Label]:
"""
Handles the quirks of having a singular 'name' parameter for general
Index and plural 'names' parameter for MultiIndex.
Expand All @@ -1196,15 +1195,25 @@ def _validate_names(self, name=None, names=None, deep: bool = False):
if names is not None and name is not None:
raise TypeError("Can only provide one of `names` and `name`")
elif names is None and name is None:
return deepcopy(self.names) if deep else self.names
new_names = deepcopy(self.names) if deep else self.names
elif names is not None:
if not is_list_like(names):
raise TypeError("Must pass list-like as `names`.")
return names
new_names = names
elif not is_list_like(name):
new_names = [name]
else:
if not is_list_like(name):
return [name]
return name
new_names = name

if len(new_names) != len(self.names):
raise ValueError(
f"Length of new names must be {len(self.names)}, got {len(new_names)}"
)

# All items in 'new_names' need to be hashable
validate_all_hashable(*new_names, error_name=f"{type(self).__name__}.name")

return new_names

def _get_names(self):
return FrozenList((self.name,))
Expand Down Expand Up @@ -1232,9 +1241,8 @@ def _set_names(self, values, level=None):

# GH 20527
# All items in 'name' need to be hashable:
for name in values:
if not is_hashable(name):
raise TypeError(f"{type(self).__name__}.name must be a hashable type")
validate_all_hashable(*values, error_name=f"{type(self).__name__}.name")

self._name = values[0]

names = property(fset=_set_names, fget=_get_names)
Expand Down
5 changes: 2 additions & 3 deletions pandas/core/indexes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,8 @@ def _shallow_copy(self, values=None, name: Label = no_default):
def copy(self, name=None, deep=False, dtype=None, names=None):
self._validate_dtype(dtype)

new_index = self._shallow_copy()
names = self._validate_names(name=name, names=names, deep=deep)
new_index = new_index.set_names(names)
name = self._validate_names(name=name, names=names, deep=deep)[0]
new_index = self._shallow_copy(name=name)
return new_index

def _minmax(self, meth: str):
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
is_list_like,
is_object_dtype,
is_scalar,
validate_all_hashable,
)
from pandas.core.dtypes.generic import ABCDataFrame
from pandas.core.dtypes.inference import is_hashable
Expand Down Expand Up @@ -491,8 +492,7 @@ def name(self) -> Label:

@name.setter
def name(self, value: Label) -> None:
if not is_hashable(value):
raise TypeError("Series.name must be a hashable type")
validate_all_hashable(value, error_name=f"{type(self).__name__}.name")
object.__setattr__(self, "_name", value)

@property
Expand Down
10 changes: 10 additions & 0 deletions pandas/tests/dtypes/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,3 +746,13 @@ def test_astype_object_preserves_datetime_na(from_type):
result = astype_nansafe(arr, dtype="object")

assert isna(result)[0]


def test_validate_allhashable():
assert com.validate_all_hashable(1, "a") is None

with pytest.raises(TypeError, match="All elements must be hashable"):
com.validate_all_hashable([])

with pytest.raises(TypeError, match="list must be a hashable type"):
com.validate_all_hashable([], error_name="list")
14 changes: 14 additions & 0 deletions pandas/tests/indexes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,20 @@ def test_copy_name(self, index):
s3 = s1 * s2
assert s3.index.name == "mario"

def test_name2(self, index):
# gh-35592
if isinstance(index, MultiIndex):
return

assert index.copy(name="mario").name == "mario"

with pytest.raises(ValueError, match="Length of new names must be 1, got 2"):
index.copy(name=["mario", "luigi"])

msg = f"{type(index).__name__}.name must be a hashable type"
with pytest.raises(TypeError, match=msg):
index.copy(name=[["mario"]])

def test_ensure_copied_data(self, index):
# Check the "copy" argument of each Index.__new__ is honoured
# GH12309
Expand Down
7 changes: 7 additions & 0 deletions pandas/tests/indexes/multi/test_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ def test_copy_names():
assert multi_idx.names == ["MyName1", "MyName2"]
assert multi_idx3.names == ["NewName1", "NewName2"]

# gh-35592
with pytest.raises(ValueError, match="Length of new names must be 2, got 1"):
multi_idx.copy(names=["mario"])

with pytest.raises(TypeError, match="MultiIndex.name must be a hashable type"):
multi_idx.copy(names=[["mario"], ["luigi"]])


def test_names(idx, index_names):

Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/test_multilevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def setup_method(self, method):
).sum()

# use Int64Index, to make sure things work
self.ymd.index.set_levels(
[lev.astype("i8") for lev in self.ymd.index.levels], inplace=True
self.ymd.index = self.ymd.index.set_levels(
[lev.astype("i8") for lev in self.ymd.index.levels]
)
self.ymd.index.set_names(["year", "month", "day"], inplace=True)

Expand Down