diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 77dc1d2b0f..7fa0e236eb 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -46,6 +46,14 @@ ) _GEMINI_PRO_ENDPOINT = "gemini-pro" +_GEMINI_1P5_PRO_PREVIEW_ENDPOINT = "gemini-1.5-pro-preview-0514" +_GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT = "gemini-1.5-flash-preview-0514" +_GEMINI_ENDPOINTS = ( + _GEMINI_PRO_ENDPOINT, + _GEMINI_1P5_PRO_PREVIEW_ENDPOINT, + _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT, +) + _ML_GENERATE_TEXT_STATUS = "ml_generate_text_status" _ML_EMBED_TEXT_STATUS = "ml_embed_text_status" @@ -547,13 +555,16 @@ def to_gbq( class GeminiTextGenerator(base.BaseEstimator): """Gemini text generator LLM model. - .. note:: - This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the - Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is" - and might have limited support. For more information, see the launch stage descriptions - (https://cloud.google.com/products#product-launch-stages). - Args: + model_name (str, Default to "gemini-pro"): + The model for natural language tasks. Accepted values are "gemini-pro", "gemini-1.5-pro-preview-0514" and "gemini-1.5-flash-preview-0514". Default to "gemini-pro". + + .. note:: + "gemini-1.5-pro-preview-0514" and "gemini-1.5-flash-preview-0514" are subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the + Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is" + and might have limited support. For more information, see the launch stage descriptions + (https://cloud.google.com/products#product-launch-stages). + session (bigframes.Session or None): BQ session to create the model. If None, use the global default session. 
connection_name (str or None): @@ -565,9 +576,13 @@ class GeminiTextGenerator(base.BaseEstimator): def __init__( self, *, + model_name: Literal[ + "gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514" + ] = "gemini-pro", session: Optional[bigframes.Session] = None, connection_name: Optional[str] = None, ): + self.model_name = model_name self.session = session or bpd.get_global_session() self._bq_connection_manager = self.session.bqconnectionmanager @@ -601,7 +616,12 @@ def _create_bqml_model(self): iam_role="aiplatform.user", ) - options = {"endpoint": _GEMINI_PRO_ENDPOINT} + if self.model_name not in _GEMINI_ENDPOINTS: + raise ValueError( + f"Model name {self.model_name} is not supported. We only support {', '.join(_GEMINI_ENDPOINTS)}." + ) + + options = {"endpoint": self.model_name} return self._bqml_model_factory.create_remote_model( session=self.session, connection_name=self.connection_name, options=options @@ -613,12 +633,17 @@ def _from_bq( ) -> GeminiTextGenerator: assert bq_model.model_type == "MODEL_TYPE_UNSPECIFIED" assert "remoteModelInfo" in bq_model._properties + assert "endpoint" in bq_model._properties["remoteModelInfo"] assert "connection" in bq_model._properties["remoteModelInfo"] # Parse the remote model endpoint + bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"] model_connection = bq_model._properties["remoteModelInfo"]["connection"] + model_endpoint = bqml_endpoint.split("/")[-1] - model = cls(session=session, connection_name=model_connection) + model = cls( + model_name=model_endpoint, session=session, connection_name=model_connection + ) model._bqml_model = core.BqmlModel(session, bq_model) return model diff --git a/bigframes/ml/loader.py b/bigframes/ml/loader.py index 8ae8d64301..66f207929a 100644 --- a/bigframes/ml/loader.py +++ b/bigframes/ml/loader.py @@ -61,6 +61,8 @@ llm._EMBEDDING_GENERATOR_GECKO_ENDPOINT: llm.PaLM2TextEmbeddingGenerator, llm._EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT: 
llm.PaLM2TextEmbeddingGenerator, llm._GEMINI_PRO_ENDPOINT: llm.GeminiTextGenerator, + llm._GEMINI_1P5_PRO_PREVIEW_ENDPOINT: llm.GeminiTextGenerator, + llm._GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT: llm.GeminiTextGenerator, } ) diff --git a/tests/system/small/ml/conftest.py b/tests/system/small/ml/conftest.py index 33351afe45..ee96646687 100644 --- a/tests/system/small/ml/conftest.py +++ b/tests/system/small/ml/conftest.py @@ -275,11 +275,6 @@ def palm2_embedding_generator_multilingual_model( ) -@pytest.fixture(scope="session") -def gemini_text_generator_model(session, bq_connection) -> llm.GeminiTextGenerator: - return llm.GeminiTextGenerator(session=session, connection_name=bq_connection) - - @pytest.fixture(scope="session") def linear_remote_model_params() -> dict: # Pre-deployed endpoint of linear reg model in Vertex. diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index 8a6874b178..20e8dd0c19 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -303,10 +303,16 @@ def test_embedding_generator_predict_series_success( assert len(value) == 768 -def test_create_gemini_text_generator_model( - gemini_text_generator_model, dataset_id, bq_connection +@pytest.mark.parametrize( + "model_name", + ("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"), +) +def test_create_load_gemini_text_generator_model( + dataset_id, model_name, session, bq_connection ): - # Model creation doesn't return error + gemini_text_generator_model = llm.GeminiTextGenerator( + model_name=model_name, connection_name=bq_connection, session=session + ) assert gemini_text_generator_model is not None assert gemini_text_generator_model._bqml_model is not None @@ -316,12 +322,25 @@ def test_create_gemini_text_generator_model( ) assert f"{dataset_id}.temp_text_model" == reloaded_model._bqml_model.model_name assert reloaded_model.connection_name == bq_connection - - + assert reloaded_model.model_name == model_name 
+ + +@pytest.mark.parametrize( + "model_name", + ( + "gemini-pro", + "gemini-1.5-pro-preview-0514", + # TODO(garrrettwu): enable when cl/637028077 is in prod. + # "gemini-1.5-flash-preview-0514" + ), +) @pytest.mark.flaky(retries=2) def test_gemini_text_generator_predict_default_params_success( - gemini_text_generator_model, llm_text_df + llm_text_df, model_name, session, bq_connection ): + gemini_text_generator_model = llm.GeminiTextGenerator( + model_name=model_name, connection_name=bq_connection, session=session + ) df = gemini_text_generator_model.predict(llm_text_df).to_pandas() assert df.shape == (3, 4) assert "ml_generate_text_llm_result" in df.columns @@ -329,10 +348,17 @@ def test_gemini_text_generator_predict_default_params_success( assert all(series.str.len() > 20) +@pytest.mark.parametrize( + "model_name", + ("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"), +) @pytest.mark.flaky(retries=2) def test_gemini_text_generator_predict_with_params_success( - gemini_text_generator_model, llm_text_df + llm_text_df, model_name, session, bq_connection ): + gemini_text_generator_model = llm.GeminiTextGenerator( + model_name=model_name, connection_name=bq_connection, session=session + ) df = gemini_text_generator_model.predict( llm_text_df, temperature=0.5, max_output_tokens=100, top_k=20, top_p=0.5 ).to_pandas()