Skip to content

Commit a75cd37

Browse files
Introducing Seed Checker for #85 (#87)
* Reformatting package_checker.py. * Finishing seed checker. * A dummpy commit to trigger workflow rerun, cuz I don't have permission to rerun the workflow inside the PR. * Fix typo. * Fix typo. * Allow seeds to be the same if they are logged on different lines or in different source files. * Update comments.
1 parent e7cbf88 commit a75cd37

File tree

2 files changed

+216
-26
lines changed

2 files changed

+216
-26
lines changed

mlperf_logging/package_checker/package_checker.py

Lines changed: 60 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import sys
1010

1111
from ..compliance_checker import mlp_compliance
12+
from .seed_checker import find_source_files_under, SeedChecker
1213

1314
_ALLOWED_BENCHMARKS = [
1415
'bert',
@@ -34,11 +35,12 @@
3435

3536

3637
def _get_sub_folders(folder):
37-
sub_folders = [os.path.join(folder, sub_folder)
38-
for sub_folder in os.listdir(folder)]
39-
return [sub_folder
40-
for sub_folder in sub_folders
41-
if os.path.isdir(sub_folder)]
38+
sub_folders = [
39+
os.path.join(folder, sub_folder) for sub_folder in os.listdir(folder)
40+
]
41+
return [
42+
sub_folder for sub_folder in sub_folders if os.path.isdir(sub_folder)
43+
]
4244

4345

4446
def _print_divider_bar():
@@ -53,6 +55,7 @@ def check_training_result_files(folder, ruleset, quiet, werror):
5355
ruleset: The ruleset such as 0.6.0 or 0.7.0.
5456
"""
5557

58+
seed_checker = SeedChecker(ruleset)
5659
too_many_errors = False
5760
result_folder = os.path.join(folder, 'results')
5861
for system_folder in _get_sub_folders(result_folder):
@@ -76,16 +79,21 @@ def check_training_result_files(folder, ruleset, quiet, werror):
7679
print('No Result Files!')
7780
continue
7881

82+
# Find all source codes for this benchmark.
83+
source_files = find_source_files_under(
84+
os.path.join(folder, 'benchmarks', benchmark))
85+
7986
_print_divider_bar()
8087
print('System {}'.format(system))
8188
print('Benchmark {}'.format(benchmark))
8289

83-
# If the organization did submit results for this benchmark, the number
84-
# of result files must be an exact number.
90+
# If the organization did submit results for this benchmark, the
91+
# number of result files must be an exact number.
8592
if len(result_files) != _EXPECTED_RESULT_FILE_COUNTS[benchmark]:
8693
print('Expected {} runs, but detected {} runs.'.format(
8794
_EXPECTED_RESULT_FILE_COUNTS[benchmark],
88-
len(result_files)))
95+
len(result_files),
96+
))
8997

9098
errors_found = 0
9199
result_files.sort()
@@ -99,23 +107,34 @@ def check_training_result_files(folder, ruleset, quiet, werror):
99107
print('Run {}'.format(run))
100108
config_file = '{ruleset}/common.yaml'.format(
101109
ruleset=ruleset,
102-
benchmark=benchmark)
110+
benchmark=benchmark,
111+
)
103112
checker = mlp_compliance.make_checker(
104113
ruleset=ruleset,
105114
quiet=quiet,
106-
werror=werror)
107-
valid, _, _, _ = mlp_compliance.main(result_file, config_file, checker)
115+
werror=werror,
116+
)
117+
valid, _, _, _ = mlp_compliance.main(
118+
result_file,
119+
config_file,
120+
checker,
121+
)
108122
if not valid:
109-
errors_found += 1
123+
errors_found += 1
110124
if errors_found == 1:
111-
print('WARNING: One file does not comply.')
112-
print('WARNING: Allowing this failure under olympic scoring rules.')
125+
print('WARNING: One file does not comply.')
126+
print('WARNING: Allowing this failure under olympic scoring '
127+
'rules.')
113128
if errors_found > 1:
114-
too_many_errors = True
129+
too_many_errors = True
130+
# Check if each run use unique seeds.
131+
if not seed_checker.check_seeds(result_files, source_files):
132+
too_many_errors = True
115133

116134
_print_divider_bar()
117135
if too_many_errors:
118-
raise Exception('Found too many errors in logging, see log above for details.')
136+
raise Exception(
137+
'Found too many errors in logging, see log above for details.')
119138

120139

121140
def check_training_package(folder, ruleset, quiet, werror):
@@ -134,16 +153,31 @@ def get_parser():
134153
description='Lint MLPerf submission packages.',
135154
)
136155

137-
parser.add_argument('folder', type=str,
138-
help='the folder for a submission package')
139-
parser.add_argument('usage', type=str,
140-
help='the usage such as training, inference_edge, inference_server')
141-
parser.add_argument('ruleset', type=str,
142-
help='the ruleset such as 0.6.0, 0.7.0')
143-
parser.add_argument('--werror', action='store_true',
144-
help='Treat warnings as errors')
145-
parser.add_argument('--quiet', action='store_true',
146-
help='Suppress warnings. Does nothing if --werror is set')
156+
parser.add_argument(
157+
'folder',
158+
type=str,
159+
help='the folder for a submission package',
160+
)
161+
parser.add_argument(
162+
'usage',
163+
type=str,
164+
help='the usage such as training, inference_edge, inference_server',
165+
)
166+
parser.add_argument(
167+
'ruleset',
168+
type=str,
169+
help='the ruleset such as 0.6.0, 0.7.0',
170+
)
171+
parser.add_argument(
172+
'--werror',
173+
action='store_true',
174+
help='Treat warnings as errors',
175+
)
176+
parser.add_argument(
177+
'--quiet',
178+
action='store_true',
179+
help='Suppress warnings. Does nothing if --werror is set',
180+
)
147181

148182
return parser
149183

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import warnings
2+
import os
3+
4+
from ..compliance_checker import mlp_parser
5+
6+
# What are source files?
7+
SOURCE_FILE_EXT = {
8+
'.py', '.cc', '.cpp', '.cxx', '.c', '.h', '.hh', '.hpp', '.hxx', '.sh',
9+
'.sub', '.cu', '.cuh'
10+
}
11+
12+
13+
def is_source_file(path):
14+
""" Check if a file is considered as a "source file" by extensions.
15+
16+
The extensions that are considered as "source file" are listed in
17+
SOURCE_FILE_EXT.
18+
19+
Args:
20+
path: The absolute path, relative path or name to/of the file.
21+
"""
22+
return os.path.splitext(path)[1].lower() in SOURCE_FILE_EXT
23+
24+
25+
def find_source_files_under(path):
26+
""" Find all source files in all sub-directories under a directory.
27+
28+
Args:
29+
path: The absolute or relative path to the directory under query.
30+
"""
31+
source_files = []
32+
for root, subdirs, files in os.walk(path):
33+
for file_name in files:
34+
if is_source_file(file_name):
35+
source_files.append(os.path.join(root, file_name))
36+
return source_files
37+
38+
39+
class SeedChecker:
40+
""" Check if the seeds fit MLPerf submission requirements.
41+
Current requirements are:
42+
43+
1. All seeds must be logged through mllog (if choose to log seeds). Any seed
44+
logged via any other method will be discarded.
45+
2. All seeds, if choose to be logged, must be valid integer (convertible via
46+
int()).
47+
3. If any run log at least one seed, we expect all runs to log at least
48+
one seed.
49+
4. The set of seed(s) that one run logs must be completely different from
50+
the set of seed(s) any other run logs.
51+
4. If one run logs one seed on a certain line in a certain source file, no
52+
other run can log the same seed on the same line in the same file.
53+
54+
Unsatisfying any of the above requirements results in check failure.
55+
56+
A warning is raised for the following situations:
57+
58+
1. Any run logs more than one seed.
59+
2. No seed is logged. However, the source files contain the keyword whose
60+
lowercase is "seed". What files are considered as source files are
61+
defined in SOURCE_FILE_EXT and is_source_file().
62+
63+
"""
64+
def __init__(self, ruleset):
65+
self._ruleset = ruleset
66+
67+
def _get_seed_records(self, result_file):
68+
loglines, errors = mlp_parser.parse_file(
69+
result_file,
70+
ruleset=self._ruleset,
71+
)
72+
if len(errors) > 0:
73+
raise ValueError('\n'.join(
74+
['Found parsing errors:'] +
75+
['{}\n ^^ {}'.format(line, error)
76+
for line, error in errors] +
77+
['', 'Log lines had parsing errors.']))
78+
return [(
79+
line.value['metadata']['file'],
80+
line.value['metadata']['lineno'],
81+
int(line.value['value']),
82+
) for line in loglines if line.key == 'seed']
83+
84+
def _assert_unique_seed_per_run(self, result_files):
85+
no_logged_seed = True
86+
error_messages = []
87+
seed_to_result_file = {}
88+
for result_file in result_files:
89+
try:
90+
seed_records = self._get_seed_records(result_file)
91+
except Exception as e:
92+
error_messages.append("Error found when querying seeds from "
93+
"{}: {}".format(result_file, e))
94+
continue
95+
96+
if not no_logged_seed and len(seed_records) == 0:
97+
error_messages.append(
98+
"Result file {} logs no seed. However, other "
99+
"result files, including {}, already logs some "
100+
"seeds.".format(result_file,
101+
list(seed_to_result_file.keys())))
102+
if no_logged_seed and len(seed_records) > 0:
103+
no_logged_seed = False
104+
if len(seed_records) > 1:
105+
warnings.warn(
106+
"Result file {} logs more than one seeds {}!".format(
107+
result_file, seed_records))
108+
for f, ln, s in seed_records:
109+
if (f, ln, s) in seed_to_result_file:
110+
error_messages.append(
111+
"Result file {} logs seed {} on {}:{}. However, "
112+
"result file {} already logs the same seed on the same "
113+
"line.".format(
114+
result_file,
115+
s,
116+
f,
117+
ln,
118+
seed_to_result_file[(f, ln, s)],
119+
))
120+
else:
121+
seed_to_result_file[(f, ln, s)] = result_file
122+
123+
return no_logged_seed, error_messages
124+
125+
def _has_seed_keyword(self, source_file):
126+
with open(source_file, 'r') as file_handle:
127+
for line in file_handle.readlines():
128+
if 'seed' in line.lower():
129+
return True
130+
return False
131+
132+
def check_seeds(self, result_files, source_files):
133+
""" Check the seeds for a specific benchmark submission.
134+
135+
Args:
136+
result_files: An iterable contains paths to all the result files for
137+
this benchmark.
138+
source_files: An iterable contains paths to all the source files for
139+
this benchmark.
140+
141+
"""
142+
no_logged_seed, error_messages = self._assert_unique_seed_per_run(
143+
result_files)
144+
145+
if len(error_messages) > 0:
146+
print("Seed checker failed and found the following "
147+
"errors:\n{}".format('\n'.join(error_messages)))
148+
return False
149+
150+
if no_logged_seed:
151+
for source_file in source_files:
152+
if self._has_seed_keyword(source_file):
153+
warnings.warn(
154+
"Source file {} contains the keyword 'seed' but no "
155+
"seed value is logged!".format(source_file))
156+
return True

0 commit comments

Comments
 (0)