@@ -284,10 +284,9 @@ def sample_hf_requests(
284
284
random_seed : int ,
285
285
fixed_output_len : Optional [int ] = None ,
286
286
) -> List [Tuple [str , str , int , Optional [Dict [str , Collection [str ]]]]]:
287
-
288
287
# Special case for vision_arena dataset
289
288
if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \
290
- and dataset_subset is None :
289
+ and dataset_subset is None :
291
290
assert dataset_split == "train"
292
291
dataset = load_dataset (dataset_path ,
293
292
name = dataset_subset ,
@@ -303,8 +302,8 @@ def sample_hf_requests(
303
302
streaming = True )
304
303
assert "conversations" in dataset .features , (
305
304
"HF Dataset must have 'conversations' column." )
306
- filter_func = lambda x : len ( x [ "conversations" ]) >= 2
307
- filtered_dataset = dataset . shuffle ( seed = random_seed ). filter ( filter_func )
305
+ filtered_dataset = dataset . shuffle ( seed = random_seed ). filter (
306
+ lambda x : len ( x [ "conversations" ]) >= 2 , )
308
307
sampled_requests : List [Tuple [str , int , int , Dict [str ,
309
308
Collection [str ]]]] = []
310
309
for data in filtered_dataset :
@@ -323,7 +322,7 @@ def sample_hf_requests(
323
322
# Prune too short sequences.
324
323
continue
325
324
if fixed_output_len is None and \
326
- (prompt_len > 1024 or prompt_len + output_len > 2048 ):
325
+ (prompt_len > 1024 or prompt_len + output_len > 2048 ):
327
326
# Prune too long sequences.
328
327
continue
329
328
@@ -342,7 +341,7 @@ def sample_hf_requests(
342
341
}
343
342
elif "image" in data and isinstance (data ["image" ], str ):
344
343
if (data ["image" ].startswith ("http://" ) or \
345
- data ["image" ].startswith ("file://" )):
344
+ data ["image" ].startswith ("file://" )):
346
345
image_url = data ["image" ]
347
346
else :
348
347
image_url = f"file://{ data ['image' ]} "
@@ -962,8 +961,8 @@ def main(args: argparse.Namespace):
962
961
)
963
962
964
963
# Traffic
965
- result_json ["request_rate" ] = (args . request_rate if args . request_rate
966
- < float ("inf" ) else "inf" )
964
+ result_json ["request_rate" ] = (
965
+ args . request_rate if args . request_rate < float ("inf" ) else "inf" )
967
966
result_json ["burstiness" ] = args .burstiness
968
967
result_json ["max_concurrency" ] = args .max_concurrency
969
968
@@ -974,7 +973,7 @@ def main(args: argparse.Namespace):
974
973
base_model_id = model_id .split ("/" )[- 1 ]
975
974
max_concurrency_str = (f"-concurrency{ args .max_concurrency } "
976
975
if args .max_concurrency is not None else "" )
977
- file_name = f"{ backend } -{ args .request_rate } qps{ max_concurrency_str } -{ base_model_id } -{ current_dt } .json" #noqa
976
+ file_name = f"{ backend } -{ args .request_rate } qps{ max_concurrency_str } -{ base_model_id } -{ current_dt } .json" # noqa
978
977
if args .result_filename :
979
978
file_name = args .result_filename
980
979
if args .result_dir :
0 commit comments