Skip to content

Commit 78f6bb2

Browse files
committed
Add assert for the original weight data type
1 parent 970aa17 commit 78f6bb2

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ def from_hp(
7676
f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}"
7777
)
7878
assert len(block_size) == w.ndim
79-
79+
assert w.dtype in [torch.float16, torch.bfloat16], (
80+
f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}"
81+
)
8082
original_shape = w.shape
8183
mapping_type = MappingType.ASYMMETRIC
8284
target_dtype = torch.int32

0 commit comments

Comments
 (0)