diff --git a/examples/models/llama/README.md b/examples/models/llama/README.md
index 07c90505237..95f92ddb887 100644
--- a/examples/models/llama/README.md
+++ b/examples/models/llama/README.md
@@ -412,6 +412,10 @@ python -m examples.models.llama.export_llama \
     -d fp32
 ```
 
+A few notes:
+- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, you can control whether embeddings are quantized with weight zeros by passing a third argument to `-E`. For example, `-E "torchao:4,32,true"` quantizes the embedding to 4 bits with group_size=32 and weight zeros (the default behavior if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` quantizes the embedding to 4 bits with group_size=32 using scales only. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but additionally uses 8-bit dynamically quantized activations.
+- To do channelwise quantization, set group_size to 0. This works for both linear and embedding layers.
+
 Once the model is exported, we need to build ExecuTorch and the runner with the low-bit kernels.
 
 The first step is to install ExecuTorch (the same as step 3.1 above):
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 37a4e6952d8..44a6226af23 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -155,6 +155,11 @@ def build_args_parser() -> argparse.ArgumentParser:
         type=str,
         help="type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
     )
+    parser.add_argument(
+        "--use_shared_embedding",
+        action="store_true",
+        help="Whether the embedding/unembedding weights should be shared. Only available with torchao kernels.",
+    )
     parser.add_argument(
         "--pt2e_quantize",
         default=None,
@@ -684,6 +689,15 @@ def _validate_args(args):
     if args.num_sharding > 0 and not args.qnn:
         raise ValueError("Model shard is only supported with qnn backend now.")
 
+    if args.use_shared_embedding:
+        if not (
+            args.embedding_quantize is not None
+            and args.embedding_quantize.startswith("torchao:")
+        ):
+            raise ValueError(
+                "Shared embedding is only supported with torchao quantization."
+            )
+
     if (
         args.quantization_mode is not None
         and args.quantization_mode.startswith("torchao:")
@@ -1122,6 +1136,21 @@ def _get_source_transforms(  # noqa
 
         transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)
 
+    if args.embedding_quantize:
+        """
+        When this option is selected, it finds all embedding layers and transforms
+        them into their quantized embedding equivalents.
+
+        There are cases where the checkpoint is already quantized, for example
+        when use_spin_quant is enabled. In that case, the appropriate
+        transformations based on the given checkpoint are applied first, and
+        this will be a no-op here.
+        """
+        modelname = f"{modelname}_e"
+        transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
+
+    # quantization_mode should be applied after embedding_quantize
+    # to support shared_embedding
     if args.quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
@@ -1145,19 +1174,6 @@ def _get_source_transforms(  # noqa
             )
         )
 
-    if args.embedding_quantize:
-        """
-        When this option is selected, it finds all embedding layers and transforms
-        into quantized embedding equivalent module.
-
-        There are cases where the checkpoint is already quantized, for example
-        on use_spin_quant is enabled. In that case, it will do the appropriate
-        transformations based on the given checkpoint first. In those cases,
-        this wil be a no-op.
-        """
-        modelname = f"{modelname}_e"
-        transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
-
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index 17cff7c63fd..36743bb3b79 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -124,9 +124,7 @@ def quantize(  # noqa C901
             model,
             Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=getattr(torch, f"int{bitwidth}"),
-                granularity=(
-                    PerRow() if group_size in [0, -1] else PerGroup(group_size)
-                ),
+                granularity=(PerRow() if group_size == 0 else PerGroup(group_size)),
                 has_weight_zeros=False,
             ),
         )
@@ -786,19 +784,43 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 
 def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
     if args.embedding_quantize.startswith("torchao:"):
-        bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
+        from torchao.experimental.quant_api import (
+            EmbeddingQuantizer,
+            SharedEmbeddingQuantizer,
+        )
+        from torchao.quantization.granularity import PerGroup, PerRow
+
+        quant_args = args.embedding_quantize.split(":")[1].split(",")
+        if len(quant_args) == 2:
+            bitwidth, group_size = quant_args
+            has_weight_zeros = True
+        else:
+            bitwidth, group_size, has_weight_zeros = quant_args
+
+        if group_size in ["none", "None", "0"]:
+            group_size = 0
+
         group_size = int(group_size)
         bitwidth = int(bitwidth)
-        from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
+        has_weight_zeros = has_weight_zeros in [True, "true", "True"]
+        weight_dtype = getattr(torch, f"int{bitwidth}")
+        granularity = PerRow() if group_size == 0 else PerGroup(group_size)
 
         def _torchao_embedding_quantizer(model):
             with torch.no_grad():
-                model = IntxWeightEmbeddingQuantizer(
-                    device="cpu",
-                    precision=torch.float32,
-                    bitwidth=bitwidth,
-                    groupsize=group_size,
-                ).quantize(model)
+                if not args.use_shared_embedding:
+                    EmbeddingQuantizer(
+                        weight_dtype=weight_dtype,
+                        granularity=granularity,
+                        has_weight_zeros=has_weight_zeros,
+                        use_fallback=False,
+                    ).quantize(model)
+                else:
+                    SharedEmbeddingQuantizer(
+                        weight_dtype=weight_dtype,
+                        granularity=granularity,
+                        has_weight_zeros=has_weight_zeros,
+                    ).quantize(model)
                 return model
 
         return _torchao_embedding_quantizer
diff --git a/third-party/ao b/third-party/ao
index 64bcf4c2575..83eb4903916 160000
--- a/third-party/ao
+++ b/third-party/ao
@@ -1 +1 @@
-Subproject commit 64bcf4c25755a783685ba7383000b3bf722523c1
+Subproject commit 83eb4903916340900c140afd0fe35dfaddf23c23
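As a usage sketch (illustrative, not prescribed by this patch), the two new options from the README notes combine on the export command roughly as follows, assuming the remaining flags match the existing low-bit export example in the README (elided here as `...`):

    python -m examples.models.llama.export_llama \
        ... \
        --use_shared_embedding \
        -E "torchao:4,32,false" \
        -d fp32

With these arguments the embedding is quantized to 4 bits with group_size=32 using scales only, and because `--use_shared_embedding` is set, the unembedding (final linear layer) shares those weights and additionally uses 8-bit dynamically quantized activations.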