Skip to content

Commit 19a7442

Browse files
authored
[Benchmark CI] Change DEFAULT_NUM_INPUTS to MAX_NUM_INPUTS (#702)
1 parent e9511ec commit 19a7442

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

benchmarks/run.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def describe_tensor(obj: object) -> object:
6868

6969
logger: logging.Logger = logging.getLogger(__name__)
7070

71-
# Default number of inputs to use when not specified in kernel config
72-
DEFAULT_NUM_INPUTS = 20
71+
# Maximum number of inputs to use
72+
MAX_NUM_INPUTS = 20
7373

7474

7575
@dataclasses.dataclass
@@ -576,7 +576,10 @@ def run_kernel_variants(
576576

577577
# Apply num_inputs if not specified in command line
578578
if "--num-inputs" not in tritonbench_args:
579-
num_inputs = (operator_args or {}).get("num_inputs", DEFAULT_NUM_INPUTS)
579+
# Get per-kernel num_inputs or use MAX_NUM_INPUTS as default
580+
per_kernel_num_inputs = (operator_args or {}).get("num_inputs", MAX_NUM_INPUTS)
581+
# Use the smaller of per_kernel_num_inputs and MAX_NUM_INPUTS
582+
num_inputs = min(per_kernel_num_inputs, MAX_NUM_INPUTS)
580583
tritonbench_args.extend(["--num-inputs", str(num_inputs)])
581584
print(
582585
f"Using num_inputs={num_inputs} for {operator_name}",

0 commit comments

Comments
 (0)