
Commit 03c4553

Add cachemask variant for fake_quantize_affine (#500)
Summary:

In QAT, we often wish to filter out the gradients corresponding to values outside the expected quantization range, for example:

```
q = _quantize_affine_no_dtype_cast(...)
dq = _dequantize_affine_no_dtype_check(...)
mask = torch.logical_and((q >= quant_min), (q <= quant_max))
grad = grad * mask
```

The existing `fake_quantize_affine` returns only the dequantized values, so callers have no access to this mask. This commit adds a variant of the op that returns both the dequantized values and the mask, similar to `fake_quantize_per_tensor_affine_cachemask` in core.

Test Plan:

python test/quantization/test_quant_primitives.py -k test_fake_quantize_affine_cachemask
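To make the intended gradient filtering concrete, here is a minimal sketch (not part of this commit) of wiring the returned mask into a custom autograd function; the class name `_FakeQuantizeSTE` and its argument handling are illustrative assumptions, not the repo's API:

```python
import torch
from torchao.quantization.quant_primitives import fake_quantize_affine_cachemask

class _FakeQuantizeSTE(torch.autograd.Function):
    """Hypothetical straight-through estimator built on the new op."""

    @staticmethod
    def forward(ctx, input, block_size, scale, zero_point, quant_dtype, quant_min, quant_max):
        fq, mask = fake_quantize_affine_cachemask(
            input, block_size, scale, zero_point, quant_dtype, quant_min, quant_max,
        )
        # Cache the mask so the backward pass can zero out-of-range gradients
        ctx.save_for_backward(mask)
        return fq

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # Pass gradients through only where the quantized value was in range;
        # one None per non-tensor forward argument
        return grad_output * mask, None, None, None, None, None, None
```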
1 parent 591df26 · commit 03c4553

File tree (2 files changed: +96 -1 lines changed)

test/quantization/test_quant_primitives.py
torchao/quantization/quant_primitives.py

test/quantization/test_quant_primitives.py

Lines changed: 24 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@
 import torch
 from torchao.quantization.quant_primitives import (
     fake_quantize_affine,
+    fake_quantize_affine_cachemask,
     quantize_affine,
     dequantize_affine,
     choose_qparams_affine,
@@ -523,5 +524,28 @@ def test_fake_quantize_affine(self):
         fake_quantized = fake_quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
         torch.testing.assert_close(dequantized, fake_quantized)
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    def test_fake_quantize_affine_cachemask(self):
+        input = torch.randn(10, 10)
+
+        mapping_type = MappingType.SYMMETRIC
+        block_size = list(input.shape)
+        for i in range(len(block_size) - 1):
+            block_size[i] = 1
+        dtype = torch.int8
+        eps = 1e-5
+        quant_min = -127
+        quant_max = 127
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float)
+
+        quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
+        dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, quant_min, quant_max)
+        (fake_quantized, mask) = fake_quantize_affine_cachemask(
+            input, block_size, scale, zero_point, dtype, quant_min, quant_max,
+        )
+        expected_mask = torch.full(input.shape, True)
+        torch.testing.assert_close(dequantized, fake_quantized)
+        torch.testing.assert_close(expected_mask, mask)
+
 if __name__ == "__main__":
     unittest.main()
```
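A note on the test above: `block_size` starts as the full input shape and every dimension except the last is set to 1; under torchao's affine quantization convention this makes each row of the (10, 10) input its own quantization block with its own scale and zero_point. Because the qparams are then chosen from the very input being fake-quantized, every quantized value lands within `[quant_min, quant_max]`, which is why the expected mask is all `True`. A minimal sketch of the block-size arithmetic (the `num_blocks` computation is illustrative, not code from the commit):

```python
import torch

input = torch.randn(10, 10)
block_size = list(input.shape)        # [10, 10]: one block covering the whole tensor
for i in range(len(block_size) - 1):
    block_size[i] = 1                 # [1, 10]: one block per row

# Number of blocks (= number of scale/zero_point pairs) along each dimension
num_blocks = [dim // blk for dim, blk in zip(input.shape, block_size)]
print(num_blocks)                     # [10, 1] -> per-row quantization parameters
```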

torchao/quantization/quant_primitives.py

Lines changed: 72 additions & 1 deletion
```diff
@@ -24,6 +24,7 @@
     "quantize_affine",
     "dequantize_affine",
     "fake_quantize_affine",
+    "fake_quantize_affine_cachemask",
 ]
 
 class MappingType(Enum):
@@ -411,6 +412,76 @@ def fake_quantize_affine(
             value during quantization
             default is ZeroPointDomain.INT
     """
+    (_, fq) = _do_fake_quantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        quant_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain,
+    )
+    return fq
+
+
+def fake_quantize_affine_cachemask(
+    input: torch.Tensor,
+    block_size: Tuple[int, ...],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    quant_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    General fake quantize op for quantization-aware training (QAT).
+    This is equivalent to calling `quantize_affine` + `dequantize_affine`
+    but without the dtype casts.
+
+    Note: Compared to :func:`~torchao.quantization.quant_primitives.fake_quantize_affine`,
+    this consumes more memory and returns an additional outlier mask for
+    intermediate quantized values.
+
+    Args:
+      Same as :func:`~torchao.quantization.quant_primitives.fake_quantize_affine`.
+
+    Returns:
+      A 2-tuple of (
+          final fake quantized values,
+          outlier mask for intermediate quantized values
+      )
+
+    """
+    (q, dq) = _do_fake_quantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        quant_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain,
+    )
+    mask = torch.logical_and((q >= quant_min), (q <= quant_max))
+    return (dq, mask)
+
+
+def _do_fake_quantize_affine(
+    input: torch.Tensor,
+    block_size: Tuple[int, ...],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    quant_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Helper function for `fake_quantize_affine` that returns both the
+    intermediate quantized values and the final dequantized values.
+    """
     input_dtype = input.dtype
     quant_min, quant_max = _get_and_check_qmin_qmax(quant_dtype, quant_min, quant_max)
     q = _quantize_affine_no_dtype_cast(
@@ -432,7 +503,7 @@ def fake_quantize_affine(
         zero_point_domain.name,
         output_dtype=input_dtype,
     )
-    return dq
+    return (q, dq)
 
 
 def choose_qparams_affine(
```
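Since both public ops route through `_do_fake_quantize_affine`, their fake-quantized outputs should be identical; the cachemask variant only adds the boolean mask. A hedged sanity-check sketch (not part of the commit; argument values mirror the test above):

```python
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    fake_quantize_affine,
    fake_quantize_affine_cachemask,
)

input = torch.randn(10, 10)
block_size = [1, 10]  # per-row quantization blocks, as in the test
scale, zero_point = choose_qparams_affine(
    input, MappingType.SYMMETRIC, block_size, torch.int8, -127, 127,
    eps=1e-5, scale_dtype=torch.float,
)

fq = fake_quantize_affine(input, block_size, scale, zero_point, torch.int8, -127, 127)
fq2, mask = fake_quantize_affine_cachemask(input, block_size, scale, zero_point, torch.int8, -127, 127)

torch.testing.assert_close(fq, fq2)  # same fake-quantized values either way
assert mask.dtype == torch.bool and mask.shape == input.shape
```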
