Add cachemask variant for fake_quantize_affine (#500)
Summary: In QAT, we often wish to filter out the gradients
corresponding to values outside the expected quantization
range, for example:
```python
# Quantize and dequantize without casting, so out-of-range values survive
q = _quantize_affine_no_dtype_cast(...)
dq = _dequantize_affine_no_dtype_check(...)
# Mask selects values that fall inside [quant_min, quant_max]
mask = torch.logical_and((q >= quant_min), (q <= quant_max))
# Zero out gradients for values that were clipped
grad = grad * mask
```
The existing `fake_quantize_affine` returns the dequantized
values only, so callers do not have access to this mask.
This commit adds a variant of this op that returns both
the dequantized values and the mask, similar to
`fake_quantize_per_tensor_affine_cachemask` in core.
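The behavior of the cachemask variant can be sketched as a standalone function. This is a simplified per-tensor illustration, not the actual `fake_quantize_affine_cachemask` signature (which operates on affine-quantized blocks and takes additional arguments); the function name and parameters below are assumptions for the sketch:

```python
import torch

def fake_quantize_cachemask_sketch(x, scale, zero_point, quant_min, quant_max):
    """Return (dequantized values, in-range mask) for per-tensor fake quant.

    Simplified illustration of the cachemask pattern: the mask records which
    elements landed inside [quant_min, quant_max] before clamping, so callers
    can reuse it to filter gradients in QAT.
    """
    q = torch.round(x / scale) + zero_point
    mask = torch.logical_and(q >= quant_min, q <= quant_max)
    dq = (torch.clamp(q, quant_min, quant_max) - zero_point) * scale
    return dq, mask
```

Returning the mask alongside the dequantized values saves callers from re-running the quantization step in the backward pass just to recompute which elements were clipped.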
Test Plan:
python test/quantization/test_quant_primitives.py -k test_fake_quantize_affine_cachemask