Skip to content

Commit 934a393

Browse files
GarrettWu, arwas11
authored and committed
test: stop checking text generation contents (#935)
1 parent 38b06e2 commit 934a393

File tree

3 files changed

+43
-59
lines changed

3 files changed

+43
-59
lines changed

tests/system/small/ml/test_core.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from datetime import datetime
1616
import typing
17-
from unittest import TestCase
1817

1918
import pandas as pd
2019
import pyarrow as pa
@@ -24,7 +23,7 @@
2423
import bigframes
2524
import bigframes.features
2625
from bigframes.ml import core
27-
import tests.system.utils
26+
from tests.system import utils
2827

2928

3029
def test_model_eval(
@@ -212,7 +211,7 @@ def test_pca_model_principal_components(penguins_bqml_pca_model: core.BqmlModel)
212211
.reset_index(drop=True)
213212
)
214213

215-
tests.system.utils.assert_pandas_df_equal_pca_components(
214+
utils.assert_pandas_df_equal_pca_components(
216215
result,
217216
expected,
218217
check_exact=False,
@@ -234,7 +233,7 @@ def test_pca_model_principal_component_info(penguins_bqml_pca_model: core.BqmlMo
234233
"cumulative_explained_variance_ratio": [0.469357, 0.651283, 0.812383],
235234
},
236235
)
237-
tests.system.utils.assert_pandas_df_equal(
236+
utils.assert_pandas_df_equal(
238237
result,
239238
expected,
240239
check_exact=False,
@@ -349,18 +348,9 @@ def test_model_generate_text(
349348
llm_text_df, options=options
350349
).to_pandas()
351350

352-
TestCase().assertSequenceEqual(df.shape, (3, 4))
353-
TestCase().assertSequenceEqual(
354-
[
355-
"ml_generate_text_llm_result",
356-
"ml_generate_text_rai_result",
357-
"ml_generate_text_status",
358-
"prompt",
359-
],
360-
df.columns.to_list(),
351+
utils.check_pandas_df_schema_and_index(
352+
df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
361353
)
362-
series = df["ml_generate_text_llm_result"]
363-
assert all(series.str.len() > 20)
364354

365355

366356
def test_model_forecast(time_series_bqml_arima_plus_model: core.BqmlModel):

tests/system/small/ml/test_llm.py

Lines changed: 33 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,9 @@ def test_create_text_generator_model_default_session(
7474
llm_text_df = bpd.read_pandas(llm_text_pandas_df)
7575

7676
df = model.predict(llm_text_df).to_pandas()
77-
assert df.shape == (3, 4)
78-
assert "ml_generate_text_llm_result" in df.columns
79-
series = df["ml_generate_text_llm_result"]
80-
assert all(series.str.len() > 20)
77+
utils.check_pandas_df_schema_and_index(
78+
df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
79+
)
8180

8281

8382
@pytest.mark.flaky(retries=2)
@@ -104,10 +103,9 @@ def test_create_text_generator_32k_model_default_session(
104103
llm_text_df = bpd.read_pandas(llm_text_pandas_df)
105104

106105
df = model.predict(llm_text_df).to_pandas()
107-
assert df.shape == (3, 4)
108-
assert "ml_generate_text_llm_result" in df.columns
109-
series = df["ml_generate_text_llm_result"]
110-
assert all(series.str.len() > 20)
106+
utils.check_pandas_df_schema_and_index(
107+
df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
108+
)
111109

112110

113111
@pytest.mark.flaky(retries=2)
@@ -131,10 +129,9 @@ def test_create_text_generator_model_default_connection(
131129
)
132130

133131
df = model.predict(llm_text_df).to_pandas()
134-
assert df.shape == (3, 4)
135-
assert "ml_generate_text_llm_result" in df.columns
136-
series = df["ml_generate_text_llm_result"]
137-
assert all(series.str.len() > 20)
132+
utils.check_pandas_df_schema_and_index(
133+
df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
134+
)
138135

139136

140137
# Marked as flaky only because BQML LLM is in preview; the service has limited capacity and is not yet stable enough.
@@ -143,21 +140,19 @@ def test_text_generator_predict_default_params_success(
143140
palm2_text_generator_model, llm_text_df
144141
):
145142
df = palm2_text_generator_model.predict(llm_text_df).to_pandas()
146-
assert df.shape == (3, 4)
147-
assert "ml_generate_text_llm_result" in df.columns
148-
series = df["ml_generate_text_llm_result"]
149-
assert all(series.str.len() > 20)
143+
utils.check_pandas_df_schema_and_index(
144+
df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
145+
)
150146

151147

152148
@pytest.mark.flaky(retries=2)
153149
def test_text_generator_predict_series_default_params_success(
154150
palm2_text_generator_model, llm_text_df
155151
):
156152
df = palm2_text_generator_model.predict(llm_text_df["prompt"]).to_pandas()
157-
assert df.shape == (3, 4)
158-
assert "ml_generate_text_llm_result" in df.columns
159-
series = df["ml_generate_text_llm_result"]
160-
assert all(series.str.len() > 20)
153+
utils.check_pandas_df_schema_and_index(
154+
df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
155+
)
161156

162157

163158
@pytest.mark.flaky(retries=2)
@@ -166,10 +161,9 @@ def test_text_generator_predict_arbitrary_col_label_success(
166161
):
167162
llm_text_df = llm_text_df.rename(columns={"prompt": "arbitrary"})
168163
df = palm2_text_generator_model.predict(llm_text_df).to_pandas()
169-
assert df.shape == (3, 4)
170-
assert "ml_generate_text_llm_result" in df.columns
171-
series = df["ml_generate_text_llm_result"]
172-
assert all(series.str.len() > 20)
164+
utils.check_pandas_df_schema_and_index(
165+
df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
166+
)
173167

174168

175169
@pytest.mark.flaky(retries=2)
@@ -179,10 +173,9 @@ def test_text_generator_predict_with_params_success(
179173
df = palm2_text_generator_model.predict(
180174
llm_text_df, temperature=0.5, max_output_tokens=100, top_k=20, top_p=0.5
181175
).to_pandas()
182-
assert df.shape == (3, 4)
183-
assert "ml_generate_text_llm_result" in df.columns
184-
series = df["ml_generate_text_llm_result"]
185-
assert all(series.str.len() > 20)
176+
utils.check_pandas_df_schema_and_index(
177+
df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
178+
)
186179

187180

188181
def test_create_embedding_generator_model(
@@ -379,10 +372,9 @@ def test_gemini_text_generator_predict_default_params_success(
379372
model_name=model_name, connection_name=bq_connection, session=session
380373
)
381374
df = gemini_text_generator_model.predict(llm_text_df).to_pandas()
382-
assert df.shape == (3, 4)
383-
assert "ml_generate_text_llm_result" in df.columns
384-
series = df["ml_generate_text_llm_result"]
385-
assert all(series.str.len() > 20)
375+
utils.check_pandas_df_schema_and_index(
376+
df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
377+
)
386378

387379

388380
@pytest.mark.parametrize(
@@ -399,10 +391,9 @@ def test_gemini_text_generator_predict_with_params_success(
399391
df = gemini_text_generator_model.predict(
400392
llm_text_df, temperature=0.5, max_output_tokens=100, top_k=20, top_p=0.5
401393
).to_pandas()
402-
assert df.shape == (3, 4)
403-
assert "ml_generate_text_llm_result" in df.columns
404-
series = df["ml_generate_text_llm_result"]
405-
assert all(series.str.len() > 20)
394+
utils.check_pandas_df_schema_and_index(
395+
df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
396+
)
406397

407398

408399
@pytest.mark.parametrize(
@@ -444,10 +435,9 @@ def test_claude3_text_generator_predict_default_params_success(
444435
model_name=model_name, connection_name=bq_connection, session=session
445436
)
446437
df = claude3_text_generator_model.predict(llm_text_df).to_pandas()
447-
assert df.shape == (3, 3)
448-
assert "ml_generate_text_llm_result" in df.columns
449-
series = df["ml_generate_text_llm_result"]
450-
assert all(series.str.len() > 20)
438+
utils.check_pandas_df_schema_and_index(
439+
df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
440+
)
451441

452442

453443
@pytest.mark.parametrize(
@@ -466,10 +456,9 @@ def test_claude3_text_generator_predict_with_params_success(
466456
df = claude3_text_generator_model.predict(
467457
llm_text_df, max_output_tokens=100, top_k=20, top_p=0.5
468458
).to_pandas()
469-
assert df.shape == (3, 3)
470-
assert "ml_generate_text_llm_result" in df.columns
471-
series = df["ml_generate_text_llm_result"]
472-
assert all(series.str.len() > 20)
459+
utils.check_pandas_df_schema_and_index(
460+
df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
461+
)
473462

474463

475464
@pytest.mark.flaky(retries=2)

tests/system/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@
4545
"log_loss",
4646
"roc_auc",
4747
]
48+
ML_GENERATE_TEXT_OUTPUT = [
49+
"ml_generate_text_llm_result",
50+
"ml_generate_text_status",
51+
"prompt",
52+
]
4853

4954

5055
def skip_legacy_pandas(test):

0 commit comments

Comments
 (0)