Skip to content

Commit 68eea61

Browse files
committed
use torch.ops.npu prefix and drop redundant torch_npu import
1 parent f3aefca commit 68eea61

File tree

1 file changed

+2
-13
lines changed

1 file changed

+2
-13
lines changed

torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor_npu.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,6 @@
2121

2222
aten = torch.ops.aten
2323

24-
try:
25-
import torch_npu
26-
except ImportError:
27-
torch_npu = None
28-
2924

3025
class Int4PlainInt32TensorNPU(TorchAOBaseTensor):
3126
"""
@@ -93,9 +88,6 @@ def from_hp(
9388
w: torch.Tensor,
9489
block_size: List[int],
9590
):
96-
if torch_npu is None:
97-
raise ImportError("Requires torch_npu but it is not installed")
98-
9991
assert w.ndim == 2 and w.device.type == "npu", (
10092
f"Expecting 2D tensor on NPU, but got: {w.shape} on {w.device.type}"
10193
)
@@ -143,7 +135,7 @@ def from_hp(
143135
f"torch_npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got {int_data.shape[-1]}"
144136
)
145137

146-
packed_weight = torch_npu.npu_convert_weight_to_int4pack(
138+
packed_weight = torch.ops.npu.npu_convert_weight_to_int4pack(
147139
int_data.contiguous(), 0
148140
)
149141

@@ -174,9 +166,6 @@ def _(func, types, args, kwargs):
174166
args[2] if len(args) > 2 else None,
175167
)
176168

177-
if torch_npu is None:
178-
raise ImportError("Requires torch_npu but it is not installed")
179-
180169
assert input_tensor.device.type == "npu", (
181170
f"For NPU device only but got: {input_tensor.device.type}"
182171
)
@@ -219,7 +208,7 @@ def _(func, types, args, kwargs):
219208
# groupwise int4 quantization
220209
groupsize = weight_tensor.block_size[1]
221210

222-
y = torch_npu.npu_weight_quant_batchmatmul(
211+
y = torch.ops.npu.npu_weight_quant_batchmatmul(
223212
x=act_mat,
224213
weight=packed_weight.contiguous().transpose(-1, -2),
225214
antiquant_scale=scale,

0 commit comments

Comments (0)