Skip to content

Commit aa7b7b0

Browse files
committed
address more comments
1 parent c36aecb commit aa7b7b0

File tree

12 files changed

+328
-202
lines changed

12 files changed

+328
-202
lines changed

bigframes/ml/compose.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
from bigframes import constants
3030
from bigframes.core import log_adapter
31-
from bigframes.ml import base, core, globals, preprocessing, utils
31+
from bigframes.ml import base, core, globals, impute, preprocessing, utils
3232
import bigframes.pandas as bpd
3333

3434
_BQML_TRANSFROM_TYPE_MAPPING = types.MappingProxyType(
@@ -40,7 +40,7 @@
4040
"ML.BUCKETIZE": preprocessing.KBinsDiscretizer,
4141
"ML.QUANTILE_BUCKETIZE": preprocessing.KBinsDiscretizer,
4242
"ML.LABEL_ENCODER": preprocessing.LabelEncoder,
43-
"ML.IMPUTER": preprocessing.SimpleImputer,
43+
"ML.IMPUTER": impute.SimpleImputer,
4444
}
4545
)
4646

@@ -59,7 +59,7 @@ def __init__(
5959
transformers: List[
6060
Tuple[
6161
str,
62-
preprocessing.PreprocessingType,
62+
Union[preprocessing.PreprocessingType, impute.SimpleImputer],
6363
Union[str, List[str]],
6464
]
6565
],
@@ -74,12 +74,14 @@ def __init__(
7474
@property
7575
def transformers_(
7676
self,
77-
) -> List[Tuple[str, preprocessing.PreprocessingType, str,]]:
77+
) -> List[
78+
Tuple[str, Union[preprocessing.PreprocessingType, impute.SimpleImputer], str]
79+
]:
7880
"""The collection of transformers as tuples of (name, transformer, column)."""
7981
result: List[
8082
Tuple[
8183
str,
82-
preprocessing.PreprocessingType,
84+
Union[preprocessing.PreprocessingType, impute.SimpleImputer],
8385
str,
8486
]
8587
] = []
@@ -108,7 +110,7 @@ def _extract_from_bq_model(
108110
transformers: List[
109111
Tuple[
110112
str,
111-
preprocessing.PreprocessingType,
113+
Union[preprocessing.PreprocessingType, impute.SimpleImputer],
112114
Union[str, List[str]],
113115
]
114116
] = []
@@ -153,7 +155,9 @@ def camel_to_snake(name):
153155

154156
def _merge(
155157
self, bq_model: bigquery.Model
156-
) -> Union[ColumnTransformer, preprocessing.PreprocessingType,]:
158+
) -> Union[
159+
ColumnTransformer, Union[preprocessing.PreprocessingType, impute.SimpleImputer]
160+
]:
157161
"""Try to merge the column transformer to a simple transformer. Depends on all the columns in bq_model are transformed with the same transformer."""
158162
transformers = self.transformers_
159163

bigframes/ml/impute.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Transformers for missing value imputation. This module is styled after
16+
scikit-learn's preprocessing module: https://scikit-learn.org/stable/modules/preprocessing.html."""
17+
18+
from __future__ import annotations
19+
20+
import typing
21+
from typing import Any, List, Literal, Optional, Tuple, Union
22+
23+
import bigframes_vendored.sklearn.impute._base
24+
25+
from bigframes.core import log_adapter
26+
from bigframes.ml import base, core, globals, utils
27+
import bigframes.pandas as bpd
28+
29+
30+
@log_adapter.class_logger
31+
class SimpleImputer(
32+
base.Transformer,
33+
bigframes_vendored.sklearn.impute._base.SimpleImputer,
34+
):
35+
36+
__doc__ = bigframes_vendored.sklearn.impute._base.SimpleImputer.__doc__
37+
38+
def __init__(
39+
self,
40+
strategy: Literal["mean", "median", "most_frequent"] = "mean",
41+
):
42+
self.strategy = strategy
43+
self._bqml_model: Optional[core.BqmlModel] = None
44+
self._bqml_model_factory = globals.bqml_model_factory()
45+
self._base_sql_generator = globals.base_sql_generator()
46+
47+
# TODO(garrettwu): implement __hash__
48+
def __eq__(self, other: Any) -> bool:
49+
return (
50+
type(other) is SimpleImputer
51+
and self.strategy == other.strategy
52+
and self._bqml_model == other._bqml_model
53+
)
54+
55+
def _compile_to_sql(
56+
self,
57+
columns: List[str],
58+
X=None,
59+
) -> List[Tuple[str, str]]:
60+
"""Compile this transformer to a list of SQL expressions that can be included in
61+
a BQML TRANSFORM clause
62+
63+
Args:
64+
columns:
65+
a list of column names to transform
66+
X:
67+
The Dataframe with training data.
68+
69+
Returns: a list of tuples of (sql_expression, output_name)"""
70+
return [
71+
(
72+
self._base_sql_generator.ml_imputer(
73+
column, self.strategy, f"imputer_{column}"
74+
),
75+
f"imputer_{column}",
76+
)
77+
for column in columns
78+
]
79+
80+
@classmethod
81+
def _parse_from_sql(cls, sql: str) -> tuple[SimpleImputer, str]:
82+
"""Parse SQL to tuple(SimpleImputer, column_label).
83+
84+
Args:
85+
sql: SQL string of format "ML.IMPUTER({col_label}, {strategy}) OVER()"
86+
87+
Returns:
88+
tuple(SimpleImputer, column_label)"""
89+
s = sql[sql.find("(") + 1 : sql.find(")")]
90+
col_label, strategy = s.split(", ")
91+
return cls(strategy[1:-1]), col_label # type: ignore
92+
93+
def fit(
94+
self,
95+
X: Union[bpd.DataFrame, bpd.Series],
96+
y=None, # ignored
97+
) -> SimpleImputer:
98+
(X,) = utils.convert_to_dataframe(X)
99+
100+
compiled_transforms = self._compile_to_sql(X.columns.tolist(), X)
101+
transform_sqls = [transform_sql for transform_sql, _ in compiled_transforms]
102+
103+
self._bqml_model = self._bqml_model_factory.create_model(
104+
X,
105+
options={"model_type": "transform_only"},
106+
transforms=transform_sqls,
107+
)
108+
109+
# The schema of TRANSFORM output is not available in the model API, so save it during fitting
110+
self._output_names = [name for _, name in compiled_transforms]
111+
return self
112+
113+
def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
114+
if not self._bqml_model:
115+
raise RuntimeError("Must be fitted before transform")
116+
117+
(X,) = utils.convert_to_dataframe(X)
118+
119+
df = self._bqml_model.transform(X)
120+
return typing.cast(
121+
bpd.DataFrame,
122+
df[self._output_names],
123+
)

bigframes/ml/loader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
ensemble,
3030
forecasting,
3131
imported,
32+
impute,
3233
linear_model,
3334
llm,
3435
pipeline,
@@ -84,6 +85,7 @@ def from_bq(
8485
pipeline.Pipeline,
8586
compose.ColumnTransformer,
8687
preprocessing.PreprocessingType,
88+
impute.SimpleImputer,
8789
]:
8890
"""Load a BQML model to BigQuery DataFrames ML.
8991

bigframes/ml/pipeline.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,15 @@
2626
import bigframes
2727
import bigframes.constants as constants
2828
from bigframes.core import log_adapter
29-
from bigframes.ml import base, compose, forecasting, loader, preprocessing, utils
29+
from bigframes.ml import (
30+
base,
31+
compose,
32+
forecasting,
33+
impute,
34+
loader,
35+
preprocessing,
36+
utils,
37+
)
3038
import bigframes.pandas as bpd
3139

3240

@@ -56,6 +64,7 @@ def __init__(self, steps: List[Tuple[str, base.BaseEstimator]]):
5664
preprocessing.MinMaxScaler,
5765
preprocessing.KBinsDiscretizer,
5866
preprocessing.LabelEncoder,
67+
impute.SimpleImputer,
5968
),
6069
):
6170
self._transform = transform

bigframes/ml/preprocessing.py

Lines changed: 0 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import typing
2121
from typing import Any, cast, List, Literal, Optional, Tuple, Union
2222

23-
import bigframes_vendored.sklearn.impute._base
2423
import bigframes_vendored.sklearn.preprocessing._data
2524
import bigframes_vendored.sklearn.preprocessing._discretization
2625
import bigframes_vendored.sklearn.preprocessing._encoder
@@ -416,102 +415,6 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
416415
)
417416

418417

419-
@log_adapter.class_logger
420-
class SimpleImputer(
421-
base.Transformer,
422-
bigframes_vendored.sklearn.impute._base.SimpleImputer,
423-
):
424-
425-
__doc__ = bigframes_vendored.sklearn.impute._base.SimpleImputer.__doc__
426-
427-
def __init__(
428-
self,
429-
strategy: Literal["mean", "median", "most_frequent"] = "mean",
430-
):
431-
self.strategy = strategy
432-
self._bqml_model: Optional[core.BqmlModel] = None
433-
self._bqml_model_factory = globals.bqml_model_factory()
434-
self._base_sql_generator = globals.base_sql_generator()
435-
436-
# TODO(garrettwu): implement __hash__
437-
def __eq__(self, other: Any) -> bool:
438-
return (
439-
type(other) is SimpleImputer
440-
and self.strategy == other.strategy
441-
and self._bqml_model == other._bqml_model
442-
)
443-
444-
def _compile_to_sql(
445-
self,
446-
columns: List[str],
447-
X=None,
448-
) -> List[Tuple[str, str]]:
449-
"""Compile this transformer to a list of SQL expressions that can be included in
450-
a BQML TRANSFORM clause
451-
452-
Args:
453-
columns:
454-
a list of column names to transform
455-
X:
456-
The Dataframe with training data.
457-
458-
Returns: a list of tuples of (sql_expression, output_name)"""
459-
return [
460-
(
461-
self._base_sql_generator.ml_imputer(
462-
column, self.strategy, f"imputer_{column}"
463-
),
464-
f"imputer_{column}",
465-
)
466-
for column in columns
467-
]
468-
469-
@classmethod
470-
def _parse_from_sql(cls, sql: str) -> tuple[SimpleImputer, str]:
471-
"""Parse SQL to tuple(SimpleImputer, column_label).
472-
473-
Args:
474-
sql: SQL string of format "ML.IMPUTER({col_label}, {strategy}) OVER()"
475-
476-
Returns:
477-
tuple(SimpleImputer, column_label)"""
478-
s = sql[sql.find("(") + 1 : sql.find(")")]
479-
col_label, strategy = s.split(", ")
480-
return cls(strategy[1:-1]), col_label # type: ignore
481-
482-
def fit(
483-
self,
484-
X: Union[bpd.DataFrame, bpd.Series],
485-
y=None, # ignored
486-
) -> SimpleImputer:
487-
(X,) = utils.convert_to_dataframe(X)
488-
489-
compiled_transforms = self._compile_to_sql(X.columns.tolist(), X)
490-
transform_sqls = [transform_sql for transform_sql, _ in compiled_transforms]
491-
492-
self._bqml_model = self._bqml_model_factory.create_model(
493-
X,
494-
options={"model_type": "transform_only"},
495-
transforms=transform_sqls,
496-
)
497-
498-
# The schema of TRANSFORM output is not available in the model API, so save it during fitting
499-
self._output_names = [name for _, name in compiled_transforms]
500-
return self
501-
502-
def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
503-
if not self._bqml_model:
504-
raise RuntimeError("Must be fitted before transform")
505-
506-
(X,) = utils.convert_to_dataframe(X)
507-
508-
df = self._bqml_model.transform(X)
509-
return typing.cast(
510-
bpd.DataFrame,
511-
df[self._output_names],
512-
)
513-
514-
515418
@log_adapter.class_logger
516419
class OneHotEncoder(
517420
base.Transformer,
@@ -765,5 +668,4 @@ def transform(self, y: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
765668
MinMaxScaler,
766669
KBinsDiscretizer,
767670
LabelEncoder,
768-
SimpleImputer,
769671
]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
bigframes.ml.impute
2+
==========================
3+
4+
.. automodule:: bigframes.ml.impute
5+
:members:
6+
:inherited-members:
7+
:undoc-members:

docs/reference/bigframes.ml/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ API Reference
1919

2020
imported
2121

22+
impute
23+
2224
linear_model
2325

2426
llm

docs/templates/toc.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@
134134
- name: XGBoostModel
135135
uid: bigframes.ml.imported.XGBoostModel
136136
name: imported
137+
- items:
138+
- name: Overview
139+
uid: bigframes.ml.impute
140+
- name: SimpleImputer
141+
uid: bigframes.ml.impute.SimpleImputer
142+
name: impute
137143
- items:
138144
- name: Overview
139145
uid: bigframes.ml.linear_model

0 commit comments

Comments
 (0)