 # LICENSE file in the root directory of this source tree.

 import sys
+from dataclasses import dataclass
 from enum import Enum
 from typing import Any, Dict, Optional

     tensor_size_hp_to_fp4x2,
 )
 from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
+from torchao.quantization.quantize_.common import (
+    QuantizeTensorKwargs,
+)
 from torchao.utils import TorchAOBaseTensor, ceil_div, fill_defaults

 E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny
@@ -38,6 +42,13 @@ class NVFP4MMConfig(Enum):
     WEIGHT_ONLY = "weight_only"


+@dataclass
+class QuantizeTensorToNVFP4Kwargs(QuantizeTensorKwargs):
+    block_size: int = 16
+    is_swizzled_scales: bool = False
+    use_triton_kernel: bool = False
+
+
 # TODO(future PR): move over to TorchAOBaseTensor's dispatch
 def implements(aten_ops):
     """Register aten ops to the NVFP4 op table"""
@@ -60,33 +71,34 @@ class NVFP4Tensor(TorchAOBaseTensor):
         qdata: Packed FP4 data (2 values per byte)
         _scale_e4m3: Blockwise scales in float8_e4m3fn format (may be swizzled)
         _per_tensor_scale: Optional global per-tensor scale in float32 format
+        _act_per_tensor_scale: Optional global per-tensor scale in float32 format for activations
         _block_size (int): Block size for quantization (fixed at 16)
         _orig_dtype (torch.dtype): Original tensor dtype before quantization
         _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
-        mm_config (NVFP4MMConfig): Matrix multiplication configuration
         use_triton_kernel (bool): Whether to use triton kernels
         """

     tensor_data_names = ["qdata", "_scale_e4m3"]
-    optional_tensor_data_names = ["_per_tensor_scale"]
+    optional_tensor_data_names = ["_per_tensor_scale", "_act_per_tensor_scale"]
     tensor_attribute_names = [
         "_block_size",
         "_orig_dtype",
-        "mm_config",
         "_is_swizzled_scales",
         "use_triton_kernel",
+        "act_quant_kwargs",
     ]

     def __new__(
         cls,
         qdata,
         blockwise_scales,
         per_tensor_scale,
+        act_per_tensor_scale,
         block_size,
         orig_dtype,
-        mm_config=NVFP4MMConfig.DYNAMIC,
         is_swizzled_scales=False,
         use_triton_kernel=False,
+        act_quant_kwargs=None,
     ):
         # FP4 tensor size handling two paths, contiguous or not
         new_size = qdata.size()
@@ -107,11 +119,12 @@ def __new__(
         self._scale_e4m3 = blockwise_scales
         self._is_swizzled_scales = is_swizzled_scales
         self._per_tensor_scale = per_tensor_scale
+        self._act_per_tensor_scale = act_per_tensor_scale
         self.qdata = qdata
         self._block_size = block_size
         self._orig_dtype = orig_dtype
-        self.mm_config = mm_config
         self.use_triton_kernel = use_triton_kernel
+        self.act_quant_kwargs = act_quant_kwargs
         return self

     def __repr__(self):
@@ -130,9 +143,10 @@ def to_nvfp4(
         data_hp: torch.Tensor,
         block_size: int = 16,
         per_tensor_scale: Optional[torch.Tensor] = None,
-        mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC,
+        act_per_tensor_scale: Optional[torch.Tensor] = None,
         is_swizzled_scales: bool = False,
         use_triton_kernel: bool = False,
+        act_quant_kwargs: Optional[QuantizeTensorToNVFP4Kwargs] = None,
     ):
         """Convert high precision tensor to NVFP4 format.

@@ -141,9 +155,11 @@ def to_nvfp4(
             block_size: Block size for quantization (must be 16)
             per_tensor_scale: Optional pre-computed absolute maximum for calibration.
                 If provided, uses per-tensor scaling. If None, uses block-wise scaling only.
-            mm_config: Matrix multiplication configuration
+            act_per_tensor_scale: Optional pre-computed absolute maximum for activation calibration.
+                If provided, uses per-tensor scaling. If None, uses block-wise scaling only.
             is_swizzled_scales: If True, store scales in swizzled format for faster matrix multiplication
             use_triton_kernel: If True, use Triton kernel for quantization
+            act_quant_kwargs: If specified, config for quantizing activations at matmul time

         Returns:
             NVFP4Tensor: Quantized tensor in NVFP4 format
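A hedged usage sketch of the updated signature. The import path is assumed from this diff's package layout; attaching act_quant_kwargs is what opts the weight into dynamic activation quantization later:

import torch

# Assumed import path, inferred from this diff; not confirmed by the patch.
from torchao.prototype.mx_formats.nvfp4_tensor import (
    NVFP4Tensor,
    QuantizeTensorToNVFP4Kwargs,
)

w_hp = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
w_nvfp4 = NVFP4Tensor.to_nvfp4(
    w_hp,
    block_size=16,
    is_swizzled_scales=True,
    # Present -> dynamic activation quant at mm time; None -> weight-only.
    act_quant_kwargs=QuantizeTensorToNVFP4Kwargs(is_swizzled_scales=True),
)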
@@ -169,11 +185,12 @@ def to_nvfp4(
             data_lp,
             blockwise_scales,
             per_tensor_scale,
+            act_per_tensor_scale,
             block_size,
             data_hp.dtype,
-            mm_config,
             is_swizzled_scales,
             use_triton_kernel,
+            act_quant_kwargs,
         )

     # Do not force the NVFP4Tensor type on the returned tensor
@@ -244,6 +261,9 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
         per_tensor_scale_equal = (
             self._per_tensor_scale is None and src._per_tensor_scale is None
         ) or (self._per_tensor_scale.shape == src._per_tensor_scale.shape)
+        act_per_tensor_scale_equal = (
+            self._act_per_tensor_scale is None and src._act_per_tensor_scale is None
+        ) or (self._act_per_tensor_scale.shape == src._act_per_tensor_scale.shape)

         return (
             isinstance(self, NVFP4Tensor)
@@ -253,7 +273,9 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
             and self._is_swizzled_scales == src._is_swizzled_scales
             and self._scale_e4m3.shape == src._scale_e4m3.shape
             and per_tensor_scale_equal
+            and act_per_tensor_scale_equal
             and self.qdata.shape == src.qdata.shape
+            and self.act_quant_kwargs == src.act_quant_kwargs
         )

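The scale comparisons above treat two optional tensors as equal when both are absent or their shapes match; note the inline form evaluates `.shape` whenever either tensor is present, so it assumes presence is always paired. A standalone, more defensive sketch of that predicate (hypothetical helper, not part of this diff):

from typing import Optional

import torch


def optional_tensors_shape_equal(
    a: Optional[torch.Tensor], b: Optional[torch.Tensor]
) -> bool:
    # Equal when both are absent, or both present with matching shapes.
    if a is None or b is None:
        return a is b
    return a.shape == b.shape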
@@ -290,12 +312,13 @@ def nvfp4_to_copy(func, types, args, kwargs):
         res = NVFP4Tensor(
             tensor._scale_e4m3,
             tensor._per_tensor_scale,
+            tensor._act_per_tensor_scale,
             tensor._data,
             tensor._block_size,
             dtype,
-            tensor.mm_config,
             tensor._is_swizzled_scales,
             tensor.use_triton_kernel,
+            tensor.act_quant_kwargs,
         )
         return res

@@ -491,11 +514,12 @@ def nvfp4_slice(func, types, args, kwargs):
         sliced_data,
         sliced_scale,
         x._per_tensor_scale,
+        x._act_per_tensor_scale,
         x._block_size,
         x._orig_dtype,
-        x.mm_config,
         x._is_swizzled_scales,
         x.use_triton_kernel,
+        x.act_quant_kwargs,
     )

     return return_and_correct_aliasing(func, args, kwargs, result)
@@ -509,11 +533,12 @@ def nvfp4_t(func, types, args, kwargs):
         old.qdata.t(),
         old._scale_e4m3,
         old._per_tensor_scale,
+        old._act_per_tensor_scale,
         old._block_size,
         old._orig_dtype,
-        old.mm_config,
         old._is_swizzled_scales,
         old.use_triton_kernel,
+        old.act_quant_kwargs,
     )
     return new

@@ -528,11 +553,12 @@ def nvfp4_view_op(func, types, args, kwargs):
         new_data,
         args[0]._scale_e4m3,
         args[0]._per_tensor_scale,
+        args[0]._act_per_tensor_scale,
         args[0]._block_size,
         args[0]._orig_dtype,
-        args[0].mm_config,
         args[0]._is_swizzled_scales,
         args[0].use_triton_kernel,
+        args[0].act_quant_kwargs,
     )


@@ -610,17 +636,19 @@ def nvfp4_linear(func, types, args, kwargs):
     if not isinstance(weight_tensor, NVFP4Tensor):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")

-    config = weight_tensor.mm_config
-
-    if config == NVFP4MMConfig.WEIGHT_ONLY:
+    if weight_tensor.act_quant_kwargs is None:
+        # weight_only quant
         weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
         return torch.nn.functional.linear(input_tensor, weight_dequant, bias)
     else:
+        # dynamic quant
+        k = weight_tensor.act_quant_kwargs
         input_tensor = NVFP4Tensor.to_nvfp4(
             input_tensor,
-            mm_config=config,
-            is_swizzled_scales=True,
-            use_triton_kernel=weight_tensor.use_triton_kernel,
+            block_size=k.block_size,
+            per_tensor_scale=weight_tensor._act_per_tensor_scale,
+            is_swizzled_scales=k.is_swizzled_scales,
+            use_triton_kernel=k.use_triton_kernel,
         )
         return _addmm_nvfp4_dispatch(input_tensor, weight_tensor.t(), func, bias=bias)

@@ -632,9 +660,7 @@ def nvfp4_mm(func, types, args, kwargs):
     if not isinstance(weight_tensor, NVFP4Tensor):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")

-    config = weight_tensor.mm_config
-
-    if config == NVFP4MMConfig.WEIGHT_ONLY:
+    if weight_tensor.act_quant_kwargs is None:
         weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
         if isinstance(input_tensor, NVFP4Tensor):
             input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype)
@@ -643,11 +669,13 @@ def nvfp4_mm(func, types, args, kwargs):
             return func(input_tensor, weight_dequant)
     else:
         if not isinstance(input_tensor, NVFP4Tensor):
+            k = weight_tensor.act_quant_kwargs
             input_tensor = NVFP4Tensor.to_nvfp4(
                 input_tensor,
-                mm_config=config,
-                is_swizzled_scales=True,
-                use_triton_kernel=weight_tensor.use_triton_kernel,
+                block_size=k.block_size,
+                per_tensor_scale=weight_tensor._act_per_tensor_scale,
+                is_swizzled_scales=k.is_swizzled_scales,
+                use_triton_kernel=k.use_triton_kernel,
             )
         return _addmm_nvfp4_dispatch(input_tensor, weight_tensor, func)

@@ -659,9 +687,7 @@ def nvfp4_addmm(func, types, args, kwargs):
     if not isinstance(weight_tensor, NVFP4Tensor):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")

-    config = weight_tensor.mm_config
-
-    if config == NVFP4MMConfig.WEIGHT_ONLY:
+    if weight_tensor.act_quant_kwargs is None:
         weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
         if isinstance(input_tensor, NVFP4Tensor):
             input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype)
@@ -670,11 +696,13 @@ def nvfp4_addmm(func, types, args, kwargs):
             return torch.addmm(bias, input_tensor, weight_dequant)
     else:
         if not isinstance(input_tensor, NVFP4Tensor):
+            k = weight_tensor.act_quant_kwargs
             input_tensor = NVFP4Tensor.to_nvfp4(
                 input_tensor,
-                mm_config=config,
-                is_swizzled_scales=True,
-                use_triton_kernel=weight_tensor.use_triton_kernel,
+                block_size=k.block_size,
+                per_tensor_scale=weight_tensor._act_per_tensor_scale,
+                is_swizzled_scales=k.is_swizzled_scales,
+                use_triton_kernel=k.use_triton_kernel,
             )
         return _addmm_nvfp4_dispatch(input_tensor, weight_tensor, func, bias=bias)
