Commit afde175

Refactor LinearActQuantizedTensor (#542)
Summary:
* rename to LinearActivationQuantizedTensor
* use the `implements` util to implement the torch function and torch dispatch overrides

Test Plan: CI
1 parent c9f79be commit afde175
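
The second bullet refers to the shared dispatch utilities in torchao/dtypes/utils.py (diffed below). A minimal sketch of the resulting registration pattern, adapted from the docstring this commit adds; MyQuantTensor and the empty handler body are illustrative, not part of the commit:

import torch
from torchao.dtypes.utils import (
    _implements,
    _dispatch__torch_function__,
    _dispatch__torch_dispatch__,
)

class MyQuantTensor(torch.Tensor):
    # the per-class op table (_ATEN_OP_OR_TORCH_FN_TABLE) is created lazily by _implements
    implements = classmethod(_implements)
    __torch_function__ = classmethod(_dispatch__torch_function__)
    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)

implements = MyQuantTensor.implements

@implements(torch.nn.functional.linear)
def _(func, types, *args, **kwargs):
    # handlers now uniformly receive (func, types, *args, **kwargs)
    ...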

File tree: 13 files changed (+288 / -269 lines)


test/quantization/test_quant_api.py

Lines changed: 7 additions & 5 deletions
@@ -22,12 +22,14 @@
 from torchao.dtypes import (
     AffineQuantizedTensor,
 )
+from torchao.quantization import (
+    LinearActivationQuantizedTensor,
+)
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
 )
 from torchao.quantization.subclass import (
-    LinearActQuantizedTensor,
     Int8WeightOnlyQuantizedLinearWeight,
     Int4WeightOnlyQuantizedLinearWeight,
 )

@@ -504,8 +506,8 @@ def test_quantized_tensor_subclass_8da4w(self):
         example_inputs = m.example_inputs()
         quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))

-        assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
-        assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
+        assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor)
+        assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor)
         assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)

@@ -577,8 +579,8 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
         quantize_(m, int8_dynamic_activation_int8_weight())

-        assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
-        assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
+        assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor)
+        assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor)
         assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)

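For reference, a hedged sketch of what the rename means for callers, following the test above; the toy module and its dimensions are made up here, only the import paths and the isinstance checks come from the diff:

import torch
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization import LinearActivationQuantizedTensor
from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int4_weight

class ToyLinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(64, 32, bias=False)
        self.linear2 = torch.nn.Linear(32, 16, bias=False)

    def forward(self, x):
        return self.linear2(self.linear1(x))

m = ToyLinearModel().eval()
quantize_(m, int8_dynamic_activation_int4_weight(group_size=32))

# weights are now wrapped in the renamed subclass, with the affine-quantized
# weight reachable via original_weight_tensor
assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor)
assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
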
torchao/dtypes/affine_quantized_tensor.py

Lines changed: 22 additions & 38 deletions
@@ -17,7 +17,8 @@
 from torchao.utils import find_multiple
 from torchao.dtypes.utils import (
     _implements,
-    _ATEN_OP_OR_TORCH_FN_TABLE,
+    _dispatch__torch_function__,
+    _dispatch__torch_dispatch__,
     _register_layout_cls,
     _get_layout_tensor_constructor,
     LayoutType,

@@ -295,17 +296,6 @@ def from_float_static(
     def layout_type(self) -> LayoutType:
         return self.layout_tensor.layout_type

-    @classmethod
-    def __torch_function__(cls, func, types, args=(), kwargs=None):
-        kwargs = {} if kwargs is None else kwargs
-
-        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
-            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs)
-
-        with torch._C.DisableTorchFunctionSubclass():
-            return func(*args, **kwargs)
-
-
     def _get_to_kwargs(self, *args, **kwargs):
         device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
         device = self.device if device is None else device

@@ -347,29 +337,23 @@ def _apply_fn_to_data(self, fn):
             strides=self.stride(),
         )

-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs):
-        # Note: we only added cpu path here for 8da4w, this is for executorch, in the future
-        # 1. we'll add cpu/cuda version (int4mm etc.)
-        # 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
-        # cpu device + et laytout --> gives current 8da4w executorch representation
-        # cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc.
-        # cuda device + some layout --> gives cuda kernel
-
-        # two scenarios where we currently fall back to vanilla mm:
-        # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
-        # kernels in CPU as well, see the note above
-        # 2 - we're given non-floats - quantizing long to int8 is crazy

-        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
-            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)
+    implements = classmethod(_implements)
+    # Note: we only added cpu path here for 8da4w, this is for executorch, in the future
+    # 1. we'll add cpu/cuda version (int4mm etc.)
+    # 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
+    # cpu device + et laytout --> gives current 8da4w executorch representation
+    # cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc.
+    # cuda device + some layout --> gives cuda kernel

-        raise NotImplementedError(
-            f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
-        )
+    # two scenarios where we currently fall back to vanilla mm:
+    # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
+    # kernels in CPU as well, see the note above
+    # 2 - we're given non-floats - quantizing long to int8 is crazy
+    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
+    __torch_function__ = classmethod(_dispatch__torch_function__)

-def implements(aten_ops_or_torch_fn):
-    return _implements(AffineQuantizedTensor, aten_ops_or_torch_fn)
+implements = AffineQuantizedTensor.implements

 def register_layout_cls(layout_type_class: type(LayoutType)):
     return _register_layout_cls(AffineQuantizedTensor, layout_type_class)

@@ -827,7 +811,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):


 @implements(torch.nn.functional.linear)
-def functional_linear(*args, **kwargs):
+def _(func, types, *args, **kwargs):
     input_tensor, weight_tensor, bias = (
         args[0],
         args[1],

@@ -846,7 +830,7 @@ def functional_linear(*args, **kwargs):
     return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

 @implements([aten.mm.default, aten.addmm.default])
-def aten_mm(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     if not args[0].is_floating_point():
         raise NotImplementedError(f"{func} is not implemented for non floating point input")

@@ -885,21 +869,21 @@ def aten_mm(func, *args, **kwargs):
     return func(input_tensor, weight_tensor)

 @implements([aten.detach.default])
-def detach(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     return return_and_correct_aliasing(
         func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
     )


 @implements([aten.clone.default])
-def clone(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     return return_and_correct_aliasing(
         func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
     )


 @implements([aten._to_copy.default])
-def _to_copy(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     return return_and_correct_aliasing(
         func,
         args,

@@ -908,7 +892,7 @@ def _to_copy(func, *args, **kwargs):
     )

 @implements([aten.t.default])
-def t(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     block_size = args[0].block_size
     assert len(block_size) == 2
     transposed_block_size = (block_size[1], block_size[0])
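
With the refactor, every AffineQuantizedTensor handler shares the (func, types, *args, **kwargs) signature and is registered through AffineQuantizedTensor.implements, so one function can serve several ops and branch on `func` (as aten.mm.default / aten.addmm.default do above). A hedged sketch of extending coverage in this style; the ops chosen and the dequantize-then-fallback body are illustrative, not part of this commit:

import torch
from torchao.dtypes import AffineQuantizedTensor

aten = torch.ops.aten
implements = AffineQuantizedTensor.implements

@implements([aten.isnan.default, aten.isinf.default])  # hypothetical ops, for illustration only
def _(func, types, *args, **kwargs):
    # dequantize to a plain tensor and fall back to the original op,
    # mirroring the lerp.Scalar handlers in the optimizer subclasses below
    return func(args[0].dequantize(), *args[1:], **kwargs)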

torchao/dtypes/utils.py

Lines changed: 48 additions & 8 deletions
@@ -5,19 +5,28 @@
 from dataclasses import dataclass

 """
-torch_function and torch_dispatch operator dispatch registrations
-
-first key is a tensor subclass type like AffineQuantizedTensor,
-second key is a `func` in __torhc_function__ or __torch_dispatch__,
-value is a function that implements the dispatch
+Helper function for implementing aten op or torch function dispatch
+and dispatching to these implementations.
 """
-_ATEN_OP_OR_TORCH_FN_TABLE: Dict[Callable, Dict[Callable, Callable]] = defaultdict(dict)
-
 def _implements(cls, aten_ops_or_torch_fns):
     """Use this decorator to implement a function for an aten ops in __torch_dispatch__
     (if user passed in a list of ops)
     or torch function in __torch_function__ (if user passed in a single object)
+
+    class MyTensor(torch.Tensor):
+        ...
+        implements = classmethod(_implements)
+
+    implements = MyTensor.implements
+
+    @implements(torch.nn.functional.linear):
+    def _(func, types, args, kwargs):
+        ...
+
     """
+    if not hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE"):
+        cls._ATEN_OP_OR_TORCH_FN_TABLE = {}
+
     if not isinstance(aten_ops_or_torch_fns, (list, tuple)):
         aten_ops_or_torch_fns = [aten_ops_or_torch_fns]
     def decorator(func):

@@ -26,10 +35,41 @@ def decorator(func):
             def wrapper(*args, **kwargs):
                 return func(*args, **kwargs)

-            _ATEN_OP_OR_TORCH_FN_TABLE[cls][op] = wrapper
+            cls._ATEN_OP_OR_TORCH_FN_TABLE[op] = wrapper
         return func
     return decorator

+def _dispatch__torch_function__(cls, func, types, args=(), kwargs=None):
+    """Use this util function for a common `__torch_function__` implementation
+    that dispatches to ops/functions registered with `_implements`
+
+    class MyTensor(torch.Tensor):
+        ...
+        __torch_function__ = classmethod(_dispatch__torch_function__)
+    """
+    kwargs = {} if kwargs is None else kwargs
+    if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
+        func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
+        return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs)
+
+    with torch._C.DisableTorchFunctionSubclass():
+        return func(*args, **kwargs)
+
+def _dispatch__torch_dispatch__(cls, func, types, args, kwargs):
+    """Use this util function for a common `__torch_dispatch__` implementation
+    that dispatches to ops/functions registered with `_implements`
+
+    class MyTensor(torch.Tensor):
+        ...
+        __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
+    """
+    if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
+        func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
+        return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs)
+
+    raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func}")
+
+
 """
 Base class for different LayoutType, should not be instantiated directly
 """

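Both helpers consult the per-class _ATEN_OP_OR_TORCH_FN_TABLE that _implements creates lazily: _dispatch__torch_function__ falls back to the vanilla function under DisableTorchFunctionSubclass, while _dispatch__torch_dispatch__ raises NotImplementedError for unregistered ops. A self-contained toy demo of that behavior, assuming the helpers work as in the diff above; MyTensor and the registered op are illustrative, not torchao classes:

import torch
from torchao.dtypes.utils import (
    _implements,
    _dispatch__torch_function__,
    _dispatch__torch_dispatch__,
)

aten = torch.ops.aten

class MyTensor(torch.Tensor):
    implements = classmethod(_implements)
    __torch_function__ = classmethod(_dispatch__torch_function__)
    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)

@MyTensor.implements(aten.mul.Tensor)
def _(func, types, *args, **kwargs):
    # unwrap to a plain tensor and run the original aten op
    return func(args[0].as_subclass(torch.Tensor), *args[1:], **kwargs)

x = torch.ones(3).as_subclass(MyTensor)
y = torch.full((3,), 2.0)

print(x * y)   # aten.mul.Tensor is in the table -> routed to the handler above

try:
    x + x      # aten.add.Tensor is not registered
except NotImplementedError as e:
    print(e)   # "MyTensor dispatch: attempting to run unimplemented operator/function: ..."
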
torchao/prototype/low_bit_optim/subclass_4bit.py

Lines changed: 6 additions & 11 deletions
@@ -2,7 +2,7 @@

 import torch
 from torch import Tensor
-from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE
+from torchao.dtypes.utils import _implements, _dispatch__torch_dispatch__

 from .quant_utils import create_dynamic_map, scale_tensor, quantize_4bit_with_qmap, dequant_with_qmap

@@ -85,16 +85,11 @@ def __repr__(self):
             f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})"
         )

-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs):
-        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
-            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)
-
-        raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported")
+    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)


 @OptimState4bit.implements(aten.copy_.default)
-def _(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     dst = args[0]
     src = args[1]

@@ -121,14 +116,14 @@ def _(func, *args, **kwargs):


 @OptimState4bit.implements(aten.lerp.Scalar)
-def _(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     args = [x.dequantize() if isinstance(x, OptimState4bit) else x for x in args]
     return func(*args, **kwargs)


 # this is needed for DTensor.from_local() and for flattening tensor
 @OptimState4bit.implements(aten.view.default)
-def _(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     x, shape = args

     if tuple(x.shape) == tuple(shape):

@@ -147,7 +142,7 @@ def _(func, *args, **kwargs):
     c10d_functional.wait_tensor.default,
     _c10d_functional.wait_tensor.default,
 ])
-def _(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     x = args[0]
     if not isinstance(x, OptimState4bit):
         raise ValueError(f"expecting a OptimState4bit but found {type(x)}")

torchao/prototype/low_bit_optim/subclass_8bit.py

Lines changed: 6 additions & 11 deletions
@@ -1,6 +1,6 @@
 import torch
 from torch import Tensor
-from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE
+from torchao.dtypes.utils import _implements, _dispatch__torch_dispatch__

 from .quant_utils import create_dynamic_map, scale_tensor, quantize_8bit_with_qmap, dequant_with_qmap

@@ -71,16 +71,11 @@ def __repr__(self):
             f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})"
         )

-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs):
-        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
-            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)
-
-        raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported")
+    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)


 @OptimState8bit.implements(aten.copy_.default)
-def _(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     dst = args[0]
     src = args[1]

@@ -103,14 +98,14 @@ def _(func, *args, **kwargs):


 @OptimState8bit.implements(aten.lerp.Scalar)
-def _(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args]
     return func(*args, **kwargs)


 # this is needed for DTensor.from_local()
 @OptimState8bit.implements(aten.view.default)
-def _(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     x, shape = args
     return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed)

@@ -122,7 +117,7 @@ def _(func, *args, **kwargs):
     c10d_functional.wait_tensor.default,
     _c10d_functional.wait_tensor.default,
 ])
-def _(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     x = args[0]
     if not isinstance(x, OptimState8bit):
         raise ValueError(f"expecting a OptimState8bit but found {type(x)}")
