Commit 9bcc422

jcaip and Diogo-V authored and committed
compile kind of working
1 parent a3f32f9 commit 9bcc422

File tree

4 files changed: +172, -42 lines changed

test/sparsity/test_marlin.py

Lines changed: 41 additions & 4 deletions
@@ -21,8 +21,9 @@
 class SparseMarlin24(TestCase):
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    def test_quant_sparse_marlin_layout_e2e(self):
-        input = torch.randn((16, 4096), dtype=torch.float16, device="cuda")
+    def test_quant_sparse_marlin_layout_eager(self):
+        # this batch input fails
+        input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
         model = (
             nn.Sequential(
                 nn.Linear(4096, 11008), # Llama2 shapes
@@ -35,20 +36,57 @@ def test_quant_sparse_marlin_layout_e2e(self):
             .cuda()
         )
 
+        apply_fake_sparsity(model)
         # Baseline
-        ref_result = model(input)
+        model_copy = copy.deepcopy(model)
+
+        # Quantized
+        quantize_(model_copy.bfloat16(), int4_weight_only())
+        dense_result = model_copy(input.bfloat16()).half()
+
+        # Sparse + quantized
+        quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        sparse_result = model(input)
+
+        error_dense = torch.mean(torch.abs(ref_result - dense_result) ** 2)
+        error_sparse = torch.mean(torch.abs(ref_result - sparse_result) ** 2)
+        assert torch.allclose(dense_model, sparse_model, atol=1e-2), "Mean error is not close"
 
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    def test_quant_sparse_marlin_layout_compile(self):
+        input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
+        model = (
+            nn.Sequential(
+                nn.Linear(4096, 11008), # Llama2 shapes
+                # nn.Linear(11008, 4096),
+                # nn.ReLU(),
+                # nn.Linear(4096, 11008),
+                # nn.Linear(11008, 4096),
+            )
+            .half()
+            .cuda()
+        )
+
+        # Baseline
         apply_fake_sparsity(model)
+        ref_result = model(input)
+
         model_copy = copy.deepcopy(model)
 
         # Quantized
         quantize_(model_copy.bfloat16(), int4_weight_only())
+        model_copy.foward = torch.compile(model_copy.forward, fullgraph=True)
         dense_result = model_copy(input.bfloat16()).half()
 
         # Sparse + quantized
         quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        model.forward = torch.compile(model.forward, fullgraph=True)
         sparse_result = model(input)
 
+        print(dense_result)
+        print(sparse_result)
+        torch.allclose(sparse_result, dense_result)
+
         error_dense = torch.mean(torch.abs(ref_result - dense_result) ** 2)
         error_sparse = torch.mean(torch.abs(ref_result - sparse_result) ** 2)
         assert torch.allclose(error_dense, error_sparse, atol=1e-2), "Mean error is not close"
@@ -70,7 +108,6 @@ def test_pack_unpack_equivalence(self):
         )
 
         scales = scales.reshape(-1, w_q_24.shape[1])
-
         # Test pack/unpack equivalence
         q_w_comp, packed_scales, meta = pack_to_marlin_24(
             w_q_24, scales, num_bits, group_size
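For reference, a minimal sketch (not part of this commit) of the error metric both new tests compute: a mean squared error of each output against the fake-sparse baseline, with the dense-quantized and sparse-quantized errors then compared via allclose. Shapes here mirror the test input; float32 is used so it runs on CPU.

import torch

def mse(ref: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    # same formula as in the tests: mean(|ref - other| ** 2)
    return torch.mean(torch.abs(ref - other) ** 2)

ref = torch.randn(32, 16, 11008)
approx = ref + 0.01 * torch.randn_like(ref)   # stand-in for a quantized/sparse output
assert mse(ref, approx) < 1e-2                # small perturbations keep the error tiny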

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 48 additions & 19 deletions
@@ -58,10 +58,13 @@ def from_plain(
     ):
         pass
 
+    @torch._dynamo.disable
     def __repr__(self):
-        int_data, scale, zero_point = self.get_plain()
-        layout_type = self.get_layout_type()
-        return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})"
+        # This is a hack, torch.compile tries to trace the __repr__ function which then calls `dequantize` function, causing an error.
+        # by removing the call to dequantize the error goes away.
+        # int_data, scale, zero_point = self.get_plain()
+        # layout_type = self.get_layout_type()
+        return f"{self.__class__.__name__}" #(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})"
 
     def _get_to_kwargs(self, *args, **kwargs):
         device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
@@ -152,10 +155,13 @@ def __init__(
         self.quant_max = quant_max
         self.zero_point_domain = zero_point_domain
 
+    @torch._dynamo.disable
     def __repr__(self):
         return (
-            f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
-            f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
+            f"{self.__class__.__name__}"
+            # Same hack here
+            #(data={self.dequantize()}, shape={self.shape}, "
+            #f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
         )
 
     def dequantize(self, output_dtype=None):
@@ -552,6 +558,8 @@ class MarlinSparseAQTLayout(AQTLayout):
     __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
     __torch_function__ = classmethod(_dispatch__torch_function__)
 
+    @staticmethod
+    @torch._dynamo.disable
     def __new__(
         cls,
         int_data: torch.Tensor,
@@ -573,6 +581,7 @@ def __new__(
         shape = int_data.shape
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
+    @torch._dynamo.disable
     def __init__(
         self,
         int_data: torch.Tensor,
@@ -593,8 +602,24 @@ def __init__(
         self.group_size = group_size
         self.num_bits = num_bits
 
+    def __tensor_flatten__(self):
+        return ["int_data", "scale", "zero_point", "meta"], [self.layout_type, self.original_shape, self.group_size, self.num_bits]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        int_data = tensor_data_dict["int_data"]
+        scale = tensor_data_dict["scale"]
+        zero_point = tensor_data_dict["zero_point"]
+        meta = tensor_data_dict["meta"]
+        layout_type, original_shape, group_size, num_bits = tensor_attributes
+        return cls(int_data, scale, zero_point, meta, layout_type, original_shape, group_size, num_bits)
+
+    @torch._dynamo.disable
     def get_plain(self):
         from torchao.sparsity.marlin import unpack_from_marlin_24  # avoid circular import
+        unpack_from_marlin_24 = torch._dynamo.disable(unpack_from_marlin_24)
         int_data_expanded, scales_expanded = unpack_from_marlin_24(
             self.int_data,
             self.scale,
@@ -606,6 +631,7 @@ def get_plain(self):
         return int_data_expanded, scales_expanded, self.zero_point
 
     @classmethod
+    @torch._dynamo.disable
     def from_plain(
         cls,
         int_data: torch.Tensor,
@@ -674,7 +700,7 @@ def _apply_fn_to_data(self, fn):
 @MarlinSparseAQTLayout.implements(aten.detach.default)
 def block_sparse_detach(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
-
+
 
 @register_layout_cls(TensorCoreTiledLayoutType)
 class TensorCoreTiledAQTLayout(AQTLayout):
@@ -920,7 +946,7 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh
     tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
     # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm
     y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
-        w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16
+        w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16,
     ).t()
     y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
         *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
@@ -1037,6 +1063,7 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
 
 def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
     return (
+        isinstance(weight_tensor, AffineQuantizedTensor) and
         _aqt_is_uint4(weight_tensor) and
         input_tensor.dtype == torch.float16 and
         len(weight_tensor.shape) == 2 and
@@ -1046,11 +1073,13 @@ def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor,
 
 def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
     from torchao.sparsity.marlin import marlin_24_workspace, const
+    assert isinstance(weight_tensor, AffineQuantizedTensor)
 
     sparse_w_int4 = weight_tensor.layout_tensor.int_data
     scale = weight_tensor.layout_tensor.scale
     meta = weight_tensor.layout_tensor.meta
     original_shape = weight_tensor.layout_tensor.original_shape
+    print("original_shape", original_shape)
     num_bits = weight_tensor.layout_tensor.num_bits
 
     # Saves batch size for reshaping back to original shape after the matmul
@@ -1059,13 +1088,15 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b
     batch_size = -1
     if input_tensor.dim() == 3:
         batch_size = input_tensor.size(0)
-        input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1]).contiguous()
+        input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1])
 
     size_m = input_tensor.shape[0]
     size_n = original_shape[1]
     size_k = input_tensor.shape[1]
     workspace_24 = marlin_24_workspace(original_shape[1])
 
+    print(size_m, size_n, size_k)
+
     # Pad input_tensor dim 1 to a multiple of the marlin tile size (16)
     if size_k % const.TILE != 0:
         pad_size = find_multiple(size_k, const.TILE)
@@ -1076,11 +1107,9 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b
         input_tensor, sparse_w_int4, meta, scale,
         workspace_24, num_bits, size_m, size_n, size_k
     )
-    torch.cuda.synchronize()
 
-    # Reshape back to original shape
     if batch_size != -1:
-        out = out.reshape(batch_size, -1, out.shape[-1])
+        out = out.view(batch_size, -1, out.shape[-1])
 
     if bias is not None:
         out += bias.to(out.dtype)
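For reference, a minimal sketch (not part of this commit) of the batch handling in the hunks above: a 3D activation is flattened to 2D before the matmul and viewed back to (batch, seq, out_features) afterwards. A plain dense matmul stands in for the sparse Marlin kernel here.

import torch

x = torch.randn(32, 16, 4096)    # (batch, seq, in_features), as in the tests
w = torch.randn(11008, 4096)     # (out_features, in_features)

batch_size = x.size(0)
x2d = x.reshape(-1, x.shape[-1])                 # (batch * seq, in_features)
out = x2d @ w.t()                                # stand-in for the Marlin 2:4 gemm
out = out.view(batch_size, -1, out.shape[-1])    # back to (batch, seq, out_features)
assert out.shape == (32, 16, 11008)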
@@ -1113,14 +1142,14 @@ def _(func, types, args, kwargs):
     # using try/except here so that we can have a general fallback when input_tensor/weight_tensor
     # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
    # make the branches easier to understand in `_quantized_linear_op`
-    try:
-        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except:
-        if isinstance(input_tensor, AffineQuantizedTensor):
-            input_tensor = input_tensor.dequantize()
-        if isinstance(weight_tensor, AffineQuantizedTensor):
-            weight_tensor = weight_tensor.dequantize()
-        return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
+    # try:
+    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
+    # except:
+    #     if isinstance(input_tensor, AffineQuantizedTensor):
+    #         input_tensor = input_tensor.dequantize()
+    #     if isinstance(weight_tensor, AffineQuantizedTensor):
+    #         weight_tensor = weight_tensor.dequantize()
+    #     return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
 
 @implements(aten.addmm.default)
 def _(func, types, args, kwargs):
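For reference, a minimal sketch (not from the commit) of the __tensor_flatten__ / __tensor_unflatten__ contract that the new MarlinSparseAQTLayout methods follow so torch.compile can trace the subclass. The toy WrapperTensor below is purely illustrative; only the shape of the contract matches the torchao class.

import torch

class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, inner: torch.Tensor, tag: str):
        # wrapper subclass with no storage of its own, like MarlinSparseAQTLayout
        return torch.Tensor._make_wrapper_subclass(cls, inner.shape, dtype=inner.dtype, device=inner.device)

    def __init__(self, inner: torch.Tensor, tag: str):
        self.inner = inner
        self.tag = tag

    def __tensor_flatten__(self):
        # names of the inner tensors, plus any non-tensor metadata
        return ["inner"], [self.tag]

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride):
        (tag,) = tensor_attributes
        return cls(tensor_data_dict["inner"], tag)

w = WrapperTensor(torch.randn(4), "metadata")
print(w.__tensor_flatten__())   # (['inner'], ['metadata'])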

torchao/quantization/utils.py

Lines changed: 1 addition & 2 deletions
@@ -362,8 +362,7 @@ def groupwise_affine_quantize_tensor_from_qparams(
     # Move to cpu, until issue with MPS memory management of temporary tensors is resolved
     if int_data_device_type == 'mps':
         int_data = int_data.cpu()
-    int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
-    if int_data_device_type == 'mps':
+        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
         int_data = int_data.to(device='mps')
     return int_data
 
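For reference, a minimal sketch (not part of this commit) of what the packing line above does: two 4-bit values per row are fused into one uint8, with the even columns in the high nibble and the odd columns in the low nibble.

import torch

int_data = torch.randint(0, 16, (2, 8), dtype=torch.int32)              # uint4 values in [0, 15]
packed = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)  # half as many columns

# Unpacking reverses the shift/mask and interleaves the nibbles back in order.
high = (packed >> 4).to(torch.int32)
low = (packed & 0xF).to(torch.int32)
unpacked = torch.stack([high, low], dim=-1).reshape(int_data.shape)
assert torch.equal(unpacked, int_data)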

wip_test_llama2.py

Lines changed: 82 additions & 17 deletions
@@ -1,24 +1,89 @@
+# This script shows how to accelerate an off-the-shelf 2:4 sparse checkpoint
+# using pytorch's `to_sparse_semi_structured`
+
+# Also shows how to use marlin
+
+# It takes advantage of the model checkpoints offered by neuralmagic:
+# https://huggingface.co/nm-testing/SparseLlama-3-8B-pruned_50.2of4-FP8
+
+import os
 import torch
-from torchao import quantize_
-from torchao.quantization import int4_weight_only
+from torchao.sparsity import sparsify_, semi_sparse_weight
+
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from torchao.utils import benchmark_model, profiler_runner
+from torchao.quantization import int4_weight_only, quantize_
 from torchao.dtypes import MarlinSparseLayoutType
-from transformers import AutoTokenizer, LlamaForCausalLM
 
-device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-name = "meta-llama/Llama-2-7b-hf"
-token = "your token"
+os.environ["TOKENIZERS_PARALLELISM"] = "false" # silence warnings when compiling
+
+torch.set_float32_matmul_precision('high')
+
+
+# Even though we need to pad the matmul shapes from (1, hidden) @ (hidden, output)
+# to (8, hidden) @ (hidden, output) we are still able to achieve speedups on
+# the mlp.up and mlp.gate linear layers of the FFN.
+def is_mlp_up_or_mlp_gate(mod, name):
+    return isinstance(mod, torch.nn.Linear) and ('mlp.gate' in name or 'mlp.up' in name)
+
+def run_benchmark(compression_config="baseline", dtype=torch.float16):
+    print (f"\n Running: {compression_config} benchmark with dtype={dtype}\n")
+
+    model = AutoModelForCausalLM.from_pretrained("nm-testing/SparseLlama-3-8B-pruned_50.2of4", torch_dtype=dtype).cuda()
+    tokenizer = AutoTokenizer.from_pretrained("nm-testing/SparseLlama-3-8B-pruned_50.2of4")
+    prompt = "Why dogs are so cute?"
+    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+
+    # Specify the max length (including both the prompt and the response)
+    # When calling `generate` with `cache_implementation="static" later, this is also used to create a `StaticCache` object
+    # with sequence length = `max_length`. The longer the more you will re-use it
+    model.generation_config.max_length = 128
+    model.generation_config.pad_token_id = tokenizer.eos_token_id
+    model.generation_config.cache_implementation = "static"
+
+    if compression_config == "24_sparse":
+        sparsify_(model, semi_sparse_weight(), filter_fn=is_mlp_up_or_mlp_gate)
+    elif compression_config == "int4_wo":
+        assert dtype == torch.bfloat16, "int4 quantization only works with bf16"
+        quantize_(model, int4_weight_only())
+    elif compression_config == "sparse_marlin":
+        assert dtype == torch.float16, "sparse_marlin only works with fp16"
+        quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+    elif compression_config == "baseline":
+        pass
+    else:
+        raise ValueError(f"Unknown compression config: {compression_config}")
+
+    # `torch.compile(model, ...)` is not recommended as you compile callbacks
+    # and full generate. We recommend compiling only the forward for now.
+    # "reduce-overhead" will use cudagraphs.
+    torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit = None
+
+    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+
+    # WARMUP
+    benchmark_model(lambda: model.generate(**inputs), 5, device_type="cuda")
+    # res is in ms so multiply by 1000 to get tok/s
+    res = benchmark_model(lambda: model.generate(**inputs), 25, device_type="cuda")
+    tokens_per_second = 1000 * (121 / res)
+    print(f"Average time: {res:.3f}ms | Tokens/second: {tokens_per_second:.3f}")
+
+    # sanity check we get same output as non-compiled model
+    outputs = model.generate(**inputs)
+    response = tokenizer.batch_decode(outputs)[0]
+    print(response)
+
+    del model
 
-model = LlamaForCausalLM.from_pretrained(name, torch_dtype=torch.float16, token=token).to(device)
-tokenizer = AutoTokenizer.from_pretrained(name, token=token)
+## baseline
+# run_benchmark(compression_config="baseline", dtype=torch.bfloat16)
 
-prompt = "Hey, are you conscious? Can you talk to me? I'm"
-inputs = tokenizer(prompt, return_tensors="pt")
+# # ## int4_wo
+run_benchmark(compression_config="int4_wo", dtype=torch.bfloat16)
 
-# Quantize
-quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+# ## sparse marlin
+# run_benchmark(compression_config="sparse_marlin", dtype=torch.float16)
 
-# Generate
-ids = inputs.input_ids.to(device)
-generate_ids = model.generate(ids, max_length=30)
-out = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-print(out)
+## sparse
+# run_benchmark(compression_config="24_sparse", dtype=torch.bfloat16)
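For reference, a minimal sketch of the tokens-per-second arithmetic in run_benchmark, assuming benchmark_model returns latency in milliseconds and that the hard-coded 121 is the number of tokens generated per call (roughly max_length minus the prompt length).

latency_ms = 2500.0               # example latency as returned by benchmark_model
num_generated_tokens = 121        # assumed token count behind the hard-coded 121
tokens_per_second = 1000 * (num_generated_tokens / latency_ms)
print(f"{tokens_per_second:.3f} tok/s")   # 48.400 tok/s for this example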
