Commit a68312d: address comments
Parent: 9182de1

2 files changed (+26, -24)

bigframes/ml/llm.py (3 additions, 4 deletions)

@@ -740,7 +740,7 @@ def score(
             "text_generation", "classification", "summarization", "question_answering"
         ] = "text_generation",
     ) -> bpd.DataFrame:
-        """Calculate evaluation metrics of the model.
+        """Calculate evaluation metrics of the model. Only "gemini-pro" model is supported for now.

         .. note::

@@ -772,10 +772,9 @@ def score(
         if not self._bqml_model:
             raise RuntimeError("A model must be fitted before score")

+        # TODO(ashleyxu): Support gemini-1.5 when the rollout is ready. b/344891364.
         if self._bqml_model.model_name.startswith("gemini-1.5"):
-            raise NotImplementedError(
-                "Score is not supported for gemini-1.5 model. Please use gemini-pro-1.0 model instead."
-            )
+            raise NotImplementedError("Score is not supported for gemini-1.5 model.")

         X, y = utils.convert_to_dataframe(X, y)
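
For orientation, a minimal sketch of what this guard looks like from the caller's side. The concrete gemini-1.5 model name and the toy data are illustrative assumptions, not taken from this diff; per the check above, any model name starting with "gemini-1.5" takes the same NotImplementedError path.

import pandas as pd

import bigframes.ml.llm
import bigframes.pandas as bpd

# Toy evaluation data; the column names mirror the test fixtures below.
df = bpd.read_pandas(pd.DataFrame({"prompt": ["What is 1+1?"], "label": ["2"]}))

# Hypothetical gemini-1.5 model name; after this change, score() rejects it
# up front instead of suggesting gemini-pro-1.0 in the error message.
model = bigframes.ml.llm.GeminiTextGenerator(model_name="gemini-1.5-pro-preview-0514")
try:
    model.score(X=df[["prompt"]], y=df[["label"]])
except NotImplementedError as exc:
    print(exc)  # Score is not supported for gemini-1.5 model.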

tests/system/load/test_llm.py (23 additions, 20 deletions)

@@ -16,6 +16,7 @@
 import pytest

 import bigframes.ml.llm
+from tests.system import utils


 @pytest.fixture(scope="session")

@@ -114,7 +115,6 @@ def test_llm_palm_score_params(llm_fine_tune_df_default_index):
     assert all(col in score_result_col for col in expected_col)


-@pytest.mark.flaky(retries=2)
 def test_llm_gemini_pro_score(llm_fine_tune_df_default_index):
     model = bigframes.ml.llm.GeminiTextGenerator(model_name="gemini-pro")

@@ -123,18 +123,19 @@ def test_llm_gemini_pro_score(llm_fine_tune_df_default_index):
         X=llm_fine_tune_df_default_index[["prompt"]],
         y=llm_fine_tune_df_default_index[["label"]],
     ).to_pandas()
-    score_result_col = score_result.columns.to_list()
-    expected_col = [
-        "bleu4_score",
-        "rouge-l_precision",
-        "rouge-l_recall",
-        "rouge-l_f1_score",
-        "evaluation_status",
-    ]
-    assert all(col in score_result_col for col in expected_col)
+    utils.check_pandas_df_schema_and_index(
+        score_result,
+        columns=[
+            "bleu4_score",
+            "rouge-l_precision",
+            "rouge-l_recall",
+            "rouge-l_f1_score",
+            "evaluation_status",
+        ],
+        index=1,
+    )


-@pytest.mark.flaky(retries=2)
 def test_llm_gemini_pro_score_params(llm_fine_tune_df_default_index):
     model = bigframes.ml.llm.GeminiTextGenerator(model_name="gemini-pro")

@@ -144,12 +145,14 @@ def test_llm_gemini_pro_score_params(llm_fine_tune_df_default_index):
         y=llm_fine_tune_df_default_index["label"],
         task_type="classification",
     ).to_pandas()
-    score_result_col = score_result.columns.to_list()
-    expected_col = [
-        "precision",
-        "recall",
-        "f1_score",
-        "label",
-        "evaluation_status",
-    ]
-    assert all(col in score_result_col for col in expected_col)
+    utils.check_pandas_df_schema_and_index(
+        score_result,
+        columns=[
+            "precision",
+            "recall",
+            "f1_score",
+            "label",
+            "evaluation_status",
+        ],
+        index=1,
+    )
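
Both tests now delegate their assertions to the shared tests.system.utils helper instead of hand-rolled column checks. The helper's implementation is not part of this diff; judging from the call sites, it verifies the result schema and row count, roughly like the standalone sketch below (the function name and the reading of index as an expected row count are assumptions).

import pandas as pd

def check_schema_and_rows(df: pd.DataFrame, columns: list[str], index: int) -> None:
    # Sketch only, not the real tests.system.utils helper: assert that the
    # expected metric columns are present and that score() returned the
    # expected number of rows (a single aggregate metrics row in these tests).
    assert all(col in df.columns for col in columns)
    assert len(df.index) == index

# Usage mirroring the tests above:
# check_schema_and_rows(score_result, ["precision", "recall", "f1_score", "label", "evaluation_status"], index=1)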
