1616import pandas as pd
1717from pandas .api .types import CategoricalDtype
1818
19- from ... import opcodes as OperandDef
19+ from ... import opcodes
2020from ...core import recursive_tile
21- from ...serialization .serializables import AnyField , StringField , ListField
21+ from ...serialization .serializables import AnyField , ListField , StringField
2222from ...tensor .base import sort
2323from ...utils import pd_release_version
2424from ..core import DATAFRAME_TYPE , SERIES_TYPE
2929
3030
3131class DataFrameAstype (DataFrameOperand , DataFrameOperandMixin ):
32- _op_type_ = OperandDef .ASTYPE
33-
34- _dtype_values = AnyField ("dtype_values" )
35- _errors = StringField ("errors" )
36- _category_cols = ListField ("category_cols" )
37-
38- def __init__ (
39- self ,
40- dtype_values = None ,
41- errors = None ,
42- category_cols = None ,
43- output_types = None ,
44- ** kw
45- ):
46- super ().__init__ (
47- _dtype_values = dtype_values ,
48- _errors = errors ,
49- _category_cols = category_cols ,
50- _output_types = output_types ,
51- ** kw
52- )
32+ _op_type_ = opcodes .ASTYPE
5333
54- @ property
55- def dtype_values ( self ):
56- return self . _dtype_values
34+ dtype_values = AnyField ( "dtype_values" , default = None )
35+ errors = StringField ( "errors" , default = None )
36+ category_cols = ListField ( "category_cols" , default = None )
5737
58- @property
59- def errors (self ):
60- return self ._errors
38+ def __init__ (self , output_types = None , ** kw ):
39+ super ().__init__ (_output_types = output_types , ** kw )
6140
62- @property
63- def category_cols (self ):
64- return self ._category_cols
41+ @classmethod
42+ def _is_categories_missing (cls , dtype ):
43+ return (isinstance (dtype , str ) and dtype == "category" ) or (
44+ isinstance (dtype , pd .CategoricalDtype ) and dtype .categories is None
45+ )
6546
6647 @classmethod
67- def _tile_one_chunk (cls , op ):
48+ def _tile_one_chunk (cls , op : "DataFrameAstype" ):
6849 c = op .inputs [0 ].chunks [0 ]
6950 chunk_op = op .copy ().reset_key ()
7051 chunk_params = op .outputs [0 ].params .copy ()
@@ -80,23 +61,22 @@ def _tile_one_chunk(cls, op):
8061 )
8162
8263 @classmethod
83- def _tile_series_index (cls , op ):
64+ def _tile_series_index (cls , op : "DataFrameAstype" ):
8465 in_series = op .inputs [0 ]
8566 out = op .outputs [0 ]
8667
8768 unique_chunk = None
88- if op .dtype_values == "category" and isinstance (op .dtype_values , str ):
89- unique_chunk = (yield from recursive_tile (sort (in_series .unique ()))).chunks [
90- 0
91- ]
69+ if cls ._is_categories_missing (op .dtype_values ):
70+ unique = yield from recursive_tile (sort (in_series .unique ()))
71+ unique_chunk = unique .chunks [0 ]
9272
9373 chunks = []
9474 for c in in_series .chunks :
9575 chunk_op = op .copy ().reset_key ()
9676 params = c .params .copy ()
9777 params ["dtype" ] = out .dtype
9878 if unique_chunk is not None :
99- chunk_op ._category_cols = [in_series .name ]
79+ chunk_op .category_cols = [in_series .name ]
10080 new_chunk = chunk_op .new_chunk ([c , unique_chunk ], ** params )
10181 else :
10282 new_chunk = chunk_op .new_chunk ([c ], ** params )
@@ -108,13 +88,13 @@ def _tile_series_index(cls, op):
10888 )
10989
11090 @classmethod
111- def _tile_dataframe (cls , op ):
91+ def _tile_dataframe (cls , op : "DataFrameAstype" ):
11292 in_df = op .inputs [0 ]
11393 out = op .outputs [0 ]
11494 cum_nsplits = np .cumsum ((0 ,) + in_df .nsplits [1 ])
11595 out_chunks = []
11696
117- if op .dtype_values == "category" :
97+ if cls . _is_categories_missing ( op .dtype_values ) :
11898 # all columns need unique values
11999 for c in in_df .chunks :
120100 chunk_op = op .copy ().reset_key ()
@@ -123,21 +103,19 @@ def _tile_dataframe(cls, op):
123103 cum_nsplits [c .index [1 ]] : cum_nsplits [c .index [1 ] + 1 ]
124104 ]
125105 params ["dtypes" ] = dtypes
126- chunk_op ._category_cols = list (c .columns_value .to_pandas ())
106+ chunk_op .category_cols = list (c .columns_value .to_pandas ())
127107 unique_chunks = []
128108 for col in c .columns_value .to_pandas ():
129109 unique = yield from recursive_tile (sort (in_df [col ].unique ()))
130110 unique_chunks .append (unique .chunks [0 ])
131111 new_chunk = chunk_op .new_chunk ([c ] + unique_chunks , ** params )
132112 out_chunks .append (new_chunk )
133- elif (
134- isinstance ( op . dtype_values , dict ) and "category" in op .dtype_values .values ()
113+ elif isinstance ( op . dtype_values , dict ) and any (
114+ cls . _is_categories_missing ( t ) for t in op .dtype_values .values ()
135115 ):
136116 # some columns' types are category
137117 category_cols = [
138- c
139- for c , v in op .dtype_values .items ()
140- if isinstance (v , str ) and v == "category"
118+ c for c , v in op .dtype_values .items () if cls ._is_categories_missing (v )
141119 ]
142120 unique_chunks = dict ()
143121 for col in category_cols :
@@ -156,7 +134,7 @@ def _tile_dataframe(cls, op):
156134 if col in category_cols :
157135 chunk_category_cols .append (col )
158136 chunk_unique_chunks .append (unique_chunks [col ])
159- chunk_op ._category_cols = chunk_category_cols
137+ chunk_op .category_cols = chunk_category_cols
160138 new_chunk = chunk_op .new_chunk ([c ] + chunk_unique_chunks , ** params )
161139 out_chunks .append (new_chunk )
162140 else :
@@ -176,7 +154,7 @@ def _tile_dataframe(cls, op):
176154 )
177155
178156 @classmethod
179- def tile (cls , op ):
157+ def tile (cls , op : "DataFrameAstype" ):
180158 if len (op .inputs [0 ].chunks ) == 1 :
181159 return cls ._tile_one_chunk (op )
182160 elif isinstance (op .inputs [0 ], DATAFRAME_TYPE ):
@@ -185,13 +163,14 @@ def tile(cls, op):
185163 return (yield from cls ._tile_series_index (op ))
186164
187165 @classmethod
188- def execute (cls , ctx , op ):
166+ def execute (cls , ctx , op : "DataFrameAstype" ):
189167 in_data = ctx [op .inputs [0 ].key ]
190168 if not isinstance (op .dtype_values , dict ):
191169 if op .category_cols is not None :
192170 uniques = [ctx [c .key ] for c in op .inputs [1 :]]
171+ ordered = getattr (op .dtype_values , "ordered" , None )
193172 dtype = dict (
194- (col , CategoricalDtype (unique_values ))
173+ (col , CategoricalDtype (unique_values , ordered = ordered ))
195174 for col , unique_values in zip (op .category_cols , uniques )
196175 )
197176 ctx [op .outputs [0 ].key ] = in_data .astype (dtype , errors = op .errors )
@@ -212,7 +191,10 @@ def execute(cls, ctx, op):
212191 if op .category_cols is not None :
213192 uniques = [ctx [c .key ] for c in op .inputs [1 :]]
214193 for col , unique_values in zip (op .category_cols , uniques ):
215- selected_dtype [col ] = CategoricalDtype (unique_values )
194+ ordered = getattr (selected_dtype [col ], "ordered" , None )
195+ selected_dtype [col ] = CategoricalDtype (
196+ unique_values , ordered = ordered
197+ )
216198 ctx [op .outputs [0 ].key ] = in_data .astype (selected_dtype , errors = op .errors )
217199
218200 def __call__ (self , df ):
0 commit comments