
Commit 52b6f4d

int8 dynamic prefill weight only decode (#1436)
This PR adds a weight_only_decode option to int8_dynamic_activation_int8_weight. When set, it uses dynamic activation quantization for matmuls of shape (> 1, x) * (x, n) and weight-only quantization for the batch_size=1 (decode) case. It also updates generate.py to accept a text file as the prompt; we use this to demonstrate the prefill speedups with sh demo_summarize.sh.
1 parent 52a5137 commit 52b6f4d
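
A minimal usage sketch of the new flag, mirroring the quantize_ call added to generate.py below (the toy model, tensor shapes, and CUDA/bfloat16 setup are illustrative assumptions, not part of this commit):

import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

# hypothetical toy model; generate.py applies the same call to the llama model
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).cuda().to(torch.bfloat16)

# weight_only_decode=True keeps int8 dynamic activation quantization for prefill
# (token dimension > 1) but skips activation quantization for single-token decode
# steps, which then run as weight-only int8 matmuls
quantize_(model, int8_dynamic_activation_int8_weight(weight_only_decode=True))

prefill_act = torch.randn(1, 8192, 4096, device="cuda", dtype=torch.bfloat16)  # prefill: dynamic activation quant
decode_act = torch.randn(1, 1, 4096, device="cuda", dtype=torch.bfloat16)      # decode: weight-only path
model(prefill_act)
model(decode_act)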

6 files changed: +88 additions, -14 deletions

scripts/prepare.sh

Lines changed: 2 additions & 0 deletions
@@ -1,11 +1,13 @@
 python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf
 python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B
 python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B
+python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B-Instruct
 python scripts/download.py --repo_id meta-llama/Llama-3.2-3B
 python scripts/download.py --repo_id nm-testing/SparseLlama-3-8B-pruned_50.2of4
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B
+python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B-Instruct
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B
 # neuralmagic doesn't come with tokenizer, so we need to copy it over
 mkdir -p checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original && cp checkpoints/meta-llama/Meta-Llama-3-8B/original/tokenizer.model checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original/tokenizer.model

torchao/_models/llama/.gitignore

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+moby.txt

torchao/_models/llama/demo_summarize.sh

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+# grab moby dick prompt
+wget -nc -O moby.txt https://gist.githubusercontent.com/jcaip/f319146bb543e92e23b2c76815b0f29f/raw/31a9cd12b0b59f323eb197c9534953bdac352986/gistfile1.txt
+
+export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B-Instruct
+
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq_prefill_wo_decode --prefill_size 8192 --max_new_tokens 256 --num_samples 1 --demo_summarize_prompt moby.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8wo --prefill_size 8192 --max_new_tokens 256 --num_samples 1 --demo_summarize_prompt moby.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --prefill_size 8192 --max_new_tokens 256 --num_samples 1 --demo_summarize_prompt moby.txt

torchao/_models/llama/eval.py

Lines changed: 10 additions & 1 deletion
@@ -34,6 +34,7 @@ def run_evaluation(
     device = "cuda",
     precision = torch.bfloat16,
     quantization: Optional[str] = None,
+    sparsity: Optional[str] = None,
     compile=False,
     max_length=None,
     calibration_tasks: Optional[List[str]] = None,
@@ -44,7 +45,7 @@ def run_evaluation(
     """Runs the evaluation of a model using LM Eval."""
     print(
         f"\nEvaluating model {checkpoint_path} on tasks: {tasks}, limit: {limit}, device: {device}, precision: {precision}, "
-        +f"quantization: {quantization}, compile: {compile}, max_length: {max_length}, calibration_tasks: {calibration_tasks}, "
+        +f"quantization: {quantization}, sparsity: {sparsity}, compile: {compile}, max_length: {max_length}, calibration_tasks: {calibration_tasks}, "
         +f"calibration_seq_length: {calibration_seq_length}, pad_calibration_inputs: {pad_calibration_inputs}\n"
     )
     torchao.quantization.utils.recommended_inductor_config_setter()
@@ -236,6 +237,13 @@ def run_evaluation(
             "float8wo, float8dq, float8saq"
         ),
     )
+    parser.add_argument(
+        "--sparsity",
+        type=str,
+        help=(
+            "Which sparsity techniques to apply: semi-structured"
+        ),
+    )
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
     parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
@@ -251,6 +259,7 @@ def run_evaluation(
         args.device,
         args.precision,
         args.quantization,
+        args.sparsity,
         args.compile,
         args.max_length,
         args.calibration_tasks,

torchao/_models/llama/generate.py

Lines changed: 34 additions & 8 deletions
@@ -25,7 +25,7 @@
 )
 
 torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
-
+torch.backends.cuda.enable_cudnn_sdp(True)
 
 class HostEvent:
     def __init__(self):
@@ -256,6 +256,7 @@ def _load_model(checkpoint_path, device, precision):
 def main(
     prefill_size: Optional[int] = None,
     prompt: str = "Hello, my name is",
+    demo_summarize_prompt: Optional[str] = None,
     interactive: bool = False,
     num_samples: int = 5,
     max_new_tokens: int = 100,
@@ -285,7 +286,11 @@ def main(
 
     if prefill_size is not None and prefill_size > 0:
         # create prompt of prefill size
-        prompt = "prompt " * (int(prefill_size) - 3)
+        if demo_summarize_prompt is None:
+            prompt = "prompt " * (int(prefill_size) - 2)
+        else:
+            with open(demo_summarize_prompt, "r") as f:
+                prompt = f.read()
 
     torchao.quantization.utils.recommended_inductor_config_setter()
 
@@ -306,6 +311,12 @@ def main(
     tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
 
     encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
+
+    if demo_summarize_prompt is not None:
+        end_tag = encode_tokens(tokenizer, "\n <END_TEXT>", bos=False, device=device)
+        encoded = encoded[:prefill_size-end_tag.size(0)]
+        encoded = torch.cat((encoded, end_tag), dim=0)
+
     prompt_length = encoded.size(0)
 
     torch.manual_seed(1234)
@@ -390,6 +401,8 @@ def ffn_or_attn_only(mod, fqn):
                 quantize_(
                     model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only
                 )
+            elif "int8dq_prefill_wo_decode" in quantization:
+                quantize_(model, int8_dynamic_activation_int8_weight(weight_only_decode=True))
             else:
                 quantize_(model, int8_dynamic_activation_int8_weight())
         if "int4wo" in quantization:
@@ -809,14 +822,23 @@ def callback(x):
                 nonlocal done_generating
                 if done_generating:
                     return
-                buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
+                buffer.append(tokenizer.decode([period_id] + x.squeeze(0).tolist())[1:])
                 if x.item() == tokenizer.eos_id():
                     done_generating = True
                 if len(buffer) == 4 or done_generating:
                     print("".join(buffer), end="", flush=True)
                     buffer.clear()
-                # print(, end='', flush=True)
+                # print(, end="", flush=True)
+
+        elif demo_summarize_prompt is not None and i >= 0:
+            buffer = []
+            period_id = tokenizer.encode(".")[0]
 
+            def callback(x):
+                buffer.append(tokenizer.decode([period_id] + x.squeeze(0).tolist())[1:])
+                if len(buffer) == 4:
+                    print("".join(buffer), end="", flush=True)
+                    buffer.clear()
         else:
             callback = lambda x: x
         t0 = time.perf_counter()
@@ -851,15 +873,15 @@ def callback(x):
             decode_start_event=decode_start_event,
             decode_end_event=decode_end_event,
         )
-        if i == -1:
+        if i < 0:
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
             continue
         if hasattr(prof, "export_chrome_trace"):
             prof.export_chrome_trace(f"{profile}.json")
         device_sync(device=device) # MKG
         t = time.perf_counter() - t0
 
-        if not interactive and prefill_size is None:
+        if not interactive and demo_summarize_prompt is None:
             tok_list = y[0].tolist()
             # truncate text after end of string token
             tokens = (
@@ -869,7 +891,7 @@ def callback(x):
             )
             print(tokenizer.decode(tokens))
         else:
-            print()
+            print("\n")
         tokens_generated = y.size(-1) - prompt_length
         tokens_sec = tokens_generated / t
         aggregate_metrics["tokens_per_sec"].append(tokens_sec)
@@ -913,7 +935,7 @@ def callback(x):
     bandwidth = model_size * tokpersec
     mem = torch.cuda.max_memory_reserved() / 1e9
     print(f"Average overall tokens/sec: {tokpersec:.2f}")
-    print(f"Average decode tokens/sec: {decode_tokens_sec:.04f} s")
+    print(f"Average decode tokens/sec: {decode_tokpersec:.04f} s")
     print(f"Average TTFT: {ttft:.04f} s")
     if device == "cuda":
         mem = torch.cuda.max_memory_reserved() / 1e9
@@ -975,6 +997,9 @@ def callback(x):
     parser.add_argument(
         "--prompt", type=str, default="Hello, my name is", help="Input prompt."
     )
+    parser.add_argument(
+        "--demo_summarize_prompt", type=str, help="Read prompt from text file"
+    )
     parser.add_argument(
         "--interactive",
         action="store_true",
@@ -1073,6 +1098,7 @@ def callback(x):
     main(
         args.prefill_size,
         args.prompt,
+        args.demo_summarize_prompt,
         args.interactive,
         args.num_samples,
         args.max_new_tokens,
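
The demo_summarize_prompt path above trims the encoded prompt so that, after appending the "\n <END_TEXT>" marker, the prefill is exactly prefill_size tokens long. A standalone sketch of that truncate-and-append step, using made-up integer token ids in place of real tokenizer output:

import torch

encoded = torch.arange(10_000)          # stand-in for the tokenized moby.txt prompt
end_tag = torch.tensor([198, 220, 27])  # stand-in ids for "\n <END_TEXT>"
prefill_size = 8192

# keep the first (prefill_size - len(end_tag)) prompt tokens, then append the end tag
encoded = encoded[:prefill_size - end_tag.size(0)]
encoded = torch.cat((encoded, end_tag), dim=0)
assert encoded.size(0) == prefill_size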

torchao/quantization/quant_api.py

Lines changed: 33 additions & 5 deletions
@@ -803,8 +803,33 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
     )
 
 
+def _int8_symm_per_token_reduced_range_quant_noop_decode(
+    x: torch.Tensor,
+) -> torch.Tensor:
+    mapping_type = MappingType.SYMMETRIC
+    target_dtype = torch.int8
+    eps = 1e-5
+    quant_min = -127
+    quant_max = 127
+    if x.shape[1] == 1:
+        return x
+    else:
+        return to_affine_quantized_intx(
+            x,
+            mapping_type,
+            _get_per_token_block_size(x),
+            target_dtype,
+            eps=eps,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            scale_dtype=torch.float32 if x.dtype == torch.float16 else None,
+        )
+
+
 def int8_dynamic_activation_int8_weight(
-    layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC
+    layout=PlainLayout(),
+    act_mapping_type=MappingType.SYMMETRIC,
+    weight_only_decode=False,
 ):
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
@@ -831,11 +856,14 @@ def get_weight_block_size(x):
         eps = torch.finfo(torch.float32).eps
         zero_point_dtype = torch.int64
 
-        # input settings
-        if act_mapping_type == MappingType.SYMMETRIC:
-            input_quant_func = _int8_symm_per_token_reduced_range_quant
+        if weight_only_decode:
+            input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode
         else:
-            input_quant_func = _int8_asymm_per_token_quant
+            # input settings
+            if act_mapping_type == MappingType.SYMMETRIC:
+                input_quant_func = _int8_symm_per_token_reduced_range_quant
+            else:
+                input_quant_func = _int8_asymm_per_token_quant
 
         block_size = get_weight_block_size(weight)
         weight = to_affine_quantized_intx(
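
The heart of the change is the noop-decode input quant function above: when the activation's second dimension is 1 (a single decode token), it is returned unquantized and the linear falls back to the weight-only int8 path; for prefill (second dimension > 1) activations are still dynamically quantized per token. A self-contained sketch of that shape-based dispatch, using a simple fake quant/dequant in place of torchao's to_affine_quantized_intx:

import torch

def fake_per_token_int8_quant(x: torch.Tensor) -> torch.Tensor:
    # stand-in for to_affine_quantized_intx: symmetric per-token int8 quantize/dequantize
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127.0
    return (x / scale).round().clamp(-127, 127) * scale

def input_quant_noop_decode(x: torch.Tensor) -> torch.Tensor:
    # decode: a single token in dim 1 -> skip activation quantization entirely
    if x.shape[1] == 1:
        return x
    # prefill: many tokens in dim 1 -> dynamically quantize per token
    return fake_per_token_int8_quant(x)

prefill_act = torch.randn(1, 8192, 4096)  # (batch, seq_len, hidden) during prefill
decode_act = torch.randn(1, 1, 4096)      # single new token during decode

assert not torch.equal(input_quant_noop_decode(prefill_act), prefill_act)  # quantized copy
assert torch.equal(input_quant_noop_decode(decode_act), decode_act)        # passed through unchanged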
