
Commit b6ad6af

init e2e check
1 parent 993ab70 commit b6ad6af

File tree: 5 files changed, +945 −384 lines changed

Lines changed: 231 additions & 86 deletions

@@ -1,98 +1,243 @@
+#!/usr/bin/env python3
+"""
+Accuracy Check Script
+Compares test results against reference data and calculates pass rates.
+Reference last updated: https://github.com/intel/torch-xpu-ops/pull/1223
+"""
 
+import re
+import json
 import argparse
 import pandas as pd
-import pathlib
-
-# Reference last updated is https://github.com/intel/torch-xpu-ops/pull/1223
-
-parser = argparse.ArgumentParser(description="Accuracy Check", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-parser.add_argument("--driver", type=str, default="rolling", help="rolling or lts")
-parser.add_argument("--category", type=str, default="inductor", help="inductor")
-parser.add_argument("--suite", type=str, required=True, help="huggingface, timm_models or torchbench")
-parser.add_argument("--mode", type=str, required=True, help="inference or training")
-parser.add_argument("--dtype", type=str, required=True, help="float32, bfloat16, float16, amp_bf16 or amp_fp16")
-# parser.add_argument("--scenario", type=str, required=True, help="accuracy or performance")
-parser.add_argument("--csv_file", type=str, required=True, help="your result csv path")
-parser.add_argument('--update', action='store_true', help="whether update new pass and new failed info")
-args = parser.parse_args()
-
-
-# load csv files
-test_data = pd.read_csv(args.csv_file, comment='#')
-# test_data = test_data.reset_index() # make sure indexes pair with number of rows
-# test_data = test_data.sort_values(by=["name"], ascending=True)
-test_names = [row["name"] for index, row in test_data.iterrows()]
-
-current_path = pathlib.Path(__file__).parent.resolve()
-refer_file = str(current_path) + "/" + args.driver + "/" + args.category + "_" + args.suite + "_" + args.mode + ".csv"
-refer_data = pd.read_csv(refer_file, comment='#')
-# refer_data = refer_data.reset_index() # make sure indexes pair with number of rows
-# refer_data = refer_data.sort_values(by=["name"], ascending=True)
-refer_names = [row["name"] for index, row in refer_data.iterrows()]
-
-# summary
-model_names = set(refer_names + test_names)
-passed_models = []
-real_failed_models = []
-expected_failed_models = []
-new_models = []
-new_pass_models = []
-lost_models = []
-timeout_models = []
-# for index, row in refer_data.iterrows():
-for model_name in model_names:
-    test_row = next(([i, line] for i, line in test_data.iterrows() if line["name"] == model_name), "N/A")
-    refer_row = next(([i, line] for i, line in refer_data.iterrows() if line["name"] == model_name), "N/A")
-    test_accuracy = test_row[1]["accuracy"] if test_row != "N/A" else "N/A"
-    refer_accuracy = refer_row[1][args.dtype] if refer_row != "N/A" else "N/A"
-    test_accuracy = str(test_accuracy)
-    refer_accuracy = str(refer_accuracy)
+from pathlib import Path
+
+
+def load_data(csv_file):
+    """Load CSV file with comment support."""
+    return pd.read_csv(csv_file, comment='#')
+
+
+def find_model_row(dataframe, model_name):
+    """Find row for a specific model in dataframe."""
+    matches = dataframe[dataframe['name'] == model_name]
+    return matches.iloc[0] if not matches.empty else None
+
+
+def get_test_result(data, suite, dtype, mode, model):
+    """
+    Get test result for specific test configuration.
+
+    Args:
+        data: JSON data containing test results
+        suite: Test suite name
+        dtype: Data type
+        mode: Inference or training mode
+        model: Model name
+
+    Returns:
+        Test result or "N/A" if not found
+    """
+    for issue in data:
+        for row in issue.get('table_rows', []):
+            if len(row) >= 6 and row[:4] == [suite, dtype, mode, model]:
+                return row[4]
+    return "N/A"
+
+
+def parse_file_name(filename):
+    """
+    Parse benchmark file name to extract suite, dtype, and mode.
+
+    Args:
+        filename: Input filename to parse
+
+    Returns:
+        tuple: (suite, dtype, mode) or ("N/A", "N/A", "N/A") if pattern not found
+    """
+    pattern = (
+        r"_(huggingface|timm_models|torchbench)_"
+        r"(float32|bfloat16|float16|amp_bf16|amp_fp16)_"
+        r"(inference|training)_"
+    )
+    match = re.search(pattern, filename)
+    return match.groups() if match else ("N/A", "N/A", "N/A")
+
+
+def load_known_data(issue_file):
+    """Load known test data from JSON file."""
+    try:
+        with open(issue_file, 'r', encoding='utf-8') as file:
+            return json.load(file)
+    except (FileNotFoundError, json.JSONDecodeError) as e:
+        print(f"Error loading known data from {issue_file}: {e}")
+        return []
+
+
+def update_reference_dataframe(refer_data, model_name, dtype, accuracy):
+    """
+    Update reference dataframe with new or updated model results.
+
+    Args:
+        refer_data: Reference dataframe to update
+        model_name: Name of the model to update
+        dtype: Data type column to update
+        accuracy: Accuracy value to set
+
+    Returns:
+        Updated dataframe
+    """
+    mask = refer_data['name'] == model_name
+    if mask.any():
+        refer_data.loc[mask, dtype] = accuracy
+    else:
+        new_row = {'name': model_name, dtype: accuracy}
+        refer_data = pd.concat([refer_data, pd.DataFrame([new_row])],
+                               ignore_index=True)
+    return refer_data
+
+
+def categorize_model(test_accuracy, refer_accuracy, known_accuracy):
+    """
+    Categorize a model based on its test results.
+
+    Returns:
+        tuple: (category, should_update_reference)
+    """
     if test_accuracy == "N/A":
-        lost_models.append([model_name, test_accuracy])
+        return "lost", False
     elif 'pass' in test_accuracy:
-        passed_models.append([model_name, test_accuracy])
         if refer_accuracy == "N/A":
-            new_models.append([model_name, test_accuracy])
-            refer_data.loc[len(refer_data), :] = "N/A"
-            refer_data.at[len(refer_data) - 1, "name"] = model_name
-            refer_data.at[len(refer_data) - 1, args.dtype] = test_accuracy
+            return "new", True
         elif 'pass' not in refer_accuracy:
-            new_pass_models.append([model_name, test_accuracy])
-            refer_data.at[refer_row[0], args.dtype] = test_accuracy
+            return "new_pass", True
+        return "passed", False
     elif 'timeout' in test_accuracy:
-        timeout_models.append([model_name, test_accuracy])
         if refer_accuracy == "N/A":
-            new_models.append([model_name, test_accuracy])
-            refer_data.loc[len(refer_data), :] = "N/A"
-            refer_data.at[len(refer_data) - 1, "name"] = model_name
-            refer_data.at[len(refer_data) - 1, args.dtype] = test_accuracy
-    else:
+            return "new", True
+        return "timeout", False
+    else:  # Failed cases
        if refer_accuracy == "N/A":
-            new_models.append([model_name, test_accuracy])
-            # Not failed for new models
-            expected_failed_models.append([model_name, test_accuracy])
-            refer_data.loc[len(refer_data), :] = "N/A"
-            refer_data.at[len(refer_data) - 1, "name"] = model_name
-            refer_data.at[len(refer_data) - 1, args.dtype] = test_accuracy
-        elif "pass" in refer_accuracy:
-            real_failed_models.append([model_name, test_accuracy])
+            return "expected_failed", True
+        elif "pass" in refer_accuracy and known_accuracy != test_accuracy:
+            return "real_failed", False
         else:
-            expected_failed_models.append([model_name, test_accuracy])
             if test_accuracy != refer_accuracy:
-                refer_data.at[refer_row[0], args.dtype] = test_accuracy
-
-# pass rate
-print(f"============ Summary for {args.suite} {args.dtype} {args.mode} accuracy ============")
-print("Total models:", len(model_names))
-print("Passed models:", len(passed_models))
-print("Real failed models:", len(real_failed_models), real_failed_models)
-print("Expected failed models:", len(expected_failed_models), expected_failed_models)
-print("Warning timeout models:", len(timeout_models), timeout_models)
-print("New models:", len(new_models), new_models)
-print("Failed to passed models:", len(new_pass_models), new_pass_models)
-print("Not run/in models:", len(lost_models), lost_models)
-print(f"Pass rate: {len(passed_models) / len(model_names) * 100:.2f}%")
-
-# update reference csv
-if len(new_pass_models + new_models) > 0 and args.update:
-    refer_data.to_csv(refer_file, sep=',', encoding='utf-8', index=False)
+                return "expected_failed", True
+        return "expected_failed", False
+
+
+def print_results_summary(suite, dtype, mode, categories):
+    """Print formatted summary of results."""
+    print(f"============ Summary for {suite} {dtype} {mode} accuracy ============")
+    print(f"Total models: {len(categories['all_models'])}")
+    print(f"Passed models: {len(categories['passed'])}")
+    print(f"Real failed models: {len(categories['real_failed'])}")
+    print(f"Expected failed models: {len(categories['expected_failed'])}")
+    print(f"Warning timeout models: {len(categories['timeout'])}")
+    print(f"New models: {len(categories['new'])}")
+    print(f"Failed to passed models: {len(categories['new_pass'])}")
+    print(f"Not run/in models: {len(categories['lost'])}")
+
+    if categories['real_failed']:
+        print(f" Real failed details: {categories['real_failed']}")
+    if categories['expected_failed']:
+        print(f" Expected failed details: {categories['expected_failed']}")
+
+    total_models = len(categories['all_models'])
+    if total_models > 0:
+        pass_rate = len(categories['passed']) / total_models * 100
+        print(f"Pass rate: {pass_rate:.2f}%")
+
+
+def main():
+    """Main function to run accuracy comparison."""
+    parser = argparse.ArgumentParser(
+        description="Accuracy Check",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument("--driver", type=str, default="rolling",
+                        help="rolling or lts")
+    parser.add_argument("--category", type=str, default="inductor",
+                        help="inductor")
+    parser.add_argument("--suite", type=str, required=True,
+                        help="huggingface, timm_models or torchbench")
+    parser.add_argument("--mode", type=str, required=True,
+                        help="inference or training")
+    parser.add_argument("--dtype", type=str, required=True,
+                        help="float32, bfloat16, float16, amp_bf16 or amp_fp16")
+    parser.add_argument("--csv_file", type=str, required=True,
+                        help="Test results CSV file path")
+    parser.add_argument("--issue_file", type=str, required=True,
+                        help="Known test data JSON file path")
+    parser.add_argument('--update', action='store_true',
+                        help="Whether to update new pass and new failed info")
+
+    args = parser.parse_args()
+
+    # Load data files
+    test_data = load_data(args.csv_file)
+    test_known_data = load_known_data(args.issue_file)
+    suite, dtype, mode = parse_file_name(args.csv_file)
+
+    # Load reference data
+    current_path = Path(__file__).parent.resolve()
+    refer_filename = f"{args.category}_{args.suite}_{args.mode}.csv"
+    refer_file = current_path / args.driver / refer_filename
+    refer_data = load_data(refer_file)
+
+    # Get model names
+    test_names = test_data['name'].tolist()
+    refer_names = refer_data['name'].tolist()
+    model_names = set(refer_names + test_names)
+
+    # Initialize result categories
+    categories = {
+        'all_models': list(model_names),
+        'passed': [],
+        'real_failed': [],
+        'expected_failed': [],
+        'new': [],
+        'new_pass': [],
+        'lost': [],
+        'timeout': []
+    }
+
+    needs_update = False
+
+    # Process each model
+    for model_name in model_names:
+        test_row = find_model_row(test_data, model_name)
+        refer_row = find_model_row(refer_data, model_name)
+
+        test_accuracy = str(test_row['accuracy']) if test_row is not None else "N/A"
+        refer_accuracy = str(refer_row[args.dtype]) if refer_row is not None else "N/A"
+        known_accuracy = get_test_result(test_known_data, suite, dtype, mode, model_name)
+
+        # Debug print (optional)
+        print(f"{model_name}: test={test_accuracy}, ref={refer_accuracy}, known={known_accuracy}")
+
+        # Categorize model and determine if reference needs update
+        category, should_update = categorize_model(
+            test_accuracy, refer_accuracy, known_accuracy
+        )
+
+        categories[category].append([model_name, test_accuracy])
+
+        # Update reference data if needed
+        if should_update and args.update:
+            refer_data = update_reference_dataframe(
+                refer_data, model_name, args.dtype, test_accuracy
+            )
+            needs_update = True
+
+    # Print summary
+    print_results_summary(args.suite, args.dtype, args.mode, categories)
+
+    # Update reference CSV if requested
+    if needs_update:
+        refer_data.to_csv(refer_file, sep=',', encoding='utf-8', index=False)
+        print(f"Reference file updated: {refer_file}")
+
+
+if __name__ == "__main__":
+    main()
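Two input conventions are implicit in the diff above: the result-file naming that parse_file_name expects, and the issue-file schema that get_test_result walks. The sketch below spells both out; the file name, model name, and issue values are made up for illustration.

import re

# Pattern copied verbatim from parse_file_name; only the
# _<suite>_<dtype>_<mode>_ infix is matched, and a trailing
# underscore after the mode is required.
pattern = (
    r"_(huggingface|timm_models|torchbench)_"
    r"(float32|bfloat16|float16|amp_bf16|amp_fp16)_"
    r"(inference|training)_"
)
name = "inductor_huggingface_float32_inference_xpu_accuracy.csv"  # hypothetical
print(re.search(pattern, name).groups())
# ('huggingface', 'float32', 'inference')

# Issue-file shape inferred from get_test_result: a JSON list of issue
# objects, each with "table_rows" whose first four columns are
# [suite, dtype, mode, model] and whose fifth (row[4]) is the recorded
# result; rows with fewer than six columns are skipped. Values are made up.
known = [{
    "table_rows": [
        ["huggingface", "float32", "inference", "BertForMaskedLM",
         "fail_accuracy", "tracking-issue link"],
    ],
}]
target = ["huggingface", "float32", "inference", "BertForMaskedLM"]
for issue in known:
    for row in issue.get("table_rows", []):
        if len(row) >= 6 and row[:4] == target:
            print(row[4])  # fail_accuracy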

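The core of the refactor is categorize_model, which replaces the nested append-and-mutate blocks with a pure function returning (category, should_update_reference). A minimal sketch of its decision table, assuming the script is saved as accuracy_check.py (the module name is an assumption):

# Hypothetical module name; each case mirrors a branch of categorize_model.
from accuracy_check import categorize_model

# Model missing from the result CSV entirely.
assert categorize_model("N/A", "pass", "N/A") == ("lost", False)
# Passes with no reference entry yet: recorded as new.
assert categorize_model("pass", "N/A", "N/A") == ("new", True)
# Passes where the reference says it failed: reference gets promoted.
assert categorize_model("pass", "fail_to_run", "N/A") == ("new_pass", True)
# Passes and agrees with the reference: nothing to update.
assert categorize_model("pass", "pass", "N/A") == ("passed", False)
# Timeout with an existing reference entry: flagged but not updated.
assert categorize_model("timeout", "pass", "N/A") == ("timeout", False)
# Fails against a passing reference with no matching known issue: regression.
assert categorize_model("fail_accuracy", "pass", "N/A") == ("real_failed", False)
# Fails the same way a known issue records it: expected, and the failing
# value is written back to the reference.
assert categorize_model("fail_accuracy", "pass", "fail_accuracy") == ("expected_failed", True)
# Fails exactly as the reference already says: expected, no update.
assert categorize_model("fail_accuracy", "fail_accuracy", "N/A") == ("expected_failed", False)
print("decision table holds")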
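Finally, how the refactored CLI might be driven from an e2e harness; every file and script name below is a placeholder. Note that --suite, --mode, and --dtype remain required even though parse_file_name re-derives suite/dtype/mode from the CSV name: the parsed values feed only the known-issues lookup, while the explicit flags select the reference file and label the summary.

# Placeholder invocation of the script above; names are illustrative only.
import subprocess

subprocess.run([
    "python", "accuracy_check.py",
    "--driver", "rolling",
    "--category", "inductor",
    "--suite", "huggingface",
    "--mode", "inference",
    "--dtype", "float32",
    "--csv_file", "inductor_huggingface_float32_inference_xpu_accuracy.csv",
    "--issue_file", "known_issues.json",
    "--update",
], check=True)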