
Commit b50e37a

add option to save profiling traces in inference roofline script (#3196)

* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]

1 parent d1a7fbc

1 file changed: +19 -3 lines changed

benchmarks/float8/float8_inference_roofline.py

@@ -60,7 +60,7 @@
 
 
 @torch.no_grad()
-def get_gpu_kernel_time(m, x):
+def get_gpu_kernel_time(m, x, trace_filename=None):
     # warm up
     for _ in range(2):
        __ = m(x)
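The new `trace_filename` parameter defaults to None, so existing call sites keep working unchanged; callers opt in by passing a path. For illustration (the filename shown is made up):

    t_s = get_gpu_kernel_time(m_bf16, x)                              # unchanged: no trace written
    t_s = get_gpu_kernel_time(m_bf16, x, "trace_512_1024_2048.json")  # opt in: also dump a trace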
@@ -72,6 +72,12 @@ def get_gpu_kernel_time(m, x):
        for _ in range(n_iter):
            __ = m(x)
        torch.cuda.synchronize()
+
+    # save a trace, if requested
+    if trace_filename is not None:
+        print(f"exporting trace to {trace_filename}")
+        prof.export_chrome_trace(trace_filename)
+
     # get the gpu kernel time and aggregate it
     num_leaf_tensors = 1 + len(list(m.parameters()))
     ref_times = profiler_output_to_filtered_time_by_kernel_name(
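The `prof` object comes from the profiler context that wraps the timing loop earlier in `get_gpu_kernel_time` (not shown in this hunk). A minimal sketch of the same export pattern, assuming a CUDA device and using a stand-in model (all names below are illustrative, not taken from the script):

    import torch
    from torch.profiler import ProfilerActivity, profile

    m = torch.nn.Linear(1024, 1024).cuda()    # stand-in model
    x = torch.randn(64, 1024, device="cuda")  # stand-in input

    # warm up so one-time costs don't pollute the trace
    for _ in range(2):
        _ = m(x)
    torch.cuda.synchronize()

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(3):
            _ = m(x)
        torch.cuda.synchronize()

    # same call the diff adds: writes a Chrome-trace JSON, viewable in
    # chrome://tracing or https://ui.perfetto.dev
    prof.export_chrome_trace("example_trace.json")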
@@ -161,13 +167,15 @@ def run(
     do_benchmarks: bool = True,
     shape_gen_name: str = "pow2",
     n_limit: Optional[int] = None,
+    save_profile_traces: bool = False,
 ):
     """
     Args:
     * `recipe_name`: quantization recipe (tensorwise, rowwise, mxfp8*, mxfp4*, nvfp4*)
     * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
     * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
+    * `save_profile_traces (optional)`: if True, saves profiling traces
     """
     config_table = [
         ["GPU", torch.cuda.get_device_name(0)],
@@ -289,7 +297,11 @@ def run(
         # get the bf16 gpu kernel time
         torch._dynamo.reset()
         m_bf16 = torch.compile(copy.deepcopy(m_orig))
-        b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)
+
+        bf16_trace_filename = None
+        if save_profile_traces:
+            bf16_trace_filename = f"{outfile}_{M_val}_{K_val}_{N_val}_bf16.json"
+        b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, bf16_trace_filename)
 
         # get the float8 dynamic scaling gpu kernel time
         torch._dynamo.reset()
@@ -325,7 +337,11 @@ def run(
         quantize_(m_fp8_dyn, config)
 
         m_fp8_dyn = torch.compile(m_fp8_dyn)
-        b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
+
+        fp8_trace_filename = None
+        if save_profile_traces:
+            fp8_trace_filename = f"{outfile}_{M_val}_{K_val}_{N_val}_fp8.json"
+        b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, fp8_trace_filename)
 
         r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
 
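Both call sites derive the trace path from `outfile` and the current shape, so a sweep produces one bf16 and one fp8 trace per (M, K, N) case. For example (shape values made up):

    outfile, M_val, K_val, N_val = "roofline_results.csv", 512, 1024, 2048
    print(f"{outfile}_{M_val}_{K_val}_{N_val}_bf16.json")
    # -> roofline_results.csv_512_1024_2048_bf16.json
    print(f"{outfile}_{M_val}_{K_val}_{N_val}_fp8.json")
    # -> roofline_results.csv_512_1024_2048_fp8.json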
