diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index f62867cdd5..2517178d89 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -571,6 +571,8 @@ class GeminiTextGenerator(base.BaseEstimator): Connection to connect with remote service. str of the format ... If None, use default connection in session context. BigQuery DataFrame will try to create the connection and attach permission if the connection isn't fully set up. + max_iterations (Optional[int], Default to 300): + The number of steps to run when performing supervised tuning. """ def __init__( @@ -581,9 +583,11 @@ def __init__( ] = "gemini-pro", session: Optional[bigframes.Session] = None, connection_name: Optional[str] = None, + max_iterations: int = 300, ): self.model_name = model_name self.session = session or bpd.get_global_session() + self.max_iterations = max_iterations self._bq_connection_manager = self.session.bqconnectionmanager connection_name = connection_name or self.session._bq_connection @@ -647,6 +651,55 @@ def _from_bq( model._bqml_model = core.BqmlModel(session, bq_model) return model + @property + def _bqml_options(self) -> dict: + """The model options as they will be set for BQML""" + options = { + "max_iterations": self.max_iterations, + "data_split_method": "NO_SPLIT", + } + return options + + def fit( + self, + X: Union[bpd.DataFrame, bpd.Series], + y: Union[bpd.DataFrame, bpd.Series], + ) -> GeminiTextGenerator: + """Fine tune GeminiTextGenerator model. Only support "gemini-pro" model for now. + + .. note:: + + This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the + Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is" + and might have limited support. For more information, see the launch stage descriptions + (https://cloud.google.com/products#product-launch-stages). + + Args: + X (bigframes.dataframe.DataFrame or bigframes.series.Series): + DataFrame of shape (n_samples, n_features). Training data. + y (bigframes.dataframe.DataFrame or bigframes.series.Series: + Training labels. + + Returns: + GeminiTextGenerator: Fitted estimator. + """ + if self._bqml_model.model_name.startswith("gemini-1.5"): + raise NotImplementedError("Fit is not supported for gemini-1.5 model.") + + X, y = utils.convert_to_dataframe(X, y) + + options = self._bqml_options + options["endpoint"] = "gemini-1.0-pro-002" + options["prompt_col"] = X.columns.tolist()[0] + + self._bqml_model = self._bqml_model_factory.create_llm_remote_model( + X, + y, + options=options, + connection_name=self.connection_name, + ) + return self + def predict( self, X: Union[bpd.DataFrame, bpd.Series], diff --git a/tests/system/load/test_llm.py b/tests/system/load/test_llm.py index beed884686..fd047b3ba6 100644 --- a/tests/system/load/test_llm.py +++ b/tests/system/load/test_llm.py @@ -99,3 +99,30 @@ def test_llm_palm_score_params(llm_fine_tune_df_default_index): "evaluation_status", ] assert all(col in score_result_col for col in expected_col) + + +@pytest.mark.flaky(retries=2) +def test_llm_gemini_configure_fit(llm_fine_tune_df_default_index, llm_remote_text_df): + model = bigframes.ml.llm.GeminiTextGenerator( + model_name="gemini-pro", max_iterations=1 + ) + + X_train = llm_fine_tune_df_default_index[["prompt"]] + y_train = llm_fine_tune_df_default_index[["label"]] + model.fit(X_train, y_train) + + assert model is not None + + df = model.predict( + llm_remote_text_df["prompt"], + temperature=0.5, + max_output_tokens=100, + top_k=20, + top_p=0.5, + ).to_pandas() + assert df.shape == (3, 4) + assert "ml_generate_text_llm_result" in df.columns + series = df["ml_generate_text_llm_result"] + assert all(series.str.len() == 1) + + # TODO(ashleyxu b/335492787): After bqml rolled out version control: save, load, check parameters to ensure configuration was kept