@@ -74,10 +74,9 @@ def test_create_text_generator_model_default_session(
7474 llm_text_df = bpd .read_pandas (llm_text_pandas_df )
7575
7676 df = model .predict (llm_text_df ).to_pandas ()
77- assert df .shape == (3 , 4 )
78- assert "ml_generate_text_llm_result" in df .columns
79- series = df ["ml_generate_text_llm_result" ]
80- assert all (series .str .len () > 20 )
77+ utils .check_pandas_df_schema_and_index (
78+ df , columns = utils .ML_GENERATE_TEXT_OUTPUT , index = 3 , col_exact = False
79+ )
8180
8281
8382@pytest .mark .flaky (retries = 2 )
@@ -104,10 +103,9 @@ def test_create_text_generator_32k_model_default_session(
104103 llm_text_df = bpd .read_pandas (llm_text_pandas_df )
105104
106105 df = model .predict (llm_text_df ).to_pandas ()
107- assert df .shape == (3 , 4 )
108- assert "ml_generate_text_llm_result" in df .columns
109- series = df ["ml_generate_text_llm_result" ]
110- assert all (series .str .len () > 20 )
106+ utils .check_pandas_df_schema_and_index (
107+ df , columns = utils .ML_GENERATE_TEXT_OUTPUT , index = 3 , col_exact = False
108+ )
111109
112110
113111@pytest .mark .flaky (retries = 2 )
@@ -131,10 +129,9 @@ def test_create_text_generator_model_default_connection(
131129 )
132130
133131 df = model .predict (llm_text_df ).to_pandas ()
134- assert df .shape == (3 , 4 )
135- assert "ml_generate_text_llm_result" in df .columns
136- series = df ["ml_generate_text_llm_result" ]
137- assert all (series .str .len () > 20 )
132+ utils .check_pandas_df_schema_and_index (
133+ df , columns = utils .ML_GENERATE_TEXT_OUTPUT , index = 3 , col_exact = False
134+ )
138135
139136
140137# Marked as flaky only because BQML LLM is in preview, the service only has limited capacity, not stable enough.
@@ -143,21 +140,19 @@ def test_text_generator_predict_default_params_success(
143140 palm2_text_generator_model , llm_text_df
144141):
145142 df = palm2_text_generator_model .predict (llm_text_df ).to_pandas ()
146- assert df .shape == (3 , 4 )
147- assert "ml_generate_text_llm_result" in df .columns
148- series = df ["ml_generate_text_llm_result" ]
149- assert all (series .str .len () > 20 )
143+ utils .check_pandas_df_schema_and_index (
144+ df , columns = utils .ML_GENERATE_TEXT_OUTPUT , index = 3 , col_exact = False
145+ )
150146
151147
152148@pytest .mark .flaky (retries = 2 )
153149def test_text_generator_predict_series_default_params_success (
154150 palm2_text_generator_model , llm_text_df
155151):
156152 df = palm2_text_generator_model .predict (llm_text_df ["prompt" ]).to_pandas ()
157- assert df .shape == (3 , 4 )
158- assert "ml_generate_text_llm_result" in df .columns
159- series = df ["ml_generate_text_llm_result" ]
160- assert all (series .str .len () > 20 )
153+ utils .check_pandas_df_schema_and_index (
154+ df , columns = utils .ML_GENERATE_TEXT_OUTPUT , index = 3 , col_exact = False
155+ )
161156
162157
163158@pytest .mark .flaky (retries = 2 )
@@ -166,10 +161,9 @@ def test_text_generator_predict_arbitrary_col_label_success(
166161):
167162 llm_text_df = llm_text_df .rename (columns = {"prompt" : "arbitrary" })
168163 df = palm2_text_generator_model .predict (llm_text_df ).to_pandas ()
169- assert df .shape == (3 , 4 )
170- assert "ml_generate_text_llm_result" in df .columns
171- series = df ["ml_generate_text_llm_result" ]
172- assert all (series .str .len () > 20 )
164+ utils .check_pandas_df_schema_and_index (
165+ df , columns = utils .ML_GENERATE_TEXT_OUTPUT , index = 3 , col_exact = False
166+ )
173167
174168
175169@pytest .mark .flaky (retries = 2 )
@@ -179,10 +173,9 @@ def test_text_generator_predict_with_params_success(
179173 df = palm2_text_generator_model .predict (
180174 llm_text_df , temperature = 0.5 , max_output_tokens = 100 , top_k = 20 , top_p = 0.5
181175 ).to_pandas ()
182- assert df .shape == (3 , 4 )
183- assert "ml_generate_text_llm_result" in df .columns
184- series = df ["ml_generate_text_llm_result" ]
185- assert all (series .str .len () > 20 )
176+ utils .check_pandas_df_schema_and_index (
177+ df , columns = utils .ML_GENERATE_TEXT_OUTPUT , index = 3 , col_exact = False
178+ )
186179
187180
188181def test_create_embedding_generator_model (
@@ -379,10 +372,9 @@ def test_gemini_text_generator_predict_default_params_success(
379372 model_name = model_name , connection_name = bq_connection , session = session
380373 )
381374 df = gemini_text_generator_model .predict (llm_text_df ).to_pandas ()
382- assert df .shape == (3 , 4 )
383- assert "ml_generate_text_llm_result" in df .columns
384- series = df ["ml_generate_text_llm_result" ]
385- assert all (series .str .len () > 20 )
375+ utils .check_pandas_df_schema_and_index (
376+ df , columns = utils .ML_GENERATE_TEXT_OUTPUT , index = 3 , col_exact = False
377+ )
386378
387379
388380@pytest .mark .parametrize (
@@ -399,10 +391,9 @@ def test_gemini_text_generator_predict_with_params_success(
399391 df = gemini_text_generator_model .predict (
400392 llm_text_df , temperature = 0.5 , max_output_tokens = 100 , top_k = 20 , top_p = 0.5
401393 ).to_pandas ()
402- assert df .shape == (3 , 4 )
403- assert "ml_generate_text_llm_result" in df .columns
404- series = df ["ml_generate_text_llm_result" ]
405- assert all (series .str .len () > 20 )
394+ utils .check_pandas_df_schema_and_index (
395+ df , columns = utils .ML_GENERATE_TEXT_OUTPUT , index = 3 , col_exact = False
396+ )
406397
407398
408399@pytest .mark .parametrize (
@@ -444,10 +435,9 @@ def test_claude3_text_generator_predict_default_params_success(
444435 model_name = model_name , connection_name = bq_connection , session = session
445436 )
446437 df = claude3_text_generator_model .predict (llm_text_df ).to_pandas ()
447- assert df .shape == (3 , 3 )
448- assert "ml_generate_text_llm_result" in df .columns
449- series = df ["ml_generate_text_llm_result" ]
450- assert all (series .str .len () > 20 )
438+ utils .check_pandas_df_schema_and_index (
439+ df , columns = utils .ML_GENERATE_TEXT_OUTPUT , index = 3 , col_exact = False
440+ )
451441
452442
453443@pytest .mark .parametrize (
@@ -466,10 +456,9 @@ def test_claude3_text_generator_predict_with_params_success(
466456 df = claude3_text_generator_model .predict (
467457 llm_text_df , max_output_tokens = 100 , top_k = 20 , top_p = 0.5
468458 ).to_pandas ()
469- assert df .shape == (3 , 3 )
470- assert "ml_generate_text_llm_result" in df .columns
471- series = df ["ml_generate_text_llm_result" ]
472- assert all (series .str .len () > 20 )
459+ utils .check_pandas_df_schema_and_index (
460+ df , columns = utils .ML_GENERATE_TEXT_OUTPUT , index = 3 , col_exact = False
461+ )
473462
474463
475464@pytest .mark .flaky (retries = 2 )
0 commit comments