From 77042cecaaa1aefbe28ece52de7d61fac2fa8d02 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 18 Aug 2025 08:15:24 -0700 Subject: [PATCH 1/4] Update [ghstack-poisoned] --- .../prototype/mx_formats/test_nvfp4_tensor.py | 8 +-- torchao/prototype/mx_formats/nvfp4_tensor.py | 56 +++++++++---------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py index 02c2d6f0d8..cb2a7a7e56 100644 --- a/test/prototype/mx_formats/test_nvfp4_tensor.py +++ b/test/prototype/mx_formats/test_nvfp4_tensor.py @@ -276,12 +276,12 @@ def test_nvfp4_swizzled_scales_view_semantics(): # Test that the sliced tensor shares storage with original for data # (Note: scales might not share storage due to swizzled layout complexity) - assert sliced_tensor._data.data_ptr() == tensor._data.data_ptr() + assert sliced_tensor.qdata.data_ptr() == tensor.qdata.data_ptr() # Test full-width column slicing (should maintain views) full_width_slice = tensor[:, 0:K] assert full_width_slice._scale_e4m3.data_ptr() == tensor._scale_e4m3.data_ptr() - assert full_width_slice._data.data_ptr() == tensor._data.data_ptr() + assert full_width_slice.qdata.data_ptr() == tensor.qdata.data_ptr() @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -399,8 +399,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype): torch.testing.assert_close( nvfp4_pt._scale_e4m3.flatten(), nvfp4_triton._scale_e4m3.flatten() ) - pt_unpacked = unpack_uint4(nvfp4_pt._data) - triton_unpacked = unpack_uint4(nvfp4_triton._data) + pt_unpacked = unpack_uint4(nvfp4_pt.qdata) + triton_unpacked = unpack_uint4(nvfp4_triton.qdata) torch.testing.assert_close( pt_unpacked, triton_unpacked, diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py index 221017b5f4..d31070df5d 100644 --- a/torchao/prototype/mx_formats/nvfp4_tensor.py +++ b/torchao/prototype/mx_formats/nvfp4_tensor.py @@ -56,18 +56,18 @@ class NVFP4Tensor(torch.Tensor): quantization algorithm for FP4 data with UE4M3 scales. Attributes: + qdata: Packed FP4 data (2 values per byte) _scale_e4m3: Blockwise scales in float8_e4m3fn format (may be swizzled) _per_tensor_scale: Optional global per-tensor scale in float32 format - _data: Packed FP4 data (2 values per byte) _block_size: Block size for quantization (fixed at 16) _orig_dtype: Original tensor dtype before quantization _is_swizzled_scales: Whether scales are stored in swizzled (blocked) format mm_config: Matrix multiplication configuration """ + qdata: torch.Tensor _scale_e4m3: torch.Tensor _per_tensor_scale: Optional[torch.Tensor] - _data: torch.Tensor _block_size: int _orig_dtype: torch.dtype _is_swizzled_scales: bool @@ -76,9 +76,9 @@ class NVFP4Tensor(torch.Tensor): def __new__( cls, + qdata, blockwise_scales, per_tensor_scale, - data_bits, block_size, orig_dtype, mm_config=NVFP4MMConfig.DYNAMIC, @@ -86,25 +86,25 @@ def __new__( use_triton_kernel=False, ): # FP4 tensor size handling two paths, contiguous or not - new_size = data_bits.size() + new_size = qdata.size() new_size = tensor_size_fp4x2_to_hp( new_size, - data_bits.stride(0) > data_bits.stride(1), + qdata.stride(0) > qdata.stride(1), ) self = torch.Tensor._make_wrapper_subclass( cls, new_size, dtype=orig_dtype, - device=data_bits.device, + device=qdata.device, requires_grad=False, ) self._scale_e4m3 = blockwise_scales self._is_swizzled_scales = is_swizzled_scales self._per_tensor_scale = per_tensor_scale - self._data = data_bits + self.qdata = qdata self._block_size = block_size self._orig_dtype = orig_dtype self.mm_config = mm_config @@ -112,7 +112,7 @@ def __new__( return self def __repr__(self): - return f"NVFP4Tensor: blockwise_scales: {self._scale_e4m3}, per_tensor_scale: {self._per_tensor_scale}, d: {self._data}, d_hp: {self.to_dtype(self._orig_dtype)}" + return f"NVFP4Tensor: blockwise_scales: {self._scale_e4m3}, per_tensor_scale: {self._per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}" @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -163,9 +163,9 @@ def to_nvfp4( ).flatten() return NVFP4Tensor( + data_lp, blockwise_scales, per_tensor_scale, - data_lp, block_size, data_hp.dtype, mm_config, @@ -181,7 +181,7 @@ def __tensor_flatten__(self): "mm_config": self.mm_config, "use_triton_kernel": self.use_triton_kernel, } - tensor_list = ["_scale_e4m3", "_data"] + tensor_list = ["qdata", "_scale_e4m3"] if self._per_tensor_scale is not None: tensor_list.append("_per_tensor_scale") return tensor_list, ctx @@ -209,9 +209,9 @@ def __tensor_unflatten__( outer_stride, ): return NVFP4Tensor( + inner_tensors["qdata"], inner_tensors["_scale_e4m3"], inner_tensors.get("_per_tensor_scale", None), - inner_tensors["_data"], metadata["_block_size"], metadata["_orig_dtype"], metadata["mm_config"], @@ -231,12 +231,12 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor: Returns: torch.Tensor: Dequantized tensor in the target dtype """ - is_transposed = self._data.stride(0) < self._data.stride(1) + is_transposed = self.qdata.stride(0) < self.qdata.stride(1) if is_transposed: M, K = self.shape[1], self.shape[0] else: M, K = self.shape[0], self.shape[1] - data = self._data.t() if is_transposed else self._data + data = self.qdata.t() if is_transposed else self.qdata data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8)) data_f32 = f4_unpacked_to_f32(data_unpacked) @@ -256,7 +256,7 @@ def get_hp_scales(self) -> torch.Tensor: Returns: torch.Tensor: Scales of the NVFP4Tensor """ - is_transposed = self._data.stride(0) < self._data.stride(1) + is_transposed = self.qdata.stride(0) < self.qdata.stride(1) if is_transposed: M, K = self.shape[1], self.shape[0] else: @@ -296,7 +296,7 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool: and self._is_swizzled_scales == src._is_swizzled_scales and self._scale_e4m3.shape == src._scale_e4m3.shape and per_tensor_scale_equal - and self._data.shape == src._data.shape + and self.qdata.shape == src.qdata.shape ) @@ -379,7 +379,7 @@ def nvfp4_slice(func, types, args, kwargs): if step != 1: raise ValueError("Only support aten.slice with step=1") - assert x._data.is_contiguous(), "Only support contiguous data for now" + assert x.qdata.is_contiguous(), "Only support contiguous data for now" M, K = x.shape[0], x.shape[1] @@ -422,7 +422,7 @@ def nvfp4_slice(func, types, args, kwargs): ) sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1) - sliced_data = aten.slice.Tensor(x._data, 0, start, end, step) + sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step) elif dim == 1: # Column slicing @@ -485,7 +485,7 @@ def nvfp4_slice(func, types, args, kwargs): packed_start = None if start is None else start // 2 packed_end = None if end is None else end // 2 sliced_data = aten.slice.Tensor( - x._data, dim, packed_start, packed_end, step + x.qdata, dim, packed_start, packed_end, step ) else: @@ -498,7 +498,7 @@ def nvfp4_slice(func, types, args, kwargs): if dim == 0: sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step) - sliced_data = aten.slice.Tensor(x._data, dim, start, end, step) + sliced_data = aten.slice.Tensor(x.qdata, dim, start, end, step) elif dim == 1: if start is not None: @@ -518,7 +518,7 @@ def nvfp4_slice(func, types, args, kwargs): packed_start = None if start is None else start // 2 packed_end = None if end is None else end // 2 sliced_data = aten.slice.Tensor( - x._data, dim, packed_start, packed_end, step + x.qdata, dim, packed_start, packed_end, step ) start_block = 0 if start is None else start // x._block_size @@ -531,9 +531,9 @@ def nvfp4_slice(func, types, args, kwargs): # Create result tensor result = NVFP4Tensor( + sliced_data, sliced_scale, x._per_tensor_scale, - sliced_data, x._block_size, x._orig_dtype, x.mm_config, @@ -549,9 +549,9 @@ def nvfp4_t(func, types, args, kwargs): # For now, only transpose(input, 0, 1) is supported. old = args[0] new = NVFP4Tensor( + old.qdata.t(), old._scale_e4m3, old._per_tensor_scale, - old._data.t(), old._block_size, old._orig_dtype, old.mm_config, @@ -563,14 +563,14 @@ def nvfp4_t(func, types, args, kwargs): @implements([aten.view.default]) def nvfp4_view_op(func, types, args, kwargs): - data = args[0]._data + data = args[0].qdata new_size = args[1] new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous()) new_data = func(data, new_size, *args[2:], **kwargs) return NVFP4Tensor( + new_data, args[0]._scale_e4m3, args[0]._per_tensor_scale, - new_data, args[0]._block_size, args[0]._orig_dtype, args[0].mm_config, @@ -586,8 +586,8 @@ def _addmm_nvfp4_dispatch( Core implementation shared between nvfp4_mm, nvfp4_addmm, and nvfp4_linear. The only difference is whether bias is None or not. """ - assert a._data.is_contiguous() - assert b._data.t().is_contiguous() + assert a.qdata.is_contiguous() + assert b.qdata.t().is_contiguous() assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}" assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}" @@ -623,8 +623,8 @@ def _addmm_nvfp4_dispatch( # should_add_bias_separately = bias is not None result = torch._scaled_mm( - a._data.view(torch.float4_e2m1fn_x2), - b._data.view(torch.float4_e2m1fn_x2), + a.qdata.view(torch.float4_e2m1fn_x2), + b.qdata.view(torch.float4_e2m1fn_x2), a_scale_blocked.view(torch.float8_e4m3fn), b_scale_blocked.view(torch.float8_e4m3fn), bias=None if should_add_bias_separately else bias, From e6dd63d2145db28b834e5e8642a2278df8525ff2 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 18 Aug 2025 08:15:28 -0700 Subject: [PATCH 2/4] Update [ghstack-poisoned] --- .../prototype/mx_formats/test_nvfp4_tensor.py | 4 +- torchao/prototype/mx_formats/nvfp4_tensor.py | 79 +++++-------------- torchao/utils.py | 2 + 3 files changed, 22 insertions(+), 63 deletions(-) diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py index cb2a7a7e56..4a52fbd6f2 100644 --- a/test/prototype/mx_formats/test_nvfp4_tensor.py +++ b/test/prototype/mx_formats/test_nvfp4_tensor.py @@ -304,8 +304,8 @@ def test_nvfp4_swizzled_scales_serialization(): tensor_list, ctx = original_tensor.__tensor_flatten__() # Verify swizzled flag is preserved in context - assert "_is_swizzled_scales" in ctx - assert ctx["_is_swizzled_scales"] == True + assert NVFP4Tensor.tensor_attribute_names[3] == "_is_swizzled_scales" + assert ctx[3] == True # Test deserialization inner_tensors = {} diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py index d31070df5d..f59813ebf8 100644 --- a/torchao/prototype/mx_formats/nvfp4_tensor.py +++ b/torchao/prototype/mx_formats/nvfp4_tensor.py @@ -6,7 +6,7 @@ import sys from enum import Enum -from typing import Any, Callable, Dict, Optional +from typing import Any, Dict, Optional import torch from torch.utils._python_dispatch import return_and_correct_aliasing @@ -24,7 +24,7 @@ tensor_size_hp_to_fp4x2, ) from torchao.prototype.mx_formats.utils import from_blocked, to_blocked -from torchao.utils import ceil_div, fill_defaults +from torchao.utils import TorchAOBaseTensor, ceil_div, fill_defaults E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny @@ -38,6 +38,7 @@ class NVFP4MMConfig(Enum): WEIGHT_ONLY = "weight_only" +# TODO(future PR): move over to TorchAOBaseTensor's dispatch def implements(aten_ops): """Register aten ops to the NVFP4 op table""" @@ -49,7 +50,7 @@ def decorator(func): return decorator -class NVFP4Tensor(torch.Tensor): +class NVFP4Tensor(TorchAOBaseTensor): """NVIDIA FP4 (NVFP4) Tensor subclass. This implements the NVIDIA variant of MX FP4 format, which uses a specific @@ -59,20 +60,22 @@ class NVFP4Tensor(torch.Tensor): qdata: Packed FP4 data (2 values per byte) _scale_e4m3: Blockwise scales in float8_e4m3fn format (may be swizzled) _per_tensor_scale: Optional global per-tensor scale in float32 format - _block_size: Block size for quantization (fixed at 16) - _orig_dtype: Original tensor dtype before quantization - _is_swizzled_scales: Whether scales are stored in swizzled (blocked) format - mm_config: Matrix multiplication configuration + _block_size (int): Block size for quantization (fixed at 16) + _orig_dtype (torch.dtype): Original tensor dtype before quantization + _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format + mm_config (NVFP4MMConfig): Matrix multiplication configuration + use_triton_kernel (bool): Whether to use triton kernels """ - qdata: torch.Tensor - _scale_e4m3: torch.Tensor - _per_tensor_scale: Optional[torch.Tensor] - _block_size: int - _orig_dtype: torch.dtype - _is_swizzled_scales: bool - mm_config: NVFP4MMConfig - use_triton_kernel: bool + tensor_data_names = ["qdata", "_scale_e4m3"] + optional_tensor_data_names = ["_per_tensor_scale"] + tensor_attribute_names = [ + "_block_size", + "_orig_dtype", + "mm_config", + "_is_swizzled_scales", + "use_triton_kernel", + ] def __new__( cls, @@ -173,52 +176,6 @@ def to_nvfp4( use_triton_kernel, ) - def __tensor_flatten__(self): - ctx = { - "_block_size": self._block_size, - "_orig_dtype": self._orig_dtype, - "_is_swizzled_scales": self._is_swizzled_scales, - "mm_config": self.mm_config, - "use_triton_kernel": self.use_triton_kernel, - } - tensor_list = ["qdata", "_scale_e4m3"] - if self._per_tensor_scale is not None: - tensor_list.append("_per_tensor_scale") - return tensor_list, ctx - - def _apply_fn_to_data(self, fn: Callable): - """Applies a fn to all tensor components stored on this class""" - tensor_names, ctx = self.__tensor_flatten__() - new_tensors = {} - for name in tensor_names: - new_tensors[name] = fn(getattr(self, name)) - if "_per_tensor_scale" not in tensor_names: - new_tensors["_per_tensor_scale"] = None - return self.__class__.__tensor_unflatten__( - new_tensors, - ctx, - None, - None, - ) - - @staticmethod - def __tensor_unflatten__( - inner_tensors, - metadata, - outer_size, - outer_stride, - ): - return NVFP4Tensor( - inner_tensors["qdata"], - inner_tensors["_scale_e4m3"], - inner_tensors.get("_per_tensor_scale", None), - metadata["_block_size"], - metadata["_orig_dtype"], - metadata["mm_config"], - metadata.get("_is_swizzled_scales", False), - metadata.get("use_triton_kernel", False), - ) - # Do not force the NVFP4Tensor type on the returned tensor __torch_function__ = torch._C._disabled_torch_function_impl diff --git a/torchao/utils.py b/torchao/utils.py index a32166d556..4a24adadb0 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -810,6 +810,8 @@ def __tensor_flatten__(self): if maybe_tensor is not None: tensor_data_names.append(tensor_data_name) + # TODO(future PR): also return names of tensor attributes for easier + # debugging return tensor_data_names, [ getattr(self, attr) for attr in self.tensor_attribute_names ] From 305e0de118be5b7d8c065dcb90c3c59669d50bb6 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 18 Aug 2025 09:25:48 -0700 Subject: [PATCH 3/4] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 6fe91a379f..43c8777cb3 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -888,12 +888,12 @@ def test_nvfp4_swizzled_scales_view_semantics(): # Test that the sliced tensor shares storage with original for data # (Note: scales might not share storage due to swizzled layout complexity) - assert sliced_tensor._data.data_ptr() == tensor._data.data_ptr() + assert sliced_tensor.qdata.data_ptr() == tensor.qdata.data_ptr() # Test full-width column slicing (should maintain views) full_width_slice = tensor[:, 0:K] assert full_width_slice._scale_e4m3.data_ptr() == tensor._scale_e4m3.data_ptr() - assert full_width_slice._data.data_ptr() == tensor._data.data_ptr() + assert full_width_slice.qdata.data_ptr() == tensor.qdata.data_ptr() @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -1011,8 +1011,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype): torch.testing.assert_close( nvfp4_pt._scale_e4m3.flatten(), nvfp4_triton._scale_e4m3.flatten() ) - pt_unpacked = unpack_uint4(nvfp4_pt._data) - triton_unpacked = unpack_uint4(nvfp4_triton._data) + pt_unpacked = unpack_uint4(nvfp4_pt.qdata) + triton_unpacked = unpack_uint4(nvfp4_triton.qdata) torch.testing.assert_close( pt_unpacked, triton_unpacked, From 1c8adb4f554baa56582c18a12ada451366bc32ba Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 18 Aug 2025 10:13:22 -0700 Subject: [PATCH 4/4] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 43c8777cb3..c91c6ac636 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -916,8 +916,8 @@ def test_nvfp4_swizzled_scales_serialization(): tensor_list, ctx = original_tensor.__tensor_flatten__() # Verify swizzled flag is preserved in context - assert "_is_swizzled_scales" in ctx - assert ctx["_is_swizzled_scales"] == True + assert NVFP4Tensor.tensor_attribute_names[3] == "_is_swizzled_scales" + assert ctx[3] == True # Test deserialization inner_tensors = {}