6060
6161
6262@torch .no_grad ()
63- def get_gpu_kernel_time (m , x ):
63+ def get_gpu_kernel_time (m , x , trace_filename = None ):
6464 # warm up
6565 for _ in range (2 ):
6666 __ = m (x )
@@ -72,6 +72,12 @@ def get_gpu_kernel_time(m, x):
7272 for _ in range (n_iter ):
7373 __ = m (x )
7474 torch .cuda .synchronize ()
75+
76+ # save a trace, if requested
77+ if trace_filename is not None :
78+ print (f"exporting trace to { trace_filename } " )
79+ prof .export_chrome_trace (trace_filename )
80+
7581 # get the gpu kernel time and aggregate it
7682 num_leaf_tensors = 1 + len (list (m .parameters ()))
7783 ref_times = profiler_output_to_filtered_time_by_kernel_name (
@@ -161,13 +167,15 @@ def run(
161167 do_benchmarks : bool = True ,
162168 shape_gen_name : str = "pow2" ,
163169 n_limit : Optional [int ] = None ,
170+ save_profile_traces : bool = False ,
164171):
165172 """
166173 Args:
167174 * `recipe_name`: quantization recipe (tensorwise, rowwise, mxfp8*, mxfp4*, nvfp4*)
168175 * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
169176 * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
170177 * `n_limit (optional)`: if specified, only runs `n_limit` iterations
178+ * `save_profile_traces (optional )`: if True , saves profiling traces
171179 """
172180 config_table = [
173181 ["GPU" , torch .cuda .get_device_name (0 )],
@@ -289,7 +297,11 @@ def run(
289297 # get the bf16 gpu kernel time
290298 torch ._dynamo .reset ()
291299 m_bf16 = torch .compile (copy .deepcopy (m_orig ))
292- b_bf16_e2e_time_s = get_gpu_kernel_time (m_bf16 , x )
300+
301+ bf16_trace_filename = None
302+ if save_profile_traces :
303+ bf16_trace_filename = f"{ outfile } _{ M_val } _{ K_val } _{ N_val } _bf16.json"
304+ b_bf16_e2e_time_s = get_gpu_kernel_time (m_bf16 , x , bf16_trace_filename )
293305
294306 # get the float8 dynamic scaling gpu kernel time
295307 torch ._dynamo .reset ()
@@ -325,7 +337,11 @@ def run(
325337 quantize_ (m_fp8_dyn , config )
326338
327339 m_fp8_dyn = torch .compile (m_fp8_dyn )
328- b_fp8_e2e_time_s = get_gpu_kernel_time (m_fp8_dyn , x )
340+
341+ fp8_trace_filename = None
342+ if save_profile_traces :
343+ fp8_trace_filename = f"{ outfile } _{ M_val } _{ K_val } _{ N_val } _fp8.json"
344+ b_fp8_e2e_time_s = get_gpu_kernel_time (m_fp8_dyn , x , fp8_trace_filename )
329345
330346 r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s )
331347
0 commit comments