Skip to content

Commit a05dbd9

Browse files
committed
Refine code
1 parent 2e2dffe commit a05dbd9

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

torchao/quantization/granularity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class PerBlock(Granularity):
110110
:func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for
111111
`block_size`
112112
Attributes:
113-
block_size (Tuple[int, ...]): The size of each quantization group
113+
block_size (tuple[int, ...]): The size of each quantization group
114114
"""
115115

116116
block_size: tuple[int, ...]

torchao/quantization/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -711,14 +711,14 @@ def get_block_size(
711711
)
712712
for i in range(len(block_size)):
713713
assert input_shape[i] % block_size[i] == 0, (
714-
f"Block size {block_size} does not divide input shape {input_shape}"
714+
f"Not all shapes in input shape {input_shape} are divisible by block size {block_size}"
715715
)
716716
return block_size
717717
elif isinstance(granularity, (PerRow, PerToken)):
718718
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
719719
elif isinstance(granularity, PerGroup):
720720
assert input_shape[-1] % granularity.group_size == 0, (
721-
f"Group size {granularity.group_size} does not divide input shape {input_shape}"
721+
f"Last dimension of input {input_shape[-1]} is not divisible by group size {granularity.group_size}"
722722
)
723723
return (1,) * (len(input_shape) - 1) + (granularity.group_size,)
724724
raise ValueError(f"Unsupported Granularity: {granularity}")

0 commit comments

Comments
 (0)