diff --git a/userbenchmark/test_bench/run.py b/userbenchmark/test_bench/run.py index 45154ec76c..a9c894989b 100644 --- a/userbenchmark/test_bench/run.py +++ b/userbenchmark/test_bench/run.py @@ -111,6 +111,57 @@ def generate_model_configs_from_bisect_yaml( return result +def generate_model_configs_from_yaml( + yaml_file: str, +) -> List[TorchBenchModelConfig]: + """ + The configuration might like this: + + devices: + - "foo" + models: + - model: BERT_pytorch + batch_size: 1 + + - model: yolov3 + skip: true + extra_args: + - "--accuracy" + """ + yaml_file_path = os.path.join(yaml_file) + assert os.path.exists(yaml_file_path) + + def _get_val(d: dict, key: str, default_value = None): + if d is None: + return default_value + else: + return d.get(key, default_value) + + with open(yaml_file_path, "r") as yf: + config_obj = yaml.safe_load(yf) + devices = _get_val(config_obj, "devices") + batch_size = _get_val(config_obj, "batch_size") + extra_args = _get_val(config_obj, "extra_args", []) + + model_names = set(list_models(internal=False)) + cfgs = itertools.product(*[devices, model_names]) + configs = [] + for device, model in cfgs: + model_cfg = next(filter(lambda c: c["model"] == model, config_obj["models"]), None) + tests = _get_val(model_cfg, "tests", ["eval"]) + for test in tests: + config = TorchBenchModelConfig( + name=model, + device=device, + test=test, + batch_size=_get_val(model_cfg, "batch_size", batch_size), + extra_args=_get_val(model_cfg, "extra_args", extra_args), + skip=_get_val(model_cfg, "skip", False), + ) + configs.append(config) + return configs + + def init_output_dir( configs: List[TorchBenchModelConfig], output_dir: pathlib.Path ) -> List[TorchBenchModelConfig]: @@ -340,6 +391,8 @@ def run(args: List[str]): args, extra_args = parse_known_args(args) if args.run_bisect: configs = generate_model_configs_from_bisect_yaml(args.run_bisect) + elif args.config: + configs = generate_model_configs_from_yaml(args.config) else: modelset = set(list_models(internal=(not args.oss))) timm_set = set(list_extended_models(suite_name="timm"))