Description
In the QAT README, there is a suggestion:
To fake quantize embedding in addition to linear, you can additionally call the following with a filter function during the prepare step:
quantize_(
    m,
    IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
Shouldn't the correct approach instead be the following, so that linear layers are also matched?
from torchao.quantization.quant_api import _is_linear

quantize_(
    m,
    IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding) or _is_linear(m),
)