
Commit 9ef86f8

support eval of float8_a1x128_w128x128
Summary:

Adds support for the new float8 scaling recipe in the official eval scripts used to generate accuracy numbers in the README. For now, I am using this as a smoke test that the scaling is working on a real model - it is. We can add official benchmark results after we hook up slayton's cuBLAS binding on H100, which should make the UX of running evals a lot better.

Test Plan:

Smoke test on Llama-3.1-8B, accuracy looks good:

```
// download checkpoint
with-proxy python scripts/download.py --hf_token {token} --repo_id meta-llama/Meta-Llama-3.1-8B

// prepare checkpoint
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B

// run bf16 eval on a single task
with-proxy time python torchao/_models/llama/eval.py --checkpoint_path checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --tasks 'winogrande'
...
winogrande: {'alias': 'winogrande', 'acc,none': 0.7426992896606156, 'acc_stderr,none': 0.012285989618865697}

// run float8 eval on the same task
with-proxy time python torchao/_models/llama/eval.py --checkpoint_path checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --tasks 'winogrande' --quantization float8_a1x128_w128x128 --compile
...
winogrande: {'alias': 'winogrande', 'acc,none': 0.7419100236779794, 'acc_stderr,none': 0.012298278833972477}
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: e87609a
ghstack-comment-id: 3474380821
Pull-Request: #3269
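Not part of this commit, but for context: a minimal standalone sketch of applying the same blockwise float8 recipe with torchao's `quantize_` API, mirroring the config added in eval.py below. The toy model, device, dtype, and layer sizes here are illustrative assumptions, not from the diff.

```python
# Minimal sketch (assumptions: CUDA device, bf16 model, dims divisible by 128
# to match the 1x128 / 128x128 scaling blocks).
import torch

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerBlock,
    quantize_,
)

# toy model standing in for a real transformer's linear layers
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).cuda().to(torch.bfloat16)

# activations scaled per (1, 128) block, weights per (128, 128) block
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
)
quantize_(model, config)
```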

2 files changed: +24, −2 lines
scripts/download.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -38,7 +38,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
     parser.add_argument(
         "--repo_id",
         type=str,
-        default="checkpoints/meta-llama/llama-2-7b-chat-hf",
+        default="meta-llama/llama-2-7b-chat-hf",
         help="Repository ID to download from.",
     )
     parser.add_argument(
```
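With this fix, running the script without `--repo_id` now uses a valid Hugging Face repository ID instead of a local checkpoint path. A hypothetical invocation relying on the default:

```
// downloads the default repo, meta-llama/llama-2-7b-chat-hf
python scripts/download.py --hf_token {token}
```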

torchao/_models/llama/eval.py

Lines changed: 23 additions & 1 deletion

```diff
@@ -23,6 +23,7 @@
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
+    PerBlock,
     PerRow,
     PerTensor,
     UIntXWeightOnlyConfig,
@@ -44,6 +45,7 @@ def run_evaluation(
     calibration_limit: Optional[int] = None,
     calibration_seq_length: Optional[int] = None,
     pad_calibration_inputs: bool = False,
+    print_model: bool = False,
 ):
     """Runs the evaluation of a model using LM Eval."""
     print(
@@ -169,6 +171,13 @@ def run_evaluation(
                 model,
                 Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
             )
+        if quantization == "float8_a1x128_w128x128":
+            config = Float8DynamicActivationFloat8WeightConfig(
+                granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
+            )
+            # TODO(future): all workflows in this file should be skipping quantization
+            # of `lm_head`
+            quantize_(model, config)
         if "autoround" in quantization:
             from transformers import AutoTokenizer

@@ -273,7 +282,16 @@ def run_evaluation(
         )

     if compile:
-        model = torch.compile(model, mode="max-autotune", fullgraph=True)
+        # TODO(future PR): clean this up
+        if quantization == "float8_a1x128_w128x128":
+            # we don't need max-autotune for float8 blockwise quant
+            model = torch.compile(model)
+        else:
+            model = torch.compile(model, mode="max-autotune", fullgraph=True)
+
+    if print_model:
+        print(model)
+
     with torch.no_grad():
         print("Running evaluation ...")
         # avoid circular imports
@@ -371,6 +389,9 @@ def run_evaluation(
         default=False,
         help="pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower",
     )
+    parser.add_argument(
+        "--print_model", action="store_true", help="Whether to print the model."
+    )

     args = parser.parse_args()
     run_evaluation(
@@ -387,4 +408,5 @@ def run_evaluation(
         args.calibration_limit,
         args.calibration_seq_length,
         args.pad_calibration_inputs,
+        args.print_model,
     )
```
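Combining the two new options from this diff, a hypothetical invocation exercising both the blockwise float8 recipe and the new `--print_model` flag (same checkpoint layout as the test plan above):

```
python torchao/_models/llama/eval.py \
  --checkpoint_path checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth \
  --tasks 'winogrande' \
  --quantization float8_a1x128_w128x128 \
  --compile --print_model
```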
