Skip to content

Commit 915e8d3

Browse files
committed
make register/deregister fn public
1 parent 5ae47f2 commit 915e8d3

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

test/dtypes/test_affine_quantized.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def test_to_device(self, apply_quant):
9090
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
9191
def test_register_new_dispatch(self):
9292
from torchao.dtypes.affine_quantized_tensor import (
93-
_register_aqt_quantized_linear_dispatch,
94-
_deregister_aqt_quantized_linear_dispatch,
93+
register_aqt_quantized_linear_dispatch,
94+
deregister_aqt_quantized_linear_dispatch,
9595
)
9696
from torchao.dtypes import to_affine_quantized_intx
9797
from torchao.dtypes import AffineQuantizedTensor
@@ -109,7 +109,7 @@ def impl(input_tensor, weight_tensor, bias):
109109
# quantized linear operator here
110110
assert False, "dispatching to my impl for uint6 weight only quant"
111111

112-
_register_aqt_quantized_linear_dispatch(dispatch_condition, impl)
112+
register_aqt_quantized_linear_dispatch(dispatch_condition, impl)
113113

114114
def apply_uint6_weight_only_quant(linear):
115115
linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight, MappingType.ASYMMETRIC, (1, linear.weight.shape[-1]), torch.uint8, 0, 2**6-1), requires_grad=False)
@@ -122,7 +122,7 @@ def apply_uint6_weight_only_quant(linear):
122122
with self.assertRaisesRegex(AssertionError, "dispatching to my impl for uint6 weight only quant"):
123123
l(example_input)
124124

125-
_deregister_aqt_quantized_linear_dispatch(dispatch_condition)
125+
deregister_aqt_quantized_linear_dispatch(dispatch_condition)
126126

127127

128128

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class QuantizedLinearNotImplementedError(NotImplementedError):
9292

9393

9494
_AQT_QLINEAR_DISPATCH_TABLE = {}
95-
def _register_aqt_quantized_linear_dispatch(dispatch_condition, impl):
95+
def register_aqt_quantized_linear_dispatch(dispatch_condition, impl):
9696
"""Register a dispatch for quantized linear op with dispatch_condition function and impl function
9797
both takes three arguments:
9898
input_tensor: dimension is (M1, M2, ..., in_features)
@@ -108,7 +108,7 @@ def _register_aqt_quantized_linear_dispatch(dispatch_condition, impl):
108108
"""
109109
_AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl
110110

111-
def _deregister_aqt_quantized_linear_dispatch(dispatch_condition):
111+
def deregister_aqt_quantized_linear_dispatch(dispatch_condition):
112112
if dispatch_condition in _AQT_QLINEAR_DISPATCH_TABLE:
113113
del _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition]
114114
else:

0 commit comments

Comments (0)