
Commit db345bd

Allow benchmark_model to accept args and kwargs (#586)
Summary: Previously `benchmark_model` accepted a single `input_tensor`; it now accepts `args` and `kwargs`.

Test Plan:
python test/integration/test_integration.py -k test_benchmark_model_cuda
python test/integration/test_integration.py -k test_benchmark_model_cpu

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent c023f71 commit db345bd
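
The new convention at call sites: positional inputs are wrapped in a tuple, keyword inputs go in an optional dict. A minimal sketch of the change at a call site (the toy model and shapes here are illustrative, not from the commit):

    import torch
    from torchao.utils import benchmark_model

    m = torch.nn.Linear(16, 16).eval()
    example_inputs = (torch.randn(4, 16),)

    # before this commit: benchmark_model(m, 100, example_inputs[0])
    # after: the tuple is unpacked inside as model(*args, **kwargs)
    avg_ms = benchmark_model(m, 100, example_inputs)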

6 files changed
+29 -22 lines changed


benchmarks/benchmark_aq.py
Lines changed: 4 additions & 5 deletions

@@ -93,15 +93,14 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
     # warmup
     WARMUP = 5
     RUNS = 100
-    input_tensor = example_inputs[0]
     m = torch.compile(m, mode='max-autotune', fullgraph=True)

-    benchmark_model(m, WARMUP, input_tensor)
-    elapsed_time = benchmark_model(m, RUNS, input_tensor)
+    benchmark_model(m, WARMUP, example_inputs)
+    elapsed_time = benchmark_model(m, RUNS, example_inputs)

     m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
-    benchmark_model(m_ref, WARMUP, input_tensor)
-    ref_elapsed_time = benchmark_model(m_ref, RUNS, input_tensor)
+    benchmark_model(m_ref, WARMUP, example_inputs)
+    ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)

     print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
     assert elapsed_time < 1.05 * ref_elapsed_time

test/integration/test_integration.py
Lines changed: 1 addition & 1 deletion

@@ -1532,7 +1532,7 @@ def run_benchmark_model(self, device):
         example_inputs = m.example_inputs(dtype=dtype, device=device)
         m_bf16 = torch.compile(m_bf16, mode='max-autotune')
         num_runs = 1
-        return benchmark_model(m_bf16, num_runs, example_inputs[0])
+        return benchmark_model(m_bf16, num_runs, example_inputs)

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_benchmark_model_cuda(self):

torchao/quantization/README.md
Lines changed: 2 additions & 2 deletions

@@ -119,9 +119,9 @@ from torchao.utils import benchmark_model

 num_runs = 100
 torch._dynamo.reset()
-bf16_time = benchmark_model(m_bf16, num_runs, example_inputs[0])
+bf16_time = benchmark_model(m_bf16, num_runs, example_inputs)
 print(f"bf16 mean time: {bf16_time}")
-int4_time = benchmark_model(m, num_runs, example_inputs[0])
+int4_time = benchmark_model(m, num_runs, example_inputs)
 print(f"int4 weight only quantized mean time: {int4_time}")
 print(f"speedup: {bf16_time / int4_time}")

torchao/utils.py
Lines changed: 14 additions & 6 deletions

@@ -42,8 +42,16 @@ def _assert_and_get_unique_device(module: torch.nn.Module) -> Any:
     return device


-def benchmark_model(model, num_runs, input_tensor):
-    device_type = _assert_and_get_unique_device(model).type
+def benchmark_model(model, num_runs, args=(), kwargs=None, device_type=None):
+    """Benchmark model runs with `args` and `kwargs` both are optional
+    """
+    if kwargs is None:
+        kwargs = {}
+
+    if device_type is None:
+        assert isinstance(model, torch.nn.Module), "Expecting `model` to be torch.nn.Module if device_type is not provided"
+        device_type = _assert_and_get_unique_device(model).type
+
     if device_type == "cuda":
         torch.cuda.synchronize()
         start_event = torch.cuda.Event(enable_timing=True)
@@ -53,7 +61,7 @@ def benchmark_model(model, num_runs, input_tensor):
         # benchmark
         for _ in range(num_runs):
             with torch.autograd.profiler.record_function("timed region"):
-                model(input_tensor)
+                model(*args, **kwargs)

         end_event.record()
         torch.cuda.synchronize()
@@ -68,7 +76,7 @@ def benchmark_model(model, num_runs, input_tensor):
         # benchmark
         for _ in range(num_runs):
             with torch.autograd.profiler.record_function("timed region"):
-                model(input_tensor)
+                model(*args, **kwargs)

         end_event.record()
         torch.mps.synchronize()
@@ -81,7 +89,7 @@ def benchmark_model(model, num_runs, input_tensor):
         # benchmark
         for _ in range(num_runs):
             with torch.autograd.profiler.record_function("timed region"):
-                model(input_tensor)
+                model(*args, **kwargs)

         end_time = time.time()
         torch.cpu.synchronize()
@@ -264,7 +272,7 @@ def unwrap_tensor_subclass(model, filter_fn=None):
             parametrize.register_parametrization(child, "weight", UnwrapTensorSubclass())
         unwrap_tensor_subclass(child)
     return model
-
+
 def is_fbcode():
     return not hasattr(torch.version, "git_version")
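
The new parameters cover two cases the old signature could not: models whose forward takes keyword arguments, and callables that are not nn.Modules (automatic device detection via `_assert_and_get_unique_device` requires a Module). A hedged sketch of both, using an illustrative toy module that is not part of the commit:

    import torch
    from torchao.utils import benchmark_model

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(8, 8)

        def forward(self, x, scale=1.0):
            return self.lin(x) * scale

    m = Toy().eval()

    # keyword inputs go through the new `kwargs` dict
    benchmark_model(m, 10, args=(torch.randn(2, 8),), kwargs={"scale": 0.5})

    # a plain callable needs an explicit device_type, since it is not a Module
    fn = lambda x: m(x, scale=0.5)
    benchmark_model(fn, 10, args=(torch.randn(2, 8),), device_type="cpu")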

tutorials/quantize_vit/run_vit_b.py
Lines changed: 4 additions & 4 deletions

@@ -11,15 +11,15 @@
 model.eval().cuda().to(torch.bfloat16)

 # Input tensor (batch_size, channels, height, width)
-input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')
+inputs = (torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda'),)

 model = torch.compile(model, mode='max-autotune')

 # Must run with no_grad when optimizing for inference
 with torch.no_grad():
     # warmup
-    benchmark_model(model, 5, input_tensor)
+    benchmark_model(model, 5, inputs)
     # benchmark
-    print("elapsed_time: ", benchmark_model(model, 100, input_tensor), " milliseconds")
+    print("elapsed_time: ", benchmark_model(model, 100, inputs), " milliseconds")
     # Create a trace
-    profiler_runner("bfloat16.json.gz", benchmark_model, model, 5, input_tensor)
+    profiler_runner("bfloat16.json.gz", benchmark_model, model, 5, inputs)

tutorials/quantize_vit/run_vit_b_quant.py
Lines changed: 4 additions & 4 deletions

@@ -12,7 +12,7 @@
 model.eval().cuda().to(torch.bfloat16)

 # Input tensor (batch_size, channels, height, width)
-input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')
+inputs = (torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda'),)

 ## Quantization code - start
 # int8 dynamic quantization act, int8 weight, see ao/torchao/quantization/README.md
@@ -39,8 +39,8 @@
 # Must run with no_grad when optimizing for inference
 with torch.no_grad():
     # warmup
-    benchmark_model(model, 20, input_tensor)
+    benchmark_model(model, 20, inputs)
     # benchmark
-    print("elapsed_time: ", benchmark_model(model, 1000, input_tensor), " milliseconds")
+    print("elapsed_time: ", benchmark_model(model, 1000, inputs), " milliseconds")
     # Create a trace
-    profiler_runner("quant.json.gz", benchmark_model, model, 5, input_tensor)
+    profiler_runner("quant.json.gz", benchmark_model, model, 5, inputs)
