Skip to content

Commit b8f15b4

Browse files
committed
nvfp4tensor: improve printing
Summary: Makes printing of linears with NVFP4 weights more descriptive, such as ```python (gate_proj): Linear(in_features=2048, out_features=1408, weight=NVFP4Tensor(self._is_swizzled_scales=True, self.use_triton_kernel=False, self.act_quant_kwargs=None)) ``` Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: b012e61 ghstack-comment-id: 3340883530 Pull Request resolved: #3086
1 parent 049eee2 commit b8f15b4

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ def __new__(
133133
def __repr__(self):
134134
return f"NVFP4Tensor: blockwise_scales: {self._scale_e4m3}, per_tensor_scale: {self._per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"
135135

136+
def _quantization_type(self):
137+
return f"{self._is_swizzled_scales=}, {self.use_triton_kernel=}, {self.act_quant_kwargs=}"
138+
136139
@classmethod
137140
def __torch_dispatch__(cls, func, types, args, kwargs=None):
138141
# Use NVFP4-specific ops table

0 commit comments

Comments
 (0)