Skip to content

feat: support AI.GENERATE #2010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,21 @@ def generate_table(

generate_table_tvf = TvfDef(generate_table, "status")

def ai_generate(
    self,
    input_data: bpd.DataFrame,
    options: dict[str, Union[int, float, bool, Mapping]],
) -> bpd.DataFrame:
    """Run AI.GENERATE over ``input_data`` through the shared TVF helper.

    Args:
        input_data: DataFrame handed to ``self._apply_ml_tvf`` as the
            source of the query.
        options: Option names and values forwarded unchanged to the SQL
            generator as ``struct_options``.

    Returns:
        Whatever ``self._apply_ml_tvf`` returns for the generated SQL.
    """

    def _build_sql(source_sql: str) -> str:
        # Keep the parameter name ``source_sql``: the helper's calling
        # convention (positional vs keyword) is not visible from here.
        return self._sql_generator.ai_generate(
            source_sql=source_sql,
            struct_options=options,
        )

    return self._apply_ml_tvf(input_data, _build_sql)

# Wraps ai_generate in a TvfDef together with the string "status".
# NOTE(review): presumably "status" names the status column of the TVF
# output — confirm against the TvfDef definition.
ai_generate_tvf = TvfDef(ai_generate, "status")

def detect_anomalies(
self, input_data: bpd.DataFrame, options: Mapping[str, int | float]
) -> bpd.DataFrame:
Expand Down
10 changes: 10 additions & 0 deletions bigframes/ml/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,3 +435,13 @@ def ai_generate_table(
struct_options_sql = self.struct_options(**struct_options)
return f"""SELECT * FROM AI.GENERATE_TABLE(MODEL {self._model_ref_sql()},
({source_sql}), {struct_options_sql})"""

def ai_generate(
    self,
    source_sql: str,
    struct_options: Mapping[str, Union[int, float, bool, Mapping]],
) -> str:
    """Encode AI.GENERATE for BQML.

    Args:
        source_sql: SQL text embedded, wrapped in parentheses, as the
            second argument of the call.
        struct_options: Option names and values, expanded as keyword
            arguments into ``self.struct_options`` to render the third
            argument.

    Returns:
        A ``SELECT * FROM AI.GENERATE(...)`` statement that references the
        model via ``self._model_ref_sql()``.
    """
    # Options are rendered before the model reference, so that any error
    # raised while rendering options surfaces first.
    options_sql = self.struct_options(**struct_options)
    model_ref = self._model_ref_sql()
    # NOTE(review): BigQuery's reference documents AI.GENERATE as a scalar
    # function rather than a MODEL-based table-valued function — confirm
    # that this statement shape is accepted by the service.
    return (
        f"SELECT * FROM AI.GENERATE(MODEL {model_ref},\n"
        f"({source_sql}), {options_sql})"
    )
10 changes: 9 additions & 1 deletion bigframes/ml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,16 @@ def combine_training_and_evaluation_data(


def standardize_type(v: str, supported_dtypes: Optional[Iterable[str]] = None):
"""Standardize type string to BQML supported type string."""
t = v.lower()
t = t.replace("boolean", "bool")
if t == "boolean":
t = "bool"
elif t == "integer":
t = "int64"
elif t == "str":
t = "string"
elif t == "float":
t = "float64"

if supported_dtypes:
if t not in supported_dtypes:
Expand Down
5 changes: 5 additions & 0 deletions bigframes/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@
"ml_generate_text_status",
"prompt",
]
# Column names associated with AI.GENERATE output.
# NOTE(review): presumably the columns AI.GENERATE adds to its result —
# confirm against the tests that consume this list.
AI_GENERATE_OUTPUT = [
"result",
"full_response",
"status",
]
ML_GENERATE_EMBEDDING_OUTPUT = [
"ml_generate_embedding_result",
"ml_generate_embedding_statistics",
Expand Down
154 changes: 153 additions & 1 deletion tests/system/small/ml/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Callable
from unittest import mock

from google.api_core import exceptions as api_core_exceptions
import pandas as pd
import pyarrow as pa
import pytest
Expand Down Expand Up @@ -216,7 +217,9 @@ def test_gemini_text_generator_predict_output_schema_success(
llm_text_df: bpd.DataFrame, model_name, session, bq_connection
):
gemini_text_generator_model = llm.GeminiTextGenerator(
model_name=model_name, connection_name=bq_connection, session=session
model_name=model_name,
connection_name=bq_connection,
session=session,
)
output_schema = {
"bool_output": "bool",
Expand Down Expand Up @@ -807,3 +810,152 @@ def test_text_embedding_generator_no_default_model_warning(model_class):
message = "Since upgrading the default model can cause unintended breakages, the\ndefault model will be removed in BigFrames 3.0. Please supply an\nexplicit model to avoid this message."
with pytest.warns(FutureWarning, match=message):
model_class(model_name=None)


@pytest.mark.parametrize(
    "model_name",
    ("gemini-2.0-flash-001", "gemini-2.0-flash-lite-001"),
)
def test_gemini_text_generator_predict_struct_schema_succeeds(
    llm_text_df: bpd.DataFrame, session, bq_connection, model_name
):
    """Predict with a STRUCT output schema and check the resulting columns.

    The ``struct_output`` column must expose exactly the ``name`` and ``age``
    fields, and the pandas result must contain the schema column plus
    ``prompt``, ``full_response`` and ``status``.
    """
    model = llm.GeminiTextGenerator(
        model_name=model_name,
        connection_name=bq_connection,
        session=session,
    )
    output_schema = {"struct_output": "struct<name string, age int64>"}

    result = model.predict(llm_text_df, output_schema=output_schema)

    struct_type = result["struct_output"].dtype.pyarrow_dtype
    assert {field.name for field in struct_type} == {"name", "age"}

    expected_columns = [*output_schema, "prompt", "full_response", "status"]
    utils.check_pandas_df_schema_and_index(
        result.to_pandas(),
        columns=expected_columns,
        index=3,  # presumably the row count of llm_text_df — confirm in fixture
        col_exact=False,
    )


@pytest.mark.parametrize(
    "model_name",
    ("gemini-2.0-flash-001", "gemini-2.0-flash-lite-001"),
)
def test_gemini_text_generator_predict_struct_schema_flat_succeeds(
    llm_text_df: bpd.DataFrame, session, bq_connection, model_name
):
    """Predict with two scalar schema entries and check their dtypes.

    ``name`` must come back as a pyarrow-backed string column and ``age`` as
    a nullable Int64 column; the pandas result must also carry ``prompt``,
    ``full_response`` and ``status``.
    """
    model = llm.GeminiTextGenerator(
        model_name=model_name,
        connection_name=bq_connection,
        session=session,
    )
    output_schema = {"name": "string", "age": "int64"}

    result = model.predict(llm_text_df, output_schema=output_schema)

    assert result["name"].dtype == pd.StringDtype(storage="pyarrow")
    assert result["age"].dtype == pd.Int64Dtype()

    expected_columns = [*output_schema, "prompt", "full_response", "status"]
    utils.check_pandas_df_schema_and_index(
        result.to_pandas(),
        columns=expected_columns,
        index=3,  # presumably the row count of llm_text_df — confirm in fixture
        col_exact=False,
    )


@pytest.mark.parametrize(
    "model_name",
    ("gemini-2.0-flash-001", "gemini-2.0-flash-lite-001"),
)
def test_gemini_text_generator_predict_array_schema_succeeds(
    llm_text_df: bpd.DataFrame, session, bq_connection, model_name
):
    """Predict with an ``array<string>`` schema entry and check its dtype.

    ``array_output`` must be an Arrow list-of-string column, and the pandas
    result must also carry ``prompt``, ``full_response`` and ``status``.
    """
    model = llm.GeminiTextGenerator(
        model_name=model_name,
        connection_name=bq_connection,
        session=session,
    )
    output_schema = {"array_output": "array<string>"}

    result = model.predict(llm_text_df, output_schema=output_schema)

    expected_dtype = pd.ArrowDtype(pa.list_(pa.string()))
    assert result["array_output"].dtype == expected_dtype

    expected_columns = [*output_schema, "prompt", "full_response", "status"]
    utils.check_pandas_df_schema_and_index(
        result.to_pandas(),
        columns=expected_columns,
        index=3,  # presumably the row count of llm_text_df — confirm in fixture
        col_exact=False,
    )


@pytest.mark.parametrize(
    "model_name",
    ("gemini-2.0-flash-001", "gemini-2.0-flash-lite-001"),
)
def test_gemini_text_generator_predict_array_struct_schema_succeeds(
    llm_text_df: bpd.DataFrame, session, bq_connection, model_name
):
    """Predict with an array-of-struct schema entry and check element fields.

    The element type of ``array_output`` must expose exactly the ``name`` and
    ``age`` fields, and the pandas result must also carry ``prompt``,
    ``full_response`` and ``status``.
    """
    model = llm.GeminiTextGenerator(
        model_name=model_name,
        connection_name=bq_connection,
        session=session,
    )
    output_schema = {"array_output": "array<struct<name string, age int64>>"}

    result = model.predict(llm_text_df, output_schema=output_schema)

    element_type = result["array_output"].dtype.pyarrow_dtype.value_type
    assert {field.name for field in element_type} == {"name", "age"}

    expected_columns = [*output_schema, "prompt", "full_response", "status"]
    utils.check_pandas_df_schema_and_index(
        result.to_pandas(),
        columns=expected_columns,
        index=3,  # presumably the row count of llm_text_df — confirm in fixture
        col_exact=False,
    )


@pytest.mark.parametrize(
    "model_name",
    ("gemini-2.0-flash-001", "gemini-2.0-flash-lite-001"),
)
def test_gemini_text_generator_predict_invalid_schema_fails(
    llm_text_df: bpd.DataFrame, session, bq_connection, model_name
):
    """Predict with an unrecognised schema type and expect ``BadRequest``."""
    model = llm.GeminiTextGenerator(
        model_name=model_name,
        connection_name=bq_connection,
        session=session,
    )
    invalid_schema = {"invalid_output": "invalid_type"}

    with pytest.raises(api_core_exceptions.BadRequest):
        model.predict(llm_text_df, output_schema=invalid_schema)
17 changes: 17 additions & 0 deletions tests/unit/ml/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,3 +529,20 @@ def test_ml_principal_component_info_correct(
sql
== """SELECT * FROM ML.PRINCIPAL_COMPONENT_INFO(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`)"""
)


def test_ai_generate_correct(
    model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
    mock_df: bpd.DataFrame,
):
    """Check the exact AI.GENERATE statement produced for two numeric options."""
    # Spelled with explicit newlines so the expected text does not depend on
    # how this test file is indented.
    expected_sql = (
        "SELECT * FROM AI.GENERATE(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,\n"
        "(input_X_y_sql), STRUCT(\n"
        "1 AS `option_key1`,\n"
        "2.2 AS `option_key2`))"
    )

    actual_sql = model_manipulation_sql_generator.ai_generate(
        source_sql=mock_df.sql,
        struct_options={"option_key1": 1, "option_key2": 2.2},
    )

    assert actual_sql == expected_sql