
Commit e531daa

HDCharles authored and jainapurva committed
adding kv_cache quantization (#532)
Adding support for kv_cache quantization. We use simple symmetric quantization, though the k and v values of the current token are kept in full precision. We see a reduction of 3-5 tok/s depending on context length:

[image: tok/s comparison]

and a reduction in peak memory:

[image: peak memory comparison]

We expect this reduction to scale to larger context lengths. In the model memory trace we can see the point where we replace the bf16 cache with the int8 cache, which visually saves about half of the memory used by the cache:

[image: memory trace screenshot, 2024-08-02]

At longer context lengths both the quantized and non-quantized kv_cache models start outputting weird stuff, but otherwise the accuracy of the kv_cache quant looks reasonable, e.g. for a 2048 context length:

<|begin_of_text|>Hello, my name is Richard Brown and I have been a professional musician for over 25 years. I have played in a number of bands, doing a wide variety of genres (soul/funk, rock, jazz, blues, latin, world). I have played on over a hundred albums so far. I have played with many different singers, as well as instrumentalists (guitarists, sax players, brass players, etc.). I love to play and try to learn as much as I can from others. I have become an all-round musician - playing keyboards, drums, programming, arranging; as well as writing songs myself. I have my own studio, and I can do sessions online. I also have my own website, where you can find out more about me and my music. I hope that you will find the music that you are looking for here.

There are also some fixes in generate.py to get things working for large context lengths without overflowing beyond the model limit.

Test plan: sh benchmarks.sh (specifically the last 6 rows of benchmark_results.txt)
1 parent fd5a858 commit e531daa
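For context, the scheme used here is per-token symmetric (absmax) int8 quantization over the head dimension, with the current token's k/v written back in full precision. The snippet below is a minimal standalone sketch of that round trip, not the torchao implementation; `quantize_per_token_absmax` and `dequantize` are illustrative names only, and the exact eps/rounding details may differ from torchao's `quantize_activation_per_token_absmax` helper used inside `AffineQuantizedKVCache.update` (see the model.py diff below).

```python
import torch

def quantize_per_token_absmax(x: torch.Tensor, eps: float = 1e-5):
    # one scale per token: reduce over the last (head_dim) axis, symmetric int8 range
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / 127
    q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return q, scale

def dequantize(q: torch.Tensor, scale: torch.Tensor, dtype=torch.bfloat16):
    # reconstruct an approximation of the original bf16 values
    return q.to(dtype) * scale.to(dtype)

# round trip on a fake (batch, n_kv_heads, seq_len, head_dim) cache slice
k_val = torch.randn(1, 8, 1, 128, dtype=torch.bfloat16)
q_k, k_scale = quantize_per_token_absmax(k_val)
k_approx = dequantize(q_k, k_scale)
print((k_val - k_approx).abs().max())  # small quantization error
```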

4 files changed: +89 lines added, -10 lines removed


torchao/_models/llama/benchmark_results.txt

Lines changed: 10 additions & 0 deletions
@@ -1,3 +1,4 @@
+llama 2
 20240619101342, tok/s= 29.85, mem/s= 788.87 GB/s, peak_mem=27.23 GB, model_size=26.43 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
 20240619101537, tok/s= 26.38, mem/s= 348.57 GB/s, peak_mem=13.62 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
 20240619105331, tok/s=106.55, mem/s=1408.06 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
@@ -8,6 +9,7 @@
 20240619110248, tok/s=199.86, mem/s= 746.66 GB/s, peak_mem= 4.50 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
 20240619114518, tok/s=159.22, mem/s=1069.87 GB/s, peak_mem= 8.91 GB, model_size= 6.72 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
 
+llama 3
 20240619114732, tok/s= 30.46, mem/s= 914.43 GB/s, peak_mem=32.34 GB, model_size=30.02 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
 20240619114939, tok/s= 26.56, mem/s= 398.65 GB/s, peak_mem=16.16 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
 20240619122811, tok/s= 96.09, mem/s=1442.32 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
@@ -17,3 +19,11 @@
 20240619123652, tok/s=139.76, mem/s=1051.02 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
 20240619123847, tok/s=179.44, mem/s= 757.60 GB/s, peak_mem= 6.62 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
 20240619131959, tok/s=137.71, mem/s=1037.74 GB/s, peak_mem=11.08 GB, model_size= 7.54 GB quant: autoquant, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+
+kv cache quantization:
+20240801093317, tok/s= 95.52, mem/s=1433.80 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240801093529, tok/s= 92.36, mem/s=1386.35 GB/s, peak_mem=16.41 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240801093944, tok/s= 89.88, mem/s=1349.13 GB/s, peak_mem=17.26 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240801094415, tok/s= 87.20, mem/s=1308.88 GB/s, peak_mem=17.22 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
+20240801095615, tok/s= 80.87, mem/s=1213.82 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
+20240801100912, tok/s= 74.65, mem/s=1120.41 GB/s, peak_mem=19.29 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8

torchao/_models/llama/benchmarks.sh

Lines changed: 8 additions & 0 deletions
@@ -22,3 +22,11 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
+
+export MODEL_REPO=meta-llama/Meta-Llama-3-8B
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192

torchao/_models/llama/generate.py

Lines changed: 33 additions & 10 deletions
@@ -68,10 +68,11 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
             next_token, next_prob = decode_one_token(
                 model, cur_token, input_pos, **sampling_kwargs
             )
+            next_token, next_prob = next_token.clone(), next_prob.clone()
             input_pos += 1
-            new_tokens.append(next_token.clone())
+            new_tokens.append(next_token)
             callback(new_tokens[-1])
-            new_probs.append(next_prob.clone())
+            new_probs.append(next_prob)
             cur_token = next_token.view(1, -1)
 
     return new_tokens, new_probs
@@ -88,6 +89,7 @@ def generate(
     *,
     interactive: bool,
     callback = lambda x: x,
+    kv_cache_quantization: bool = False,
     **sampling_kwargs
 ) -> torch.Tensor:
     """
@@ -97,14 +99,27 @@ def generate(
     # create an empty tensor of the expected final shape and fill in the current tokens
     device = prompt.device
     T = prompt.numel()
-    T_new = T + max_new_tokens
-    seq = torch.empty(T_new, dtype=prompt.dtype, device=device)
+
+    # calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size)
+    max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350
+    new_tokens = max_seq_length - T
+
+    # full prompt+output will be stored in seq
+    seq = torch.empty(max_seq_length, dtype=prompt.dtype, device=device)
     seq[:T] = prompt.view(-1)
 
-    # setup model cache
-    max_seq_length = min(T_new, model.config.block_size) if not interactive else 350
+    # setup model caches
     with torch.device(device):
         model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+        if kv_cache_quantization:
+            from model import AffineQuantizedKVCache
+            from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+            _replace_with_custom_fn_if_matches_filter(
+                model,
+                AffineQuantizedKVCache.from_float,
+                lambda x, y: isinstance(x, torchao._models.llama.model.KVCache),
+            )
+
 
     # format model input
     x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens)
@@ -113,8 +128,9 @@ def generate(
     next_token = prefill(model, x, input_pos, **sampling_kwargs).clone()
     seq[T] = next_token
 
+    # execute token generation
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
-    generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
+    generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
     seq[T + 1:] = torch.cat(generated_tokens)
 
     return seq
@@ -147,6 +163,7 @@ def main(
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
     quantization: Optional[str] = None,
+    kv_cache_quantization: bool = False,
     compile: bool = True,
     compile_prefill: bool = False,
     profile: Optional[Path] = None,
@@ -276,6 +293,7 @@ def callback(x):
                 callback=callback,
                 temperature=temperature,
                 top_k=top_k,
+                kv_cache_quantization=kv_cache_quantization,
             )
         if i == -1:
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
@@ -286,7 +304,10 @@ def callback(x):
         t = time.perf_counter() - t0
 
         if not interactive:
-            print(tokenizer.decode(y.tolist()))
+            tok_list = y.tolist()
+            # truncate text after end of string token
+            tokens = tok_list if not tokenizer.eos_id() in y else tok_list[:tok_list.index(tokenizer.eos_id())]
+            print(tokenizer.decode(tokens))
         else:
            print()
         tokens_generated = y.size(0) - prompt_length
@@ -305,12 +326,13 @@ def callback(x):
    print(f"Model Size: {model_size:.02f} GB")
    if write_result:
        result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
-        result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
+        result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
        result_txt += f"repro: python generate.py "
        result_txt += f"--quantization {quantization} " if quantization else ""
        result_txt += f"--checkpoint_path {checkpoint_path} "
        result_txt += f"--device {device} "
        result_txt += f"--precision {precision} "
+        result_txt += f"--kv_cache_quantization " if kv_cache_quantization else ""
        result_txt += f"--compile " if compile else ""
        result_txt += f"--compile_prefill " if compile_prefill else ""
        result_txt += f"--profile {profile} " if profile else ""
@@ -337,6 +359,7 @@ def callback(x):
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
     parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
+    parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
@@ -347,5 +370,5 @@ def callback(x):
     args = parser.parse_args()
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
-        args.temperature, args.checkpoint_path, args.quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
+        args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
     )
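For readers unfamiliar with `_replace_with_custom_fn_if_matches_filter`, it walks the module tree and swaps every submodule that passes a filter, which is how the bf16 `KVCache` modules get replaced by `AffineQuantizedKVCache` above. A simplified sketch of that pattern is below; `replace_modules` is a stand-in written for illustration, not the torchao implementation:

```python
import torch.nn as nn

def replace_modules(model: nn.Module, replacement_fn, filter_fn, fqn: str = ""):
    # depth-first walk over named children; swap any module the filter accepts
    for child_name, child in model.named_children():
        child_fqn = f"{fqn}.{child_name}" if fqn else child_name
        if filter_fn(child, child_fqn):
            setattr(model, child_name, replacement_fn(child))
        else:
            replace_modules(child, replacement_fn, filter_fn, child_fqn)
    return model

# used in the same spirit as the generate.py change above:
# replace_modules(model, AffineQuantizedKVCache.from_float,
#                 lambda mod, fqn: isinstance(mod, KVCache))
```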

torchao/_models/llama/model.py

Lines changed: 38 additions & 0 deletions
@@ -12,6 +12,7 @@
 from torch.nn import functional as F
 from torchao.utils import find_multiple
 
+# TODO remove superfluous arg
 def prepare_inputs_for_model(inps, max_new_tokens=1):
     # this is because input from lm-eval is 2d
     if inps.dim() > 2:
@@ -97,6 +98,43 @@ def update(self, input_pos, k_val, v_val):
 
         return k_out, v_out
 
+
+from torchao.quantization.quant_primitives import quantize_affine, dequantize_affine
+from torchao.quantization.utils import quantize_activation_per_token_absmax
+
+class AffineQuantizedKVCache(nn.Module):
+    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype=torch.bfloat16):
+        super().__init__()
+        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+        scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
+        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=torch.int8))
+        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=torch.int8))
+        self.register_buffer('k_cache_scale', torch.ones(scale_shape, dtype=scale_dtype))
+        self.register_buffer('v_cache_scale', torch.ones(scale_shape, dtype=scale_dtype))
+
+    def update(self, input_pos, k_val, v_val):
+        # quantize current k_val and store it in the cache
+        q_k_val, k_scale = quantize_activation_per_token_absmax(k_val)
+        self.k_cache[:, :, input_pos] = q_k_val
+        self.k_cache_scale[:, :, input_pos] = k_scale.unsqueeze(-1)
+        k_out = self.k_cache*self.k_cache_scale
+        k_out[:, :, input_pos] = k_val
+
+        q_v_val, v_scale = quantize_activation_per_token_absmax(v_val)
+        self.v_cache[:, :, input_pos] = q_v_val
+        self.v_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1)
+        v_out = self.v_cache*self.v_cache_scale
+        v_out[:, :, input_pos] = v_val
+
+        return k_out, v_out
+
+    @classmethod
+    def from_float(cls, kv_cache):
+        cache_shape = kv_cache.k_cache.shape
+        max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
+        scale_dtype = kv_cache.k_cache.dtype
+        return cls(max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype)
+
 class Transformer(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
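As a rough sanity check on the memory numbers above: assuming Meta-Llama-3-8B's 32 layers, 8 KV heads, and head_dim of 128 (architecture figures used for a back-of-the-envelope estimate, not measured values), the cache sizes at an 8192-token context work out to roughly:

```python
# back-of-the-envelope KV-cache size at 8192 context (assumed Llama-3-8B geometry)
n_layers, n_kv_heads, head_dim = 32, 8, 128
batch, seq_len = 1, 8192

elems = 2 * n_layers * batch * n_kv_heads * seq_len * head_dim           # k and v caches
bf16_bytes = elems * 2                                                    # 2 bytes per element
int8_bytes = elems * 1                                                    # 1 byte per element
scale_bytes = 2 * n_layers * batch * n_kv_heads * seq_len * 2             # per-token bf16 scales

print(f"bf16 cache:          {bf16_bytes / 1e9:.2f} GB")                  # ~1.07 GB
print(f"int8 cache + scales: {(int8_bytes + scale_bytes) / 1e9:.2f} GB")  # ~0.55 GB
```

That roughly halves the cache footprint, which is consistent with the ~0.5 GB gap between the kv_quant: False and kv_quant: True peak_mem values in the 8192-token rows of benchmark_results.txt.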
