diff --git a/scripts/download.py b/scripts/download.py
index 7a9da38db5..8dd4214f6e 100644
--- a/scripts/download.py
+++ b/scripts/download.py
@@ -38,7 +38,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
     parser.add_argument(
         "--repo_id",
         type=str,
-        default="checkpoints/meta-llama/llama-2-7b-chat-hf",
+        default="meta-llama/llama-2-7b-chat-hf",
         help="Repository ID to download from.",
     )
     parser.add_argument(
diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index fdd9792cb4..676d3569e7 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -23,6 +23,7 @@
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
+    PerBlock,
     PerRow,
     PerTensor,
     UIntXWeightOnlyConfig,
@@ -44,6 +45,7 @@ def run_evaluation(
     calibration_limit: Optional[int] = None,
     calibration_seq_length: Optional[int] = None,
     pad_calibration_inputs: bool = False,
+    print_model: bool = False,
 ):
     """Runs the evaluation of a model using LM Eval."""
     print(
@@ -169,6 +171,14 @@ def run_evaluation(
             model,
             Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
         )
+    if quantization == "float8_a1x128_w128x128":
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
+            activation_value_lb=1e-12,
+        )
+        # TODO(future): all workflows in this file should be skipping quantization
+        # of `lm_head`
+        quantize_(model, config)
     if "autoround" in quantization:
         from transformers import AutoTokenizer
 
@@ -273,7 +283,16 @@ def run_evaluation(
     )
 
     if compile:
-        model = torch.compile(model, mode="max-autotune", fullgraph=True)
+        # TODO(future PR): clean this up
+        if quantization == "float8_a1x128_w128x128":
+            # we don't need max-autotune for float8 blockwise quant
+            model = torch.compile(model)
+        else:
+            model = torch.compile(model, mode="max-autotune", fullgraph=True)
+
+        if print_model:
+            print(model)
+
     with torch.no_grad():
         print("Running evaluation ...")
         # avoid circular imports
@@ -371,6 +390,9 @@
         default=False,
         help="pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower",
     )
+    parser.add_argument(
+        "--print_model", action="store_true", help="Whether to print the model."
+    )
 
     args = parser.parse_args()
     run_evaluation(
@@ -387,4 +409,5 @@ def run_evaluation(
         args.calibration_limit,
         args.calibration_seq_length,
         args.pad_calibration_inputs,
+        args.print_model,
     )
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index ddeb8c7ca6..235cd85a0f 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1778,8 +1778,6 @@ def __post_init__(self):
         default_use_fast_accum = True
 
         if _granularity_is_a_1_128_w_128_128(self.granularity):
-            assert self.activation_value_lb is None, "unimplemented"
-            assert self.activation_value_ub is None, "unimplemented"
             assert self.kernel_preference in (
                 KernelPreference.AUTO,
                 KernelPreference.TORCH,
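
Note for reviewers (not part of the patch): a minimal sketch of how the new blockwise-float8 config added above could be exercised directly through the public `quantize_` API, outside the eval harness. The toy model, shapes, and device assumptions are illustrative, not from the PR; the `activation_value_lb=1e-12` setting mirrors the eval.py change, which the quant_api.py change unblocks by dropping the "unimplemented" asserts.

# Reviewer sketch, not part of the patch: exercise the
# float8_a1x128_w128x128 path directly. Assumes a CUDA GPU with
# float8 support; the toy model and sizes are illustrative only.
import torch
import torch.nn as nn

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerBlock,
    quantize_,
)

# in/out features chosen to be divisible by the 128x128 weight block size
model = nn.Sequential(nn.Linear(256, 512, bias=False)).cuda().to(torch.bfloat16)

config = Float8DynamicActivationFloat8WeightConfig(
    # 1x128 blocks for activations, 128x128 blocks for weights
    granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
    # small lower bound on activation values, mirroring the eval.py change
    activation_value_lb=1e-12,
)
quantize_(model, config)

x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([4, 512])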