Commit ada04bf

Add quant+sparse subclasses to torchao
This PR adds functionality for quant + sparse composition, using tensor subclasses.
1 parent eff2f5a commit ada04bf
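As a quick orientation, here is a minimal usage sketch of the two entry points this PR adds, mirroring the calls in benchmark_sam.py below. The toy model and the CUDA/bfloat16 setup are assumptions for illustration, not part of the commit.

import torch
from torchao.quantization.quant_api import change_linear_weights_to_int8_dq_semi_structured_sparsetensors
from torchao.quantization.dynamic_quant_sparse import apply_int4_dynamic_quant_sparse

# toy stand-in for a real model with nn.Linear layers (assumption for illustration)
model = torch.nn.Sequential(torch.nn.Linear(1280, 5120), torch.nn.GELU(), torch.nn.Linear(5120, 1280))
model = model.cuda().to(torch.bfloat16).eval()

# Option A: swap Linear weights for the new int8 dynamic quant + semi-structured sparse tensor subclass
change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model)

# Option B (on a fresh copy of the model): module swap that fake-sparsifies the weights
# and applies int4-clamped dynamic quantization
# apply_int4_dynamic_quant_sparse(model)

model = torch.compile(model, mode='max-autotune')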

File tree

7 files changed: +406 -0 lines changed
benchmark_sam.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
import torch
from torchao.quantization import change_linear_weights_to_int8_dqtensors
from torchao.quantization.quant_api import change_linear_weights_to_int8_dq_semi_structured_sparsetensors
from segment_anything import sam_model_registry
from torch.utils.benchmark import Timer
from torchao.sparsity import apply_fake_sparsity, apply_sparse

from torchao.quantization.dynamic_quant_sparse import apply_int4_dynamic_quant_sparse

sam_checkpoint_base_path = "/home/jessecai/local/MODELS"
model_type = 'vit_h'
model_name = 'sam_vit_h_4b8939.pth'
checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}"
batchsize = 16
only_one_block = False

@torch.no_grad()
def benchmark(f, *args, **kwargs):
    # warmup
    for _ in range(3):
        f(*args, **kwargs)
        torch.cuda.synchronize()

    torch.cuda.reset_peak_memory_stats()
    t0 = Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20)
    return {'time': res.median * 1e3, 'memory': torch.cuda.max_memory_allocated() / 1e9}

def get_sam_model(only_one_block=False, batchsize=1):
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda()
    model = sam.image_encoder.eval()
    image = torch.randn(batchsize, 3, 1024, 1024, device='cuda')

    # code to use just a single block of the model
    if only_one_block:
        model = model.blocks[0]
        image = torch.randn(batchsize, 64, 64, 1280, device='cuda')
    return model, image

# 1) int8 dynamic quantization via tensor subclasses
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True

change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)

print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")

# 2) sparsity only (apply_sparse)
del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
apply_sparse(model)
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)

print(f"bf16 compiled runtime of the final sparsified block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")

# 3) int8 dynamic quantization + semi-structured sparsity via tensor subclasses
del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model)
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)

print(f"bf16 compiled runtime of the final quant + sparsified block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")

# 4) int4-clamped dynamic quantization + sparsity via module swap
del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
apply_int4_dynamic_quant_sparse(model)
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)

print(f"bf16 compiled runtime of the final int4 quant + sparsified block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")

torchao/quantization/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
     "apply_weight_only_int8_quant",
     "apply_dynamic_quant",
     "change_linear_weights_to_int8_dqtensors",
+    "change_linear_weights_to_int8_dq_semi_structured_sparsetensors",
     "change_linear_weights_to_int8_woqtensors",
     "change_linear_weights_to_int4_woqtensors",
     "swap_conv2d_1x1_to_linear"

torchao/quantization/dynamic_quant_sparse.py

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
import torch
import torch.nn as nn
from typing import Tuple, Optional

from functools import partial

from torchao.quantization.quant_primitives import (
    dynamically_quantize_per_channel,
    quant_int8_dynamic_per_token_linear,
    quantize_activation_per_token_absmax
)
from torchao.quantization import quant_api
from torchao.sparsity import apply_fake_sparsity

# Quant + Sparse helper functions
def sparse_quant_int8_dynamic_per_token_linear(
    x,
    w_vals_int8,
    w_scales,
    bias,
    out_dtype=torch.float32,
    fuse_dequant=True,
):
    # like F.linear, but with int8 dynamic quantization of activation,
    # and a quantized weight
    x_vals_int8, x_scales = quantize_activation_per_token_absmax(x)
    mm_out = sparse_quant_int8_per_token_matmul(
        x_vals_int8, x_scales, w_vals_int8, w_scales, out_dtype, fuse_dequant=fuse_dequant)
    if bias is not None:
        mm_out += bias
    return mm_out

def sparse_quant_int8_per_token_matmul(
    x_vals_int8,
    x_scales,
    w_vals_int8,
    w_scales,
    out_dtype=torch.float32,
    fuse_dequant=True,
):
    # Quantized sparse matmul of int8 operands that accumulates to fp16 and returns
    # out_dtype. This matmul uses cuSPARSELt as a backend.

    # Assumes that activation and weight quantization are symmetric,
    # i.e. act_zp and w_zp are 0.
    # Assumes that weight quantization is per-channel.
    # NOTE: sparsity is only compatible with symmetric (zero-preserving) quantization techniques.

    # see
    # https://github.com/google/gemmlowp/blob/master/doc/quantization.md
    # for an overview of quantized matmul compute

    # in scalar form, assuming out_dtype is fp32 and zw == 0:
    #
    #   Y_i_j_fp32 = sx * sw * dot(X_i, W_j)
    #

    assert x_vals_int8.dtype == torch.int8, \
        f'x dtype {x_vals_int8.dtype} not yet supported'
    assert w_vals_int8.dtype == torch.int8, \
        f'w dtype {w_vals_int8.dtype} not yet supported'
    assert w_scales.dtype == out_dtype, \
        f'{w_scales.dtype} does not match {out_dtype}'

    #
    # 1. do the matrix form of dot(X_i, W_j)
    #

    # For sparse matmul, we need one of the input operands to be transposed.
    # This is because cuSPARSELt only supports int8 matmul for specific formats:
    # https://docs.nvidia.com/cuda/cusparselt/functions.html#matmul-descriptor-functions
    # Because we currently only support the first operand being sparse,
    # we cannot transpose w_vals_int8, so instead we transpose x_vals_int8.
    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous()
    # Since cuSPARSELt does not have support for int32 output, we use the fp16 kernel
    # instead, by setting out_dtype.
    # y_dot_fp16 = torch._sparse_semi_structured_linear(tmp, w_vals_int8, out_dtype=torch.float16)
    y_dot_fp16 = torch._cslt_sparse_mm(w_vals_int8, tmp.t(), out_dtype=torch.float16).t()
    y_dot_fp32 = y_dot_fp16.reshape(*x_vals_int8.shape[:-1], -1).to(out_dtype)

    #
    # 2. rescale the output
    #
    # in cases with large matrices, the accumulated result can grow sufficiently
    # large that multiplying it by a float16 scale exceeds the maximum value of
    # a float16 (which results in a value of inf even if multiplying by the
    # other scale would bring it back within the expected range)

    # assert x_scales.dtype == torch.float, f"x_scales needs to be a torch.float32 but got {x_scales.dtype}"

    y = y_dot_fp32 * x_scales * w_scales

    # can downcast only at the very end
    y = y.to(out_dtype)
    return y

class SparseDynamicallyPerAxisQuantizedLinear(torch.nn.Linear):
    """
    This class is a replacement for `torch.nn.Linear`, implementing sparse dynamic quantization on
    the input across all axes except for the last axis.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True
    ):
        super().__init__(in_features, out_features, bias)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass of the sparse quantized linear layer.

        This method applies dynamic quantization to the input tensor across all axes except
        the last axis using the `sparse_quant_int8_dynamic_per_token_linear` function.

        We artificially limit the quantization values to the int4 range to ensure we stay within the range of fp16.
        This method uses cuSPARSELt to perform the sparse matmul.

        Args:
            X (torch.Tensor): The input tensor to the sparse quantized linear layer.
        Returns:
            torch.Tensor: The output tensor after the sparse quantized matmul and rescale.
        """
        Y = sparse_quant_int8_dynamic_per_token_linear(
            X, self.W_int_repr, self.W_scales, self.bias, X.dtype, fuse_dequant=self.fuse_dequant)
        return Y

    @classmethod
    def from_float(cls, mod: torch.nn.Linear, fuse_dequant=True) -> 'SparseDynamicallyPerAxisQuantizedLinear':
        """
        Converts a `mod` of class `torch.nn.Linear` to the sparse dynamically quantized version of it.
        Note: this class does not require calibration.
        Args:
            mod (torch.nn.Linear): The original `torch.nn.Linear` module to convert.
        Returns:
            SparseDynamicallyPerAxisQuantizedLinear: The converted sparse quantized linear module.
        """

        # create the new module with a toy size to ensure initialization is fast
        fake_in_features, fake_out_features = 8, 8
        new_mod = cls(
            fake_in_features, fake_out_features, bias=mod.bias is not None)
        new_mod.in_features = mod.in_features
        new_mod.out_features = mod.out_features
        # NOTE: We artificially clamp the values to the int4 quantization range to ensure we stay
        # within the dynamic range of fp16
        W_int_repr, W_scales, _W_zps = dynamically_quantize_per_channel(
            mod.weight, -8, 7, torch.int8)
        new_mod.register_buffer('W_int_repr', torch._cslt_compress(W_int_repr.contiguous()))
        new_mod.register_buffer('W_scales', W_scales)
        new_mod.bias = mod.bias
        new_mod.fuse_dequant = fuse_dequant
        del new_mod.weight

        device_to_use = next(mod.parameters()).device
        new_mod.to(device_to_use)
        return new_mod

def apply_int4_dynamic_quant_sparse(model, fuse_dequant=False):
    apply_fake_sparsity(model)
    quant_api._replace_with_custom_fn_if_matches_filter(
        model,
        partial(SparseDynamicallyPerAxisQuantizedLinear.from_float, fuse_dequant=fuse_dequant),
        lambda mod, fqn: isinstance(mod, torch.nn.Linear))
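
To make the rescale logic in sparse_quant_int8_per_token_matmul concrete, here is a small dense reference sketch (not part of the commit) of the same symmetric scheme: per-token activation scales, per-channel weight scales, zero zero-points, and a final Y = sx * sw * dot(X_i, W_j) rescale, compared against a plain fp32 matmul.

import torch

torch.manual_seed(0)
X = torch.randn(4, 16)   # activations: one scale per token (row)
W = torch.randn(8, 16)   # weights: one scale per output channel (row)

# symmetric quantization: zero points are 0, so zeros (and hence sparsity) are preserved
x_scales = X.abs().amax(dim=-1, keepdim=True) / 127.0
w_scales = W.abs().amax(dim=-1, keepdim=True) / 127.0
X_int8 = torch.clamp(torch.round(X / x_scales), -128, 127).to(torch.int8)
W_int8 = torch.clamp(torch.round(W / w_scales), -128, 127).to(torch.int8)

# integer matmul, then rescale by the product of activation and weight scales
y_dot = X_int8.to(torch.int32) @ W_int8.to(torch.int32).t()
Y = y_dot.to(torch.float32) * x_scales * w_scales.t()

# small quantization error relative to the fp32 reference
print((Y - X @ W.t()).abs().max())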

torchao/quantization/quant_api.py

Lines changed: 13 additions & 0 deletions
@@ -22,6 +22,7 @@
 from .subclass import (
     QuantizedLinearWeightBase,
     Int8DynamicallyQuantizedLinearWeight,
+    Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
     Int4WeightOnlyQuantizedLinearWeight,
 )
@@ -33,6 +34,7 @@
     "apply_weight_only_int8_quant",
     "apply_dynamic_quant",
     "change_linear_weights_to_int8_dqtensors",
+    "change_linear_weights_to_int8_dq_semi_structured_sparsetensors",
     "change_linear_weights_to_int8_woqtensors",
     "change_linear_weights_to_int4_woqtensors",
     "swap_conv2d_1x1_to_linear"
@@ -153,6 +155,17 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):
         filter_fn,
     )
 
+
+def change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model, **kwargs):
+    filter_fn = kwargs.pop("filter_fn", _is_linear)
+
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight, **kwargs),
+        filter_fn,
+    )
+
+
 def swap_conv2d_1x1_to_linear(model, filter_fn=None):
     """
     Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized.
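
Like the existing change_linear_weights_to_* helpers, the new function accepts a filter_fn kwarg (popped from **kwargs above), so the swap can be restricted to specific layers. A sketch under assumptions: the toy model is a placeholder and the 64-divisibility heuristic is purely illustrative, not something this PR prescribes.

import torch
from torchao.quantization.quant_api import change_linear_weights_to_int8_dq_semi_structured_sparsetensors

# placeholder model on CUDA in bfloat16, as in benchmark_sam.py
model = torch.nn.Sequential(torch.nn.Linear(1280, 5120), torch.nn.Linear(5120, 1280))
model = model.cuda().to(torch.bfloat16).eval()

def my_filter(mod, fqn):
    # illustrative heuristic: only swap Linear layers with 64-divisible shapes
    return (
        isinstance(mod, torch.nn.Linear)
        and mod.in_features % 64 == 0
        and mod.out_features % 64 == 0
    )

change_linear_weights_to_int8_dq_semi_structured_sparsetensors(model, filter_fn=my_filter)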
