Skip to content

feat: add Linear_Regression.global_explain() #1446

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 45 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
4d3e4a3
feat: add Linear_Regression.global_explain()
rey-esp Mar 3, 2025
87db2b7
remove class_level_explain param
rey-esp Mar 4, 2025
0b73343
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 5, 2025
1d0c69b
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 5, 2025
6d563d5
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 11, 2025
024a989
Merge branch 'b338872698-global-explain' of github.com:googleapis/pyt…
rey-esp Mar 11, 2025
813fbd7
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 11, 2025
e99fdd7
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 11, 2025
82a234a
working global_explain()
rey-esp Mar 11, 2025
6dc4fac
Merge branch 'b338872698-global-explain' of github.com:googleapis/pyt…
rey-esp Mar 11, 2025
d583a37
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 11, 2025
ed73f88
begin adding tests
rey-esp Mar 11, 2025
5b7a4b7
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 12, 2025
47b9862
update snippet
rey-esp Mar 12, 2025
606a7b8
Merge branch 'b338872698-global-explain' of github.com:googleapis/pyt…
rey-esp Mar 12, 2025
5fe306f
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 12, 2025
b0e8a5d
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 12, 2025
0664d6a
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 12, 2025
eb33e09
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 12, 2025
31d741d
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 12, 2025
7046dc3
complete snippet
rey-esp Mar 12, 2025
b0b9552
failing, near complete linear model test
rey-esp Mar 12, 2025
3b0526e
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 13, 2025
1ad5208
passing system test
rey-esp Mar 14, 2025
7e24b4c
Merge branch 'b338872698-global-explain' of github.com:googleapis/pyt…
rey-esp Mar 14, 2025
c754816
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 14, 2025
d2d8b0c
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 17, 2025
a600539
Update core.py - set index to have sorted by feature
rey-esp Mar 17, 2025
7fc0cc6
Update test_linear_model.py - remove set/set index
rey-esp Mar 17, 2025
57c3d4a
Update linear_model.py - fix doc section
rey-esp Mar 17, 2025
c2c0837
Update conftest.py - rename penguins w global explain
rey-esp Mar 17, 2025
b2f8c9f
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 17, 2025
3a0c6b9
Update linear_model.py - complete doc
rey-esp Mar 17, 2025
5dac41d
lint
rey-esp Mar 17, 2025
e5f4aad
passing test and fixed expected results
rey-esp Mar 18, 2025
cd321e6
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 18, 2025
26c6a74
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 18, 2025
f47f5b7
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 18, 2025
7bcade0
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 19, 2025
1379a56
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 24, 2025
0bb9186
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 24, 2025
9a2b8e4
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 25, 2025
c8fec3a
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 25, 2025
562d0b8
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 25, 2025
083af6c
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,17 @@ def explain_predict(
),
)

def global_explain(
self, input_data: bpd.DataFrame, options: Mapping[str, bool]
) -> bpd.DataFrame:
return self._apply_ml_tvf(
input_data,
lambda source_sql: self._model_manipulation_sql_generator.ml_global_explain(
source_sql=source_sql,
struct_options=options,
),
)

def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
return self._apply_ml_tvf(
input_data,
Expand Down
45 changes: 45 additions & 0 deletions bigframes/ml/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,51 @@ def predict_explain(
X, options={"top_k_features": top_k_features}
)

def global_explain(
self,
X: utils.ArrayType,
*,
class_level_explain: bool = False,
) -> bpd.DataFrame:
"""
Provide explanations for an entire linear regression model.

.. note::
Output matches that of the BigQuery ML.GLOBAL_PREDICT function.
See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain

Args:
X (bigframes.dataframe.DataFrame or bigframes.series.Series or
pandas.core.frame.DataFrame or pandas.core.series.Series):
Series or a DataFrame to explain its predictions.
class_level_explain (bool, default False):
a BOOL value that specifies whether global feature importances
are returned for each class. Applies only to non-AutoML Tables
classification models. When set to FALSE, the global feature
importance of the entire model is returned rather than that of
each class. The default value is FALSE.

Regression models and AutoML Tables classification models only
have model-level global feature importance.

Returns:
bigframes.pandas.DataFrame:
The predicted DataFrames with feature and attribution columns.
"""
if class_level_explain is not True or False:
raise ValueError(
f"`class_level_explain` must be set to `True` or `False` but is currently {class_level_explain}"
)

if not self._bqml_model:
raise RuntimeError("A model must be fitted before predict")

(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

return self._bqml_model.global_explain(
X, options={"class_level_explain": class_level_explain}
)

def score(
self,
X: utils.ArrayType,
Expand Down
8 changes: 8 additions & 0 deletions bigframes/ml/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,14 @@ def ml_explain_predict(
return f"""SELECT * FROM ML.EXPLAIN_PREDICT(MODEL {self._model_ref_sql()},
({source_sql}), {struct_options_sql})"""

def ml_global_explain(
self, source_sql: str, struct_options: Mapping[str, bool]
) -> str:
"""Encode ML.GLOBAL_EXPLAIN for BQML"""
struct_options_sql = self.struct_options(**struct_options)
return f"""SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL {self._model_ref_sql()},
({source_sql}), {struct_options_sql})"""

def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
"""Encode ML.FORECAST for BQML"""
struct_options_sql = self.struct_options(**struct_options)
Expand Down
4 changes: 4 additions & 0 deletions samples/snippets/linear_regression_tutorial_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def test_linear_regression(random_model_id: str) -> None:
# 3 5349.603734 [{'feature': 'island', 'attribution': 7348.877... -5320.222128 5349.603734 0.0 Gentoo penguin (Pygoscelis papua) Biscoe 46.4 15.6 221.0 5000.0 MALE
# 4 4637.165037 [{'feature': 'island', 'attribution': 7348.877... -5320.222128 4637.165037 0.0 Gentoo penguin (Pygoscelis papua) Biscoe 46.1 13.2 211.0 4500.0 FEMALE
# [END bigquery_dataframes_bqml_linear_predict_explain]
# [START bigquery_dataframes_bqml_linear_global_explain]
explain_model = model.global_explain(biscoe_data, class_level_explain=True)
# [END bigquery_dataframes_bqml_linear_global_explain]
assert explain_model is not None
assert feature_columns is not None
assert label_columns is not None
assert model is not None
Expand Down