From 9e8ade5b640ae026d51e159ea043859e464f3ea6 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 16 Jul 2021 16:59:16 -0700 Subject: [PATCH] Add categories setter to CategoricalAccessor and CategoricalIndex. --- python/pyspark/pandas/categorical.py | 18 ++++++++++--- python/pyspark/pandas/indexes/category.py | 16 ++++++++--- .../pandas/tests/indexes/test_category.py | 27 +++++++++++++++++++ .../pyspark/pandas/tests/test_categorical.py | 14 ++++++++++ 4 files changed, 69 insertions(+), 6 deletions(-) diff --git a/python/pyspark/pandas/categorical.py b/python/pyspark/pandas/categorical.py index b8cc88c95a1e1..aeba20d51046b 100644 --- a/python/pyspark/pandas/categorical.py +++ b/python/pyspark/pandas/categorical.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Optional, TYPE_CHECKING, cast +from typing import List, Optional, Union, TYPE_CHECKING, cast import pandas as pd from pandas.api.types import CategoricalDtype @@ -89,8 +89,20 @@ def categories(self) -> pd.Index: return self._dtype.categories @categories.setter - def categories(self, categories: pd.Index) -> None: - raise NotImplementedError() + def categories(self, categories: Union[pd.Index, List]) -> None: + dtype = CategoricalDtype(categories, ordered=self.ordered) + + if len(self.categories) != len(dtype.categories): + raise ValueError( + "new categories need to have the same number of items as the old categories!" + ) + + internal = self._data._psdf._internal.with_new_spark_column( + self._data._column_label, + self._data.spark.column, + field=self._data._internal.data_fields[0].copy(dtype=dtype), + ) + self._data._psdf._update_internal_frame(internal) @property def ordered(self) -> bool: diff --git a/python/pyspark/pandas/indexes/category.py b/python/pyspark/pandas/indexes/category.py index a7ad2a0fdd585..1b6588646f894 100644 --- a/python/pyspark/pandas/indexes/category.py +++ b/python/pyspark/pandas/indexes/category.py @@ -15,7 +15,7 @@ # limitations under the License. # from functools import partial -from typing import Any, Optional, cast, no_type_check +from typing import Any, List, Optional, Union, cast, no_type_check import pandas as pd from pandas.api.types import is_hashable, CategoricalDtype @@ -174,8 +174,18 @@ def categories(self) -> pd.Index: return self.dtype.categories @categories.setter - def categories(self, categories: pd.Index) -> None: - raise NotImplementedError() + def categories(self, categories: Union[pd.Index, List]) -> None: + dtype = CategoricalDtype(categories, ordered=self.ordered) + + if len(self.categories) != len(dtype.categories): + raise ValueError( + "new categories need to have the same number of items as the old categories!" + ) + + internal = self._psdf._internal.copy( + index_fields=[self._internal.index_fields[0].copy(dtype=dtype)] + ) + self._psdf._update_internal_frame(internal) @property def ordered(self) -> bool: diff --git a/python/pyspark/pandas/tests/indexes/test_category.py b/python/pyspark/pandas/tests/indexes/test_category.py index 02752ec0dd134..d04f89684e20d 100644 --- a/python/pyspark/pandas/tests/indexes/test_category.py +++ b/python/pyspark/pandas/tests/indexes/test_category.py @@ -67,6 +67,33 @@ def test_categorical_index(self): self.assert_eq(psidx.codes, pd.Index(pidx.codes)) self.assert_eq(psidx.ordered, pidx.ordered) + def test_categories_setter(self): + pdf = pd.DataFrame( + { + "a": pd.Categorical([1, 2, 3, 1, 2, 3]), + "b": pd.Categorical(["a", "b", "c", "a", "b", "c"], categories=["c", "b", "a"]), + }, + index=pd.Categorical([10, 20, 30, 20, 30, 10], categories=[30, 10, 20], ordered=True), + ) + psdf = ps.from_pandas(pdf) + + pidx = pdf.index + psidx = psdf.index + + pidx.categories = ["z", "y", "x"] + psidx.categories = ["z", "y", "x"] + if LooseVersion(pd.__version__) >= LooseVersion("1.0.5"): + self.assert_eq(pidx, psidx) + self.assert_eq(pdf, psdf) + else: + pidx = pidx.set_categories(pidx.categories) + pdf.index = pidx + self.assert_eq(pidx, psidx) + self.assert_eq(pdf, psdf) + + with self.assertRaises(ValueError): + psidx.categories = [1, 2, 3, 4] + def test_as_ordered_unordered(self): pidx = pd.CategoricalIndex(["x", "y", "z"], categories=["z", "y", "x"]) psidx = ps.from_pandas(pidx) diff --git a/python/pyspark/pandas/tests/test_categorical.py b/python/pyspark/pandas/tests/test_categorical.py index a4c9b148305c8..fb0561d560c83 100644 --- a/python/pyspark/pandas/tests/test_categorical.py +++ b/python/pyspark/pandas/tests/test_categorical.py @@ -65,6 +65,20 @@ def test_categorical_series(self): self.assert_eq(psser.cat.codes, pser.cat.codes) self.assert_eq(psser.cat.ordered, pser.cat.ordered) + def test_categories_setter(self): + pdf, psdf = self.df_pair + + pser = pdf.a + psser = psdf.a + + pser.cat.categories = ["z", "y", "x"] + psser.cat.categories = ["z", "y", "x"] + self.assert_eq(pser, psser) + self.assert_eq(pdf, psdf) + + with self.assertRaises(ValueError): + psser.cat.categories = [1, 2, 3, 4] + def test_as_ordered_unordered(self): pdf, psdf = self.df_pair