|
53 | 53 | _GEMINI_2_FLASH_EXP_ENDPOINT = "gemini-2.0-flash-exp"
|
54 | 54 | _GEMINI_2_FLASH_001_ENDPOINT = "gemini-2.0-flash-001"
|
55 | 55 | _GEMINI_2_FLASH_LITE_001_ENDPOINT = "gemini-2.0-flash-lite-001"
|
| 56 | +_GEMINI_2P5_PRO_PREVIEW_ENDPOINT = "gemini-2.5-pro-preview-05-06" |
56 | 57 | _GEMINI_ENDPOINTS = (
|
57 | 58 | _GEMINI_1P5_PRO_PREVIEW_ENDPOINT,
|
58 | 59 | _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT,
|
|
104 | 105 |
|
105 | 106 | _REMOVE_DEFAULT_MODEL_WARNING = "Since upgrading the default model can cause unintended breakages, the default model will be removed in BigFrames 3.0. Please supply an explicit model to avoid this message."
|
106 | 107 |
|
| 108 | +_GEMINI_MULTIMODAL_MODEL_NOT_SUPPORTED_WARNING = ( |
| 109 | + "The model '{model_name}' may not be fully supported by GeminiTextGenerator for Multimodal prompts. " |
| 110 | + "GeminiTextGenerator is known to support the following models for Multimodal prompts: {known_models}. " |
| 111 | + "If you proceed with '{model_name}', it might not work as expected or could lead to errors with multimodal inputs." |
| 112 | +) |
| 113 | + |
107 | 114 |
|
108 | 115 | @log_adapter.class_logger
|
109 | 116 | class TextEmbeddingGenerator(base.RetriableRemotePredictor):
|
@@ -540,9 +547,10 @@ def fit(
|
540 | 547 | GeminiTextGenerator: Fitted estimator.
|
541 | 548 | """
|
542 | 549 | if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS:
|
543 |
| - raise NotImplementedError( |
| 550 | + msg = exceptions.format_message( |
544 | 551 | "fit() only supports gemini-1.5-pro-002, or gemini-1.5-flash-002 model."
|
545 | 552 | )
|
| 553 | + warnings.warn(msg) |
546 | 554 |
|
547 | 555 | X, y = utils.batch_convert_to_dataframe(X, y)
|
548 | 556 |
|
@@ -651,9 +659,13 @@ def predict(
|
651 | 659 |
|
652 | 660 | if prompt:
|
653 | 661 | if self.model_name not in _GEMINI_MULTIMODAL_ENDPOINTS:
|
654 |
| - raise NotImplementedError( |
655 |
| - f"GeminiTextGenerator only supports model_name {', '.join(_GEMINI_MULTIMODAL_ENDPOINTS)} for Multimodal prompt." |
| 662 | + msg = exceptions.format_message( |
| 663 | + _GEMINI_MULTIMODAL_MODEL_NOT_SUPPORTED_WARNING.format( |
| 664 | + model_name=self.model_name, |
| 665 | + known_models=", ".join(_GEMINI_MULTIMODAL_ENDPOINTS), |
| 666 | + ) |
656 | 667 | )
|
| 668 | + warnings.warn(msg) |
657 | 669 |
|
658 | 670 | df_prompt = X[[X.columns[0]]].rename(
|
659 | 671 | columns={X.columns[0]: "bigframes_placeholder_col"}
|
@@ -750,9 +762,10 @@ def score(
|
750 | 762 | raise RuntimeError("A model must be fitted before score")
|
751 | 763 |
|
752 | 764 | if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS:
|
753 |
| - raise NotImplementedError( |
| 765 | + msg = exceptions.format_message( |
754 | 766 | "score() only supports gemini-1.5-pro-002, gemini-1.5-flash-2, gemini-2.0-flash-001, and gemini-2.0-flash-lite-001 model."
|
755 | 767 | )
|
| 768 | + warnings.warn(msg) |
756 | 769 |
|
757 | 770 | X, y = utils.batch_convert_to_dataframe(X, y, session=self._bqml_model.session)
|
758 | 771 |
|
|
0 commit comments