diff --git a/torchao/quantization/qat/README.md b/torchao/quantization/qat/README.md
index 0f024dbf61..42ff4e2567 100644
--- a/torchao/quantization/qat/README.md
+++ b/torchao/quantization/qat/README.md
@@ -115,10 +115,11 @@
 To fake quantize embedding in addition to linear, you can additionally
 call the following with a filter function during the prepare step:
 
 ```
+from torchao.quantization.quant_api import _is_linear
 quantize_(
     m,
     IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
-    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
+    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding) or _is_linear(m),
 )
 ```