Skip to content

Commit 8ab52a7

Browse files
committed
Update on "Adding uint4 dtype implementation"
Summary: This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this we plan to move the uint4 tensor subclass to core after it is more mature Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
1 parent 49c6a43 commit 8ab52a7

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

torchao/dtypes/uint4.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __new__(cls, elem, **kwargs):
8282
kwargs["requires_grad"] = False
8383
return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.uint4, **kwargs)
8484

85-
def __init__(self, elem):
85+
def __init__(self, elem, **kwargs):
8686
self.elem = elem
8787

8888
@classmethod
@@ -242,12 +242,11 @@ def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype
242242

243243
class PerChannelSymmetricWeightUInt4Tensor(UInt4Tensor):
244244
@staticmethod
245-
def __new__(cls, elem, scales):
246-
return super().__new__(cls, elem)
245+
def __new__(cls, elem, scales, **kwargs):
246+
return super().__new__(cls, elem, **kwargs)
247247

248-
def __init__(self, elem, scales):
249-
# super().__init__(elem)
250-
self.elem = elem
248+
def __init__(self, elem, scales, **kwargs):
249+
super().__init__(elem, **kwargs)
251250
self.scales = scales
252251

253252

0 commit comments

Comments
 (0)