Skip to content

Commit a11e455

Browse files
committed
Fix CI
1 parent 93e68e6 commit a11e455

File tree

4 files changed

+15
-5
lines changed

4 files changed

+15
-5
lines changed

test/quantization/test_quant_primitives.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_3,
     TORCH_VERSION_AFTER_2_4,
+    TORCH_VERSION_AFTER_2_5,
     is_fbcode,
 )

@@ -99,7 +100,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
         .to(torch.int32)
         .reshape_as(w)
     )
-    w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
+    if TORCH_VERSION_AFTER_2_5:
+        w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)

     return w_int4x8

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
 )
 from typing import ClassVar
 from dataclasses import dataclass
+from torchao.utils import TORCH_VERSION_AFTER_2_5

 aten = torch.ops.aten

@@ -532,8 +533,11 @@ def from_plain(
         layout_type: LayoutType
     ):
         assert isinstance(layout_type, TensorCoreTiledLayoutType)
-        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
-        assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
+        if TORCH_VERSION_AFTER_2_5:
+            int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
+            assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype"
+        else:
+            assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype"
         packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles)
         scale = scale.reshape(int_data.shape[0], -1)
         zero_point = zero_point.reshape(int_data.shape[0], -1)

torchao/prototype/hqq/hqq_tinygemm_linear.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
 from hqq.core.utils import *

 import torch.nn.functional as F
+from torchao.utils import TORCH_VERSION_AFTER_2_5


 class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
@@ -198,7 +199,8 @@ def hqq_quants_to_torch_quants(
         .reshape(shape)
         .contiguous()
     )
-    W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
+    if TORCH_VERSION_AFTER_2_5:
+        W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)

     # group_dequantize_tensor_from_qparams
     # W_r = W_q*scales + min_val

torchao/quantization/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
     dequantize_affine,
     int_scaled_matmul,
 )
+from torchao.utils import TORCH_VERSION_AFTER_2_5

 __all__ = [
     "compute_error",
@@ -349,7 +350,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
     quant_max = 2 ** n_bit - 1

     int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
-    int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
+    if TORCH_VERSION_AFTER_2_5:
+        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
     return int_data

 def groupwise_affine_dequantize_tensor_from_qparams(

0 commit comments

Comments (0)