@@ -1119,24 +1119,24 @@ def _aqt_is_int8(aqt):
1119
1119
"""Check if an AffineQuantizedTensor is int8 quantized Tensor"""
1120
1120
return (
1121
1121
aqt .layout_tensor .dtype == torch .int8 and
1122
- aqt .quant_min is None or aqt .quant_min == - 128 and
1123
- aqt .quant_max is None or aqt .quant_max == 127
1122
+ ( aqt .quant_min is None or aqt .quant_min == - 128 ) and
1123
+ ( aqt .quant_max is None or aqt .quant_max == 127 )
1124
1124
)
1125
1125
1126
1126
def _aqt_is_int8_reduced_range(aqt):
    """Check if an AffineQuantizedTensor is int8 quantized with a reduced range.

    Returns a truthy value only when all of the following hold:
      * ``aqt.layout_tensor.dtype`` is ``torch.int8``;
      * ``aqt.quant_min`` is exactly ``-127``;
      * ``aqt.quant_max`` is either ``None`` or ``127``.
    """
    return (
        aqt.layout_tensor.dtype == torch.int8 and
        aqt.quant_min == -127 and
        # The parentheses are required: `and` binds tighter than `or`, so
        # without them `quant_max == 127` alone would satisfy the whole
        # check regardless of the dtype and quant_min.
        (aqt.quant_max is None or aqt.quant_max == 127)
    )
1132
1132
1133
def _aqt_is_tensor_core_tile_uint4(aqt):
    """Check if an AffineQuantizedTensor is uint4 quantized with int32 storage.

    Returns a truthy value only when all of the following hold:
      * ``aqt.layout_tensor.dtype`` is ``torch.int32``;
      * ``aqt.quant_min`` is exactly ``0``;
      * ``aqt.quant_max`` is exactly ``15``.

    Unlike the int8 checks, ``None`` is not accepted for ``quant_min`` or
    ``quant_max``: the int32 storage dtype alone does not identify a 4-bit
    range, so the range must be stated explicitly.

    NOTE(review): the int32 storage presumably corresponds to the packed
    tensor-core-tiled layout suggested by the function name -- confirm
    against the layout implementation.
    """
    # TODO: use torch.uint4
    return (
        aqt.layout_tensor.dtype == torch.int32 and
        aqt.quant_min == 0 and
        aqt.quant_max == 15
    )
1141
1141
1142
1142
@@ -1228,7 +1228,7 @@ def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
1228
1228
input_tensor .dtype == torch .bfloat16 and
1229
1229
# weight is uint4, group quantized tensor_core_tiled layout affine quantized tensor
1230
1230
isinstance (weight_tensor , AffineQuantizedTensor ) and
1231
- _aqt_is_uint4 (weight_tensor ) and
1231
+ _aqt_is_tensor_core_tile_uint4 (weight_tensor ) and
1232
1232
weight_tensor .dtype == torch .bfloat16 and
1233
1233
len (weight_tensor .shape ) == 2 and
1234
1234
weight_tensor .zero_point_domain == ZeroPointDomain .FLOAT and
@@ -1429,7 +1429,7 @@ def _linear_fp_act_fp8_weight_impl(
1429
1429
def _linear_fp_act_int4_weight_sparse_marlin_check (input_tensor , weight_tensor , bias ):
1430
1430
return (
1431
1431
isinstance (weight_tensor , AffineQuantizedTensor ) and
1432
- _aqt_is_uint4 (weight_tensor ) and
1432
+ _aqt_is_tensor_core_tile_uint4 (weight_tensor ) and
1433
1433
input_tensor .dtype == torch .float16 and
1434
1434
len (weight_tensor .shape ) == 2 and
1435
1435
weight_tensor .zero_point_domain == ZeroPointDomain .INT and
0 commit comments