
Commit df358ce

Minor fix for logical operator precedence in _aqt_is_* checks. (#899)
Parent: a4221df

1 file changed: 8 additions, 8 deletions

torchao/dtypes/affine_quantized_tensor.py

@@ -1119,24 +1119,24 @@ def _aqt_is_int8(aqt):
     """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
     return (
         aqt.layout_tensor.dtype == torch.int8 and
-        aqt.quant_min is None or aqt.quant_min == -128 and
-        aqt.quant_max is None or aqt.quant_max == 127
+        (aqt.quant_min is None or aqt.quant_min == -128) and
+        (aqt.quant_max is None or aqt.quant_max == 127)
     )

 def _aqt_is_int8_reduced_range(aqt):
     return (
         aqt.layout_tensor.dtype == torch.int8 and
         aqt.quant_min == -127 and
-        aqt.quant_max is None or aqt.quant_max == 127
+        (aqt.quant_max is None or aqt.quant_max == 127)
     )

-def _aqt_is_uint4(aqt):
+def _aqt_is_tensor_core_tile_uint4(aqt):
     """Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
     # TODO: use torch.uint4
     return (
         aqt.layout_tensor.dtype == torch.int32 and
-        aqt.quant_min is None or aqt.quant_min == 0 and
-        aqt.quant_max is None or aqt.quant_max == 15
+        aqt.quant_min == 0 and
+        aqt.quant_max == 15
     )

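Why the parentheses matter: Python's `and` binds more tightly than `or`, so the old `_aqt_is_int8` body parsed as `(dtype_ok and quant_min is None) or (quant_min == -128 and quant_max is None) or quant_max == 127`, meaning any tensor with `quant_max == 127` passed the check regardless of its backing dtype. A minimal sketch of the misfire, using plain booleans and ints as hypothetical stand-ins for the tensor attributes:

    # Hypothetical stand-in values: backing dtype is NOT int8, but
    # quant_max happens to equal 127.
    dtype_ok = False
    quant_min, quant_max = 0, 127

    # Old, unparenthesized form. `and` binds tighter than `or`, so it parses as
    # (dtype_ok and quant_min is None) or (quant_min == -128 and quant_max is None) or quant_max == 127
    old = dtype_ok and quant_min is None or quant_min == -128 and quant_max is None or quant_max == 127

    # Fixed form from this commit: each "None or exact value" test is grouped first.
    new = dtype_ok and (quant_min is None or quant_min == -128) and (quant_max is None or quant_max == 127)

    print(old)  # True  -- wrongly treated as int8 just because quant_max == 127
    print(new)  # False -- the dtype mismatch now short-circuits the whole check
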
@@ -1228,7 +1228,7 @@ def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
         input_tensor.dtype == torch.bfloat16 and
         # weight is uint4, group quantized tensor_core_tiled layout affine quantized tensor
         isinstance(weight_tensor, AffineQuantizedTensor) and
-        _aqt_is_uint4(weight_tensor) and
+        _aqt_is_tensor_core_tile_uint4(weight_tensor) and
         weight_tensor.dtype == torch.bfloat16 and
         len(weight_tensor.shape) == 2 and
         weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT and
@@ -1429,7 +1429,7 @@ def _linear_fp_act_fp8_weight_impl(
 def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
     return (
         isinstance(weight_tensor, AffineQuantizedTensor) and
-        _aqt_is_uint4(weight_tensor) and
+        _aqt_is_tensor_core_tile_uint4(weight_tensor) and
         input_tensor.dtype == torch.float16 and
         len(weight_tensor.shape) == 2 and
         weight_tensor.zero_point_domain == ZeroPointDomain.INT and
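For context on the two renamed call sites: these `_linear_*_check` predicates gate which specialized kernel the quantized linear override dispatches to, so a too-permissive `_aqt_is_*` helper can silently route a tensor to the wrong kernel. Below is a rough sketch of that check-then-impl pattern; `_DISPATCH_TABLE` and `quantized_linear` are hypothetical stand-ins, not torchao's actual registration machinery.

    from typing import Callable, List, Tuple

    # Hypothetical, simplified registry: (check, impl) pairs tried in order.
    # torchao's real registration in affine_quantized_tensor.py is more
    # involved, but the shape of the mechanism is the same.
    _DISPATCH_TABLE: List[Tuple[Callable, Callable]] = [
        # (_linear_bf16_act_uint4_weight_check, <matching impl>),
        # (_linear_fp_act_int4_weight_sparse_marlin_check, <matching impl>),
    ]

    def quantized_linear(input_tensor, weight_tensor, bias):
        # The first predicate accepting the (activation, weight, bias) triple
        # decides which kernel runs, so a check that returns True too eagerly
        # sends tensors to an implementation that cannot handle them.
        for check_fn, impl_fn in _DISPATCH_TABLE:
            if check_fn(input_tensor, weight_tensor, bias):
                return impl_fn(input_tensor, weight_tensor, bias)
        raise NotImplementedError("no matching quantized linear dispatch")
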
