
Commit 5f300dd

Refactor LinearActQuantizedTensor
Summary:
* rename to LinearActivationQuantizedTensor
* use the `implements` util to implement torch function and torch dispatch overrides

Test Plan: CI
1 parent e5df48e commit 5f300dd

12 files changed: +277 −242 lines changed

test/quantization/test_quant_api.py

Lines changed: 7 additions & 5 deletions
@@ -22,12 +22,14 @@
 from torchao.dtypes import (
     AffineQuantizedTensor,
 )
+from torchao.quantization import (
+    LinearActivationQuantizedTensor,
+)
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
 )
 from torchao.quantization.subclass import (
-    LinearActQuantizedTensor,
     Int8WeightOnlyQuantizedLinearWeight,
     Int4WeightOnlyQuantizedLinearWeight,
 )
@@ -504,8 +506,8 @@ def test_quantized_tensor_subclass_8da4w(self):
         example_inputs = m.example_inputs()
         quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))

-        assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
-        assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
+        assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor)
+        assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor)
         assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)

@@ -577,8 +579,8 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
         quantize_(m, int8_dynamic_activation_int8_weight())

-        assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
-        assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
+        assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor)
+        assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor)
         assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)
torchao/dtypes/affine_quantized_tensor.py

Lines changed: 22 additions & 38 deletions
@@ -17,7 +17,8 @@
 from torchao.utils import find_multiple
 from torchao.dtypes.utils import (
     _implements,
-    _ATEN_OP_OR_TORCH_FN_TABLE,
+    _dispatch__torch_function__,
+    _dispatch__torch_dispatch__,
     _register_layout_cls,
     _get_layout_tensor_constructor,
     LayoutType,
@@ -283,17 +284,6 @@ def from_float_static(
     def layout_type(self) -> LayoutType:
         return self.layout_tensor.layout_type

-    @classmethod
-    def __torch_function__(cls, func, types, args=(), kwargs=None):
-        kwargs = {} if kwargs is None else kwargs
-
-        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
-            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs)
-
-        with torch._C.DisableTorchFunctionSubclass():
-            return func(*args, **kwargs)
-
-
     def _get_to_kwargs(self, *args, **kwargs):
         device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
         device = self.device if device is None else device
@@ -335,29 +325,23 @@ def _apply_fn_to_data(self, fn):
             strides=self.stride(),
         )

-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs):
-        # Note: we only added cpu path here for 8da4w, this is for executorch, in the future
-        # 1. we'll add cpu/cuda version (int4mm etc.)
-        # 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
-        # cpu device + et laytout --> gives current 8da4w executorch representation
-        # cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc.
-        # cuda device + some layout --> gives cuda kernel
-
-        # two scenarios where we currently fall back to vanilla mm:
-        # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
-        # kernels in CPU as well, see the note above
-        # 2 - we're given non-floats - quantizing long to int8 is crazy

-        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
-            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)
+    implements = classmethod(_implements)
+    # Note: we only added cpu path here for 8da4w, this is for executorch, in the future
+    # 1. we'll add cpu/cuda version (int4mm etc.)
+    # 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
+    # cpu device + et laytout --> gives current 8da4w executorch representation
+    # cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc.
+    # cuda device + some layout --> gives cuda kernel

-        raise NotImplementedError(
-            f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
-        )
+    # two scenarios where we currently fall back to vanilla mm:
+    # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
+    # kernels in CPU as well, see the note above
+    # 2 - we're given non-floats - quantizing long to int8 is crazy
+    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
+    __torch_function__ = classmethod(_dispatch__torch_function__)

-def implements(aten_ops_or_torch_fn):
-    return _implements(AffineQuantizedTensor, aten_ops_or_torch_fn)
+implements = AffineQuantizedTensor.implements

 def register_layout_cls(layout_type_class: type(LayoutType)):
     return _register_layout_cls(AffineQuantizedTensor, layout_type_class)
@@ -749,7 +733,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):


 @implements(torch.nn.functional.linear)
-def functional_linear(*args, **kwargs):
+def _(func, types, *args, **kwargs):
     input_tensor, weight_tensor, bias = (
         args[0],
         args[1],
@@ -768,7 +752,7 @@ def functional_linear(*args, **kwargs):
     return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

 @implements([aten.mm.default, aten.addmm.default])
-def aten_mm(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     if not args[0].is_floating_point():
         raise NotImplementedError(f"{func} is not implemented for non floating point input")

@@ -807,21 +791,21 @@ def aten_mm(func, *args, **kwargs):
     return func(input_tensor, weight_tensor)

 @implements([aten.detach.default])
-def detach(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     return return_and_correct_aliasing(
         func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
     )


 @implements([aten.clone.default])
-def clone(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     return return_and_correct_aliasing(
         func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
     )


 @implements([aten._to_copy.default])
-def _to_copy(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     return return_and_correct_aliasing(
         func,
         args,
@@ -830,7 +814,7 @@ def _to_copy(func, *args, **kwargs):
     )

 @implements([aten.t.default])
-def t(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     block_size = args[0].block_size
     assert len(block_size) == 2
     transposed_block_size = (block_size[1], block_size[0])
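
With this change, every override on `AffineQuantizedTensor` goes through the class-level table created by `_implements`, and handlers receive `func` and `types` explicitly. A hedged sketch of what registering an additional override would look like after this refactor; the `torch.matmul` handler below is purely illustrative (it is not part of this commit) and assumes `AffineQuantizedTensor.dequantize()` behaves as used elsewhere in the file:

```python
import torch
from torchao.dtypes.affine_quantized_tensor import (
    AffineQuantizedTensor,
    implements,  # module-level alias for AffineQuantizedTensor.implements added above
)

# Hypothetical extra override, written against the new (func, types, *args, **kwargs)
# handler signature; it simply dequantizes any AffineQuantizedTensor argument and
# falls back to the plain torch.matmul.
@implements(torch.matmul)
def _(func, types, *args, **kwargs):
    args = [
        a.dequantize() if isinstance(a, AffineQuantizedTensor) else a
        for a in args
    ]
    return func(*args, **kwargs)
```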

torchao/dtypes/utils.py

Lines changed: 48 additions & 8 deletions
@@ -5,19 +5,28 @@
 from dataclasses import dataclass

 """
-torch_function and torch_dispatch operator dispatch registrations
-
-first key is a tensor subclass type like AffineQuantizedTensor,
-second key is a `func` in __torhc_function__ or __torch_dispatch__,
-value is a function that implements the dispatch
+Helper function for implementing aten op or torch function dispatch
+and dispatching to these implementations.
 """
-_ATEN_OP_OR_TORCH_FN_TABLE: Dict[Callable, Dict[Callable, Callable]] = defaultdict(dict)
-
 def _implements(cls, aten_ops_or_torch_fns):
     """Use this decorator to implement a function for an aten ops in __torch_dispatch__
     (if user passed in a list of ops)
     or torch function in __torch_function__ (if user passed in a single object)
+
+    class MyTensor(torch.Tensor):
+        ...
+        implements = classmethod(_implements)
+
+    implements = MyTensor.implements
+
+    @implements(torch.nn.functional.linear):
+    def _(func, types, args, kwargs):
+        ...
+
     """
+    if not hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE"):
+        cls._ATEN_OP_OR_TORCH_FN_TABLE = {}
+
     if not isinstance(aten_ops_or_torch_fns, (list, tuple)):
         aten_ops_or_torch_fns = [aten_ops_or_torch_fns]
     def decorator(func):
@@ -26,10 +35,41 @@ def decorator(func):
             def wrapper(*args, **kwargs):
                 return func(*args, **kwargs)

-            _ATEN_OP_OR_TORCH_FN_TABLE[cls][op] = wrapper
+            cls._ATEN_OP_OR_TORCH_FN_TABLE[op] = wrapper
         return func
     return decorator

+def _dispatch__torch_function__(cls, func, types, args=(), kwargs=None):
+    """Use this util function for a common `__torch_function__` implementation
+    that dispatches to ops/functions registered with `_implements`
+
+    class MyTensor(torch.Tensor):
+        ...
+        __torch_function__ = classmethod(_dispatch__torch_function__)
+    """
+    kwargs = {} if kwargs is None else kwargs
+    if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
+        func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
+        return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs)
+
+    with torch._C.DisableTorchFunctionSubclass():
+        return func(*args, **kwargs)
+
+def _dispatch__torch_dispatch__(cls, func, types, args, kwargs):
+    """Use this util function for a common `__torch_dispatch__` implementation
+    that dispatches to ops/functions registered with `_implements`
+
+    class MyTensor(torch.Tensor):
+        ...
+        __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
+    """
+    if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
+        func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
+        return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs)
+
+    raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func}")
+
+
 """
 Base class for different LayoutType, should not be instantiated directly
 """

torchao/prototype/low_bit_optim/subclass_4bit.py

Lines changed: 5 additions & 10 deletions
@@ -2,7 +2,7 @@

 import torch
 from torch import Tensor
-from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE
+from torchao.dtypes.utils import _implements, _dispatch__torch_dispatch__

 from .quant_utils import create_dynamic_map, scale_tensor, quantize_4bit_with_qmap, dequant_with_qmap

@@ -72,16 +72,11 @@ def __repr__(self):
         f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})"
     )

-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs):
-        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
-            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)
-
-        raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported")
+    __torch__dispatch = classmethod(_dispatch__torch_dispatch__)


 @OptimState4bit.implements(aten.copy_.default)
-def _(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     dst = args[0]
     src = args[1]

@@ -108,13 +103,13 @@ def _(func, *args, **kwargs):


 @OptimState4bit.implements(aten.lerp.Scalar)
-def _(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     args = [x.dequantize() if isinstance(x, OptimState4bit) else x for x in args]
     return func(*args, **kwargs)


 @OptimState4bit.implements(aten.view.default)
-def _(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     x, shape = args
     if len(shape) > 1 or shape[0] != -1:
         raise ValueError(f"{x.__class__.__name__} only supports .view() with shape=[-1]")

torchao/prototype/low_bit_optim/subclass_8bit.py

Lines changed: 4 additions & 9 deletions
@@ -1,6 +1,6 @@
 import torch
 from torch import Tensor
-from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE
+from torchao.dtypes.utils import _implements, _dispatch__torch_dispatch__

 from .quant_utils import create_dynamic_map, scale_tensor, quantize_8bit_with_qmap, dequant_with_qmap

@@ -62,16 +62,11 @@ def __repr__(self):
         f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})"
     )

-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs):
-        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
-            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)
-
-        raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported")
+    __torch__dispatch = classmethod(_dispatch__torch_dispatch__)


 @OptimState8bit.implements(aten.copy_.default)
-def _(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     dst = args[0]
     src = args[1]

@@ -94,6 +89,6 @@ def _(func, *args, **kwargs):


 @OptimState8bit.implements(aten.lerp.Scalar)
-def _(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args]
     return func(*args, **kwargs)

torchao/prototype/low_bit_optim/subclass_fp8.py

Lines changed: 4 additions & 9 deletions
@@ -1,6 +1,6 @@
 import torch
 from torch import Tensor
-from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE
+from torchao.dtypes.utils import _implements, _dispatch__torch_dispatch__


 aten = torch.ops.aten
@@ -66,16 +66,11 @@ def __repr__(self):
         f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})"
     )

-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs):
-        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
-            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)
-
-        raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported")
+    __torch__dispatch = classmethod(_dispatch__torch_dispatch__)


 @OptimStateFp8.implements(aten.copy_.default)
-def _(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     dst = args[0]
     src = args[1]

@@ -96,6 +91,6 @@ def _(func, *args, **kwargs):


 @OptimStateFp8.implements(aten.lerp.Scalar)
-def _(func, *args, **kwargs):
+def _(func, types, *args, **kwargs):
     args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args]
     return func(*args, **kwargs)

torchao/quantization/README.md

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@ for n, m in model.named_modules():
         # note: quantization for activation need to be applied after the weight quantization
         # quantization activation (needed by dynamic quantization)
         input_quant_func = int8wo_quant # specify how input activation is quantized
-        m.weight = nn.Parameter(to_linear_act_quantized(m.weight, input_quant_func))
+        m.weight = nn.Parameter(to_linear_activation_quantized(m.weight, input_quant_func))
 ```
 The model/tensor subclass should also be compatible with AOTI and torch.export, currently we can support
 `torch.export.export` and `torch.aot_compile` with the following workaround:
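
For context, here is a self-contained, hedged sketch of the renamed helper in isolation; `input_quant_func` is only a stand-in for the activation-quantization callable the README builds earlier (e.g. `int8wo_quant`), and wrapping an unquantized float weight is just for illustration:

```python
import torch
from torch import nn
from torchao.quantization import (
    LinearActivationQuantizedTensor,
    to_linear_activation_quantized,
)

linear = nn.Linear(16, 32)
input_quant_func = lambda x: x  # placeholder; real code passes e.g. int8wo_quant
linear.weight = nn.Parameter(
    to_linear_activation_quantized(linear.weight, input_quant_func),
    requires_grad=False,
)
assert isinstance(linear.weight, LinearActivationQuantizedTensor)
```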

torchao/quantization/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -12,6 +12,10 @@
 from .weight_only import * # noqa: F403
 from .unified import *
 from .autoquant import *
+from .linear_activation_quantized_tensor import ( # noqat: F403
+    LinearActivationQuantizedTensor,
+    to_linear_activation_quantized,
+)

 __all__ = [
     "swap_conv2d_1x1_to_linear"
@@ -34,4 +38,6 @@
     "int8_dynamic_activation_int8_weight",
     "int4_weight_only",
     "int8_weight_only",
+    "LinearActivationQuantizedTensor",
+    "to_linear_activation_quantized",
 ]
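
With these re-exports in place, downstream code (including the updated test file at the top of this commit) imports the renamed API directly from the package root rather than from `torchao.quantization.subclass`:

```python
from torchao.quantization import (
    LinearActivationQuantizedTensor,
    to_linear_activation_quantized,
)
```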
