@@ -70,6 +70,12 @@ def __repr__(self):
70
70
# Tensor Subclass Definition #
71
71
##############################
72
72
73
+
74
+ class QuantizedLinearNotImplementedError (NotImplementedError ):
75
+ """ Thin wrapper around NotImplementedError, raised by _quantized_linear_op when no registered dispatch condition matches, so the linear/addmm/mm handlers can catch it specifically (instead of a bare except) and fall back to dequantized computation """
76
+ pass
77
+
78
+
73
79
_QLINEAR_DISPATCH_TABLE = {}
74
80
def _register_quantized_linear_dispatch (dispatch_condition , impl ):
75
81
_QLINEAR_DISPATCH_TABLE [dispatch_condition ] = impl
@@ -159,7 +165,7 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias):
159
165
if dispatch_condition (input_tensor , weight_tensor , bias ):
160
166
return impl (input_tensor , weight_tensor , bias )
161
167
162
- raise NotImplementedError ("No specialized dispatch found for quantized linear op" )
168
+ raise QuantizedLinearNotImplementedError ("No specialized dispatch found for quantized linear op" )
163
169
164
170
def __tensor_flatten__ (self ):
165
171
return ["layout_tensor" ], [self .block_size , self .shape , self .quant_min , self .quant_max , self .zero_point_domain , self .dtype ]
@@ -887,7 +893,7 @@ def _(func, types, args, kwargs):
887
893
# make the branches easier to understand in `_quantized_linear_op`
888
894
try :
889
895
return weight_tensor ._quantized_linear_op (input_tensor , weight_tensor , bias )
890
- except :
896
+ except QuantizedLinearNotImplementedError :
891
897
if isinstance (input_tensor , AffineQuantizedTensor ):
892
898
input_tensor = input_tensor .dequantize ()
893
899
if isinstance (weight_tensor , AffineQuantizedTensor ):
@@ -910,7 +916,7 @@ def _(func, types, args, kwargs):
910
916
try :
911
917
weight_tensor = weight_tensor .t ()
912
918
return weight_tensor ._quantized_linear_op (input_tensor , weight_tensor , bias )
913
- except :
919
+ except QuantizedLinearNotImplementedError :
914
920
if isinstance (input_tensor , AffineQuantizedTensor ):
915
921
input_tensor = input_tensor .dequantize ()
916
922
if isinstance (weight_tensor , AffineQuantizedTensor ):
@@ -930,7 +936,7 @@ def _(func, types, args, kwargs):
930
936
try :
931
937
weight_tensor = weight_tensor .t ()
932
938
return weight_tensor ._quantized_linear_op (input_tensor , weight_tensor , bias )
933
- except :
939
+ except QuantizedLinearNotImplementedError :
934
940
if isinstance (input_tensor , AffineQuantizedTensor ):
935
941
input_tensor = input_tensor .dequantize ()
936
942
if isinstance (weight_tensor , AffineQuantizedTensor ):
0 commit comments