 
 import sys
 from enum import Enum
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Dict, Optional
 
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
     tensor_size_hp_to_fp4x2,
 )
 from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
-from torchao.utils import ceil_div, fill_defaults
+from torchao.utils import TorchAOBaseTensor, ceil_div, fill_defaults
 
 E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny
 
@@ -38,6 +38,7 @@ class NVFP4MMConfig(Enum):
     WEIGHT_ONLY = "weight_only"
 
 
+# TODO(future PR): move over to TorchAOBaseTensor's dispatch
 def implements(aten_ops):
     """Register aten ops to the NVFP4 op table"""
 
@@ -49,7 +50,7 @@ def decorator(func):
     return decorator
 
 
-class NVFP4Tensor(torch.Tensor):
+class NVFP4Tensor(TorchAOBaseTensor):
     """NVIDIA FP4 (NVFP4) Tensor subclass.
 
     This implements the NVIDIA variant of MX FP4 format, which uses a specific
@@ -59,20 +60,22 @@ class NVFP4Tensor(torch.Tensor):
         qdata: Packed FP4 data (2 values per byte)
         _scale_e4m3: Blockwise scales in float8_e4m3fn format (may be swizzled)
         _per_tensor_scale: Optional global per-tensor scale in float32 format
-        _block_size: Block size for quantization (fixed at 16)
-        _orig_dtype: Original tensor dtype before quantization
-        _is_swizzled_scales: Whether scales are stored in swizzled (blocked) format
-        mm_config: Matrix multiplication configuration
+        _block_size (int): Block size for quantization (fixed at 16)
+        _orig_dtype (torch.dtype): Original tensor dtype before quantization
+        _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
+        mm_config (NVFP4MMConfig): Matrix multiplication configuration
+        use_triton_kernel (bool): Whether to use triton kernels
     """
 
-    qdata: torch.Tensor
-    _scale_e4m3: torch.Tensor
-    _per_tensor_scale: Optional[torch.Tensor]
-    _block_size: int
-    _orig_dtype: torch.dtype
-    _is_swizzled_scales: bool
-    mm_config: NVFP4MMConfig
-    use_triton_kernel: bool
+    tensor_data_names = ["qdata", "_scale_e4m3"]
+    optional_tensor_data_names = ["_per_tensor_scale"]
+    tensor_attribute_names = [
+        "_block_size",
+        "_orig_dtype",
+        "mm_config",
+        "_is_swizzled_scales",
+        "use_triton_kernel",
+    ]
 
     def __new__(
         cls,
@@ -173,52 +176,6 @@ def to_nvfp4(
             use_triton_kernel,
         )
 
-    def __tensor_flatten__(self):
-        ctx = {
-            "_block_size": self._block_size,
-            "_orig_dtype": self._orig_dtype,
-            "_is_swizzled_scales": self._is_swizzled_scales,
-            "mm_config": self.mm_config,
-            "use_triton_kernel": self.use_triton_kernel,
-        }
-        tensor_list = ["qdata", "_scale_e4m3"]
-        if self._per_tensor_scale is not None:
-            tensor_list.append("_per_tensor_scale")
-        return tensor_list, ctx
-
-    def _apply_fn_to_data(self, fn: Callable):
-        """Applies a fn to all tensor components stored on this class"""
-        tensor_names, ctx = self.__tensor_flatten__()
-        new_tensors = {}
-        for name in tensor_names:
-            new_tensors[name] = fn(getattr(self, name))
-        if "_per_tensor_scale" not in tensor_names:
-            new_tensors["_per_tensor_scale"] = None
-        return self.__class__.__tensor_unflatten__(
-            new_tensors,
-            ctx,
-            None,
-            None,
-        )
-
-    @staticmethod
-    def __tensor_unflatten__(
-        inner_tensors,
-        metadata,
-        outer_size,
-        outer_stride,
-    ):
-        return NVFP4Tensor(
-            inner_tensors["qdata"],
-            inner_tensors["_scale_e4m3"],
-            inner_tensors.get("_per_tensor_scale", None),
-            metadata["_block_size"],
-            metadata["_orig_dtype"],
-            metadata["mm_config"],
-            metadata.get("_is_swizzled_scales", False),
-            metadata.get("use_triton_kernel", False),
-        )
-
     # Do not force the NVFP4Tensor type on the returned tensor
     __torch_function__ = torch._C._disabled_torch_function_impl
 
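For context on why the hand-written `__tensor_flatten__`, `_apply_fn_to_data`, and `__tensor_unflatten__` methods can be deleted: `TorchAOBaseTensor` derives that boilerplate from the declared `tensor_data_names`, `optional_tensor_data_names`, and `tensor_attribute_names` lists. The snippet below is a minimal sketch of that pattern, not torchao's actual implementation; the base-class name `DeclarativeTensorBase` is made up for illustration, the `__new__`/`__torch_dispatch__` plumbing a real tensor subclass needs is omitted, and it assumes the subclass constructor accepts tensor data, optional tensor data, and attributes in the declared order (which is how `NVFP4Tensor` is constructed in the removed `__tensor_unflatten__`).

```python
# Minimal sketch (NOT the real TorchAOBaseTensor): derive flatten/unflatten
# behavior from class-level name lists, the pattern NVFP4Tensor switches onto
# in this diff.
from typing import Any, Dict, List, Tuple

import torch


class DeclarativeTensorBase(torch.Tensor):
    # Subclasses override these to describe their layout declaratively.
    tensor_data_names: List[str] = []
    optional_tensor_data_names: List[str] = []
    tensor_attribute_names: List[str] = []

    def __tensor_flatten__(self) -> Tuple[List[str], Dict[str, Any]]:
        # Required tensor components, plus any optional ones that are present.
        names = list(self.tensor_data_names)
        names += [
            n for n in self.optional_tensor_data_names if getattr(self, n) is not None
        ]
        ctx = {n: getattr(self, n) for n in self.tensor_attribute_names}
        return names, ctx

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, metadata, outer_size, outer_stride):
        # Assumption: the subclass constructor takes (tensor data...,
        # optional tensor data..., attributes...) in the declared order.
        args = [inner_tensors[n] for n in cls.tensor_data_names]
        args += [inner_tensors.get(n) for n in cls.optional_tensor_data_names]
        args += [metadata[n] for n in cls.tensor_attribute_names]
        return cls(*args)

    def _apply_fn_to_data(self, fn):
        # Rebuild the subclass with fn applied to every tensor component,
        # mirroring what the removed hand-written helper did.
        names, ctx = self.__tensor_flatten__()
        new_tensors = {n: fn(getattr(self, n)) for n in names}
        return self.__class__.__tensor_unflatten__(new_tensors, ctx, None, None)
```

With a base class along these lines, `NVFP4Tensor` only needs the three class-level lists added above; aten op dispatch is still handled by the local `implements` table for now (see the TODO earlier in the diff).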