diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 111ad20f8a..f93ba8b720 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -53,6 +53,7 @@ _GEMINI_2_FLASH_EXP_ENDPOINT = "gemini-2.0-flash-exp" _GEMINI_2_FLASH_001_ENDPOINT = "gemini-2.0-flash-001" _GEMINI_2_FLASH_LITE_001_ENDPOINT = "gemini-2.0-flash-lite-001" +_GEMINI_2P5_PRO_PREVIEW_ENDPOINT = "gemini-2.5-pro-preview-05-06" _GEMINI_ENDPOINTS = ( _GEMINI_1P5_PRO_PREVIEW_ENDPOINT, _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT, @@ -104,6 +105,12 @@ _REMOVE_DEFAULT_MODEL_WARNING = "Since upgrading the default model can cause unintended breakages, the default model will be removed in BigFrames 3.0. Please supply an explicit model to avoid this message." +_GEMINI_MULTIMODAL_MODEL_NOT_SUPPORTED_WARNING = ( + "The model '{model_name}' may not be fully supported by GeminiTextGenerator for Multimodal prompts. " + "GeminiTextGenerator is known to support the following models for Multimodal prompts: {known_models}. " + "If you proceed with '{model_name}', it might not work as expected or could lead to errors with multimodal inputs." +) + @log_adapter.class_logger class TextEmbeddingGenerator(base.RetriableRemotePredictor): @@ -540,9 +547,10 @@ def fit( GeminiTextGenerator: Fitted estimator. """ if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS: - raise NotImplementedError( + msg = exceptions.format_message( "fit() only supports gemini-1.5-pro-002, or gemini-1.5-flash-002 model." ) + warnings.warn(msg) X, y = utils.batch_convert_to_dataframe(X, y) @@ -651,9 +659,13 @@ def predict( if prompt: if self.model_name not in _GEMINI_MULTIMODAL_ENDPOINTS: - raise NotImplementedError( - f"GeminiTextGenerator only supports model_name {', '.join(_GEMINI_MULTIMODAL_ENDPOINTS)} for Multimodal prompt." + msg = exceptions.format_message( + _GEMINI_MULTIMODAL_MODEL_NOT_SUPPORTED_WARNING.format( + model_name=self.model_name, + known_models=", ".join(_GEMINI_MULTIMODAL_ENDPOINTS), + ) ) + warnings.warn(msg) df_prompt = X[[X.columns[0]]].rename( columns={X.columns[0]: "bigframes_placeholder_col"} @@ -750,9 +762,10 @@ def score( raise RuntimeError("A model must be fitted before score") if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS: - raise NotImplementedError( + msg = exceptions.format_message( "score() only supports gemini-1.5-pro-002, gemini-1.5-flash-2, gemini-2.0-flash-001, and gemini-2.0-flash-lite-001 model." ) + warnings.warn(msg) X, y = utils.batch_convert_to_dataframe(X, y, session=self._bqml_model.session) diff --git a/bigframes/ml/loader.py b/bigframes/ml/loader.py index 83c665a50b..a6366273fe 100644 --- a/bigframes/ml/loader.py +++ b/bigframes/ml/loader.py @@ -66,6 +66,7 @@ llm._GEMINI_2_FLASH_EXP_ENDPOINT: llm.GeminiTextGenerator, llm._GEMINI_2_FLASH_001_ENDPOINT: llm.GeminiTextGenerator, llm._GEMINI_2_FLASH_LITE_001_ENDPOINT: llm.GeminiTextGenerator, + llm._GEMINI_2P5_PRO_PREVIEW_ENDPOINT: llm.GeminiTextGenerator, llm._CLAUDE_3_HAIKU_ENDPOINT: llm.Claude3TextGenerator, llm._CLAUDE_3_SONNET_ENDPOINT: llm.Claude3TextGenerator, llm._CLAUDE_3_5_SONNET_ENDPOINT: llm.Claude3TextGenerator,