Skip to content

Commit c36aecb

Browse files
committed
address comments
1 parent 5e4e065 commit c36aecb

File tree

4 files changed

+35
-23
lines changed

4 files changed

+35
-23
lines changed

bigframes/ml/compose.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
"ML.BUCKETIZE": preprocessing.KBinsDiscretizer,
4141
"ML.QUANTILE_BUCKETIZE": preprocessing.KBinsDiscretizer,
4242
"ML.LABEL_ENCODER": preprocessing.LabelEncoder,
43-
"ML.IMPUTER": preprocessing.Imputer,
43+
"ML.IMPUTER": preprocessing.SimpleImputer,
4444
}
4545
)
4646

bigframes/ml/preprocessing.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
import typing
2121
from typing import Any, cast, List, Literal, Optional, Tuple, Union
2222

23+
import bigframes_vendored.sklearn.impute._base
2324
import bigframes_vendored.sklearn.preprocessing._data
2425
import bigframes_vendored.sklearn.preprocessing._discretization
2526
import bigframes_vendored.sklearn.preprocessing._encoder
26-
import bigframes_vendored.sklearn.preprocessing._imputation
2727
import bigframes_vendored.sklearn.preprocessing._label
2828

2929
from bigframes.core import log_adapter
@@ -417,12 +417,12 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
417417

418418

