Skip to content

Commit 9e8ade5

Browse files
committed
Add categories setter to CategoricalAccessor and CategoricalIndex.
1 parent 376fadc commit 9e8ade5

File tree

4 files changed: 69 additions (+69) and 6 deletions (-6).

python/pyspark/pandas/categorical.py

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
#
17-
from typing import Optional, TYPE_CHECKING, cast
17+
from typing import List, Optional, Union, TYPE_CHECKING, cast
1818

1919
import pandas as pd
2020
from pandas.api.types import CategoricalDtype
@@ -89,8 +89,20 @@ def categories(self) -> pd.Index:
8989
return self._dtype.categories
9090

9191
@categories.setter
92-
def categories(self, categories: pd.Index) -> None:
93-
raise NotImplementedError()
92+
def categories(self, categories: Union[pd.Index, List]) -> None:
93+
dtype = CategoricalDtype(categories, ordered=self.ordered)
94+
95+
if len(self.categories) != len(dtype.categories):
96+
raise ValueError(
97+
"new categories need to have the same number of items as the old categories!"
98+
)
99+
100+
internal = self._data._psdf._internal.with_new_spark_column(
101+
self._data._column_label,
102+
self._data.spark.column,
103+
field=self._data._internal.data_fields[0].copy(dtype=dtype),
104+
)
105+
self._data._psdf._update_internal_frame(internal)
94106

95107
@property
96108
def ordered(self) -> bool:

python/pyspark/pandas/indexes/category.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# limitations under the License.
1616
#
1717
from functools import partial
18-
from typing import Any, Optional, cast, no_type_check
18+
from typing import Any, List, Optional, Union, cast, no_type_check
1919

2020
import pandas as pd
2121
from pandas.api.types import is_hashable, CategoricalDtype
@@ -174,8 +174,18 @@ def categories(self) -> pd.Index:
174174
return self.dtype.categories
175175

176176
@categories.setter
177-
def categories(self, categories: pd.Index) -> None:
178-
raise NotImplementedError()
177+
def categories(self, categories: Union[pd.Index, List]) -> None:
178+
dtype = CategoricalDtype(categories, ordered=self.ordered)
179+
180+
if len(self.categories) != len(dtype.categories):
181+
raise ValueError(
182+
"new categories need to have the same number of items as the old categories!"
183+
)
184+
185+
internal = self._psdf._internal.copy(
186+
index_fields=[self._internal.index_fields[0].copy(dtype=dtype)]
187+
)
188+
self._psdf._update_internal_frame(internal)
179189

180190
@property
181191
def ordered(self) -> bool:

python/pyspark/pandas/tests/indexes/test_category.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,33 @@ def test_categorical_index(self):
6767
self.assert_eq(psidx.codes, pd.Index(pidx.codes))
6868
self.assert_eq(psidx.ordered, pidx.ordered)
6969

70+
def test_categories_setter(self):
71+
pdf = pd.DataFrame(
72+
{
73+
"a": pd.Categorical([1, 2, 3, 1, 2, 3]),
74+
"b": pd.Categorical(["a", "b", "c", "a", "b", "c"], categories=["c", "b", "a"]),
75+
},
76+
index=pd.Categorical([10, 20, 30, 20, 30, 10], categories=[30, 10, 20], ordered=True),
77+
)
78+
psdf = ps.from_pandas(pdf)
79+
80+
pidx = pdf.index
81+
psidx = psdf.index
82+
83+
pidx.categories = ["z", "y", "x"]
84+
psidx.categories = ["z", "y", "x"]
85+
if LooseVersion(pd.__version__) >= LooseVersion("1.0.5"):
86+
self.assert_eq(pidx, psidx)
87+
self.assert_eq(pdf, psdf)
88+
else:
89+
pidx = pidx.set_categories(pidx.categories)
90+
pdf.index = pidx
91+
self.assert_eq(pidx, psidx)
92+
self.assert_eq(pdf, psdf)
93+
94+
with self.assertRaises(ValueError):
95+
psidx.categories = [1, 2, 3, 4]
96+
7097
def test_as_ordered_unordered(self):
7198
pidx = pd.CategoricalIndex(["x", "y", "z"], categories=["z", "y", "x"])
7299
psidx = ps.from_pandas(pidx)

python/pyspark/pandas/tests/test_categorical.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,20 @@ def test_categorical_series(self):
6565
self.assert_eq(psser.cat.codes, pser.cat.codes)
6666
self.assert_eq(psser.cat.ordered, pser.cat.ordered)
6767

68+
def test_categories_setter(self):
69+
pdf, psdf = self.df_pair
70+
71+
pser = pdf.a
72+
psser = psdf.a
73+
74+
pser.cat.categories = ["z", "y", "x"]
75+
psser.cat.categories = ["z", "y", "x"]
76+
self.assert_eq(pser, psser)
77+
self.assert_eq(pdf, psdf)
78+
79+
with self.assertRaises(ValueError):
80+
psser.cat.categories = [1, 2, 3, 4]
81+
6882
def test_as_ordered_unordered(self):
6983
pdf, psdf = self.df_pair
7084

0 commit comments

Comments
 (0)