Commit 92dcc62

Update docs + add deprecation warning (#825)
* update docstrings
* updated README
* update docs
* updated docs
* updated
* fix
* update doc
* update png
* fix affine quantized test
1 parent 28dc4fb commit 92dcc62

File tree

8 files changed: +83 additions, -42 deletions

README.md

Lines changed: 0 additions & 1 deletion
@@ -31,7 +31,6 @@ from torchao.quantization.quant_api import (
     quantize_,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
-    int8_dynamic_activation_int8_semi_sparse_weight,
     int4_weight_only,
     int8_weight_only
 )
[Binary image file updated (348 KB); preview not shown]

test/dtypes/test_affine_quantized.py

Lines changed: 2 additions & 1 deletion
@@ -10,6 +10,7 @@
     int8_dynamic_activation_int8_semi_sparse_weight,
     float8_weight_only,
 )
+from torchao.dtypes import SemiSparseLayoutType
 from torch.testing._internal import common_utils
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

@@ -30,7 +31,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
         base_functions.append(int4_weight_only(group_size=32))

     if do_sparse:
-        base_functions.append(int8_dynamic_activation_int8_semi_sparse_weight())
+        base_functions.append(int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()))

     if is_cuda_8_9:
         base_functions.append(float8_weight_only())

test/sparsity/test_sparse_api.py

Lines changed: 2 additions & 3 deletions
@@ -8,10 +8,9 @@
 from torchao.sparsity import (
     apply_fake_sparsity,
     sparsify_,
-    int8_dynamic_activation_int8_semi_sparse_weight,
     semi_sparse_weight,
 )
-from torchao.dtypes import MarlinSparseLayoutType
+from torchao.dtypes import MarlinSparseLayoutType, SemiSparseLayoutType
 from torchao.quantization.quant_api import (
     int8_dynamic_activation_int8_weight,
     quantize_,
@@ -67,7 +66,7 @@ def test_quant_semi_sparse(self):
         quantize_(model_copy, int8_dynamic_activation_int8_weight())
         dense_result = model_copy(input)

-        quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight())
+        quantize_(model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()))
         sparse_result = model(input)

         assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2)

torchao/_models/sam/eval_combo.py

Lines changed: 3 additions & 2 deletions
@@ -10,7 +10,8 @@
 import resource

 from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, int4_weight_only
-from torchao.sparsity import sparsify_, apply_fake_sparsity, int8_dynamic_activation_int8_semi_sparse_weight, semi_sparse_weight
+from torchao.sparsity import sparsify_, apply_fake_sparsity, semi_sparse_weight
+from torchao.dtypes import SemiSparseLayoutType, MarlinSparseLayoutType
 from torchao.utils import unwrap_tensor_subclass
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

@@ -314,7 +315,7 @@ def mlp_only(mod, name):
                       int8_dynamic_activation_int8_weight(),
                       attn_only)
             quantize_(predictor.model.image_encoder,
-                      int8_dynamic_activation_int8_semi_sparse_weight(),
+                      int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()),
                       mlp_lin1_only)
             sparsify_(predictor.model.image_encoder,
                       semi_sparse_weight(),
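
Note: for readers unfamiliar with the pattern above, here is a minimal, hypothetical sketch of applying `quantize_` selectively through a filter function, mirroring the `mlp_lin1_only` usage in this file; the toy model and the `lin1_only` predicate are illustrative assumptions, not code from this commit.

```py
import torch
import torch.nn as nn
from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
from torchao.sparsity import apply_fake_sparsity
from torchao.dtypes import SemiSparseLayoutType

# Hypothetical stand-in for an MLP block inside an image encoder.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).half().cuda()

# Select only the first linear layer by fully-qualified name, analogous to mlp_lin1_only.
def lin1_only(module: nn.Module, fqn: str) -> bool:
    return isinstance(module, nn.Linear) and fqn.endswith("0")

apply_fake_sparsity(model)  # prune weights into a 2:4 pattern before using the semi-sparse layout
quantize_(model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()), lin1_only)
```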

torchao/quantization/quant_api.py

Lines changed: 6 additions & 0 deletions
@@ -15,6 +15,7 @@
 and mixed GEMM kernels
 """
 from functools import partial
+import warnings
 import torch
 import torchao
 import torch.nn as nn
@@ -612,6 +613,11 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
+    warnings.warn("""int8_dynamic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout_type kwarg in int8_dynamic_activation_int8_weight instead.
+
+    from torchao.dtypes import SemiSparseLayoutType
+    int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())""")
+
     return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())

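Note: as a reference for users hitting the new warning, a minimal before/after sketch of the migration it suggests (the toy model is a hypothetical stand-in):

```py
import torch
import torch.nn as nn
from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
from torchao.sparsity import apply_fake_sparsity
from torchao.dtypes import SemiSparseLayoutType

model = nn.Sequential(nn.Linear(128, 128)).half().cuda()
apply_fake_sparsity(model)  # weights should already follow a 2:4 sparsity pattern

# Before (now emits a deprecation warning):
#   quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight())
# After: pass the semi-sparse layout to the non-deprecated API.
quantize_(model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()))
```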
torchao/sparsity/README.md

Lines changed: 65 additions & 31 deletions
@@ -2,27 +2,10 @@
 
 Sparsity is the technique of removing parameters from a neural network in order to reduce its memory overhead or latency. By carefully choosing how the elements are pruned, one can achieve significant reduction in memory overhead and latency, while paying a reasonably low or no price in terms of model quality (accuracy / f1).
 
-## Goal
 
-We feel that the main problem current sparsity researchers / users face is fragmentation. Researchers rightfully aim to show end-to-end results, but this means a lot of time is spent figuring out how to integrate with PyTorch and implementation questions like:
-- *When should I mask?*
-- *When/how should I store the compressed representation?*
-- *Do I want in-place or out-of-place mask updates?*
-- *How can I call sparse matmul instead of dense?*
-
-We feel like the above problems can be solved once by `torchao`, letting researchers focus on what really matters - pushing sparse kernel performance or more accurate pruning algorithms.
-
-More concretely, we hope to provide tutorials and APIs for both sparse kernels (tensor subclassing) and pruning algorithms (torch.ao.pruning.Sparsifier) that users can extend. We aim to provide modular building blocks, that can be used to accelerate not only inference but training as well, and that compose nicely with `torchao` quantization workflows.
-
-1. Train sparse models from scratch with hardware acceleration, with minimal accuracy loss.
-2. Recover accuracy loss of pruned model with custom pruning algorthim.
-3. Accelerate masked/pruned models on sparsity-supported hardware to realize performance improvements.
-
-## Success Stories
-
-#### segment-anything-fast
-We applied 2:4 sparsity to accelerate segment-anything, as part of [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast).
+## Benchmarks
 
+### segment-anything-fast
 We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIOU)**.
 
 Overall, we found that accelerating the MLP linear layers provided the most speedups (`lin1`, `lin2`), while mitigating accuracy loss.
@@ -47,20 +30,23 @@ The following benchmarks we ran for sam ViT-h on an NVIDIA-A100-80GB, with batch
 
 To reproduce our benchmarks please follow these [instructions](/torchao/_models/sam/README.md).
 
-#### BERT
+### Llama3
 
-We were able to accelerate BERT 1.23x on an A100 with a negligible accuracy drop on SQuAD.
-For more information about accelerting BERT with semi-sturcutred sparsity, please see our [tutorial](https://pytorch.org/tutorials/advanced/semi_structured_sparse.html?highlight=beta).
+On Meta Llama3, we observe a 25% tok/s increase (180 -> 226) compared to our existing int4-wo implementation when using the sparse marlin kernel @Diogo-V added.
 
-| Metrics | fp16 | 2:4 sparse | delta / speedup |
-| --- | --- | --- | --- |
-| Exact Match (%) | 78.53 | 78.44 | -0.09 |
-| F1 (%) | 86.93 | 86.49 | -0.44 |
-| Time (bs=16) | 19.35 | 15.74 | 1.23x |
+| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
+| ----------- | ----------------------- | ------------- | ----------------------- | ---------------- | --------------- |
+| Llama-3-8B | Base (bfloat16) | 95.64 | 1435.54 | 16.43 | 15.01 |
+| | int8dq | 8.61 | 64.75 | 9.24 | 7.52 |
+| | int8wo | 153.03 | 1150.80 | 10.42 | 7.52 |
+| | int4wo-64 | 180.80 | 763.33 | 6.88 | 4.22 |
+| | int4wo-64-sparse-marlin | 226.02 | 689.20 | 5.32 | 3.05 |
 
-# Implemented APIs
+These benchmarks were also run on an NVIDIA-A100-80GB.
 
-## Quantization + Sparsity
+## Supported APIs
+
+![support_matrix](/docs/static/supported_sparsity_patterns.png)
 
 ### Sparse Marlin 2:4
 
@@ -72,11 +58,59 @@ from torchao.dtypes import MarlinSparseLayoutType
 
 # Your FP16 model
 model = model.cuda().half()
-
 quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
 ```
 
-# Design
+### int8 dynamic quant + 2:4 sparsity
+
+We support composing int8 dynamic quantization with 2:4 sparsity. We fuse one of the scalar dequant multiplications into our cuSPARSELt sparse mm in order to remain performant.
+
+```py
+from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
+from torchao.dtypes import SemiSparseLayoutType
+
+model = model.cuda()
+quantize_(model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()))
+```
+
+### 2:4 sparsity
+
+```py
+from torchao.sparsity.sparse_api import sparsify_, semi_sparse_weight
+from torchao.dtypes import SemiSparseLayoutType
+
+model = model.cuda()
+sparsify_(model, semi_sparse_weight())
+```
+
+### Block sparsity (prototype)
+We offer prototype support for accelerating block sparsity with our triton kernels for bfloat16/float16 workloads.
+
+```py
+from torchao.sparsity.sparse_api import sparsify_
+from torchao.sparsity.prototype.superblock.blocksparse import block_sparse_weight
+
+model = model.cuda()
+sparsify_(model, block_sparse_weight())
+```
+
+# Goal
+
+We feel that the main problem current sparsity researchers / users face is fragmentation. Researchers rightfully aim to show end-to-end results, but this means a lot of time is spent figuring out how to integrate with PyTorch and implementation questions like:
+- *When should I mask?*
+- *When/how should I store the compressed representation?*
+- *Do I want in-place or out-of-place mask updates?*
+- *How can I call sparse matmul instead of dense?*
+
+We feel like the above problems can be solved once by `torchao`, letting researchers focus on what really matters - pushing sparse kernel performance or more accurate pruning algorithms.
+
+More concretely, we hope to provide tutorials and APIs for both sparse kernels (tensor subclassing) and pruning algorithms (torch.ao.pruning.Sparsifier) that users can extend. We aim to provide modular building blocks that can be used to accelerate not only inference but training as well, and that compose nicely with `torchao` quantization workflows.
+
+1. Train sparse models from scratch with hardware acceleration, with minimal accuracy loss.
+2. Recover accuracy loss of pruned model with custom pruning algorithm.
+3. Accelerate masked/pruned models on sparsity-supported hardware to realize performance improvements.
+
+## Design
 
 Sparsity, like quantization, is an accuracy/performance trade-off, where we care not only about the speedup but also about the accuracy degradation of our architecture optimization technique.
 
torchao/sparsity/sparse_api.py

Lines changed: 5 additions & 4 deletions
@@ -42,9 +42,10 @@ def sparsify_(model: torch.nn.Module,
     """Convert the weight of linear modules in the model with `apply_tensor_subclass`
     This function is essentially the same as quantize, but for sparsity subclasses.
 
-    Currently, we support two options for sparsity:
+    Currently, we support three options for sparsity:
     - semi-structured (2:4) sparsity with `semi_sparse_weight`
-    - int8 dynamic quantization + 2:4 sparsity with `int8_dynamic_activation_int8_semi_sparse_weight`, which is also available via the quantize API
+    - int8 dynamic quantization + 2:4 sparsity with `layout_type=SemiSparseLayoutType`
+    - int4 weight-only quantization + 2:4 sparsity with `layout_type=MarlinSparseLayoutType`
 
     Args:
         model (torch.nn.Module): input model
@@ -67,8 +68,8 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
         m = sparsify_(m, semi_sparse_weight(), filter_fn)
 
         # for int8 dynamic quantization + 2:4 sparsity
-        from torchao.sparsity.prototype import int8_dynamic_activation_int8_semi_sparse_weight
-        m = sparsify_(m, int8_dynamic_activation_int8_semi_sparse_weight(), filter_fn)
+        from torchao.dtypes import SemiSparseLayoutType
+        m = quantize_(m, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()), filter_fn)
     """
     _replace_with_custom_fn_if_matches_filter(
         model,
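
Note: to complement the updated docstring, a minimal, self-contained sketch of the `sparsify_` flow it describes; the toy model and the `filter_fn` predicate below are illustrative assumptions, not code from this commit.

```py
import torch
import torch.nn as nn
from torchao.sparsity import sparsify_, semi_sparse_weight, apply_fake_sparsity

# Hypothetical toy model; any module with sufficiently large nn.Linear layers works.
model = nn.Sequential(nn.Linear(256, 1024), nn.ReLU(), nn.Linear(1024, 256)).half().cuda()

# Matches the filter_fn signature shown in the docstring: only the first linear is converted.
def filter_fn(module: nn.Module, fqn: str) -> bool:
    return isinstance(module, nn.Linear) and fqn == "0"

apply_fake_sparsity(model)                         # prune weights into a 2:4 pattern first
sparsify_(model, semi_sparse_weight(), filter_fn)  # swap the selected weight for a semi-structured sparse tensor
```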
