2 changes: 1 addition & 1 deletion scripts/download.py
@@ -38,7 +38,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
parser.add_argument(
"--repo_id",
type=str,
default="checkpoints/meta-llama/llama-2-7b-chat-hf",
default="meta-llama/llama-2-7b-chat-hf",
help="Repository ID to download from.",
)
parser.add_argument(
25 changes: 24 additions & 1 deletion torchao/_models/llama/eval.py
@@ -23,6 +23,7 @@
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
+ PerBlock,
PerRow,
PerTensor,
UIntXWeightOnlyConfig,
@@ -44,6 +45,7 @@ def run_evaluation(
calibration_limit: Optional[int] = None,
calibration_seq_length: Optional[int] = None,
pad_calibration_inputs: bool = False,
+ print_model: bool = False,
):
"""Runs the evaluation of a model using LM Eval."""
print(
@@ -169,6 +171,14 @@ def run_evaluation(
model,
Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
)
if quantization == "float8_a1x128_w128x128":
@jainapurva (Contributor) commented on Nov 4, 2025:
The evaluation framework for torchao has multiple scripts:
- torchao/_models/llama/eval.py
- benchmarks/_models/eval_hf_models.py, which will need to be cleaned up as part of BE #3289

For now, I feel the quantization technique should also be added to the benchmarking framework here:

def string_to_config(
quantization: Optional[str], sparsity: Optional[str], **kwargs
) -> AOBaseConfig:

This will enable float8_a1x128_w128x128 in the torchao benchmarking module and allow running it on HF models.

Rest, LGTM!
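
As a follow-up to the comment above, here is a rough sketch of what such a branch could look like inside string_to_config. The surrounding function body and its exact placement in the benchmarking code are assumptions, not the real file; the config itself mirrors the eval.py change in this PR.

from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerBlock

def string_to_config(quantization, sparsity, **kwargs):
    ...  # existing branches for other quantization strings (elided)
    # hypothetical branch mirroring the eval.py change in this PR
    if quantization == "float8_a1x128_w128x128":
        return Float8DynamicActivationFloat8WeightConfig(
            granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
            activation_value_lb=1e-12,
        )
    ...  # remaining branches and the default return (elided)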

+     config = Float8DynamicActivationFloat8WeightConfig(
+         granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
+         activation_value_lb=1e-12,
+     )
+     # TODO(future): all workflows in this file should be skipping quantization
+     # of `lm_head`
+     quantize_(model, config)
if "autoround" in quantization:
from transformers import AutoTokenizer

@@ -273,7 +283,16 @@ def run_evaluation(
)

if compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)
# TODO(future PR): clean this up
if quantization == "float8_a1x128_w128x128":
# we don't need max-autotune for float8 blockwise quant
model = torch.compile(model)
else:
model = torch.compile(model, mode="max-autotune", fullgraph=True)

if print_model:
print(model)

with torch.no_grad():
print("Running evaluation ...")
# avoid circular imports
@@ -371,6 +390,9 @@ def run_evaluation(
default=False,
help="pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower",
)
+ parser.add_argument(
+     "--print_model", action="store_true", help="Whether to print the model."
+ )

args = parser.parse_args()
run_evaluation(
@@ -387,4 +409,5 @@ def run_evaluation(
args.calibration_limit,
args.calibration_seq_length,
args.pad_calibration_inputs,
+ args.print_model,
)
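
For readers new to the naming: float8_a1x128_w128x128 scales activations per 1x128 block and weights per 128x128 block, one scale per block. Below is a minimal sketch of how the per-block scale shapes work out for a single linear layer; it is illustrative only, not torchao internals, and the clamp is an assumption about the kind of degenerate value that activation_value_lb=1e-12 guards against.

import torch

F8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def blockwise_scales(t: torch.Tensor, block: tuple[int, int]) -> torch.Tensor:
    # One scale per (block[0] x block[1]) tile of a 2D tensor.
    rows, cols = t.shape
    br, bc = block
    tiles = t.reshape(rows // br, br, cols // bc, bc)
    amax = tiles.abs().amax(dim=(1, 3))      # shape: (rows // br, cols // bc)
    return amax.clamp(min=1e-12) / F8_MAX    # clamp avoids zero scales

x = torch.randn(4, 512)    # activation: 1x128 blocks -> scales of shape (4, 4)
w = torch.randn(256, 512)  # weight: 128x128 blocks -> scales of shape (2, 4)
print(blockwise_scales(x, (1, 128)).shape)
print(blockwise_scales(w, (128, 128)).shape)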
2 changes: 0 additions & 2 deletions torchao/quantization/quant_api.py
@@ -1778,8 +1778,6 @@ def __post_init__(self):

default_use_fast_accum = True
if _granularity_is_a_1_128_w_128_128(self.granularity):
- assert self.activation_value_lb is None, "unimplemented"
- assert self.activation_value_ub is None, "unimplemented"
assert self.kernel_preference in (
KernelPreference.AUTO,
KernelPreference.TORCH,
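
With the two asserts removed above, activation_value_lb can now be passed together with the 1x128 / 128x128 block granularity. A hedged end-to-end sketch of that usage follows; the import paths are assumed to match what eval.py uses, a recent CUDA GPU with float8 support is assumed, and this is not an official recipe.

import torch
from torch import nn
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerBlock,
    quantize_,
)

# toy model; in/out features are multiples of 128 to match the block sizes
model = nn.Sequential(nn.Linear(512, 256)).to(torch.bfloat16).cuda()

config = Float8DynamicActivationFloat8WeightConfig(
    granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
    activation_value_lb=1e-12,  # accepted now that the asserts are gone
)
quantize_(model, config)

x = torch.randn(4, 512, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([4, 256])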