@@ -56,18 +56,18 @@ class NVFP4Tensor(torch.Tensor):
5656 quantization algorithm for FP4 data with UE4M3 scales.
5757
5858 Attributes:
59+ qdata: Packed FP4 data (2 values per byte)
5960 _scale_e4m3: Blockwise scales in float8_e4m3fn format (may be swizzled)
6061 _per_tensor_scale: Optional global per-tensor scale in float32 format
61- _data: Packed FP4 data (2 values per byte)
6262 _block_size: Block size for quantization (fixed at 16)
6363 _orig_dtype: Original tensor dtype before quantization
6464 _is_swizzled_scales: Whether scales are stored in swizzled (blocked) format
6565 mm_config: Matrix multiplication configuration
6666 """
6767
68+ qdata : torch .Tensor
6869 _scale_e4m3 : torch .Tensor
6970 _per_tensor_scale : Optional [torch .Tensor ]
70- _data : torch .Tensor
7171 _block_size : int
7272 _orig_dtype : torch .dtype
7373 _is_swizzled_scales : bool
@@ -76,43 +76,43 @@ class NVFP4Tensor(torch.Tensor):
7676
7777 def __new__ (
7878 cls ,
79+ qdata ,
7980 blockwise_scales ,
8081 per_tensor_scale ,
81- data_bits ,
8282 block_size ,
8383 orig_dtype ,
8484 mm_config = NVFP4MMConfig .DYNAMIC ,
8585 is_swizzled_scales = False ,
8686 use_triton_kernel = False ,
8787 ):
8888 # FP4 tensor size handling two paths, contiguous or not
89- new_size = data_bits .size ()
89+ new_size = qdata .size ()
9090
9191 new_size = tensor_size_fp4x2_to_hp (
9292 new_size ,
93- data_bits .stride (0 ) > data_bits .stride (1 ),
93+ qdata .stride (0 ) > qdata .stride (1 ),
9494 )
9595
9696 self = torch .Tensor ._make_wrapper_subclass (
9797 cls ,
9898 new_size ,
9999 dtype = orig_dtype ,
100- device = data_bits .device ,
100+ device = qdata .device ,
101101 requires_grad = False ,
102102 )
103103
104104 self ._scale_e4m3 = blockwise_scales
105105 self ._is_swizzled_scales = is_swizzled_scales
106106 self ._per_tensor_scale = per_tensor_scale
107- self ._data = data_bits
107+ self .qdata = qdata
108108 self ._block_size = block_size
109109 self ._orig_dtype = orig_dtype
110110 self .mm_config = mm_config
111111 self .use_triton_kernel = use_triton_kernel
112112 return self
113113
114114 def __repr__ (self ):
115- return f"NVFP4Tensor: blockwise_scales: { self ._scale_e4m3 } , per_tensor_scale: { self ._per_tensor_scale } , d: { self ._data } , d_hp: { self .to_dtype (self ._orig_dtype )} "
115+ return f"NVFP4Tensor: blockwise_scales: { self ._scale_e4m3 } , per_tensor_scale: { self ._per_tensor_scale } , d: { self .qdata } , d_hp: { self .to_dtype (self ._orig_dtype )} "
116116
117117 @classmethod
118118 def __torch_dispatch__ (cls , func , types , args , kwargs = None ):
@@ -163,9 +163,9 @@ def to_nvfp4(
163163 ).flatten ()
164164
165165 return NVFP4Tensor (
166+ data_lp ,
166167 blockwise_scales ,
167168 per_tensor_scale ,
168- data_lp ,
169169 block_size ,
170170 data_hp .dtype ,
171171 mm_config ,
@@ -181,7 +181,7 @@ def __tensor_flatten__(self):
181181 "mm_config" : self .mm_config ,
182182 "use_triton_kernel" : self .use_triton_kernel ,
183183 }
184- tensor_list = ["_scale_e4m3 " , "_data " ]
184+ tensor_list = ["qdata " , "_scale_e4m3 " ]
185185 if self ._per_tensor_scale is not None :
186186 tensor_list .append ("_per_tensor_scale" )
187187 return tensor_list , ctx
@@ -209,9 +209,9 @@ def __tensor_unflatten__(
209209 outer_stride ,
210210 ):
211211 return NVFP4Tensor (
212+ inner_tensors ["qdata" ],
212213 inner_tensors ["_scale_e4m3" ],
213214 inner_tensors .get ("_per_tensor_scale" , None ),
214- inner_tensors ["_data" ],
215215 metadata ["_block_size" ],
216216 metadata ["_orig_dtype" ],
217217 metadata ["mm_config" ],
@@ -231,12 +231,12 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
231231 Returns:
232232 torch.Tensor: Dequantized tensor in the target dtype
233233 """
234- is_transposed = self ._data .stride (0 ) < self ._data .stride (1 )
234+ is_transposed = self .qdata .stride (0 ) < self .qdata .stride (1 )
235235 if is_transposed :
236236 M , K = self .shape [1 ], self .shape [0 ]
237237 else :
238238 M , K = self .shape [0 ], self .shape [1 ]
239- data = self ._data .t () if is_transposed else self ._data
239+ data = self .qdata .t () if is_transposed else self .qdata
240240 data_unpacked = unpack_uint4 (data .contiguous ().view (torch .uint8 ))
241241 data_f32 = f4_unpacked_to_f32 (data_unpacked )
242242
@@ -256,7 +256,7 @@ def get_hp_scales(self) -> torch.Tensor:
256256 Returns:
257257 torch.Tensor: Scales of the NVFP4Tensor
258258 """
259- is_transposed = self ._data .stride (0 ) < self ._data .stride (1 )
259+ is_transposed = self .qdata .stride (0 ) < self .qdata .stride (1 )
260260 if is_transposed :
261261 M , K = self .shape [1 ], self .shape [0 ]
262262 else :
@@ -296,7 +296,7 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
296296 and self ._is_swizzled_scales == src ._is_swizzled_scales
297297 and self ._scale_e4m3 .shape == src ._scale_e4m3 .shape
298298 and per_tensor_scale_equal
299- and self ._data .shape == src ._data .shape
299+ and self .qdata .shape == src .qdata .shape
300300 )
301301
302302
@@ -379,7 +379,7 @@ def nvfp4_slice(func, types, args, kwargs):
379379 if step != 1 :
380380 raise ValueError ("Only support aten.slice with step=1" )
381381
382- assert x ._data .is_contiguous (), "Only support contiguous data for now"
382+ assert x .qdata .is_contiguous (), "Only support contiguous data for now"
383383
384384 M , K = x .shape [0 ], x .shape [1 ]
385385
@@ -422,7 +422,7 @@ def nvfp4_slice(func, types, args, kwargs):
422422 )
423423
424424 sliced_scale = aten .slice .Tensor (x ._scale_e4m3 , 0 , start_idx , end_idx , 1 )
425- sliced_data = aten .slice .Tensor (x ._data , 0 , start , end , step )
425+ sliced_data = aten .slice .Tensor (x .qdata , 0 , start , end , step )
426426
427427 elif dim == 1 :
428428 # Column slicing
@@ -485,7 +485,7 @@ def nvfp4_slice(func, types, args, kwargs):
485485 packed_start = None if start is None else start // 2
486486 packed_end = None if end is None else end // 2
487487 sliced_data = aten .slice .Tensor (
488- x ._data , dim , packed_start , packed_end , step
488+ x .qdata , dim , packed_start , packed_end , step
489489 )
490490
491491 else :
@@ -498,7 +498,7 @@ def nvfp4_slice(func, types, args, kwargs):
498498
499499 if dim == 0 :
500500 sliced_scale = aten .slice .Tensor (scale_shaped , dim , start , end , step )
501- sliced_data = aten .slice .Tensor (x ._data , dim , start , end , step )
501+ sliced_data = aten .slice .Tensor (x .qdata , dim , start , end , step )
502502
503503 elif dim == 1 :
504504 if start is not None :
@@ -518,7 +518,7 @@ def nvfp4_slice(func, types, args, kwargs):
518518 packed_start = None if start is None else start // 2
519519 packed_end = None if end is None else end // 2
520520 sliced_data = aten .slice .Tensor (
521- x ._data , dim , packed_start , packed_end , step
521+ x .qdata , dim , packed_start , packed_end , step
522522 )
523523
524524 start_block = 0 if start is None else start // x ._block_size
@@ -531,9 +531,9 @@ def nvfp4_slice(func, types, args, kwargs):
531531
532532 # Create result tensor
533533 result = NVFP4Tensor (
534+ sliced_data ,
534535 sliced_scale ,
535536 x ._per_tensor_scale ,
536- sliced_data ,
537537 x ._block_size ,
538538 x ._orig_dtype ,
539539 x .mm_config ,
@@ -549,9 +549,9 @@ def nvfp4_t(func, types, args, kwargs):
549549 # For now, only transpose(input, 0, 1) is supported.
550550 old = args [0 ]
551551 new = NVFP4Tensor (
552+ old .qdata .t (),
552553 old ._scale_e4m3 ,
553554 old ._per_tensor_scale ,
554- old ._data .t (),
555555 old ._block_size ,
556556 old ._orig_dtype ,
557557 old .mm_config ,
@@ -563,14 +563,14 @@ def nvfp4_t(func, types, args, kwargs):
563563
564564@implements ([aten .view .default ])
565565def nvfp4_view_op (func , types , args , kwargs ):
566- data = args [0 ]._data
566+ data = args [0 ].qdata
567567 new_size = args [1 ]
568568 new_size = tensor_size_hp_to_fp4x2 (new_size , data .is_contiguous ())
569569 new_data = func (data , new_size , * args [2 :], ** kwargs )
570570 return NVFP4Tensor (
571+ new_data ,
571572 args [0 ]._scale_e4m3 ,
572573 args [0 ]._per_tensor_scale ,
573- new_data ,
574574 args [0 ]._block_size ,
575575 args [0 ]._orig_dtype ,
576576 args [0 ].mm_config ,
@@ -586,8 +586,8 @@ def _addmm_nvfp4_dispatch(
586586 Core implementation shared between nvfp4_mm, nvfp4_addmm, and nvfp4_linear.
587587 The only difference is whether bias is None or not.
588588 """
589- assert a ._data .is_contiguous ()
590- assert b ._data .t ().is_contiguous ()
589+ assert a .qdata .is_contiguous ()
590+ assert b .qdata .t ().is_contiguous ()
591591 assert a ._block_size == 16 , f"NVFP4 requires block_size=16, got { a ._block_size } "
592592 assert b ._block_size == 16 , f"NVFP4 requires block_size=16, got { b ._block_size } "
593593
@@ -623,8 +623,8 @@ def _addmm_nvfp4_dispatch(
623623 # should_add_bias_separately = bias is not None
624624
625625 result = torch ._scaled_mm (
626- a ._data .view (torch .float4_e2m1fn_x2 ),
627- b ._data .view (torch .float4_e2m1fn_x2 ),
626+ a .qdata .view (torch .float4_e2m1fn_x2 ),
627+ b .qdata .view (torch .float4_e2m1fn_x2 ),
628628 a_scale_blocked .view (torch .float8_e4m3fn ),
629629 b_scale_blocked .view (torch .float8_e4m3fn ),
630630 bias = None if should_add_bias_separately else bias ,
0 commit comments