
Commit 76100f9

pianpwk authored and Silv3S committed
[DebugMode] record triton kernels, run-to-run determinism checks (pytorch#167028)
Following up on pytorch#166348, this extends DebugMode to capture inductor triton kernels at runtime, and adds an API for checking run-to-run determinism based on tensor hashes. The workflow looks something like:

```python
# do 1st run with hashes, get logs
with DebugMode() as debug_mode, DebugMode.log_tensor_hashes():
    compiled_model(*inputs)
logs1 = debug_mode.logs

# do 2nd run
with DebugMode() as debug_mode, DebugMode.log_tensor_hashes():
    compiled_model(*inputs)
logs2 = debug_mode.logs

# returns list of calls w/ mismatched outputs
mismatches = DebugMode.check_hash_mismatches(logs1, logs2)
```

Example dump from a smaller version of @drisspg's FlexAttention fwd+bwd determinism tests [script](https://gist.github.com/pianpwk/f65cc63811d12853709dcc77d7eb69f1) (without forced reduction order):

```
cfg: TestConfig(name='Standard', B=2, Hq=32, Hkv=32, Q=2048, KV=2048, Dqk=128, Dv=128)
DETERMINISM: fwd: True, bwd_q: False, bwd_k: False, bwd_v: True

$$$ DEBUG MODE DUMP $$$  (this is what the logs look like)

[triton] triton_tem_fused_0(arg_Q=t: bf16[2, 32, 2048, 128], arg_K=t: bf16[2, 32, 2048, 128], arg_V=t: bf16[2, 32, 2048, 128], arg_LSE=t: f32[2, 32, 2048], arg_MAX=t: f32[2, 32, 2048], arg_KV_NUM_BLKS=t: i32[2, 32, 16], arg_KV_IDX=t: i32[2, 32, 16, 16], arg_FULL_KV_NUM_BLKS=t: i32[2, 32, 16], arg_FULL_KV_IDX=t: i32[2, 32, 16, 16], out_ptr0=t: bf16[2, 32, 2048, 128])
  # post-kernel hashes: {arg_Q: 13385916.068706088, arg_K: 13389356.409105342, arg_V: 13384993.48412523, arg_LSE: 1347168.9026973695, arg_MAX: 81775.3811062593, arg_KV_NUM_BLKS: 1024.0, arg_KV_IDX: 122880.0, arg_FULL_KV_NUM_BLKS: 7680.0, arg_FULL_KV_IDX: 122880.0, out_ptr0: 924917.7918248245}

[triton] triton_per_fused_zeros_0(in_ptr0=t: bf16[2, 32, 2048, 128], in_ptr1=t: bf16[2, 32, 2048, 128], out_ptr1=t: f32[2, 32, 2048], xnumel=131072, r0_numel=128)
  # post-kernel hashes: {in_ptr0: 924917.7918248245, in_ptr1: 13389213.797377996, out_ptr1: 81775.38106592931}

[triton] triton_tem_fused_zeros_1(arg_Q=t: bf16[2, 32, 2048, 128], arg_K=t: bf16[2, 32, 2048, 128], arg_V=t: bf16[2, 32, 2048, 128], arg_LSE=t: f32[2, 32, 2048], arg_DELTA=t: f32[2, 32, 2048], arg_DO=t: bf16[2, 32, 2048, 128], arg_DQ=t: bf16[2, 32, 2048, 128], arg_DV=t: bf16[2, 32, 2048, 128], arg_KV_NUM_BLKS=t: i32[2, 32, 16], arg_KV_IDX=t: i32[2, 32, 16, 16], arg_Q_NUM_BLKS=t: i32[2, 32, 16], arg_Q_IDX=t: i32[2, 32, 16, 16], arg_FULL_KV_NUM_BLKS=t: i32[2, 32, 16], arg_FULL_KV_IDX=t: i32[2, 32, 16, 16], arg_FULL_Q_NUM_BLKS=t: i32[2, 32, 16], arg_FULL_Q_IDX=t: i32[2, 32, 16, 16], out_ptr0=t: bf16[2, 32, 2048, 128])
  # post-kernel hashes: {arg_Q: 13385916.068706088, arg_K: 13389356.409105342, arg_V: 13384993.48412523, arg_LSE: 1347168.9026973695, arg_DELTA: 81775.38106592931, arg_DO: 13389213.797377996, arg_DQ: 874474.8084187683, arg_DV: 727742.3138379117, arg_KV_NUM_BLKS: 1024.0, arg_KV_IDX: 122880.0, arg_Q_NUM_BLKS: 1024.0, arg_Q_IDX: 122880.0, arg_FULL_KV_NUM_BLKS: 7680.0, arg_FULL_KV_IDX: 122880.0, arg_FULL_Q_NUM_BLKS: 7680.0, arg_FULL_Q_IDX: 122880.0, out_ptr0: 700542.3431890717}

$$$ MISMATCHES $$$

mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_0', 'arg_name': 'arg_MAX', 'pytree_path': None, 'hash1': 0.0, 'hash2': 81775.3811062593, 'rel_diff': 1.0, 'is_input_hash': False}  # I guess this one is misleading? not sure if I'm doing something wrong with waiting for kernel results
mismatch: {'call_type': 'triton kernel', 'call': 'triton_per_fused_zeros_0', 'arg_name': 'out_ptr1', 'pytree_path': None, 'hash1': 81775.3811062593, 'hash2': 81775.38106592931, 'rel_diff': 4.931801261646669e-10, 'is_input_hash': False}
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_zeros_1', 'arg_name': 'arg_DELTA', 'pytree_path': None, 'hash1': 81775.3811062593, 'hash2': 81775.38106592931, 'rel_diff': 4.931801261646669e-10, 'is_input_hash': False}
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_zeros_1', 'arg_name': 'arg_DQ', 'pytree_path': None, 'hash1': 874474.8097136207, 'hash2': 874474.8084187683, 'rel_diff': 1.480720012120795e-09, 'is_input_hash': False}
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_zeros_1', 'arg_name': 'out_ptr0', 'pytree_path': None, 'hash1': 700542.3488049245, 'hash2': 700542.3431890717, 'rel_diff': 8.016435812581196e-09, 'is_input_hash': False}
```

note: the current hash implementation is basically a tensor norm, so tensor closeness implies hash closeness. This is likely to change soon, e.g. maybe to `torch.hash_tensor` (pytorch#154149) by default.

Sample paste diff between log dumps from 2 runs:

<img width="1665" height="445" alt="Screenshot 2025-11-05 at 11 27 24 PM" src="https://github.com/user-attachments/assets/41402e37-f50b-4a9e-a17c-bb98b5917076" />

In another case, running this for FSDP2 on Llama3-8B helped narrow down a divergence between aot_eager and inductor to inductor's forward RMSNorm kernels: P2027003180

Pull Request resolved: pytorch#167028
Approved by: https://github.com/v0i0
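For intuition on the "tensor closeness implies hash closeness" point, here is a minimal sketch of a norm-style hash; the helper `norm_hash` is hypothetical, not DebugMode's actual implementation:

```python
import torch

def norm_hash(t: torch.Tensor) -> float:
    # Norm-style "hash": numerically close tensors map to close values,
    # so a small rel_diff between runs indicates small numeric drift
    # rather than a structural mismatch.
    return t.to(torch.float64).abs().sum().item()

x = torch.randn(64, 64)
y = x + 1e-9 * torch.randn(64, 64)  # tiny run-to-run perturbation
h1, h2 = norm_hash(x), norm_hash(y)
print(abs(h1 - h2) / max(abs(h1), abs(h2)))  # tiny rel_diff, like the dumps above
```

A true hash (e.g. `torch.hash_tensor`) would instead change completely on any bitwise difference, which is stricter for determinism checks but loses the "how far apart" signal.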
1 parent 4c5e920 · commit 76100f9

File tree: 4 files changed, +574 −32 lines changed


test/distributed/tensor/debug/test_debug_mode.py

Lines changed: 113 additions & 1 deletion
```diff
@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: distributed"]

 import contextlib
+import unittest

 import torch
 import torch.distributed as dist
@@ -23,8 +24,15 @@
     TestCase,
 )
 from torch.testing._internal.distributed.fake_pg import FakeStore
-from torch.utils._debug_mode import _OpCall, _RedistributeCall, DebugMode
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
+from torch.utils._debug_mode import (
+    _OpCall,
+    _RedistributeCall,
+    _TritonKernelCall,
+    DebugMode,
+)
 from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils._triton import has_triton_package


 @requires_cuda
@@ -434,6 +442,110 @@ def forward(self, x):
         ][-1]
         self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)

+    @unittest.skipIf(not HAS_GPU, "requires GPU")
+    @unittest.skipIf(not has_triton_package(), "requires triton")
+    def test_triton_kernel_logs(self):
+        import triton
+
+        from torch.testing._internal.triton_utils import add_kernel_autotuned
+
+        def call_triton(x, y):
+            output = torch.zeros_like(x)
+            n_elements = output.numel()
+            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)  # noqa: E731
+            add_kernel_autotuned[grid](x, y, output, n_elements)
+            return output
+
+        x = torch.randn(128, device=GPU_TYPE)
+        y = torch.randn(128, device=GPU_TYPE)
+
+        with DebugMode() as debug_mode:
+            torch.compile(call_triton)(x, y)
+
+        triton_calls = [
+            op for op in debug_mode.operators if isinstance(op, _TritonKernelCall)
+        ]
+        self.assertGreater(len(triton_calls), 0)
+        self.assertIn("[triton]", triton_calls[0].render([]))
+
+    def test_check_hash_mismatches(self):
+        x = torch.randn(64, 64, device=GPU_TYPE)
+        x_different = torch.randn(64, 64, device=GPU_TYPE)
+
+        # Identical runs should have no mismatches
+        with DebugMode() as dm1, DebugMode.log_tensor_hashes():
+            x.sin().sum()
+        with DebugMode() as dm2, DebugMode.log_tensor_hashes():
+            x.sin().sum()
+        mismatches = DebugMode.check_hash_mismatches(dm1.logs, dm2.logs)
+        self.assertEqual(len(mismatches), 0)
+
+        # Different inputs should produce hash mismatches
+        with DebugMode() as dm3, DebugMode.log_tensor_hashes():
+            x_different.sin().sum()
+
+        # Check that mismatches are detected
+        mismatches = DebugMode.check_hash_mismatches(dm1.logs, dm3.logs)
+        self.assertEqual(len(mismatches), 2)
+        self.assertEqual(
+            [call["call"] for call in mismatches], ["aten::sin", "aten::sum"]
+        )
+
+    @unittest.skipIf(not HAS_GPU, "requires GPU")
+    @unittest.skipIf(not has_triton_package(), "requires triton")
+    def test_check_triton_hash_mismatches(self):
+        import triton
+
+        from torch.testing._internal.triton_utils import add_kernel_autotuned
+
+        def call_triton(x, y):
+            output = torch.zeros_like(x)
+            n_elements = output.numel()
+            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)  # noqa: E731
+            add_kernel_autotuned[grid](x, y, output, n_elements)
+            return output
+
+        a = torch.randn(128, device=GPU_TYPE)
+        b = torch.randn(128, device=GPU_TYPE)
+        c = torch.randn(128, device=GPU_TYPE)
+
+        # Run with hash logging to verify triton kernels can be hashed
+        with DebugMode() as dm_t1, DebugMode.log_tensor_hashes(hash_inputs=True):
+            torch.compile(call_triton)(a, b)
+
+        # Different inputs should have different hashes in triton kernels
+        with DebugMode() as dm_t2, DebugMode.log_tensor_hashes(hash_inputs=True):
+            torch.compile(call_triton)(a, c)
+
+        # Compare triton kernel hashes
+        mismatches = DebugMode.check_hash_mismatches(
+            dm_t1.logs, dm_t2.logs, compare_inputs=True
+        )
+        triton_mismatches = [m for m in mismatches if m["call_type"] == "triton kernel"]
+        self.assertGreater(len(triton_mismatches), 0)
+
+        # check both input & output hash mismatches are detected
+        self.assertGreater(len([m for m in triton_mismatches if m["is_input_hash"]]), 0)
+        self.assertGreater(
+            len([m for m in triton_mismatches if not m["is_input_hash"]]), 0
+        )
+
+    def test_check_structure_mismatches(self):
+        x = torch.randn(32, 32, device=self.device_type)
+
+        with DebugMode() as dm1, DebugMode.log_tensor_hashes():
+            x.sin()
+        with DebugMode() as dm2, DebugMode.log_tensor_hashes():
+            x.cos()
+        with DebugMode() as dm3, DebugMode.log_tensor_hashes():
+            x.sin().cos()
+
+        with self.assertRaisesRegex(ValueError, "Operators don't match"):
+            DebugMode.check_hash_mismatches(dm1.logs, dm2.logs)
+
+        with self.assertRaisesRegex(ValueError, "Log lengths don't match"):
+            DebugMode.check_hash_mismatches(dm1.logs, dm3.logs)
+
     def test_pretty_print_dtensor_make_fx(self):
         mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
```

torch/_inductor/runtime/benchmarking.py

Lines changed: 9 additions & 6 deletions
```diff
@@ -12,6 +12,7 @@
 import torch.utils._pytree as pytree
 from torch._dynamo.utils import counters, dynamo_timed
 from torch._inductor.config import use_experimental_benchmarker
+from torch.utils._debug_mode import DebugMode


 logger = torch._logging.getArtifactLogger(__name__, "benchmarking")
@@ -189,12 +190,14 @@ def benchmark(
         else:
             _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731

-        if inferred_device == torch.device("cpu"):
-            return self.benchmark_cpu(_callable, **kwargs)
-        # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
-        # implementation which was written specifically with CUDA devices in mind, we may want to
-        # explore alternate implementations for other device types.
-        return self.benchmark_gpu(_callable, **kwargs)
+        # Surfacing all kernels during autotuning is super noisy; filtering these out.
+        with DebugMode._benchmarking_inductor():
+            if inferred_device == torch.device("cpu"):
+                return self.benchmark_cpu(_callable, **kwargs)
+            # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
+            # implementation which was written specifically with CUDA devices in mind, we may want to
+            # explore alternate implementations for other device types.
+            return self.benchmark_gpu(_callable, **kwargs)

     @time_and_count
     def benchmark_cpu(
```

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 18 additions & 2 deletions
```diff
@@ -25,6 +25,7 @@
 from torch._environment import is_fbcode
 from torch._inductor import metrics
 from torch._prims_common import compute_required_storage_length
+from torch.utils._debug_mode import get_active_debug_mode
 from torch.utils._ordered_set import OrderedSet

 from ..triton_bundler import TritonBundler
@@ -1337,6 +1338,17 @@ def run(
         benchmark_run=False,
         **kwargs,
     ):  # type:ignore[override]
+        """Launch triton kernel call and return result."""
+        debug_mode = get_active_debug_mode()
+        debug_call = None
+        if debug_mode:
+            arg_names = list(self.triton_meta.get("signature", {}).keys())
+            kernel_kwargs = dict(zip(arg_names, args))
+            kernel_kwargs.update(kwargs)
+            debug_call = debug_mode.record_triton_kernel(
+                kernel_name=self.fn.__name__, kwargs=kernel_kwargs
+            )
+
         if hasattr(triton, "set_allocator"):

             def alloc_fn(size: int, align: int, stream: int | None):
@@ -1392,18 +1404,22 @@ def alloc_fn(size: int, align: int, stream: int | None):
                 args_without_constexprs,
                 profiler_kwargs,
             ):
-                return launcher(
+                result = launcher(
                     *args,
                     **kwargs,
                     stream=stream,
                 )
         else:
-            return launcher(
+            result = launcher(
                 *args,
                 **kwargs,
                 stream=stream,
             )

+        if debug_call:
+            debug_call.finalize(self.get_device_interface())
+        return result
+
     def _interpret_args_grid(
         self, args: tuple[Any, ...], cfg: Config
     ) -> tuple[tuple[Any, ...], tuple[int, int, int]]:
```
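The record-then-finalize split in `run()` matters because triton launches are asynchronous: output buffers only hold their final values after the kernel completes, so hashing is deferred until after a device sync. A rough sketch of that pattern; `KernelCall` and its fields are illustrative, not the actual `_TritonKernelCall`:

```python
import torch

class KernelCall:
    def __init__(self, name: str, kwargs: dict):
        self.name = name
        self.kwargs = kwargs  # tensor args captured by reference at launch time
        self.post_hashes = None

    def finalize(self, synchronize):
        # Launches are async: synchronize before reading output buffers,
        # otherwise we might hash memory the kernel hasn't finished writing.
        synchronize()
        self.post_hashes = {
            k: v.to(torch.float64).abs().sum().item()
            for k, v in self.kwargs.items()
            if isinstance(v, torch.Tensor)
        }

# usage: record before launch, finalize after
# call = KernelCall("triton_tem_fused_0", kernel_kwargs)
# launcher(*args, **kwargs, stream=stream)
# call.finalize(torch.cuda.synchronize)
```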
