
Commit ccb38e2

Add support for saving a quantized checkpoint in the llama code
Summary: The goal is to upload a torchao-quantized model to Hugging Face so that we can run the model there.

Test Plan: python generate.py -q int4wo-32 --save
1 parent 8fa11a6 commit ccb38e2
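
The upload step itself is not part of this commit; a minimal sketch of it using huggingface_hub, with placeholder file and repo names, might look like:

from huggingface_hub import upload_file

upload_file(
    path_or_fileobj="model-int4wo-32.pt",       # file produced by: python generate.py -q int4wo-32 --save
    path_in_repo="model-int4wo-32.pt",          # destination path inside the Hub repo
    repo_id="your-username/llama-2-7b-int4wo",  # hypothetical target repo
)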

File tree: 2 files changed, +22 -4 lines

scripts/hf_eval.py
torchao/_models/llama/generate.py

scripts/hf_eval.py

Lines changed: 13 additions & 3 deletions
@@ -40,12 +40,12 @@ def format_value(value):
 
     print(tabulate(main_table, headers=['Task', 'Metrics'], tablefmt='grid'))
 
-def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length):
+def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, save, batch_size, max_length):
 
     tokenizer = AutoTokenizer.from_pretrained(repo_id)
     model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)
 
-    if compile:
+    if quantization == "autoquant" and compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
 
     if quantization == "int8dq":
@@ -57,6 +57,10 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
         quantize_(model.to(device=device), int4_weight_only())
     elif quantization == "autoquant":
         model = autoquant(model.to(device=device))
+
+    if quantization != "autoquant" and compile:
+        model = torch.compile(model, mode="max-autotune", fullgraph=True)
+
     with torch.no_grad():
         result = evaluate(
             HFLM(
@@ -70,6 +74,11 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
 
     pretty_print_nested_results(result)
 
+    if save:
+        # This doesn't work yet: https://github.com/huggingface/transformers/issues/32364
+        # model.save_pretrained("quantized_model_test", safe_serialization=False)
+        torch.save(model.state_dict(), "model-" + quantization + ".pt")
+
 
 if __name__ == '__main__':
     import argparse
@@ -81,8 +90,9 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
     parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
+    parser.add_argument('--save', action='store_true', help='Whether to save the model.')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes')
     parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
 
     args = parser.parse_args()
-    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length)
+    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.save, args.batch_size, args.max_length)
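
Because save_pretrained does not yet work for torchao-quantized models (see the linked transformers issue), only the raw state dict is saved. A minimal sketch of one way to load it back, not part of this commit: it assumes the int8wo path, uses a placeholder repo_id, and follows the re-quantize-then-load pattern so the torchao tensor subclasses line up.

import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import quantize_, int8_weight_only

repo_id = "meta-llama/Llama-2-7b-chat-hf"  # placeholder; any causal LM repo
model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cuda")
quantize_(model, int8_weight_only())  # re-apply the same quantization so the parameter structure matches
# weights_only=False is needed on newer PyTorch: quantized weights are custom tensor subclasses
state_dict = torch.load("model-int8wo.pt", map_location="cuda", weights_only=False)
model.load_state_dict(state_dict, assign=True)  # assign=True keeps the torchao subclasses intact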

torchao/_models/llama/generate.py

Lines changed: 9 additions & 1 deletion
@@ -3,6 +3,7 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import os
 import sys
 import time
 from pathlib import Path
@@ -147,6 +148,7 @@ def main(
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
     quantization: Optional[str] = None,
+    save: bool = False,
     compile: bool = True,
     compile_prefill: bool = False,
     profile: Optional[Path] = None,
@@ -219,6 +221,11 @@
 
     model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9
 
+    if save:
+        output_dir = str(checkpoint_path.cwd())
+        filename = str(checkpoint_path.name).split(".")[0]
+        torch.save(model.state_dict(), os.path.join(output_dir, filename + f"-{quantization}.pt"))
+
     if compile:
         print("Compiling Model")
         global decode_one_token, prefill
@@ -337,6 +344,7 @@ def callback(x):
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
     parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
+    parser.add_argument('--save', action='store_true', help='Whether to save the quantized model.')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
@@ -347,5 +355,5 @@ def callback(x):
     args = parser.parse_args()
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
-        args.temperature, args.checkpoint_path, args.quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
+        args.temperature, args.checkpoint_path, args.quantization, args.save, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
     )
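
For reference, a worked trace of the save-path logic above under the test-plan invocation (python generate.py -q int4wo-32 --save). Note that checkpoint_path.cwd() resolves to Path.cwd(), the process working directory, not the checkpoint's parent directory, so the .pt file lands wherever the script is run from.

import os
from pathlib import Path

checkpoint_path = Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
quantization = "int4wo-32"

output_dir = str(checkpoint_path.cwd())             # Path.cwd(): ignores checkpoint_path entirely
filename = str(checkpoint_path.name).split(".")[0]  # -> "model"
print(os.path.join(output_dir, filename + f"-{quantization}.pt"))
# e.g. <current working directory>/model-int4wo-32.pt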
