Skip to content

Commit dc235db

Browse files
committed
update
1 parent e987088 commit dc235db

File tree

3 files changed

+50
-38
lines changed

3 files changed

+50
-38
lines changed

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
DynamicType,
2222
QuantizationArgs,
2323
QuantizationStrategy,
24-
round_to_quantized_type,
24+
round_to_quantized_type_args,
2525
)
2626
from compressed_tensors.quantization.quant_config import QuantizationStatus
2727
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
@@ -473,13 +473,10 @@ def _quantize(
473473
if zero_point is not None:
474474
scaled += zero_point.to(x.dtype)
475475

476-
# clamp first because cast isn't guaranteed to be saturated (ie for fp8)
477-
clamped_value = torch.clamp(
478-
scaled,
479-
q_min,
480-
q_max,
476+
# clamp and round
477+
quantized_value = round_to_quantized_type_args(
478+
tensor=scaled, args=args, min=q_min, max=q_max
481479
)
482-
quantized_value = round_to_quantized_type(clamped_value, args)
483480

484481
if dtype is not None:
485482
quantized_value = quantized_value.to(dtype)

src/compressed_tensors/quantization/quant_args.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
"QuantizationType",
3232
"QuantizationStrategy",
3333
"QuantizationArgs",
34-
"round_to_quantized_type",
34+
"round_to_quantized_type_args",
35+
"round_to_quantized_type_dtype",
3536
"ActivationOrdering",
3637
"DynamicType",
3738
]
@@ -392,47 +393,57 @@ def get_observer(self) -> str:
392393
model_config = ConfigDict(extra="forbid")
393394

394395

395-
def _round_dtype(tensor: torch.Tensor, dtype: torch.dtype):
396+
def round_to_quantized_type_dtype(
397+
tensor: torch.Tensor, dtype: torch.dtype
398+
) -> torch.Tensor:
399+
"""
400+
Rounds an input tensor to the nearest quantized representation given a dtype.
401+
The original dtype is kept post-rounding.
402+
403+
:param tensor: tensor to round
404+
:param dtype: dtype to use for rounding
405+
:return: rounded tensor
406+
"""
407+
original_dtype = tensor.dtype
396408
if torch.is_floating_point(torch.tensor([], dtype=dtype)):
397409
finfo = torch.finfo(dtype)
398-
return torch.clamp(tensor, finfo.min, finfo.max).to(dtype)
410+
rounded = torch.clamp(tensor, finfo.min, finfo.max).to(dtype)
399411
else:
400412
iinfo = torch.iinfo(dtype)
401-
return torch.round(torch.clamp(tensor, iinfo.min, iinfo.max))
402-
413+
rounded = torch.round(torch.clamp(tensor, iinfo.min, iinfo.max))
403414

404-
def _round_args(tensor: torch.Tensor, args: QuantizationArgs):
405-
if args.type == QuantizationType.FLOAT:
406-
if args.num_bits == 8:
407-
return tensor.to(FP8_E4M3_DATA.dtype)
408-
elif args.num_bits == 4:
409-
return FP4_E2M1_DATA.cast_to_fp4(tensor)
410-
else:
411-
raise NotImplementedError("Only num_bits in (4, 8) are supported")
412-
elif args.type == QuantizationType.INT:
413-
return torch.round(tensor)
414-
else:
415-
raise ValueError(f"Invalid quantization type {args.type}")
415+
return rounded.to(original_dtype)
416416

417417

418-
def round_to_quantized_type(
418+
def round_to_quantized_type_args(
419419
tensor: torch.Tensor,
420-
args: Optional[QuantizationArgs] = None,
421-
dtype: Optional[torch.dtype] = None,
420+
args: QuantizationArgs,
421+
min: torch.Tensor,
422+
max: torch.Tensor,
422423
) -> torch.Tensor:
423424
"""
424-
Rounds each element of the input tensor to the nearest quantized representation,
425-
keeping to original dtype. This can be done given QuantizationArgs or dtype
425+
Rounds an input tensor to the nearest quantized representation given
426+
quantization args. The original dtype is kept post-rounding.
426427
427428
:param tensor: tensor to round
428-
:param args: QuantizationArgs to pull appropriate dtype from
429-
:param dtype: dtype to use for rounding
429+
:param args: quantization args to use for rounding
430+
:param min: min value to use for clamping
431+
:param max: max value to use for clamping
430432
:return: rounded tensor
431433
"""
434+
432435
original_dtype = tensor.dtype
433-
if dtype is not None:
434-
rounded = _round_dtype(tensor=tensor, dtype=dtype)
435-
elif args is not None:
436-
rounded = _round_args(tensor=tensor, args=args)
436+
tensor = torch.clamp(tensor, min, max)
437+
if args.type == QuantizationType.FLOAT:
438+
if args.num_bits == 8:
439+
rounded = tensor.to(FP8_E4M3_DATA.dtype)
440+
elif args.num_bits == 4:
441+
rounded = FP4_E2M1_DATA.cast_to_fp4(tensor)
442+
else:
443+
raise NotImplementedError("Only num_bits in (4, 8) are supported")
444+
elif args.type == QuantizationType.INT:
445+
rounded = torch.round(tensor)
446+
else:
447+
raise ValueError(f"Invalid quantization type {args.type}")
437448

438449
return rounded.to(original_dtype)

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
QuantizationArgs,
2525
QuantizationStrategy,
2626
QuantizationType,
27-
round_to_quantized_type,
27+
round_to_quantized_type_dtype,
2828
)
2929
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
3030
from compressed_tensors.utils import deprecated
@@ -108,7 +108,9 @@ def calculate_qparams(
108108

109109
# 3. Conditionally round the scale to the quantized dtype, if scale_dtype is set
110110
if quantization_args.scale_dtype is not None:
111-
scales = round_to_quantized_type(scales, dtype=quantization_args.scale_dtype)
111+
scales = round_to_quantized_type_dtype(
112+
scales, dtype=quantization_args.scale_dtype
113+
)
112114

113115
# 4. Update any 0s with small values to
114116
# prevent div by 0
@@ -124,7 +126,9 @@ def calculate_qparams(
124126
)
125127

126128
# 5. Round the zp to zp_dtype
127-
zero_points = round_to_quantized_type(zero_points, dtype=quantization_args.zp_dtype)
129+
zero_points = round_to_quantized_type_dtype(
130+
zero_points, dtype=quantization_args.zp_dtype
131+
)
128132

129133
if scales.ndim == 0:
130134
scales = scales.reshape(1)

0 commit comments

Comments
 (0)