From 61a1d31a5b2d9ae2e6b9e7c4b70c06f22d37df5d Mon Sep 17 00:00:00 2001 From: Yi Liu Date: Wed, 11 Sep 2024 09:52:54 +0800 Subject: [PATCH 1/9] Bring `torch.compile` to `quant_block_v2_`. (#18) Signed-off-by: yiliu30 --- torchao/prototype/autoround/autoround_llm.py | 12 +++++++++++- torchao/prototype/autoround/core.py | 8 +++++++- torchao/prototype/autoround/eval_autoround.py | 8 ++++++++ torchao/prototype/autoround/utils.py | 16 +++++++++++++--- 4 files changed, 39 insertions(+), 5 deletions(-) diff --git a/torchao/prototype/autoround/autoround_llm.py b/torchao/prototype/autoround/autoround_llm.py index 2d464be0f4..eb324a9b02 100644 --- a/torchao/prototype/autoround/autoround_llm.py +++ b/torchao/prototype/autoround/autoround_llm.py @@ -5,7 +5,7 @@ import torchao import torchao.prototype.autoround.utils as ar_utils - +from typing import Optional from torchao.prototype.autoround.core import ( apply_auto_round, prepare_model_for_applying_auto_round_, @@ -29,6 +29,7 @@ def quantize_model_with_autoround_( bs: int = 8, nsamples: int = 128, use_optimized_layer_output: bool = False, + compile_optimization_process: Optional[bool] = False, ): # Step 1. Prepare the model for applying auto-round @@ -42,6 +43,7 @@ def quantize_model_with_autoround_( group_size, iters, use_optimized_layer_output, + compile_optimization_process, device=device, ) @@ -107,6 +109,7 @@ def main(args): bs=args.train_bs, nsamples=args.nsamples, use_optimized_layer_output=args.use_optimized_layer_output, + compile_optimization_process=args.compile_optimization_process, ) # Revert the `use_cache` for generation stage. model.config.use_cache = True @@ -168,6 +171,13 @@ def main(args): action="store_true", help="Use the optimized layer output for next layer or not", ) + parser.add_argument( + "-c", + "--compile_optimization_process", + default=False, + action="store_true", + help="Whether to compile the optimization process", + ) parser.add_argument( "-d", "--model_device", diff --git a/torchao/prototype/autoround/core.py b/torchao/prototype/autoround/core.py index 342f14d825..bede40cd70 100644 --- a/torchao/prototype/autoround/core.py +++ b/torchao/prototype/autoround/core.py @@ -20,6 +20,7 @@ class _AutoRoundConfig: group_size: int = 128 iters: int = 200 use_optimized_layer_output: bool = False + compile_optimization_process: bool = False _auto_round_config = _AutoRoundConfig() @@ -82,6 +83,7 @@ def prepare_model_for_applying_auto_round_( group_size: int = 128, iters: int = 200, use_optimized_layer_output: bool = False, + compile_optimization_process: Optional[bool] = False, device: Optional[torch.types.Device] = None, ): """Prepares the model for applying auto round optimization. @@ -94,6 +96,7 @@ def prepare_model_for_applying_auto_round_( group_size (int, optional): The group size for quantization. Defaults to 128. iters (int, optional): The number of iterations for optimization. Defaults to 200. use_optimized_layer_output (bool, optional): Whether to use optimized layer output. Defaults to False. + compile_optimization_process (Optional[bool], optional): Whether to compile the optimization process. Defaults to False. device (Optional[torch.types.Device], optional): The device to use for accelrating optimization and calibration. Defaults to None. 
""" @@ -105,6 +108,7 @@ def prepare_model_for_applying_auto_round_( _auto_round_config.group_size = group_size _auto_round_config.iters = iters _auto_round_config.use_optimized_layer_output = use_optimized_layer_output + _auto_round_config.compile_optimization_process = compile_optimization_process logging.warning(f"config {_auto_round_config}") @@ -315,6 +319,8 @@ def _apply_auto_round_optimization( amp=True, model_dtype=next(block.parameters()).dtype, ) + if config.compile_optimization_process: + rounder.quant_block_v2_ = torch.compile(rounder.quant_block_v2_) with torch.enable_grad(): rounder.quant_block_v2_( @@ -326,7 +332,7 @@ def _apply_auto_round_optimization( block.to(orig_device) -@ar_utils.dump_elapsed_time() +@ar_utils.dump_elapsed_time(record=True) @torch.no_grad() def apply_auto_round_optimization( module: torch.nn.Module, diff --git a/torchao/prototype/autoround/eval_autoround.py b/torchao/prototype/autoround/eval_autoround.py index e72205aba9..c3f7a1c7ea 100644 --- a/torchao/prototype/autoround/eval_autoround.py +++ b/torchao/prototype/autoround/eval_autoround.py @@ -121,6 +121,7 @@ def main(args): bs=args.train_bs, nsamples=args.nsamples, use_optimized_layer_output=args.use_optimized_layer_output, + compile_optimization_process=args.compile_optimization_process, ) quantized_layer_cnt = ar_utils.count_tensor_of_type( model, torchao.dtypes.AffineQuantizedTensor @@ -184,6 +185,13 @@ def main(args): action="store_true", help="Use the optimized layer output for next layer or not", ) + parser.add_argument( + "-c", + "--compile_optimization_process", + default=False, + action="store_true", + help="Whether to compile the optimization process", + ) parser.add_argument( "-d", "--model_device", diff --git a/torchao/prototype/autoround/utils.py b/torchao/prototype/autoround/utils.py index ea62b2e34d..3a47289f55 100644 --- a/torchao/prototype/autoround/utils.py +++ b/torchao/prototype/autoround/utils.py @@ -6,6 +6,7 @@ import numpy as np import torch +import collections def _is_package_available(pkg_name, metadata_name=None): @@ -149,8 +150,8 @@ def get_float_model_info(model_name_or_path, torch_dtype=torch.float32): ) return model, tokenizer, decoder_cls - -def dump_elapsed_time(customized_msg=""): +execution_records = collections.defaultdict(list) +def dump_elapsed_time(customized_msg="", record=False): """Get the elapsed time for decorated functions. 
Args: @@ -164,13 +165,22 @@ def fi(*args, **kwargs): start = time.time() res = func(*args, **kwargs) end = time.time() + dur = round((end - start) * 1000, 2) + if record: + execution_records[func.__qualname__].append(dur) logging.warning( "%s elapsed time: %s ms" % ( customized_msg if customized_msg else func.__qualname__, - round((end - start) * 1000, 2), + dur, ) ) + if record: + avg_time = sum(execution_records[func.__qualname__])/len(execution_records[func.__qualname__]) + std_time = np.std(execution_records[func.__qualname__]) + logging.warning( + f"For {func.__qualname__}, the average elapsed time: {avg_time: .2f} ms, the std: {std_time: .2f} ms" + ) return res return fi From 77e5dcc0f02c1a9f63bf087270e429d9d07dd048 Mon Sep 17 00:00:00 2001 From: Yi Liu Date: Wed, 11 Sep 2024 10:07:21 +0800 Subject: [PATCH 2/9] Add `AO_USE_DETERMINISTIC_ALGORITHMS` for reproducing results (#19) Signed-off-by: yiliu30 --- torchao/prototype/autoround/eval_autoround.py | 41 +++++++++++++++---- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/torchao/prototype/autoround/eval_autoround.py b/torchao/prototype/autoround/eval_autoround.py index c3f7a1c7ea..1fcae4ecc3 100644 --- a/torchao/prototype/autoround/eval_autoround.py +++ b/torchao/prototype/autoround/eval_autoround.py @@ -1,16 +1,40 @@ import argparse +import logging +import os -import torchao.prototype.autoround.utils as ar_utils - -ar_utils.freeze_random(42) import torch -torch.use_deterministic_algorithms(True, warn_only=True) import torchao - +import torchao.prototype.autoround.utils as ar_utils import torchao.quantization from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +logger = logging.getLogger(__name__) + +ar_utils.freeze_random(42) + + +def _use_deterministic(): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True, warn_only=False) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + logger.warning( + ( + "Reproducibility is enabled with `AO_USE_DETERMINISTIC_ALGORITHMS=1`, which sets " + "`torch.use_deterministic_algorithms(True, warn_only=False)` and " + "environment variable `CUBLAS_WORKSPACE_CONFIG` to `:4096:8`.\n" + "Please note that this may impact performance, or cause crashes if the model includes non-deterministic operations." 
+ ) + ) + + +AO_USE_DETERMINISTIC_ALGORITHMS = ( + os.environ.get("AO_USE_DETERMINISTIC_ALGORITHMS", "0") == "1" +) +if AO_USE_DETERMINISTIC_ALGORITHMS: + _use_deterministic() + @ar_utils.dump_elapsed_time() def run_evaluation(model, tokenizer, tasks, compile=False, batch_size=4): @@ -62,7 +86,9 @@ def main(args): ) model.eval() model_device = args.model_device - ar_utils.gen_text(model, tokenizer, "Float model", max_length=50) + # `sorted_logits` does not have a deterministic implementation + if not AO_USE_DETERMINISTIC_ALGORITHMS: + ar_utils.gen_text(model, tokenizer, "Float model", max_length=50) model = model.to(model_device) model.config.use_cache = False msg = "Float-model" if args.eval_float_model else "Quantized-model" @@ -127,7 +153,8 @@ def main(args): model, torchao.dtypes.AffineQuantizedTensor ) msg += f" quantized {quantized_layer_cnt} Linear layers " - ar_utils.gen_text(model, tokenizer, msg, max_length=50) + if not AO_USE_DETERMINISTIC_ALGORITHMS: + ar_utils.gen_text(model, tokenizer, msg, max_length=50) bench_accuracy(model, tokenizer, tasks=args.tasks, msg=msg) From b3c70ff99b27fd9ea58e178c28481bad1d126568 Mon Sep 17 00:00:00 2001 From: Yi Liu Date: Wed, 11 Sep 2024 11:03:56 +0800 Subject: [PATCH 3/9] Add `gradient_accumulate_steps` and update results (#20) Signed-off-by: yiliu30 --- torchao/prototype/autoround/README.md | 39 +++++++++++-------- torchao/prototype/autoround/autoround_llm.py | 9 +++++ torchao/prototype/autoround/core.py | 9 ++++- torchao/prototype/autoround/eval_autoround.py | 7 ++++ 4 files changed, 45 insertions(+), 19 deletions(-) diff --git a/torchao/prototype/autoround/README.md b/torchao/prototype/autoround/README.md index 11671009b5..7de729cf3f 100644 --- a/torchao/prototype/autoround/README.md +++ b/torchao/prototype/autoround/README.md @@ -71,31 +71,36 @@ quantize_(model, apply_auto_round(), is_target_module) ## End-to-End Results ### [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) -| | Avg. | Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai | -| -------------- | ------- | ------ | ------ | ---------- | --------- | -------------- | -| bf16 | 0.7080 | 0.6783 | 0.8003 | 0.7403 | 0.5910 | 0.7303 | -| auto-round-4bit | 0.6988 | 0.6533 | 0.7949 | 0.7372 | 0.5837 | 0.7250 | -| torchao-int4wo | 0.6883 | 0.6363 | 0.7938 | 0.7348 | 0.5784 | 0.6980 | +| | Avg. | Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai | +| ---------------- | ------ | ------ | ------ | ---------- | --------- | -------------- | +| bf16 | 0.7080 | 0.6783 | 0.8003 | 0.7403 | 0.5910 | 0.7303 | +| torchao-int4wo | 0.6883 | 0.6363 | 0.7938 | 0.7348 | 0.5784 | 0.6980 | +| autoround-4bit | 0.6996 | 0.6669 | 0.7916 | 0.7285 | 0.5846 | 0.7262 | +| autoround-4bit* | 0.7010 | 0.6621 | 0.7976 | 0.7316 | 0.5847 | 0.7291 | ### [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) -| | Avg. | Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai | -| -------------- | ------- | ------ | ------ | ---------- | --------- | -------------- | -| bf16 | 0.6881 | 0.6389 | 0.7840 | 0.7222 | 0.5772 | 0.7184 | -| auto-round-4bit | 0.6818 | 0.6232 | 0.7862 | 0.7230 | 0.5661 | 0.7105 | -| torchao-int4wo | 0.6728 | 0.5939 | 0.7737 | 0.7222 | 0.5612 | 0.7132 | +| | Avg. 
| Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai | +| ---------------- | ------ | ------ | ------ | ---------- | --------- | -------------- | +| bf16 | 0.6881 | 0.6389 | 0.7840 | 0.7222 | 0.5772 | 0.7184 | +| torchao-int4wo | 0.6728 | 0.5939 | 0.7737 | 0.7222 | 0.5612 | 0.7132 | +| autoround-4bit | 0.6796 | 0.6237 | 0.7758 | 0.7198 | 0.5664 | 0.7122 | +| autoround-4bit* | 0.6827 | 0.6273 | 0.7737 | 0.7348 | 0.5657 | 0.7120 | ### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) -| | Avg. | Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai | -| -------------- | ------- | ------ | ------ | ---------- | --------- | -------------- | -| bf16 | 0.6347 | 0.4647 | 0.7644 | 0.6606 | 0.577 | 0.7070 | -| auto-round-4bit | 0.6327 | 0.4534 | 0.7590 | 0.6661 | 0.5706 | 0.7143 | -| torchao-int4wo | 0.6252 | 0.4427 | 0.7617 | 0.6654 | 0.5674 | 0.6889 | +| | Avg. | Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai | +| ---------------- | ------ | ------ | ------ | ---------- | --------- | -------------- | +| bf16 | 0.6347 | 0.4647 | 0.7644 | 0.6606 | 0.5770 | 0.7070 | +| torchao-int4wo | 0.6252 | 0.4427 | 0.7617 | 0.6654 | 0.5674 | 0.6889 | +| autoround-4bit | 0.6311 | 0.4548 | 0.7606 | 0.6614 | 0.5717 | 0.7072 | +| autoround-4bit* | 0.6338 | 0.4566 | 0.7661 | 0.6646 | 0.5688 | 0.7130 | > [!NOTE] -> - `auto-round-4bit` represents the following configuration: `bits=4`, `iters=200`, `seqlen=2048`, `train_bs=8`, `group_size=128`, and `quant_lm_head=False`.
> - `torchao-int4wo` represents `int4_weight_only(group_size=128)` and `quant_lm_head=False`. -> - If the model includes operations without a deterministic implementation (such as Flash Attention), the results may differ slightly. +> - `auto-round-4bit` represents the following configuration: `bits=4`, `iters=200`, `seqlen=2048`, `train_bs=8`, `group_size=128`, and `quant_lm_head=False`.
+> - `auto-round-4bit*` represents the following configuration: `bits=4`, `iters=200`, `seqlen=2048`, `train_bs=4`, `gradient_accumulate_steps=2`, `group_size=128`, and `quant_lm_head=False`.
+> - Compared to `auto-round-4bit` (train_bs=8), the `auto-round-4bit*` accumulates two batches (4 samples per batch) before performing the backward pass.
+> - To reproduce results, run `eval_autoround.py` with `AO_USE_DETERMINISTIC_ALGORITHMS=1`. ## Credits diff --git a/torchao/prototype/autoround/autoround_llm.py b/torchao/prototype/autoround/autoround_llm.py index eb324a9b02..05a9c5a087 100644 --- a/torchao/prototype/autoround/autoround_llm.py +++ b/torchao/prototype/autoround/autoround_llm.py @@ -29,6 +29,7 @@ def quantize_model_with_autoround_( bs: int = 8, nsamples: int = 128, use_optimized_layer_output: bool = False, + gradient_accumulate_steps: Optional[int] = 1, compile_optimization_process: Optional[bool] = False, ): # Step 1. Prepare the model for applying auto-round @@ -43,6 +44,7 @@ def quantize_model_with_autoround_( group_size, iters, use_optimized_layer_output, + gradient_accumulate_steps, compile_optimization_process, device=device, ) @@ -109,6 +111,7 @@ def main(args): bs=args.train_bs, nsamples=args.nsamples, use_optimized_layer_output=args.use_optimized_layer_output, + gradient_accumulate_steps=args.gradient_accumulate_steps, compile_optimization_process=args.compile_optimization_process, ) # Revert the `use_cache` for generation stage. @@ -159,6 +162,12 @@ def main(args): type=int, help="Sequence length for calibration process", ) + parser.add_argument( + "--gradient_accumulate_steps", + default=1, + type=int, + help="Number of gradient accumulation steps", + ) parser.add_argument( "--quant_lm_head", default=False, diff --git a/torchao/prototype/autoround/core.py b/torchao/prototype/autoround/core.py index bede40cd70..46cc7b4b73 100644 --- a/torchao/prototype/autoround/core.py +++ b/torchao/prototype/autoround/core.py @@ -20,6 +20,7 @@ class _AutoRoundConfig: group_size: int = 128 iters: int = 200 use_optimized_layer_output: bool = False + gradient_accumulate_steps: int = 1 compile_optimization_process: bool = False @@ -83,6 +84,7 @@ def prepare_model_for_applying_auto_round_( group_size: int = 128, iters: int = 200, use_optimized_layer_output: bool = False, + gradient_accumulate_steps: Optional[int] = 1, compile_optimization_process: Optional[bool] = False, device: Optional[torch.types.Device] = None, ): @@ -96,8 +98,9 @@ def prepare_model_for_applying_auto_round_( group_size (int, optional): The group size for quantization. Defaults to 128. iters (int, optional): The number of iterations for optimization. Defaults to 200. use_optimized_layer_output (bool, optional): Whether to use optimized layer output. Defaults to False. - compile_optimization_process (Optional[bool], optional): Whether to compile the optimization process. Defaults to False. - device (Optional[torch.types.Device], optional): The device to use for accelrating optimization and calibration. + gradient_accumulate_steps (Optional[int]): The number of gradient accumulation steps. Defaults to 1. + compile_optimization_process (Optional[bool]): Whether to compile the optimization process. Defaults to False. + device (Optional[torch.types.Device]): The device to use for accelrating optimization and calibration. Defaults to None. 
""" _multi_tensor_config.device = device @@ -108,6 +111,7 @@ def prepare_model_for_applying_auto_round_( _auto_round_config.group_size = group_size _auto_round_config.iters = iters _auto_round_config.use_optimized_layer_output = use_optimized_layer_output + _auto_round_config.gradient_accumulate_steps = gradient_accumulate_steps _auto_round_config.compile_optimization_process = compile_optimization_process logging.warning(f"config {_auto_round_config}") @@ -316,6 +320,7 @@ def _apply_auto_round_optimization( bits=config.bits, iters=config.iters, group_size=config.group_size, + gradient_accumulate_steps=config.gradient_accumulate_steps, amp=True, model_dtype=next(block.parameters()).dtype, ) diff --git a/torchao/prototype/autoround/eval_autoround.py b/torchao/prototype/autoround/eval_autoround.py index 1fcae4ecc3..003d688587 100644 --- a/torchao/prototype/autoround/eval_autoround.py +++ b/torchao/prototype/autoround/eval_autoround.py @@ -147,6 +147,7 @@ def main(args): bs=args.train_bs, nsamples=args.nsamples, use_optimized_layer_output=args.use_optimized_layer_output, + gradient_accumulate_steps=args.gradient_accumulate_steps, compile_optimization_process=args.compile_optimization_process, ) quantized_layer_cnt = ar_utils.count_tensor_of_type( @@ -200,6 +201,12 @@ def main(args): type=int, help="Sequence length for calibration process", ) + parser.add_argument( + "--gradient_accumulate_steps", + default=1, + type=int, + help="Number of gradient accumulation steps", + ) parser.add_argument( "--quant_lm_head", default=False, From 27c767183741a6fe33610d5ae02b51d4d88565b1 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 13 Sep 2024 05:29:41 -0400 Subject: [PATCH 4/9] update the readme Signed-off-by: yiliu30 --- torchao/prototype/autoround/README.md | 24 ++++++++++++++++++++---- torchao/prototype/autoround/utils.py | 4 +++- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/torchao/prototype/autoround/README.md b/torchao/prototype/autoround/README.md index 7de729cf3f..cf3b573cd1 100644 --- a/torchao/prototype/autoround/README.md +++ b/torchao/prototype/autoround/README.md @@ -10,6 +10,23 @@ Auto-Round is an advanced quantization algorithm designed for low-bit LLM infere python autoround_llm.py -m /model/name/or/path ``` +This script allows you to apply `Auto-Round` on a given model directly, more configurations options are list below: + +| Argument |Default | Description | +|------------------------------------|----------------------------|-------------------------------------------------------------------| +| `model_name_or_path` |`"facebook/opt-125m"` | Pretrained model name or path | +| `dataset_name` | `"NeelNanda/pile-10k"` | Dataset name for calibration | +| `iters` | 200 | Number of steps for optimizing each block | +| `bits` | 4 | Number of bits for quantization | +| `train_bs` | 8 | Batch size for calibration | +| `nsamples` | 128 | Number of samples for calibration process | +| `seqlen` | 2048 | Sequence length for each samples | +| `gradient_accumulate_steps` | 1 | Number of gradient accumulation steps per block optimization | +| `quant_lm_head` | `False` | Whether to quantize the `lm_head` | +| `use_optimized_layer_output` | `False` | Whether to use optimized layer output as input for the next layer | +| `compile_optimization_process` | `False` | Whether to compile the optimization process | +| `model_device` | `"cuda"` | Device for loading the float model (choices: `cpu`, `cuda`) | + > [!NOTE] > Before running, ensure you have installed the `auto-round` with `pip 
install -r requirements.txt`. @@ -96,10 +113,9 @@ quantize_(model, apply_auto_round(), is_target_module) | autoround-4bit* | 0.6338 | 0.4566 | 0.7661 | 0.6646 | 0.5688 | 0.7130 | > [!NOTE] -> - `torchao-int4wo` represents `int4_weight_only(group_size=128)` and `quant_lm_head=False`. -> - `auto-round-4bit` represents the following configuration: `bits=4`, `iters=200`, `seqlen=2048`, `train_bs=8`, `group_size=128`, and `quant_lm_head=False`.
-> - `auto-round-4bit*` represents the following configuration: `bits=4`, `iters=200`, `seqlen=2048`, `train_bs=4`, `gradient_accumulate_steps=2`, `group_size=128`, and `quant_lm_head=False`.
-> - Compared to `auto-round-4bit` (train_bs=8), the `auto-round-4bit*` accumulates two batches (4 samples per batch) before performing the backward pass.
+> - `torchao-int4wo` represents `int4_weight_only(group_size=128)` and `quant_lm_head=False`.
> - `auto-round-4bit` uses the default configuration from quick start.
> - `auto-round-4bit*` follows the same settings as `auto-round-4bit`, but with `gradient_accumulate_steps=2` and `train_bs=4`, which accumulates two batches (4 samples per batch) before performing the backward pass.
> - To reproduce results, run `eval_autoround.py` with `AO_USE_DETERMINISTIC_ALGORITHMS=1`. diff --git a/torchao/prototype/autoround/utils.py b/torchao/prototype/autoround/utils.py index 3a47289f55..606fb99011 100644 --- a/torchao/prototype/autoround/utils.py +++ b/torchao/prototype/autoround/utils.py @@ -111,8 +111,10 @@ def see_memory_usage(message: str = "", force=True): @torch.no_grad() def gen_text( - model, tokenizer, msg="", device="cuda", prompt="What's AI?", max_length=20 + model, tokenizer, msg="", device=None, prompt="What's AI?", max_length=20 ): + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" inputs = tokenizer(prompt, return_tensors="pt") model = model.to(device) new_tokens = model.generate(**inputs.to(device), max_length=max_length) From 00530e1e9e1c4f4d228fef00c4f552bd970c5fd6 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 13 Sep 2024 05:35:37 -0400 Subject: [PATCH 5/9] udpate Signed-off-by: yiliu30 --- torchao/prototype/autoround/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/autoround/README.md b/torchao/prototype/autoround/README.md index cf3b573cd1..7ae8f122bc 100644 --- a/torchao/prototype/autoround/README.md +++ b/torchao/prototype/autoround/README.md @@ -113,8 +113,8 @@ quantize_(model, apply_auto_round(), is_target_module) | autoround-4bit* | 0.6338 | 0.4566 | 0.7661 | 0.6646 | 0.5688 | 0.7130 | > [!NOTE] -> - `torchao-int4wo` represents `int4_weight_only(group_size=128)` and `quant_lm_head=False`.
-> - `auto-round-4bit` uses the default configuration from quick start.
+> - `torchao-int4wo` quantizes the model to 4 bits with a group size of 128 (`int4_weight_only(group_size=128)`) while leaving the `lm-head` unquantized.
+> - `auto-round-4bit` uses the default configuration from [quick start](#quick-start).
> - `auto-round-4bit*` follows the same settings as `auto-round-4bit`, but with `gradient_accumulate_steps=2` and `train_bs=4`, which accumulates two batches (4 samples per batch) before performing the backward pass.
> - To reproduce results, run `eval_autoround.py` with `AO_USE_DETERMINISTIC_ALGORITHMS=1`.

From 35a9be26aa583d7edaca470f25549846c1bc22ab Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Fri, 13 Sep 2024 06:12:32 -0400
Subject: [PATCH 6/9] update the desc

Signed-off-by: yiliu30
---
 torchao/prototype/autoround/README.md         |  3 ++-
 torchao/prototype/autoround/autoround_llm.py  | 27 ++++++++++++-------
 torchao/prototype/autoround/core.py           |  3 ++-
 torchao/prototype/autoround/eval_autoround.py | 23 +++++++++++-----
 torchao/prototype/autoround/utils.py          | 13 +++++----
 5 files changed, 46 insertions(+), 23 deletions(-)

diff --git a/torchao/prototype/autoround/README.md b/torchao/prototype/autoround/README.md
index 7ae8f122bc..5b5bd3959e 100644
--- a/torchao/prototype/autoround/README.md
+++ b/torchao/prototype/autoround/README.md
@@ -21,7 +21,8 @@ This script allows you to apply `Auto-Round` on a given model directly; more con
 | `train_bs` | 8 | Batch size for calibration |
 | `nsamples` | 128 | Number of samples for calibration process |
 | `seqlen` | 2048 | Sequence length for each sample |
-| `gradient_accumulate_steps` | 1 | Number of gradient accumulation steps per block optimization |
+| `group_size` | 128 | Group size for quantization |
+| `gradient_accumulate_steps` | 1 | Number of steps for accumulating gradients before performing the backward pass |
 | `quant_lm_head` | `False` | Whether to quantize the `lm_head` |
 | `use_optimized_layer_output` | `False` | Whether to use optimized layer output as input for the next layer |
diff --git a/torchao/prototype/autoround/autoround_llm.py b/torchao/prototype/autoround/autoround_llm.py
index 05a9c5a087..45331b7a0a 100644
--- a/torchao/prototype/autoround/autoround_llm.py
+++ b/torchao/prototype/autoround/autoround_llm.py
@@ -1,11 +1,11 @@
 import argparse
 import logging
+from typing import Optional

 import torch

 import torchao
 import torchao.prototype.autoround.utils as ar_utils
-from typing import Optional
 from torchao.prototype.autoround.core import (
     apply_auto_round,
     prepare_model_for_applying_auto_round_,
@@ -30,7 +30,7 @@ def quantize_model_with_autoround_(
     nsamples: int = 128,
     use_optimized_layer_output: bool = False,
     gradient_accumulate_steps: Optional[int] = 1,
-    compile_optimization_process: Optional[bool] = False,
+    compile_optimization_process: Optional[bool] = False,
 ):
     # Step 1. Prepare the model for applying auto-round
@@ -130,7 +130,7 @@ def main(args):
         "--model_name_or_path",
         type=str,
         default="facebook/opt-125m",
-        help="Model name or path",
+        help="Pretrained model name or path",
     )
     parser.add_argument(
         "--dataset_name",
@@ -142,13 +142,13 @@ def main(args):
         "--iters",
         default=200,
         type=int,
-        help="Number of iterations for auto-round optimization",
+        help="Number of steps for optimizing each block",
     )
     parser.add_argument(
         "--bits", default=4, type=int, help="Number of bits for quantization"
     )
     parser.add_argument(
-        "--train_bs", default=8, type=int, help="Batch size for auto-round optimization"
+        "--train_bs", default=8, type=int, help="Batch size for calibration"
     )
     parser.add_argument(
         "--nsamples",
@@ -156,29 +156,38 @@ def main(args):
         type=int,
         help="Number of samples for calibration process",
     )
+    parser.add_argument(
+        "--group_size",
+        default=128,
+        type=int,
+        help="Group size for quantization",
+    )
     parser.add_argument(
         "--seqlen",
         default=2048,
         type=int,
-        help="Sequence length for calibration process",
+        help="Sequence length for each sample",
     )
     parser.add_argument(
         "--gradient_accumulate_steps",
         default=1,
         type=int,
-        help="Number of gradient accumulation steps",
+        help=(
+            "Number of steps for accumulating gradients before performing "
+            "the backward pass when optimizing each target module"
+        ),
     )
     parser.add_argument(
         "--quant_lm_head",
         default=False,
         action="store_true",
-        help="Quantize the `lm_head` or not",
+        help="Whether to quantize the `lm_head`",
     )
     parser.add_argument(
         "--use_optimized_layer_output",
         default=False,
         action="store_true",
-        help="Use the optimized layer output for next layer or not",
+        help="Whether to use optimized layer output as input for the next layer",
     )
     parser.add_argument(
         "-c",
diff --git a/torchao/prototype/autoround/core.py b/torchao/prototype/autoround/core.py
index 46cc7b4b73..05d1552c55 100644
--- a/torchao/prototype/autoround/core.py
+++ b/torchao/prototype/autoround/core.py
@@ -98,7 +98,8 @@ def prepare_model_for_applying_auto_round_(
         group_size (int, optional): The group size for quantization. Defaults to 128.
         iters (int, optional): The number of iterations for optimization. Defaults to 200.
         use_optimized_layer_output (bool, optional): Whether to use optimized layer output. Defaults to False.
-        gradient_accumulate_steps (Optional[int]): The number of gradient accumulation steps. Defaults to 1.
+        gradient_accumulate_steps (Optional[int]): Number of steps for accumulating gradients before
+            performing the backward pass when optimizing each target module. Defaults to 1.
         compile_optimization_process (Optional[bool]): Whether to compile the optimization process. Defaults to False.
         device (Optional[torch.types.Device]): The device to use for accelrating optimization and calibration.
             Defaults to None.
diff --git a/torchao/prototype/autoround/eval_autoround.py b/torchao/prototype/autoround/eval_autoround.py
index 003d688587..33a814e307 100644
--- a/torchao/prototype/autoround/eval_autoround.py
+++ b/torchao/prototype/autoround/eval_autoround.py
@@ -169,19 +169,25 @@ def main(args):
         "--model_name_or_path",
         type=str,
         default="facebook/opt-125m",
-        help="Model name or path",
+        help="Pretrained model name or path",
+    )
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default="NeelNanda/pile-10k",
+        help="Dataset name for calibration",
     )
     parser.add_argument(
         "--iters",
         default=200,
         type=int,
-        help="Number of iterations for auto-round optimization",
+        help="Number of steps for optimizing each block",
     )
     parser.add_argument(
         "--bits", default=4, type=int, help="Number of bits for quantization"
     )
     parser.add_argument(
-        "--train_bs", default=8, type=int, help="Batch size for auto-round optimization"
+        "--train_bs", default=8, type=int, help="Batch size for calibration"
     )
     parser.add_argument(
         "--nsamples",
@@ -199,25 +205,28 @@ def main(args):
         "--seqlen",
         default=2048,
         type=int,
-        help="Sequence length for calibration process",
+        help="Sequence length for each sample",
     )
     parser.add_argument(
         "--gradient_accumulate_steps",
         default=1,
         type=int,
-        help="Number of gradient accumulation steps",
+        help=(
+            "Number of steps for accumulating gradients before performing "
+            "the backward pass when optimizing each target module"
+        ),
     )
     parser.add_argument(
         "--quant_lm_head",
         default=False,
         action="store_true",
-        help="Quantize the `lm_head` or not",
+        help="Whether to quantize the `lm_head`",
     )
     parser.add_argument(
         "--use_optimized_layer_output",
         default=False,
         action="store_true",
-        help="Use the optimized layer output for next layer or not",
+        help="Whether to use optimized layer output as input for the next layer",
     )
     parser.add_argument(
         "-c",
diff --git a/torchao/prototype/autoround/utils.py b/torchao/prototype/autoround/utils.py
index 606fb99011..634c2269a1 100644
--- a/torchao/prototype/autoround/utils.py
+++ b/torchao/prototype/autoround/utils.py
@@ -1,12 +1,12 @@
 # ==------------------------------------------------------------------------------------------==
 # Utils for the auto-round
 # ==------------------------------------------------------------------------------------------==
+import collections
 import logging
 import random

 import numpy as np
 import torch
-import collections


 def _is_package_available(pkg_name, metadata_name=None):
@@ -110,9 +110,7 @@ def see_memory_usage(message: str = "", force=True):

 @torch.no_grad()
-def gen_text(
-    model, tokenizer, msg="", device=None, prompt="What's AI?", max_length=20
-):
+def gen_text(model, tokenizer, msg="", device=None, prompt="What's AI?", max_length=20):
     if device is None:
         device = "cuda" if torch.cuda.is_available() else "cpu"
     inputs = tokenizer(prompt, return_tensors="pt")
@@ -152,7 +150,10 @@ def get_float_model_info(model_name_or_path, torch_dtype=torch.float32):
     )
     return model, tokenizer, decoder_cls

+
 execution_records = collections.defaultdict(list)
+
+
 def dump_elapsed_time(customized_msg="", record=False):
     """Get the elapsed time
for decorated functions. @@ -178,7 +179,9 @@ def fi(*args, **kwargs): ) ) if record: - avg_time = sum(execution_records[func.__qualname__])/len(execution_records[func.__qualname__]) + avg_time = sum(execution_records[func.__qualname__]) / len( + execution_records[func.__qualname__] + ) std_time = np.std(execution_records[func.__qualname__]) logging.warning( f"For {func.__qualname__}, the average elapsed time: {avg_time: .2f} ms, the std: {std_time: .2f} ms" From a92b1431cee9aec4589c48d501373130287a290b Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 16 Sep 2024 20:40:43 -0400 Subject: [PATCH 7/9] rename `train_bs` to `batch_size` Signed-off-by: yiliu30 --- torchao/prototype/autoround/README.md | 4 ++-- torchao/prototype/autoround/autoround_llm.py | 8 ++++---- torchao/prototype/autoround/core.py | 2 +- torchao/prototype/autoround/eval_autoround.py | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/torchao/prototype/autoround/README.md b/torchao/prototype/autoround/README.md index 5b5bd3959e..18f3663427 100644 --- a/torchao/prototype/autoround/README.md +++ b/torchao/prototype/autoround/README.md @@ -18,7 +18,7 @@ This script allows you to apply `Auto-Round` on a given model directly, more con | `dataset_name` | `"NeelNanda/pile-10k"` | Dataset name for calibration | | `iters` | 200 | Number of steps for optimizing each block | | `bits` | 4 | Number of bits for quantization | -| `train_bs` | 8 | Batch size for calibration | +| `batch_size` | 8 | Batch size for calibration | | `nsamples` | 128 | Number of samples for calibration process | | `seqlen` | 2048 | Sequence length for each samples | | `group_size` | 128 | Group size for quantization | @@ -116,7 +116,7 @@ quantize_(model, apply_auto_round(), is_target_module) > [!NOTE] > - `torchao-int4wo` quantizes the model to 4 bits with a group size of 128 (`int4_weight_only(group_size=128)`) while leaving the `lm-head` unquantized.
> - `auto-round-4bit` uses the default configuration from [quick start](#quick-start).
-> - `auto-round-4bit*` follows the same settings as `auto-round-4bit`, but with `gradient_accumulate_steps=2` and `train_bs=4`, which accumulates two batches (4 samples per batch) before performing the backward pass.
+> - `auto-round-4bit*` follows the same settings as `auto-round-4bit`, but with `gradient_accumulate_steps=2` and `batch_size=4`, which accumulates two batches (4 samples per batch) before performing the backward pass.
> - To reproduce results, run `eval_autoround.py` with `AO_USE_DETERMINISTIC_ALGORITHMS=1`. diff --git a/torchao/prototype/autoround/autoround_llm.py b/torchao/prototype/autoround/autoround_llm.py index 45331b7a0a..f1cb528bf9 100644 --- a/torchao/prototype/autoround/autoround_llm.py +++ b/torchao/prototype/autoround/autoround_llm.py @@ -26,7 +26,7 @@ def quantize_model_with_autoround_( iters: int = 200, seqlen: int = 2048, dataset_name: str = "NeelNanda/pile-10k", - bs: int = 8, + batch_size: int = 8, nsamples: int = 128, use_optimized_layer_output: bool = False, gradient_accumulate_steps: Optional[int] = 1, @@ -54,7 +54,7 @@ def quantize_model_with_autoround_( tokenizer, seqlen=seqlen, dataset_name=dataset_name, - bs=bs, + bs=batch_size, nsamples=nsamples, ) input_ids_lst = [] @@ -108,7 +108,7 @@ def main(args): iters=args.iters, seqlen=args.seqlen, dataset_name=args.dataset_name, - bs=args.train_bs, + batch_size=args.batch_size, nsamples=args.nsamples, use_optimized_layer_output=args.use_optimized_layer_output, gradient_accumulate_steps=args.gradient_accumulate_steps, @@ -148,7 +148,7 @@ def main(args): "--bits", default=4, type=int, help="Number of bits for quantization" ) parser.add_argument( - "--train_bs", default=8, type=int, help="Batch size for calibration" + "--batch_size", default=8, type=int, help="Batch size for calibration" ) parser.add_argument( "--nsamples", diff --git a/torchao/prototype/autoround/core.py b/torchao/prototype/autoround/core.py index 05d1552c55..c602473c56 100644 --- a/torchao/prototype/autoround/core.py +++ b/torchao/prototype/autoround/core.py @@ -181,7 +181,7 @@ def to_uintx_weight(input_float): quant_min = 0 quant_max = _auto_round_config.bits**2 - 1 block_size = (1, observed_linear.group_size) - from torchao.dtypes.uintx.Uintx import ( + from torchao.dtypes.uintx.uintx import ( _BIT_WIDTH_TO_DTYPE, UintxLayoutType, ) diff --git a/torchao/prototype/autoround/eval_autoround.py b/torchao/prototype/autoround/eval_autoround.py index 33a814e307..c28a0fc390 100644 --- a/torchao/prototype/autoround/eval_autoround.py +++ b/torchao/prototype/autoround/eval_autoround.py @@ -107,7 +107,7 @@ def main(args): ) elif args.uintx: msg += f" (uintx {args.bits} bits)" - from torchao.dtypes.uintx.Uintx import _BIT_WIDTH_TO_DTYPE + from torchao.dtypes.uintx.uintx import _BIT_WIDTH_TO_DTYPE from torchao.quantization.quant_api import quantize_, uintx_weight_only bits = args.bits @@ -144,7 +144,7 @@ def main(args): group_size=args.group_size, iters=args.iters, seqlen=args.seqlen, - bs=args.train_bs, + batch_size=args.batch_size, nsamples=args.nsamples, use_optimized_layer_output=args.use_optimized_layer_output, gradient_accumulate_steps=args.gradient_accumulate_steps, @@ -187,7 +187,7 @@ def main(args): "--bits", default=4, type=int, help="Number of bits for quantization" ) parser.add_argument( - "--train_bs", default=8, type=int, help="Batch size for calibration" + "--batch_size", default=8, type=int, help="Batch size for calibration" ) parser.add_argument( "--nsamples", From a5722d8cf04554a452beb1a00289d78e6388b9f0 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 16 Sep 2024 21:02:46 -0400 Subject: [PATCH 8/9] update the eval Signed-off-by: yiliu30 --- torchao/_models/llama/eval.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 673c4f595f..6130b98751 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -128,19 +128,28 @@ def 
run_evaluation( _tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent) # parse args from quantization string: - # autoround------- + # autoround--------- _quant_args = quantization.split("-") - _default_quant_args = [False, 200, 128, 8, 2048, 128] + _default_quant_args = [False, 200, 128, 8, 2048, 128, 0, 1] _model_devie = _quant_args[1] if len(_quant_args) > 1 else device _quant_args = _quant_args[2:] - quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples = [ - int(x) for x in _quant_args - ] + _default_quant_args[len(_quant_args) :] + ( + quant_lm_head, + iters, + groupsize, + batch_size, + seqlen, + nsamples, + grad_acc_steps, + compile_optimization_process, + ) = [int(x) for x in _quant_args] + _default_quant_args[len(_quant_args) :] model = model.to(_model_devie) print( ( f"Quantizing model with autoround(iters={iters}, groupsize={groupsize}, " - f"quant_lm_head={quant_lm_head}, batch_size={batch_size}, seqlen={seqlen}, nsamples={nsamples})" + f"quant_lm_head={quant_lm_head}, batch_size={batch_size}, seqlen={seqlen}, nsamples={nsamples}, " + f"gradient_accumulate_steps={grad_acc_steps}, " + f"compile_optimization_process={compile_optimization_process})" ) ) with torch.device(_model_devie): @@ -161,9 +170,11 @@ def run_evaluation( is_target_module=is_target_module, bits=4, seqlen=seqlen, - bs=batch_size, + batch_size=batch_size, iters=iters, nsamples=nsamples, + gradient_accumulate_steps=grad_acc_steps, + compile_optimization_process=compile_optimization_process == 1, ) model.to(device) model.reset_caches() @@ -195,9 +206,10 @@ def run_evaluation( "--quantization", type=str, help=( - "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--gptq, " - "autoquant, autoquant-int4, int4wo--hqq, uintx--, uintx---hqq, " - "sparse-marlin, autoround-------" + "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, " + "int4wo--gptq, autoquant, autoquant-int4, int4wo--hqq, " + "uintx--, uintx---hqq, sparse-marlin, " + "autoround---------" ), ) parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') From 47103a536cc4f13997ac6d50064c9f24fb03aa14 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 16 Sep 2024 21:05:43 -0400 Subject: [PATCH 9/9] update Signed-off-by: yiliu30 --- torchao/_models/llama/eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 6130b98751..d495c2065b 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -130,7 +130,7 @@ def run_evaluation( # parse args from quantization string: # autoround--------- _quant_args = quantization.split("-") - _default_quant_args = [False, 200, 128, 8, 2048, 128, 0, 1] + _default_quant_args = [False, 200, 128, 8, 2048, 128, 1, 0] _model_devie = _quant_args[1] if len(_quant_args) > 1 else device _quant_args = _quant_args[2:] (
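
Taken together, this series wires four new behaviors through the auto-round prototype: a compiled block-optimization loop (`compile_optimization_process`, patch 1), opt-in deterministic execution for evaluation (`AO_USE_DETERMINISTIC_ALGORITHMS`, patch 2), gradient accumulation (`gradient_accumulate_steps`, patch 3), and the `train_bs` to `batch_size` rename (patch 7). The sketch below shows how these knobs fit together from user code. It is a minimal illustration, not part of the series: the model name and device are placeholder assumptions, the `is_target_module` predicate is an assumed example, and the keyword arguments mirror the call sites shown in patches 3 and 8 rather than a verified public API.

```python
# Minimal usage sketch (assumptions: CUDA is available, facebook/opt-125m is a
# stand-in model, and `is_target_module` below is an illustrative predicate).
import torchao.prototype.autoround.utils as ar_utils
from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_

# Returns the float model, its tokenizer, and the decoder-block class.
model, tokenizer, decoder_cls = ar_utils.get_float_model_info("facebook/opt-125m")
model = model.to("cuda")
model.config.use_cache = False  # calibration runs without the KV cache


def is_target_module(module, fqn):
    # Optimize each decoder block; leave `lm_head` in float (quant_lm_head=False).
    return isinstance(module, decoder_cls)


quantize_model_with_autoround_(
    model=model,
    tokenizer=tokenizer,
    is_target_module=is_target_module,
    bits=4,
    iters=200,
    seqlen=2048,
    batch_size=4,  # renamed from `train_bs` in patch 7
    nsamples=128,
    gradient_accumulate_steps=2,  # accumulate two batches per backward pass (patch 3)
    compile_optimization_process=True,  # wrap `quant_block_v2_` in torch.compile (patch 1)
)

model.config.use_cache = True  # restore the KV cache for generation
```

For reproducible numbers, note that `eval_autoround.py` reads `AO_USE_DETERMINISTIC_ALGORITHMS` at import time, so the flag must be set in the environment before the script starts (e.g. `AO_USE_DETERMINISTIC_ALGORITHMS=1 python eval_autoround.py ...`); as patch 2 notes, text generation is skipped in that mode because `sorted_logits` has no deterministic implementation.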