4 changes: 2 additions & 2 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -916,8 +916,8 @@ def test_nvfp4_swizzled_scales_serialization():
tensor_list, ctx = original_tensor.__tensor_flatten__()

# Verify swizzled flag is preserved in context
assert "_is_swizzled_scales" in ctx
assert ctx["_is_swizzled_scales"] == True
assert NVFP4Tensor.tensor_attribute_names[3] == "_is_swizzled_scales"
assert ctx[3] == True

# Test deserialization
inner_tensors = {}
4 changes: 2 additions & 2 deletions test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -304,8 +304,8 @@ def test_nvfp4_swizzled_scales_serialization():
tensor_list, ctx = original_tensor.__tensor_flatten__()

# Verify swizzled flag is preserved in context
assert "_is_swizzled_scales" in ctx
assert ctx["_is_swizzled_scales"] == True
assert NVFP4Tensor.tensor_attribute_names[3] == "_is_swizzled_scales"
assert ctx[3] == True

# Test deserialization
inner_tensors = {}
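Both test updates above reflect that the flattening context is now a positional list of attribute values (ordered like `NVFP4Tensor.tensor_attribute_names`) rather than a dict keyed by attribute name, which is why the swizzled-scales flag is checked at index 3. A minimal sketch of how such a context could still be read by name in a test or debugging session, assuming only the name/value ordering shown in the `torchao/utils.py` hunk further down; the helper itself is hypothetical and not part of this PR:

```python
# Hypothetical helper, not part of this PR: rebuild a name -> value view of the
# positional ctx returned by __tensor_flatten__, assuming ctx[i] corresponds to
# tensor_attribute_names[i].
def ctx_by_name(tensor_cls, ctx):
    return dict(zip(tensor_cls.tensor_attribute_names, ctx))


# Usage sketch inside the test above:
#   attrs = ctx_by_name(NVFP4Tensor, ctx)
#   assert attrs["_is_swizzled_scales"] == True
```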
79 changes: 18 additions & 61 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -6,7 +6,7 @@

import sys
from enum import Enum
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Dict, Optional

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
@@ -24,7 +24,7 @@
tensor_size_hp_to_fp4x2,
)
from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
-from torchao.utils import ceil_div, fill_defaults
+from torchao.utils import TorchAOBaseTensor, ceil_div, fill_defaults

E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny

@@ -38,6 +38,7 @@ class NVFP4MMConfig(Enum):
WEIGHT_ONLY = "weight_only"


+# TODO(future PR): move over to TorchAOBaseTensor's dispatch
def implements(aten_ops):
"""Register aten ops to the NVFP4 op table"""

@@ -49,7 +49,7 @@ def decorator(func):
return decorator


-class NVFP4Tensor(torch.Tensor):
+class NVFP4Tensor(TorchAOBaseTensor):
"""NVIDIA FP4 (NVFP4) Tensor subclass.

This implements the NVIDIA variant of MX FP4 format, which uses a specific
@@ -59,20 +59,22 @@ class NVFP4Tensor(torch.Tensor):
qdata: Packed FP4 data (2 values per byte)
_scale_e4m3: Blockwise scales in float8_e4m3fn format (may be swizzled)
_per_tensor_scale: Optional global per-tensor scale in float32 format
-        _block_size: Block size for quantization (fixed at 16)
-        _orig_dtype: Original tensor dtype before quantization
-        _is_swizzled_scales: Whether scales are stored in swizzled (blocked) format
-        mm_config: Matrix multiplication configuration
+        _block_size (int): Block size for quantization (fixed at 16)
+        _orig_dtype (torch.dtype): Original tensor dtype before quantization
+        _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
+        mm_config (NVFP4MMConfig): Matrix multiplication configuration
+        use_triton_kernel (bool): Whether to use triton kernels
"""

-    qdata: torch.Tensor
-    _scale_e4m3: torch.Tensor
-    _per_tensor_scale: Optional[torch.Tensor]
-    _block_size: int
-    _orig_dtype: torch.dtype
-    _is_swizzled_scales: bool
-    mm_config: NVFP4MMConfig
-    use_triton_kernel: bool
+    tensor_data_names = ["qdata", "_scale_e4m3"]
Review comments on this line:

Contributor: nit: remove underscores?

Author: not related to current PR, we can do this in a future PR

Contributor: from a typehint / ide experience this kinda sucks
is it not possible to have the actual attributes typehinted on the class w/ AOBaseTensor?

Contributor: we can try to do that as an improvement for future maybe, not sure how to do that yet, maybe we can do some inspection

+    optional_tensor_data_names = ["_per_tensor_scale"]
+    tensor_attribute_names = [
+        "_block_size",
+        "_orig_dtype",
+        "mm_config",
+        "_is_swizzled_scales",
+        "use_triton_kernel",
+    ]
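On the type-hint discussion in the thread above: one possible future direction, sketched here only as an illustration of the "maybe we can do some inspection" idea and not something this PR does, is to keep ordinary class annotations for IDE support and derive the declarative name lists from them:

```python
# Sketch only, not part of this PR: derive the declarative name lists from
# ordinary class annotations so IDEs and type checkers still see typed attributes.
from typing import Optional, Union, get_args, get_origin

import torch


class AnnotatedNVFP4Example:
    # Plain annotations, shown for a subset of NVFP4Tensor's attributes.
    qdata: torch.Tensor
    _scale_e4m3: torch.Tensor
    _per_tensor_scale: Optional[torch.Tensor]
    _block_size: int
    _orig_dtype: torch.dtype


def names_from_annotations(cls):
    """Split annotated attributes into (tensor, optional-tensor, plain) name lists."""
    tensor_names, optional_tensor_names, attribute_names = [], [], []
    for name, ann in cls.__annotations__.items():
        if ann is torch.Tensor:
            tensor_names.append(name)
        elif get_origin(ann) is Union and torch.Tensor in get_args(ann):
            # Optional[torch.Tensor] is Union[torch.Tensor, None]
            optional_tensor_names.append(name)
        else:
            attribute_names.append(name)
    return tensor_names, optional_tensor_names, attribute_names
```

For the annotations above this would yield `(["qdata", "_scale_e4m3"], ["_per_tensor_scale"], ["_block_size", "_orig_dtype"])`, matching the hand-written lists for that subset.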

def __new__(
cls,
@@ -173,52 +176,6 @@ def to_nvfp4(
use_triton_kernel,
)

-    def __tensor_flatten__(self):
-        ctx = {
-            "_block_size": self._block_size,
-            "_orig_dtype": self._orig_dtype,
-            "_is_swizzled_scales": self._is_swizzled_scales,
-            "mm_config": self.mm_config,
-            "use_triton_kernel": self.use_triton_kernel,
-        }
-        tensor_list = ["qdata", "_scale_e4m3"]
-        if self._per_tensor_scale is not None:
-            tensor_list.append("_per_tensor_scale")
-        return tensor_list, ctx
-
-    def _apply_fn_to_data(self, fn: Callable):
-        """Applies a fn to all tensor components stored on this class"""
-        tensor_names, ctx = self.__tensor_flatten__()
-        new_tensors = {}
-        for name in tensor_names:
-            new_tensors[name] = fn(getattr(self, name))
-        if "_per_tensor_scale" not in tensor_names:
-            new_tensors["_per_tensor_scale"] = None
-        return self.__class__.__tensor_unflatten__(
-            new_tensors,
-            ctx,
-            None,
-            None,
-        )
-
-    @staticmethod
-    def __tensor_unflatten__(
-        inner_tensors,
-        metadata,
-        outer_size,
-        outer_stride,
-    ):
-        return NVFP4Tensor(
-            inner_tensors["qdata"],
-            inner_tensors["_scale_e4m3"],
-            inner_tensors.get("_per_tensor_scale", None),
-            metadata["_block_size"],
-            metadata["_orig_dtype"],
-            metadata["mm_config"],
-            metadata.get("_is_swizzled_scales", False),
-            metadata.get("use_triton_kernel", False),
-        )

# Do not force the NVFP4Tensor type on the returned tensor
__torch_function__ = torch._C._disabled_torch_function_impl

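The hand-written `__tensor_flatten__`, `_apply_fn_to_data`, and `__tensor_unflatten__` deleted above are now inherited from `TorchAOBaseTensor`, driven by the three declarative name lists. A minimal sketch of that pattern follows; it is not the actual torchao implementation (see the `torchao/utils.py` hunk below for the real flatten path), just the shape of the mechanism:

```python
# Minimal sketch of name-list-driven flatten/unflatten; not the actual
# TorchAOBaseTensor code from torchao/utils.py.
import torch


class DeclarativeTensorSketch(torch.Tensor):
    tensor_data_names: list = []
    optional_tensor_data_names: list = []
    tensor_attribute_names: list = []

    def __tensor_flatten__(self):
        names = list(self.tensor_data_names)
        # Only include optional tensor components that are actually present.
        for name in self.optional_tensor_data_names:
            if getattr(self, name) is not None:
                names.append(name)
        # Attributes are returned positionally, in tensor_attribute_names order.
        return names, [getattr(self, name) for name in self.tensor_attribute_names]

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, metadata, outer_size, outer_stride):
        tensors = [inner_tensors[name] for name in cls.tensor_data_names]
        tensors += [
            inner_tensors.get(name, None) for name in cls.optional_tensor_data_names
        ]
        # Assumes the constructor takes tensor components first, then attributes,
        # in declaration order, which is the convention NVFP4Tensor follows.
        return cls(*tensors, *metadata)

    def _apply_fn_to_data(self, fn):
        # Apply fn (e.g. torch.Tensor.detach or .clone) to every present tensor
        # component and rebuild an equivalent instance.
        names, ctx = self.__tensor_flatten__()
        new_tensors = {name: fn(getattr(self, name)) for name in names}
        return self.__class__.__tensor_unflatten__(new_tensors, ctx, None, None)
```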
2 changes: 2 additions & 0 deletions torchao/utils.py
@@ -810,6 +810,8 @@ def __tensor_flatten__(self):
if maybe_tensor is not None:
tensor_data_names.append(tensor_data_name)

+        # TODO(future PR): also return names of tensor attributes for easier
Review comments on this line:

Contributor: not sure if we could change it, that's what we've been using in the past

Author: I think it's worth changing, right now it's a list of attribute values without the names, which is way harder to debug than a key:value dictionary.

+        # debugging
return tensor_data_names, [
getattr(self, attr) for attr in self.tensor_attribute_names
]
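A small illustration of the change the author argues for above, i.e. what a self-describing, name-keyed version of the returned metadata could look like; this is a sketch of the TODO, not something this PR implements:

```python
# Sketch of the TODO above, not implemented in this PR: return the attribute
# metadata keyed by name so flattened output is easier to inspect while debugging.
def flatten_with_named_attributes(subclass_tensor):
    tensor_data_names, _ = subclass_tensor.__tensor_flatten__()
    named_ctx = {
        name: getattr(subclass_tensor, name)
        for name in type(subclass_tensor).tensor_attribute_names
    }
    return tensor_data_names, named_ctx
```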