diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py
index 73b8ba8dbc..d8eaec76ec 100644
--- a/bigframes/ml/core.py
+++ b/bigframes/ml/core.py
@@ -217,6 +217,25 @@ def generate_table(
 
     generate_table_tvf = TvfDef(generate_table, "status")
 
+    def ai_generate(
+        self,
+        input_data: bpd.DataFrame,
+        options: Mapping[str, Union[int, float, bool, Mapping]],
+    ) -> bpd.DataFrame:
+        """Delegate to ``_apply_ml_tvf`` using the generator's ``ai_generate`` SQL.
+
+        ``options`` is forwarded unchanged as that SQL's ``struct_options``.
+        """
+        return self._apply_ml_tvf(
+            input_data,
+            lambda source_sql: self._sql_generator.ai_generate(
+                source_sql=source_sql,
+                struct_options=options,
+            ),
+        )
+
+    ai_generate_tvf = TvfDef(ai_generate, "status")
+
     def detect_anomalies(
         self, input_data: bpd.DataFrame, options: Mapping[str, int | float]
     ) -> bpd.DataFrame:
diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py
index 2937368c92..01bac17446 100644
--- a/bigframes/ml/sql.py
+++ b/bigframes/ml/sql.py
@@ -435,3 +435,13 @@ def ai_generate_table(
         struct_options_sql = self.struct_options(**struct_options)
         return f"""SELECT * FROM AI.GENERATE_TABLE(MODEL {self._model_ref_sql()},
   ({source_sql}), {struct_options_sql})"""
+
+    def ai_generate(
+        self,
+        source_sql: str,
+        struct_options: Mapping[str, Union[int, float, bool, Mapping]],
+    ) -> str:
+        """Encode AI.GENERATE for BQML"""
+        struct_options_sql = self.struct_options(**struct_options)
+        return f"""SELECT * FROM AI.GENERATE(MODEL {self._model_ref_sql()},
+  ({source_sql}), {struct_options_sql})"""
diff --git a/bigframes/ml/utils.py b/bigframes/ml/utils.py
index 5c02789576..1a51250200 100644
--- a/bigframes/ml/utils.py
+++ b/bigframes/ml/utils.py
@@ -191,8 +191,18 @@ def combine_training_and_evaluation_data(
 
 
 def standardize_type(v: str, supported_dtypes: Optional[Iterable[str]] = None):
+    """Standardize type string to BQML supported type string."""
     t = v.lower()
-    t = t.replace("boolean", "bool")
+    # Only exact alias names are rewritten; anything else (e.g. a nested
+    # "array<boolean>") is left as the lowercased input.
+    if t == "boolean":
+        t = "bool"
+    elif t == "integer":
+        t = "int64"
+    elif t == "str":
+        t = "string"
+    elif t == "float":
+        t = "float64"
 
     if supported_dtypes:
         if t not in supported_dtypes:
diff --git a/bigframes/testing/utils.py b/bigframes/testing/utils.py
index 5da24c5b9b..fd09945ee1 100644
--- a/bigframes/testing/utils.py
+++ b/bigframes/testing/utils.py
@@ -49,6 +49,11 @@
     "ml_generate_text_status",
     "prompt",
 ]
+AI_GENERATE_OUTPUT = [
+    "result",
+    "full_response",
+    "status",
+]
 ML_GENERATE_EMBEDDING_OUTPUT = [
     "ml_generate_embedding_result",
     "ml_generate_embedding_statistics",
diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py
index 245fead028..7644cec816 100644
--- a/tests/system/small/ml/test_llm.py
+++ b/tests/system/small/ml/test_llm.py
@@ -15,6 +15,7 @@
 from typing import Callable
 from unittest import mock
 
+from google.api_core import exceptions as api_core_exceptions
 import pandas as pd
 import pyarrow as pa
 import pytest
@@ -216,7 +217,9 @@ def test_gemini_text_generator_predict_output_schema_success(
     llm_text_df: bpd.DataFrame, model_name, session, bq_connection
 ):
     gemini_text_generator_model = llm.GeminiTextGenerator(
-        model_name=model_name, connection_name=bq_connection, session=session
+        model_name=model_name,
+        connection_name=bq_connection,
+        session=session,
     )
     output_schema = {
         "bool_output": "bool",
@@ -807,3 +810,157 @@ def test_text_embedding_generator_no_default_model_warning(model_class):
     message = "Since upgrading the default model can cause unintended breakages, the\ndefault model will be removed in BigFrames 3.0. Please supply an\nexplicit model to avoid this message."
     with pytest.warns(FutureWarning, match=message):
         model_class(model_name=None)
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    (
+        "gemini-2.0-flash-001",
+        "gemini-2.0-flash-lite-001",
+    ),
+)
+def test_gemini_text_generator_predict_struct_schema_succeeds(
+    llm_text_df: bpd.DataFrame, session, bq_connection, model_name
+):
+    """A struct output_schema column exposes the name and age fields."""
+    gemini_text_generator_model = llm.GeminiTextGenerator(
+        model_name=model_name,
+        connection_name=bq_connection,
+        session=session,
+    )
+    output_schema = {
+        "struct_output": "struct<name string, age int64>",
+    }
+    df = gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema)
+    assert {field.name for field in df["struct_output"].dtype.pyarrow_dtype} == {
+        "name",
+        "age",
+    }
+
+    pd_df = df.to_pandas()
+    utils.check_pandas_df_schema_and_index(
+        pd_df,
+        columns=list(output_schema.keys()) + ["prompt", "full_response", "status"],
+        index=3,
+        col_exact=False,
+    )
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    (
+        "gemini-2.0-flash-001",
+        "gemini-2.0-flash-lite-001",
+    ),
+)
+def test_gemini_text_generator_predict_struct_schema_flat_succeeds(
+    llm_text_df: bpd.DataFrame, session, bq_connection, model_name
+):
+    """Scalar output_schema entries come back as string and Int64 columns."""
+    gemini_text_generator_model = llm.GeminiTextGenerator(
+        model_name=model_name,
+        connection_name=bq_connection,
+        session=session,
+    )
+    output_schema = {
+        "name": "string",
+        "age": "int64",
+    }
+    df = gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema)
+    assert df["name"].dtype == pd.StringDtype(storage="pyarrow")
+    assert df["age"].dtype == pd.Int64Dtype()
+
+    pd_df = df.to_pandas()
+    utils.check_pandas_df_schema_and_index(
+        pd_df,
+        columns=list(output_schema.keys()) + ["prompt", "full_response", "status"],
+        index=3,
+        col_exact=False,
+    )
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    (
+        "gemini-2.0-flash-001",
+        "gemini-2.0-flash-lite-001",
+    ),
+)
+def test_gemini_text_generator_predict_array_schema_succeeds(
+    llm_text_df: bpd.DataFrame, session, bq_connection, model_name
+):
+    """An array<string> output_schema column comes back as a list of strings."""
+    gemini_text_generator_model = llm.GeminiTextGenerator(
+        model_name=model_name,
+        connection_name=bq_connection,
+        session=session,
+    )
+    output_schema = {
+        "array_output": "array<string>",
+    }
+    df = gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema)
+    assert df["array_output"].dtype == pd.ArrowDtype(pa.list_(pa.string()))
+
+    pd_df = df.to_pandas()
+    utils.check_pandas_df_schema_and_index(
+        pd_df,
+        columns=list(output_schema.keys()) + ["prompt", "full_response", "status"],
+        index=3,
+        col_exact=False,
+    )
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    (
+        "gemini-2.0-flash-001",
+        "gemini-2.0-flash-lite-001",
+    ),
+)
+def test_gemini_text_generator_predict_array_struct_schema_succeeds(
+    llm_text_df: bpd.DataFrame, session, bq_connection, model_name
+):
+    """An array-of-struct output_schema column keeps the name and age fields."""
+    gemini_text_generator_model = llm.GeminiTextGenerator(
+        model_name=model_name,
+        connection_name=bq_connection,
+        session=session,
+    )
+    output_schema = {
+        "array_output": "array<struct<name string, age int64>>",
+    }
+    df = gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema)
+    assert {
+        field.name for field in df["array_output"].dtype.pyarrow_dtype.value_type
+    } == {"name", "age"}
+
+    pd_df = df.to_pandas()
+    utils.check_pandas_df_schema_and_index(
+        pd_df,
+        columns=list(output_schema.keys()) + ["prompt", "full_response", "status"],
+        index=3,
+        col_exact=False,
+    )
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    (
+        "gemini-2.0-flash-001",
+        "gemini-2.0-flash-lite-001",
+    ),
+)
+def test_gemini_text_generator_predict_invalid_schema_fails(
+    llm_text_df: bpd.DataFrame, session, bq_connection, model_name
+):
+    """An unknown output_schema type is rejected with BadRequest."""
+    gemini_text_generator_model = llm.GeminiTextGenerator(
+        model_name=model_name,
+        connection_name=bq_connection,
+        session=session,
+    )
+    output_schema = {
+        "invalid_output": "invalid_type",
+    }
+    with pytest.raises(api_core_exceptions.BadRequest):
+        gemini_text_generator_model.predict(llm_text_df, output_schema=output_schema)
diff --git a/tests/unit/ml/test_sql.py b/tests/unit/ml/test_sql.py
index d605b571f3..e36d7d8acb 100644
--- a/tests/unit/ml/test_sql.py
+++ b/tests/unit/ml/test_sql.py
@@ -529,3 +529,21 @@ def test_ml_principal_component_info_correct(
         sql
         == """SELECT * FROM ML.PRINCIPAL_COMPONENT_INFO(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`)"""
     )
+
+
+def test_ai_generate_correct(
+    model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
+    mock_df: bpd.DataFrame,
+):
+    """ai_generate renders the model ref, source SQL and STRUCT options."""
+    sql = model_manipulation_sql_generator.ai_generate(
+        source_sql=mock_df.sql,
+        struct_options={"option_key1": 1, "option_key2": 2.2},
+    )
+    assert (
+        sql
+        == """SELECT * FROM AI.GENERATE(MODEL `my_project_id`.`my_dataset_id`.`my_model_id`,
+  (input_X_y_sql), STRUCT(
+  1 AS `option_key1`,
+  2.2 AS `option_key2`))"""
+    )