11import functools
22from logging import Logger
3- from typing import Any , Dict , List , Mapping , Optional , Tuple , Union , cast
3+ from typing import Any , Dict , List , Mapping , Optional , Set , Tuple , Union , cast
44
55import numpy as np
66
2121from autoPyTorch .data .utils import (
2222 DatasetCompressionInputType ,
2323 DatasetDTypeContainerType ,
24+ has_object_columns ,
2425 reduce_dataset_size_if_too_large
2526)
2627from autoPyTorch .utils .common import ispandas
@@ -105,9 +106,10 @@ def __init__(
105106 logger : Optional [Union [PicklableClientLogger , Logger ]] = None ,
106107 dataset_compression : Optional [Mapping [str , Any ]] = None ,
107108 ) -> None :
109+ super ().__init__ (logger )
108110 self ._dataset_compression = dataset_compression
109111 self ._reduced_dtype : Optional [DatasetDTypeContainerType ] = None
110- super (). __init__ ( logger )
112+ self . all_nan_columns : Optional [ Set [ str ]] = None
111113
112114 @staticmethod
113115 def _comparator (cmp1 : str , cmp2 : str ) -> int :
@@ -132,6 +134,41 @@ def _comparator(cmp1: str, cmp2: str) -> int:
132134 idx1 , idx2 = choices .index (cmp1 ), choices .index (cmp2 )
133135 return idx1 - idx2
134136
    def _convert_all_nan_columns_to_numeric(self, X: pd.DataFrame, fit: bool = False) -> pd.DataFrame:
        """
        Convert columns whose values were all nan in the training dataset to numeric.

        At fit time the all-NaN columns are detected from ``X`` itself and
        remembered in ``self.all_nan_columns``; at transform time the columns
        remembered from fit are forced to NaN/numeric so that train and test
        data share the same schema.

        NOTE(review): ``X`` is modified in place (columns are reassigned on
        the caller's frame) in addition to being returned.

        Args:
            X (pd.DataFrame):
                The data to transform.
            fit (bool):
                Whether this call is the fit to X or the transform using pre-fitted transformer.

        Returns:
            pd.DataFrame:
                ``X`` with every all-NaN training column coerced to a numeric
                dtype, and object-dtype columns passed through
                ``self.infer_objects``.

        Raises:
            ValueError:
                If called with ``fit=False`` before a fitting call has
                populated ``self.all_nan_columns``.
        """
        # Guard: transform requires the column set recorded during fit.
        if not fit and self.all_nan_columns is None:
            raise ValueError('_fit must be called before calling transform')

        if fit:
            # Detect columns in which every entry is NaN.
            all_nan_columns = X.columns[X.isna().all()]
        else:
            # Narrow the Optional for mypy; the guard above already ensured this.
            assert self.all_nan_columns is not None
            all_nan_columns = list(self.all_nan_columns)

        for col in all_nan_columns:
            # Overwrite with NaN (a no-op at fit time, but at transform time it
            # discards any values a training-all-NaN column may carry), then
            # coerce the all-NaN column to a numeric (float) dtype so later
            # pipeline stages do not treat it as categorical/object.
            X[col] = np.nan
            X[col] = pd.to_numeric(X[col])
            # Keep the positionally-indexed dtype record in sync with the
            # coerced column dtype. NOTE(review): positional lookup assumes
            # unique column labels — confirm against callers.
            if fit and len(self.dtypes):
                self.dtypes[list(X.columns).index(col)] = X[col].dtype

        # Resolve any remaining object-dtype columns through the validator's
        # helper (presumably dtype inference — semantics defined elsewhere).
        if has_object_columns(X.dtypes.values):
            X = self.infer_objects(X)

        if fit:
            # TODO: Check how to integrate below
            # self.dtypes = [dt.name for dt in X.dtypes]
            # Remember the detected columns for later transform calls.
            self.all_nan_columns = set(all_nan_columns)

        return X
171+
135172 def _fit (
136173 self ,
137174 X : SupportedFeatTypes ,
@@ -158,22 +195,7 @@ def _fit(
158195
159196 if ispandas (X ) and not issparse (X ):
160197 X = cast (pd .DataFrame , X )
161- # Treat a column with all instances a NaN as numerical
162- # This will prevent doing encoding to a categorical column made completely
163- # out of nan values -- which will trigger a fail, as encoding is not supported
164- # with nan values.
165- # Columns that are completely made of NaN values are provided to the pipeline
166- # so that later stages decide how to handle them
167- if np .any (pd .isnull (X )):
168- for column in X .columns :
169- if X [column ].isna ().all ():
170- X [column ] = pd .to_numeric (X [column ])
171- # Also note this change in self.dtypes
172- if len (self .dtypes ) != 0 :
173- self .dtypes [list (X .columns ).index (column )] = X [column ].dtype
174-
175- if not X .select_dtypes (include = 'object' ).empty :
176- X = self .infer_objects (X )
198+ X = self ._convert_all_nan_columns_to_numeric (X , fit = True )
177199
178200 self .transformed_columns , self .feat_type = self ._get_columns_to_encode (X )
179201
@@ -247,14 +269,7 @@ def transform(
247269 X = self .numpy_array_to_pandas (X )
248270
249271 if ispandas (X ) and not issparse (X ):
250- if np .any (pd .isnull (X )):
251- for column in X .columns :
252- if X [column ].isna ().all ():
253- X [column ] = pd .to_numeric (X [column ])
254-
255- # Also remove the object dtype for new data
256- if not X .select_dtypes (include = 'object' ).empty :
257- X = self .infer_objects (X )
272+ X = self ._convert_all_nan_columns_to_numeric (X )
258273
259274 # Check the data here so we catch problems on new test data
260275 self ._check_data (X )
@@ -369,7 +384,7 @@ def _check_data(
369384 X = cast (pd .DataFrame , X )
370385
371386 # Handle objects if possible
372- if not X . select_dtypes ( include = 'object' ). empty :
387+ if has_object_columns ( X . dtypes . values ) :
373388 X = self .infer_objects (X )
374389
375390 # Define the column to be encoded here as the feature validator is fitted once
0 commit comments