16
16
import pytest
17
17
18
18
import bigframes .ml .llm
19
+ from tests .system import utils
19
20
20
21
21
22
@pytest .fixture (scope = "session" )
@@ -114,7 +115,6 @@ def test_llm_palm_score_params(llm_fine_tune_df_default_index):
114
115
assert all (col in score_result_col for col in expected_col )
115
116
116
117
117
- @pytest .mark .flaky (retries = 2 )
118
118
def test_llm_gemini_pro_score (llm_fine_tune_df_default_index ):
119
119
model = bigframes .ml .llm .GeminiTextGenerator (model_name = "gemini-pro" )
120
120
@@ -123,18 +123,19 @@ def test_llm_gemini_pro_score(llm_fine_tune_df_default_index):
123
123
X = llm_fine_tune_df_default_index [["prompt" ]],
124
124
y = llm_fine_tune_df_default_index [["label" ]],
125
125
).to_pandas ()
126
- score_result_col = score_result .columns .to_list ()
127
- expected_col = [
128
- "bleu4_score" ,
129
- "rouge-l_precision" ,
130
- "rouge-l_recall" ,
131
- "rouge-l_f1_score" ,
132
- "evaluation_status" ,
133
- ]
134
- assert all (col in score_result_col for col in expected_col )
126
+ utils .check_pandas_df_schema_and_index (
127
+ score_result ,
128
+ columns = [
129
+ "bleu4_score" ,
130
+ "rouge-l_precision" ,
131
+ "rouge-l_recall" ,
132
+ "rouge-l_f1_score" ,
133
+ "evaluation_status" ,
134
+ ],
135
+ index = 1 ,
136
+ )
135
137
136
138
137
- @pytest .mark .flaky (retries = 2 )
138
139
def test_llm_gemini_pro_score_params (llm_fine_tune_df_default_index ):
139
140
model = bigframes .ml .llm .GeminiTextGenerator (model_name = "gemini-pro" )
140
141
@@ -144,12 +145,14 @@ def test_llm_gemini_pro_score_params(llm_fine_tune_df_default_index):
144
145
y = llm_fine_tune_df_default_index ["label" ],
145
146
task_type = "classification" ,
146
147
).to_pandas ()
147
- score_result_col = score_result .columns .to_list ()
148
- expected_col = [
149
- "precision" ,
150
- "recall" ,
151
- "f1_score" ,
152
- "label" ,
153
- "evaluation_status" ,
154
- ]
155
- assert all (col in score_result_col for col in expected_col )
148
+ utils .check_pandas_df_schema_and_index (
149
+ score_result ,
150
+ columns = [
151
+ "precision" ,
152
+ "recall" ,
153
+ "f1_score" ,
154
+ "label" ,
155
+ "evaluation_status" ,
156
+ ],
157
+ index = 1 ,
158
+ )
0 commit comments