Commit ea70c74

chore: review

1 parent ef17ed6

File tree

3 files changed: +61, -55 lines


test/sparsity/test_marlin.py

Lines changed: 39 additions & 54 deletions
@@ -4,6 +4,7 @@
 
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass
 from torchao.dtypes import MarlinSparseLayoutType
 from torchao.sparsity.sparse_api import apply_fake_sparsity
 from torchao.quantization.quant_api import int4_weight_only, quantize_
@@ -12,20 +13,22 @@
     unpack_from_marlin_24,
     inject_24
 )
-from torchao.quantization.utils import (
-    get_group_qparams_symmetric,
-    groupwise_affine_quantize_tensor_from_qparams,
+from torchao.quantization.quant_primitives import (
+    choose_qparams_affine,
+    quantize_affine,
+    ZeroPointDomain,
+    MappingType,
 )
 
 
 class SparseMarlin24(TestCase):
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    def test_quant_sparse_marlin_layout_eager(self):
+    def setUp(self):
+        super().setUp()
         torch.manual_seed(0)
 
-        input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
-        model = (
+        self.input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
+        self.model = (
             nn.Sequential(
                 nn.Linear(4096, 21504),
                 nn.Linear(21504, 4096),
@@ -37,48 +40,38 @@ def test_quant_sparse_marlin_layout_eager(self):
             .half()
             .cuda()
         )
 
-        apply_fake_sparsity(model)
-        model_copy = copy.deepcopy(model)
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    def test_quant_sparse_marlin_layout_eager(self):
+        apply_fake_sparsity(self.model)
+        model_copy = copy.deepcopy(self.model)
 
         # Quantized
         quantize_(model_copy.bfloat16(), int4_weight_only())
-        dense_result = model_copy(input.bfloat16()).half()
+        dense_result = model_copy(self.input.bfloat16()).half()
 
         # Sparse + quantized
-        quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
-        sparse_result = model(input)
+        quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        sparse_result = self.model(self.input)
 
         assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     def test_quant_sparse_marlin_layout_compile(self):
-        torch.manual_seed(0)
-
-        input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
-        model = (
-            nn.Sequential(
-                nn.Linear(4096, 21504),
-                nn.Linear(21504, 4096),
-                nn.ReLU(),
-                nn.Linear(4096, 21504),
-                nn.Linear(21504, 4096),
-            )
-            .half()
-            .cuda()
-        )
-
-        apply_fake_sparsity(model)
-        model_copy = copy.deepcopy(model)
+        apply_fake_sparsity(self.model)
+        model_copy = copy.deepcopy(self.model)
 
         # Quantized
         quantize_(model_copy.bfloat16(), int4_weight_only())
         model_copy.foward = torch.compile(model_copy.forward, fullgraph=True)
-        dense_result = model_copy(input.bfloat16()).half()
+        dense_result = model_copy(self.input.bfloat16()).half()
 
         # Sparse + quantized
-        quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
-        model.forward = torch.compile(model.forward, fullgraph=True)
-        sparse_result = model(input)
+        quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            unwrap_tensor_subclass(self.model)
+
+        self.model.forward = torch.compile(self.model.forward, fullgraph=True)
+        sparse_result = self.model(self.input)
 
         assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
 
@@ -87,34 +80,26 @@ def test_pack_unpack_equivalence(self):
         num_bits = 4
         group_size = 128
         shape = (11008, 4096)
-        max_q_val = 2**num_bits - 1
-        half_q_val = (max_q_val + 1) // 2
+        block_size = (1, group_size)
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        eps = 1e-6
+        zero_point_dtype = torch.bfloat16
+        mapping_type = MappingType.SYMMETRIC
+        preserve_zero = True
+        zero_point_domain = ZeroPointDomain.INT
+        scale_dtype = None
 
         w = torch.rand(shape, dtype=torch.float16, device="cuda")
-        size_k, size_n = w.shape
 
         # Inject 2:4 sparsity mask
        w_24, _ = inject_24(w, *w.shape)
 
         # Quantize weights
-        w_24 = w_24.reshape((-1, group_size, size_n))
-        w_24 = w_24.permute(1, 0, 2)
-        w_24 = w_24.reshape((group_size, -1))
-
-        # Compute scale for each group
-        scales = torch.max(torch.abs(w_24), 0, keepdim=True)[0]
-        scales *= 2 / max_q_val  # 2 => symmetric
-
-        # Quantize
-        w_q_24 = torch.round(w_24 / scales).int()
-        w_q_24 += half_q_val
-        w_q_24 = torch.clamp(w_q_24, 0, max_q_val)
-
-        # Shape back to original shape
-        w_q_24 = w_q_24.reshape((group_size, -1, size_n))
-        w_q_24 = w_q_24.permute(1, 0, 2)
-        w_q_24 = w_q_24.reshape((size_k, size_n)).contiguous()
-        scales = scales.reshape((-1, size_n)).contiguous()
+        scales, zeros = choose_qparams_affine(w_24, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
+        w_q_24 = quantize_affine(w_24, block_size, scales, zeros, target_dtype, quant_min, quant_max, zero_point_domain)
+        scales = scales.reshape(-1, w_q_24.shape[1])
 
         # Test pack/unpack equivalence
         q_w_comp, packed_scales, meta = pack_to_marlin_24(
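
The manual path removed from `test_pack_unpack_equivalence` is ordinary group-wise symmetric int4 quantization; the updated test computes its quantization parameters with `choose_qparams_affine`/`quantize_affine` instead. As a reference point, here is a minimal, self-contained sketch of the removed manual scheme (hypothetical helper name, not a torchao API):

```py
import torch

def groupwise_symmetric_int4(w: torch.Tensor, group_size: int = 128):
    """Sketch of the quantization the old test did by hand: one symmetric scale
    per group of `group_size` values along dim 0, integers shifted into [0, 15]
    around an implicit zero point of 8."""
    max_q_val = 2**4 - 1                   # 15
    half_q_val = (max_q_val + 1) // 2      # 8
    size_k, size_n = w.shape

    # Group consecutive k values: (size_k, size_n) -> (group_size, num_groups * size_n)
    w_g = w.reshape(-1, group_size, size_n).permute(1, 0, 2).reshape(group_size, -1)

    # One symmetric scale per group (the clamp guards against all-zero groups)
    scales = (w_g.abs().max(dim=0, keepdim=True)[0] * 2 / max_q_val).clamp(min=1e-6)

    # Round, shift into the unsigned range, clamp
    q = torch.clamp(torch.round(w_g / scales).int() + half_q_val, 0, max_q_val)

    # Restore the original (size_k, size_n) layout
    q = q.reshape(group_size, -1, size_n).permute(1, 0, 2).reshape(size_k, size_n).contiguous()
    scales = scales.reshape(-1, size_n).contiguous()
    return q, scales
```

Integer values in [0, 15] plus a per-group scale is the form the test then hands to `pack_to_marlin_24`.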

torchao/sparsity/README.md

Lines changed: 17 additions & 0 deletions
@@ -58,6 +58,23 @@ For more information about accelerting BERT with semi-sturcutred sparsity, pleas
 | F1 (%) | 86.93 | 86.49 | -0.44 |
 | Time (bs=16) | 19.35 | 15.74 | 1.23x |
 
+# Implemented APIs
+
+## Quantization + Sparsity
+
+### Sparse Marlin 2:4
+
+Sparse-Marlin 2:4 is an optimized GPU kernel that extends the Mixed Auto-Regressive Linear (Marlin) dense kernel to support 4-bit quantized weights and 2:4 sparsity, improving performance in matrix multiplication and accumulation. Full documentation can be found [here](https://github.com/IST-DASLab/Sparse-Marlin).
+
+```py
+from torchao.quantization.quant_api import quantize_, int4_weight_only
+from torchao.dtypes import MarlinSparseLayoutType
+
+# Your FP16 model
+model = model.cuda().half()
+
+quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+```
 
 # Design
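
Piecing the README snippet together with the compile test above gives a rough end-to-end sketch (the toy model and shapes below are illustrative, not taken from the commit):

```py
import torch
from torchao.quantization.quant_api import quantize_, int4_weight_only
from torchao.dtypes import MarlinSparseLayoutType
from torchao.sparsity.sparse_api import apply_fake_sparsity
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass

# Toy FP16 model on CUDA; any nn.Module built from nn.Linear layers works
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).half().cuda()

# Prune to a 2:4 pattern, then quantize with the Marlin sparse layout
apply_fake_sparsity(model)
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))

# On torch < 2.5 the tests unwrap the tensor subclass before compiling
if not TORCH_VERSION_AT_LEAST_2_5:
    unwrap_tensor_subclass(model)

model.forward = torch.compile(model.forward, fullgraph=True)
out = model(torch.randn(8, 4096, dtype=torch.float16, device="cuda"))
```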

torchao/sparsity/marlin/utils.py

Lines changed: 5 additions & 1 deletion
@@ -96,7 +96,11 @@ def get_perms_24(num_bits: int) -> Tuple[torch.Tensor, List[int], List[int]]:
     """Precompute permutations for Marlin24 weight and scale shuffling
 
     Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible
-    with the tensor-core format.
+    with the tensor-core format that is described here:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+
+    As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core
+    (without the need to use ldmatrix instructions)
 
     Args:
         num_bits (int): Number of bits to pack.
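
For intuition about the docstring change: the precomputed permutation is a fixed index order applied to each tile of quantized weights, so that a plain contiguous load already produces the mma.m16n8k16 fragment layout. A hypothetical illustration of applying such a permutation (not torchao's actual implementation):

```py
import torch

def shuffle_tile(q_tile: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    """Reorder a flattened [16*2, 64] tile of quantized weights by a fixed permutation."""
    flat = q_tile.reshape(-1)                  # flatten the tile
    return flat[perm].reshape(q_tile.shape)    # gather into tensor-core friendly order

# Toy usage: the identity permutation leaves the tile untouched
tile = torch.randint(0, 16, (32, 64), dtype=torch.int32)
perm = torch.arange(tile.numel())
assert torch.equal(shuffle_tile(tile, perm), tile)
```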