419419
@log_adapter.class_logger
420-
class Imputer(
420+
class SimpleImputer(
421421
base.Transformer,
422-
bigframes_vendored.sklearn.preprocessing._imputation.Imputer,
422+
bigframes_vendored.sklearn.impute._base.SimpleImputer,
423423
):
424424

425-
__doc__ = bigframes_vendored.sklearn.preprocessing._imputation.Imputer.__doc__
425+
__doc__ = bigframes_vendored.sklearn.impute._base.SimpleImputer.__doc__
426426

427427
def __init__(
428428
self,
@@ -436,7 +436,7 @@ def __init__(
436436
# TODO(garrettwu): implement __hash__
437437
def __eq__(self, other: Any) -> bool:
438438
return (
439-
type(other) is Imputer
439+
type(other) is SimpleImputer
440440
and self.strategy == other.strategy
441441
and self._bqml_model == other._bqml_model
442442
)
@@ -467,14 +467,14 @@ def _compile_to_sql(
467467
]
468468

469469
@classmethod
470-
def _parse_from_sql(cls, sql: str) -> tuple[Imputer, str]:
471-
"""Parse SQL to tuple(Imputer, column_label).
470+
def _parse_from_sql(cls, sql: str) -> tuple[SimpleImputer, str]:
471+
"""Parse SQL to tuple(SimpleImputer, column_label).
472472
473473
Args:
474474
sql: SQL string of format "ML.IMPUTER({col_label}, {strategy}) OVER()"
475475
476476
Returns:
477-
tuple(Imputer, column_label)"""
477+
tuple(SimpleImputer, column_label)"""
478478
s = sql[sql.find("(") + 1 : sql.find(")")]
479479
col_label, strategy = s.split(", ")
480480
return cls(strategy[1:-1]), col_label # type: ignore
@@ -483,7 +483,7 @@ def fit(
483483
self,
484484
X: Union[bpd.DataFrame, bpd.Series],
485485
y=None, # ignored
486-
) -> Imputer:
486+
) -> SimpleImputer:
487487
(X,) = utils.convert_to_dataframe(X)
488488

489489
compiled_transforms = self._compile_to_sql(X.columns.tolist(), X)
@@ -765,5 +765,5 @@ def transform(self, y: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
765765
MinMaxScaler,
766766
KBinsDiscretizer,
767767
LabelEncoder,
768-
Imputer,
768+
SimpleImputer,
769769
]

tests/system/small/ml/test_preprocessing.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -548,15 +548,15 @@ def test_k_bins_discretizer_save_load_quantile(new_penguins_df, dataset_id):
548548
assert reloaded_transformer._bqml_model is not None
549549

550550

551-
def test_imputer_normalized_fit_transform_default_params():
551+
def test_simple_imputer_normalized_fit_transform_default_params():
552552
missing_df = bpd.DataFrame(
553553
{
554554
"culmen_length_mm": [39.5, 38.5, 37.9],
555555
"culmen_depth_mm": [np.nan, 17.2, 18.1],
556556
"flipper_length_mm": [np.nan, 181.0, 188.0],
557557
}
558558
)
559-
imputer = preprocessing.Imputer(strategy="mean")
559+
imputer = preprocessing.SimpleImputer(strategy="mean")
560560
result = imputer.fit_transform(
561561
missing_df[["culmen_length_mm", "culmen_depth_mm", "flipper_length_mm"]]
562562
).to_pandas()
@@ -574,15 +574,15 @@ def test_imputer_normalized_fit_transform_default_params():
574574
pd.testing.assert_frame_equal(result, expected)
575575

576576

577-
def test_imputer_series_normalizes(new_penguins_df):
577+
def test_simple_imputer_series_normalizes(new_penguins_df):
578578
missing_df = bpd.DataFrame(
579579
{
580580
"culmen_length_mm": [39.5, 38.5, 37.9],
581581
"culmen_depth_mm": [np.nan, 17.2, 18.1],
582582
"flipper_length_mm": [np.nan, 181.0, 188.0],
583583
}
584584
)
585-
imputer = preprocessing.Imputer()
585+
imputer = preprocessing.SimpleImputer()
586586
imputer.fit(missing_df["culmen_depth_mm"])
587587

588588
result = imputer.transform(missing_df["culmen_depth_mm"]).to_pandas()
@@ -599,44 +599,44 @@ def test_imputer_series_normalizes(new_penguins_df):
599599
pd.testing.assert_frame_equal(result, expected, rtol=0.1)
600600

601601

602-
def test_imputer_save_load_mean(dataset_id):
602+
def test_simple_imputer_save_load_mean(dataset_id):
603603
missing_df = bpd.DataFrame(
604604
{
605605
"culmen_length_mm": [39.5, 38.5, 37.9],
606606
"culmen_depth_mm": [np.nan, 17.2, 18.1],
607607
"flipper_length_mm": [np.nan, 181.0, 188.0],
608608
}
609609
)
610-
transformer = preprocessing.Imputer(strategy="mean")
610+
transformer = preprocessing.SimpleImputer(strategy="mean")
611611
transformer.fit(
612612
missing_df[["culmen_length_mm", "culmen_depth_mm", "flipper_length_mm"]]
613613
)
614614

615615
reloaded_transformer = transformer.to_gbq(
616616
f"{dataset_id}.temp_configured_model", replace=True
617617
)
618-
assert isinstance(reloaded_transformer, preprocessing.Imputer)
618+
assert isinstance(reloaded_transformer, preprocessing.SimpleImputer)
619619
assert reloaded_transformer.strategy == transformer.strategy
620620
assert reloaded_transformer._bqml_model is not None
621621

622622

623-
def test_imputer_save_load_most_frequent(dataset_id):
623+
def test_simple_imputer_save_load_most_frequent(dataset_id):
624624
missing_df = bpd.DataFrame(
625625
{
626626
"culmen_length_mm": [39.5, 38.5, 37.9],
627627
"culmen_depth_mm": [np.nan, 17.2, 18.1],
628628
"flipper_length_mm": [np.nan, 181.0, 188.0],
629629
}
630630
)
631-
transformer = preprocessing.Imputer(strategy="most_frequent")
631+
transformer = preprocessing.SimpleImputer(strategy="most_frequent")
632632
transformer.fit(
633633
missing_df[["culmen_length_mm", "culmen_depth_mm", "flipper_length_mm"]]
634634
)
635635

636636
reloaded_transformer = transformer.to_gbq(
637637
f"{dataset_id}.temp_configured_model", replace=True
638638
)
639-
assert isinstance(reloaded_transformer, preprocessing.Imputer)
639+
assert isinstance(reloaded_transformer, preprocessing.SimpleImputer)
640640
assert reloaded_transformer.strategy == transformer.strategy
641641
assert reloaded_transformer._bqml_model is not None
642642

third_party/bigframes_vendored/sklearn/preprocessing/_imputation.py renamed to third_party/bigframes_vendored/sklearn/impute/_base.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,25 @@
11
# Authors: Nicolas Tresegnie <[email protected]>
2+
# Sergey Feldman <[email protected]>
23
# License: BSD 3 clause
34

45
from bigframes_vendored.sklearn.base import BaseEstimator, TransformerMixin
56

67
from bigframes import constants
78

89

9-
class Imputer(BaseEstimator, TransformerMixin):
10-
"""Imputation transformer for completing missing values.
10+
class _BaseImputer(TransformerMixin, BaseEstimator):
11+
"""Base class for all imputers.
12+
13+
It adds automatically support for `add_indicator`.
14+
"""
15+
16+
17+
class SimpleImputer(_BaseImputer):
18+
"""
19+
Univariate imputer for completing missing values with simple strategies.
20+
21+
Replace missing values using a descriptive statistic (e.g. mean, median, or
22+
most frequent) along each column, or using a constant value.
1123
1224
Args:
1325
strategy ({'mean', 'median', 'most_frequent'}, default='mean'):

0 commit comments

Comments
 (0)