Commit 03c2d28

nvfp4: support inference_mode and rank 3 (#3240)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent 14eff10 commit 03c2d28

File tree

2 files changed (+29, -3 lines)
test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 22 additions & 1 deletion

@@ -172,6 +172,8 @@ def test_inference_workflow_mx(
     ],
     ids=lambda s: f"{s[0]}x{s[1]}x{s[2]}",
 )
+@pytest.mark.parametrize("use_inference_mode", [False, True])
+@pytest.mark.parametrize("x_rank", [2, 3])
 @torch.no_grad()
 @skip_if_rocm("ROCm float4 gemm require gfx950")
 def test_inference_workflow_nvfp4(
@@ -182,6 +184,8 @@ def test_inference_workflow_nvfp4(
     use_triton_kernel: bool,
     use_dynamic_per_tensor_scale: bool,
     shapes: tuple,
+    use_inference_mode: bool,
+    x_rank: int,
 ):
     """
     Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16
@@ -196,6 +200,16 @@

     if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
         pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")
+
+    if use_inference_mode and (
+        shapes != (128, 64, 256) or inpt_dtype != torch.bfloat16 or use_triton_kernel
+    ):
+        pytest.skip("skipping unnecessary tests for inference mode")
+    if x_rank == 3 and (
+        shapes != (128, 64, 256) or inpt_dtype != torch.bfloat16 or use_triton_kernel
+    ):
+        pytest.skip("skipping unnecessary tests for x_rank 3")
+
     batch_size, in_features, out_features = shapes

     m = nn.Linear(in_features, out_features, bias=bias, dtype=inpt_dtype, device="cuda")
@@ -212,14 +226,21 @@
         m_mx = torch.compile(m_mx, fullgraph=True, backend="aot_eager")

     x = torch.randn(batch_size, in_features, device="cuda", dtype=inpt_dtype)
+    if x_rank == 3:
+        x = x.unsqueeze(0)
+
     y_ref = m(x)

     if use_triton_kernel and mm_config != NVFP4MMConfig.WEIGHT_ONLY:
         with cuda_kernel_profiler("quantize_nvfp4_triton_kernel") as result:
             y_mx = m_mx(x)
         assert result["found"], "Expected quantize_nvfp4 kernel to be found"
     else:
-        y_mx = m_mx(x)
+        if use_inference_mode:
+            with torch.inference_mode():
+                y_mx = m_mx(x)
+        else:
+            y_mx = m_mx(x)

     sqnr = compute_error(y_ref, y_mx)

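For readers unfamiliar with the two new test parameters, the snippet below shows the calling pattern they exercise: a rank-3 activation (x_rank == 3) pushed through a linear module under torch.inference_mode (use_inference_mode == True). It is a minimal sketch with a plain nn.Linear standing in for the NVFP4-quantized module so it runs without CUDA; the shapes and module are illustrative only, not taken from the test suite.

import torch
import torch.nn as nn

# Stand-in for the NVFP4-quantized module from the test (illustrative, not the real m_mx).
m = nn.Linear(64, 256, bias=True, dtype=torch.bfloat16)

# x_rank == 3: add a leading batch dimension, as the test does via unsqueeze(0).
x = torch.randn(128, 64, dtype=torch.bfloat16).unsqueeze(0)  # shape (1, 128, 64)

# use_inference_mode == True: run the forward pass under inference mode, which
# produces inference tensors and disables autograd tracking entirely.
with torch.inference_mode():
    y = m(x)

print(y.shape)  # torch.Size([1, 128, 256])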
torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 7 additions & 2 deletions

@@ -502,6 +502,7 @@ def _addmm_nvfp4_dispatch(
     assert b.scale.t().is_contiguous()
     assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
     assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"
+    assert len(a.shape) == 2 and len(b.shape) == 2

     M, K = a.shape[0], a.shape[1]
     N = b.shape[1]
@@ -576,15 +577,19 @@ def nvfp4_linear(func, types, args, kwargs):
         tensor_amax = torch.max(torch.abs(input_tensor))
         per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
     else:
-        per_tensor_scale = weight_tensor._act_per_tensor_scale
+        per_tensor_scale = weight_tensor.act_per_tensor_scale
+    orig_shape = input_tensor.shape
+    input_tensor = input_tensor.view(-1, orig_shape[-1])
     input_tensor = NVFP4Tensor.to_nvfp4(
         input_tensor,
         block_size=k.block_size,
         per_tensor_scale=per_tensor_scale,
         is_swizzled_scales=k.is_swizzled_scales,
         use_triton_kernel=k.use_triton_kernel,
     )
-    return _addmm_nvfp4_dispatch(input_tensor, weight_tensor.t(), func, bias=bias)
+    res = _addmm_nvfp4_dispatch(input_tensor, weight_tensor.t(), func, bias=bias)
+    res = res.reshape(*orig_shape[:-1], res.shape[-1])
+    return res


 @implements([aten.mm.default, aten.matmul.default])
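The rank-3 support added to nvfp4_linear follows the standard pattern for feeding an N-D input through a strictly 2-D matmul path: flatten all leading dimensions into one, run the 2-D computation, then restore the leading dimensions on the output. Below is a minimal sketch of just that shape bookkeeping, with an ordinary float linear standing in for the quantize-then-_addmm_nvfp4_dispatch path; the helper name and shapes are illustrative.

import torch
import torch.nn.functional as F

def linear_any_rank(x: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Tensor:
    # Flatten every leading dimension so the core op only sees a 2-D input,
    # mirroring input_tensor.view(-1, orig_shape[-1]) in nvfp4_linear.
    orig_shape = x.shape
    x2d = x.view(-1, orig_shape[-1])

    # Stand-in for the quantized 2-D path; here it is a plain 2-D linear.
    res = F.linear(x2d, weight, bias)

    # Restore the original leading dimensions, matching the new reshape in the diff.
    return res.reshape(*orig_shape[:-1], res.shape[-1])

x = torch.randn(2, 128, 64)          # rank-3 activation
w = torch.randn(256, 64)             # out_features x in_features
print(linear_any_rank(x, w).shape)   # torch.Size([2, 128, 256])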
