
Commit 333a88f

update benchmarks + README
1 parent 153fd0b commit 333a88f

7 files changed: +53 -29 lines changed

torchao/_models/llama/benchmark_results.txt

Lines changed: 1 addition & 0 deletions
@@ -38,3 +38,4 @@ kv cache quantization:
 20240826171015, tok/s= 1.95, mem/s= 29.21 GB/s, peak_mem=59.27 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072
 20240826172121, tok/s= 1.73, mem/s= 26.02 GB/s, peak_mem=52.62 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: True, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072--kv_cache_quantization
 20240826173230, tok/s= 1.73, mem/s= 25.95 GB/s, peak_mem=34.18 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: True, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072--kv_cache_quantization --linear_causal_mask
+20240906054415, tok/s=226.02, mem/s= 689.20 GB/s, peak_mem= 5.32 GB, model_size= 3.05 GB quant: marlin, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization marlin --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8

torchao/_models/llama/benchmarks.sh

Lines changed: 3 additions & 1 deletion
@@ -30,13 +30,15 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt
-
+# sparse marlin (NOTE: float16)
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
 # auto-round w/ quant_lm_head
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround
 # auto-round w/o quant_lm_head
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0


+
 export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 --kv_cache_quantization

torchao/_models/llama/generate.py

Lines changed: 3 additions & 0 deletions
@@ -225,6 +225,9 @@ def main(
             groupsize=int(quantization.split("-")[-1])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
             quantize_(model, int4_weight_only(group_size=groupsize))
+        if "marlin" in quantization:
+            from torchao.dtypes import MarlinSparseLayoutType
+            quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
         if "autoround" in quantization:
             from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_
             from transformers import AutoTokenizer
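
For reference, a minimal standalone sketch of the sparse-marlin path this hunk adds, using the same torchao calls (quantize_, int4_weight_only, MarlinSparseLayoutType). The toy single-layer model, its dimensions, and the CUDA/float16 setup are illustrative assumptions, not part of the commit:

import torch
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import MarlinSparseLayoutType

# Marlin sparse kernels run in float16, matching the float16 benchmark entries above.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()

# Swap each Linear's weight for an int4 tensor in the Marlin sparse layout, in place.
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))

out = model(torch.randn(1, 1024, dtype=torch.float16, device="cuda"))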

torchao/_models/sam/benchmark.sh

Lines changed: 2 additions & 0 deletions
@@ -8,3 +8,5 @@ python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017
 python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
 # int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
 python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
+# int8 dynamic quant attn + int4 wo + sparse marlin lin 1 + 2:4 sparse lin2
+python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half float16 --device cuda --compress int4_weight_only_sparse

torchao/_models/sam/eval_combo.py

Lines changed: 24 additions & 10 deletions
@@ -283,6 +283,16 @@ def run(
     for block in predictor.model.image_encoder.blocks:
         block.attn.use_rel_pos = use_rel_pos
 
+    # Helper filter functions
+    def attn_only(mod, name):
+        return isinstance(mod, torch.nn.Linear) and 'attn' in name
+    def mlp_lin1_only(mod, name):
+        return isinstance(mod, torch.nn.Linear) and 'lin1' in name
+    def mlp_lin2_only(mod, name):
+        return isinstance(mod, torch.nn.Linear) and 'lin2' in name
+    def mlp_only(mod, name):
+        return isinstance(mod, torch.nn.Linear) and 'mlp' in name
+
     if compress == "int8_dynamic_quant":
         quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight())
         if not TORCH_VERSION_AT_LEAST_2_5:
@@ -296,15 +306,6 @@ def mlp_only(mod, name):
         apply_fake_sparsity(predictor.model.image_encoder)
         sparsify_(predictor.model.image_encoder, semi_sparse_weight())
     elif compress == "int8_dynamic_quant_sparse":
-        def attn_only(mod, name):
-            return isinstance(mod, torch.nn.Linear) and 'attn' in name
-        def mlp_lin1_only(mod, name):
-            return isinstance(mod, torch.nn.Linear) and 'lin1' in name
-        def mlp_lin2_only(mod, name):
-            return isinstance(mod, torch.nn.Linear) and 'lin2' in name
-        def mlp_only(mod, name):
-            return isinstance(mod, torch.nn.Linear) and 'mlp' in name
-
         # apply sparsify first to set qparams
         apply_fake_sparsity(predictor.model.image_encoder,
                             filter_fn=mlp_only)
@@ -320,7 +321,20 @@ def mlp_only(mod, name):
                   mlp_lin2_only)
         if not TORCH_VERSION_AT_LEAST_2_5:
             predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
-
+    elif compress == "int4_weight_only_sparse":
+        # apply sparsify first to set qparams
+        apply_fake_sparsity(predictor.model.image_encoder,
+                            filter_fn=mlp_only)
+        from torchao.dtypes import MarlinSparseLayoutType
+        quantize_(predictor.model.image_encoder,
+                  int8_dynamic_activation_int8_weight(),
+                  attn_only)
+        quantize_(predictor.model.image_encoder, int4_weight_only(layout_type=MarlinSparseLayoutType()), mlp_lin1_only)
+        sparsify_(predictor.model.image_encoder,
+                  semi_sparse_weight(),
+                  mlp_lin2_only)
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
     else:
         assert compress is None, f"Unsupported compress mode {compress}"
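
For readers who want the new branch in one place, here is a condensed, hypothetical sketch of the same int4_weight_only_sparse recipe applied to a toy encoder rather than SAM's image encoder. The ToyBlock module, its submodule names (attn, mlp.lin1, mlp.lin2), the layer sizes, and the CUDA/float16 setup are illustrative assumptions; the torchao calls mirror the diff above.

import torch
from torchao.quantization import quantize_, int4_weight_only, int8_dynamic_activation_int8_weight
from torchao.sparsity import apply_fake_sparsity, sparsify_, semi_sparse_weight
from torchao.dtypes import MarlinSparseLayoutType

class ToyMLP(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.lin1 = torch.nn.Linear(dim, 4 * dim)
        self.lin2 = torch.nn.Linear(4 * dim, dim)

class ToyBlock(torch.nn.Module):
    def __init__(self, dim=1024):
        super().__init__()
        self.attn = torch.nn.Linear(dim, dim)
        self.mlp = ToyMLP(dim)

model = ToyBlock().half().cuda()

# Filter functions select layers by fully qualified module name, as in eval_combo.py.
def attn_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'attn' in name
def mlp_lin1_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'lin1' in name
def mlp_lin2_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'lin2' in name
def mlp_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'mlp' in name

# Sparsify the MLP first so quantization parameters are set on the pruned weights,
# then quantize attention (int8 dynamic), mlp.lin1 (int4 + sparse marlin layout),
# and make mlp.lin2 2:4 semi-structured sparse.
apply_fake_sparsity(model, filter_fn=mlp_only)
quantize_(model, int8_dynamic_activation_int8_weight(), attn_only)
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()), mlp_lin1_only)
sparsify_(model, semi_sparse_weight(), mlp_lin2_only)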

torchao/_models/sam/results.csv

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@ cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,ma
 cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
 cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
 cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,17068,21,23.96093702681232,41.73459489004953,0.5485481164943489,max-autotune,torch.float16,int4_weight_only_sparse,False,True,True,32,154,4928,None,None

torchao/quantization/README.md

Lines changed: 19 additions & 18 deletions
@@ -5,24 +5,25 @@ Typically quantization algorithms will have different schemes for how the activa
 Benchmarks are run on a machine with a single A100 GPU using the script in _models/llama, which generates text in a latency-optimized way (batchsize=1); evaluation was done
 using lm_eval. The models used were meta-llama/Llama-2-7b-chat-hf and meta-llama/Meta-Llama-3-8B.

 | Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
 | ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
 | Llama-2-7B | Base (bfloat16) | 12.212 | 107.38 | 1418.93 | 13.88 | 13.21 |
 | | int8dq | 12.262 | 9.61 | 63.67 | 8.61 | 6.62 |
 | | int8wo | 12.204 | 170.83 | 1131.18 | 8.95 | 6.62 |
 | | int4wo-64 | 12.843 | 201.14 | 751.42 | 4.87 | 3.74 |
 | | int4wo-64-GPTQ | 12.527 | 201.14 | 751.42 | 4.87 | 3.74 |
 | | uintx-4-64 | 12.891 | 48.25 | 189.32 | 6.29 | 3.92 |
 | | uintx-2-8 | 28.766 | 36.11 | 238.58 | 9.26 | 6.61 |
 | | autoquant-int4hqq | 12.825 | 209.19 | 804.32 | 4.89 | 3.84 |
 | Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 |
 | | int8dq | 7.581 | 8.61 | 64.75 | 9.24 | 7.52 |
 | | int8wo | 7.447 | 153.03 | 1150.80 | 10.42 | 7.52 |
 | | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 |
 | | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 |
+| | int4wo-64-sparse-marlin | N/A | 226.02 | 689.20 | 5.32 | 3.05 |
 | | uintx-4-64 | 8.113 | 47.77 | 212.90 | 11.85 | 4.46 |
 | | uintx-2-8 | 39.368 | 33.21 | 249.22 | 15.04 | 7.51 |
 | | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 |
 
 Note: int8 dynamic quantization works best on compute-bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast), whereas Llama with batchsize=1 tends to be memory bound, hence its rather low performance here.
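
A rough sanity check on the memory-bound point, using only numbers from the table above (an approximation that ignores KV-cache and activation traffic): for weight-dominated batchsize=1 decoding, Tokens/Second is roughly Memory Bandwidth divided by Model Size. For the bfloat16 Llama-2-7B baseline this gives 1418.93 GB/s ÷ 13.21 GB ≈ 107 tok/s, in line with the measured 107.38; the smaller quantized checkpoints raise that ceiling accordingly.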
