Commit b8401a6

jerryzh168 authored and kwen2501 committed
Adding example for quantized tensor + tensor parallelism (#785)
* [WIP] Adding example for quantized tensor + tensor parallelism

  Summary: This PR adds an example of how a quantized tensor subclass can work with DTensor (https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/README.md). The end goal is to rewrite https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama2.py with a normal llama2 implementation and to show that with DTensor + AffineQuantizedTensor + torch.compile we can get on-par performance with the custom tensor parallel implementation.

  Test Plan: torchrun --standalone --nnodes=1 --nproc-per-node=4 tutorials/developer_api_guide/tensor_parallel.py

  Reviewers:
  Subscribers:
  Tasks:
  Tags:

* tensor parallel file
* Use DTensor.from instead of distribute_tensor
* implementing aten.slice.Tensor (WIP)
* working
* some shape fix and use more quant primitive ops
* Add rowwise test
* make rowwise sharding work
* compile still not working yet
* fake tensor didn't pick up shape changes from transpose
* backend='eager'
* change transpose to non-inplace op
* add error message
* works now with torch nightly
* remove print
* ruff
* Clean up
* Fix device id

---------

Co-authored-by: Ke Wen <[email protected]>
1 parent 36cf3ed commit b8401a6
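
The pattern the example demonstrates is to quantize each rank's local shard with the tensor subclass and then wrap that shard as a DTensor, so DTensor handles placement and collectives while the subclass owns the quantized representation. A minimal sketch of column-wise sharding in that spirit (simplified, not the commit's tensor_parallel.py; the to_my_dtype import path is assumed, and per the commit notes this relies on a recent PyTorch nightly):

# tp_sketch.py -- hypothetical, simplified illustration; run under torchrun, e.g.
#   torchrun --standalone --nnodes=1 --nproc-per-node=4 tp_sketch.py
import os

import torch
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh

from my_dtype_tensor_subclass import to_my_dtype  # import path assumed

def colwise_shard(linear: torch.nn.Linear, mesh) -> torch.nn.Linear:
    """Shard the weight over output features (dim 0), quantizing only this rank's slice."""
    rank = mesh.get_local_rank()
    n_local = linear.weight.size(0) // mesh.size()
    local = linear.weight.detach()[rank * n_local:(rank + 1) * n_local, :]
    # DTensor.from_local wraps the already-sharded (and already-quantized) local
    # tensor as-is; distribute_tensor would instead shard a full global tensor.
    dist_weight = DTensor.from_local(to_my_dtype(local), mesh, [Shard(0)])
    linear.weight = torch.nn.Parameter(dist_weight, requires_grad=False)
    return linear

if __name__ == "__main__":
    world_size = int(os.environ["WORLD_SIZE"])
    mesh = init_device_mesh("cuda", (world_size,))
    m = colwise_shard(torch.nn.Linear(1024, 1024, bias=False).to("cuda"), mesh)
    print(m.weight.shape)  # global shape (1024, 1024); each rank stores a quantized slice

Switching from distribute_tensor to DTensor.from_local (one of the bullets above) matters because from_local keeps whatever local tensor it is given, including a quantized subclass instance, rather than redistributing a full tensor.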

File tree

3 files changed, +294 -35 lines changed

tutorials/developer_api_guide/my_dtype_tensor_subclass.py

Lines changed: 100 additions & 32 deletions
@@ -15,7 +15,12 @@
 import torch
 
 from torch.utils._python_dispatch import return_and_correct_aliasing
-from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType
+from torchao.quantization.quant_primitives import (
+    choose_qparams_affine,
+    MappingType,
+    quantize_affine,
+    dequantize_affine,
+)
 from torchao.dtypes.utils import (
     LayoutType,
     PlainLayoutType,
@@ -24,6 +29,32 @@
 
 aten = torch.ops.aten
 
+# TODO: move to torchao/utils.py
+def fill_defaults(args, n, defaults_tail):
+    """
+    __torch_dispatch__ doesn't guarantee the number of arguments you are
+    passed (e.g., defaulted arguments are not passed); but usually it is
+    convenient to pad out the arguments list with defaults. This function
+    helps you do that.
+    Args:
+        args: the list of positional arguments passed to __torch_dispatch__
+        n: the number of arguments you are expecting to get
+        defaults_tail: default values for the arguments, starting from the
+            end of the list
+    Example:
+        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
+        [1, 2, 3, 4, 5]
+        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
+        [1, 2, 3, None, None]
+    """
+    if n - len(defaults_tail) > len(args):
+        raise RuntimeError("not enough defaults to fill arguments")
+    r = list(args)
+    for i in range(len(args), n):
+        r.append(defaults_tail[i - n + len(defaults_tail)])
+    return r
+
+
 ###############################
 # Base Layout Tensor Subclass #
 ###############################
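
As used further down in this diff for aten.slice.Tensor, fill_defaults pads whatever positional arguments __torch_dispatch__ actually received up to the op's full five-argument signature. A quick illustration (the values are made up; the first element stands in for the layout tensor):

args = ("<layout_tensor>", 1, 0, 512)  # (self, dim, start, end); step omitted by the caller
self_, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
print(dim, start, end, step)  # 1 0 512 1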
@@ -140,10 +171,10 @@ def from_float(
         layout_type: LayoutType = PlainLayoutType(),
     ):
         mapping_type = MappingType.SYMMETRIC
-        block_size = input_float.shape
+        block_size = (1, input_float.shape[-1])
         dtype = torch.int16
-        scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
-        int_data = (input_float / scale).to(torch.int8)
+        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
+        int_data = quantize_affine(input_float, block_size, scale, zero_point, dtype)
         layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
         layout_tensor = layout_tensor_ctr(int_data, scale, layout_type)
         return cls(layout_tensor, input_float.shape)
@@ -160,7 +191,14 @@ def dequantize(self, output_dtype=None):
         if output_dtype is None:
             output_dtype = torch.get_default_dtype()
         int_data, scale = self.layout_tensor.get_plain()
-        return int_data.to(output_dtype) * scale
+        transposed = False
+        block_size = (1, int_data.shape[-1])
+        if hasattr(self.layout_tensor, "transposed") and self.layout_tensor.transposed:
+            transposed = True
+        res = dequantize_affine(int_data, block_size, scale, None, int_data.dtype, output_dtype=output_dtype)
+        if transposed:
+            res = res.t()
+        return res
 
     def __repr__(self):
         return (
@@ -203,6 +241,7 @@ def __new__(
         cls,
         int_data: torch.Tensor,
         scale: torch.Tensor,
+        transposed: bool,
         layout_type: LayoutType,
     ):
         kwargs = {}
@@ -219,22 +258,24 @@ def __init__(
         self,
         int_data: torch.Tensor,
         scale: torch.Tensor,
+        transposed: bool,
         layout_type: LayoutType,
     ):
         self.int_data = int_data
         self.scale = scale
+        self.transposed = transposed
         self.layout_type = layout_type
 
     def __tensor_flatten__(self):
-        return ["int_data", "scale"], [self.layout_type]
+        return ["int_data", "scale"], [self.transposed, self.layout_type]
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         int_data, scale = tensor_data_dict["int_data"], tensor_data_dict["scale"]
-        layout_type, = tensor_attributes
-        return cls(int_data, scale, layout_type)
+        transposed, layout_type, = tensor_attributes
+        return cls(int_data, scale, transposed, layout_type)
 
     @classmethod
     def from_plain(
@@ -247,12 +288,13 @@ def from_plain(
         extra metadata for packing etc.
         """
         assert isinstance(layout_type, PlainLayoutType)
-        return cls(int_data, scale, layout_type)
+        return cls(int_data, scale, False, layout_type)
 
     def _apply_fn_to_data(self, fn):
         return self.__class__(
             fn(self.int_data),
             fn(self.scale),
+            self.transposed,
             self.layout_type,
         )
 
@@ -265,8 +307,36 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )
 
+        # Tensor parallel support START
+        elif func in [aten._to_copy.default, aten.clone.default]:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+        elif func is aten.split.Tensor:
+            int_data_list = func(args[0].int_data, *args[1:], **kwargs)
+            scale_list = func(args[0].scale, *args[1:], **kwargs)
+            out = [PlainMyDTypeLayout(int_data, scale, args[0].transposed, args[0].layout_type) for int_data, scale in zip(int_data_list, scale_list)]
+            return out
+        elif func is aten.empty_like.default:
+            int_data_empty_like = func(args[0].int_data, *args[1:], **kwargs)
+            return PlainMyDTypeLayout(int_data_empty_like, args[0].scale, args[0].transposed, args[0].layout_type)
+        elif func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                return return_and_correct_aliasing(
+                    func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+                )
+            elif dim == 1:
+                return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1, 1), self.transposed, self.layout_type)
+            else:
+                raise NotImplementedError(f"PlainMyDTypeLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")
+        elif func is aten.t.default:
+            return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeLayout(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type))
+
+        # Tensor parallel support END
+
         raise NotImplementedError(
-            f"MyDTypeLayout dispatch: attempting to run {func}, this is not supported"
+            f"PlainMyDTypeLayout dispatch: attempting to run {func}, this is not supported"
         )
 
 #####################################################
@@ -315,15 +385,6 @@ def _(func, types, args, kwargs):
         func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
     )
 
-
-class M(torch.nn.Module):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.linear = torch.nn.Linear(1024, 1024)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.linear(x)
-
 #####################
 # Factory functions #
 #####################
@@ -333,42 +394,49 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 ########
 # Test #
 ########
-
-def test():
+def main():
     from torchao.utils import benchmark_model
-
+
+    class M(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.linear = torch.nn.Linear(1024, 128)
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            return self.linear(x)
+
     m = M()
-    example_inputs = (100 * torch.randn(1024, 1024),)
+    example_inputs = (100 * torch.randn(512, 1024),)
     NUM_WARMUPS = 10
     NUM_RUNS = 100
-
+
     for _ in range(NUM_WARMUPS):
         m(*example_inputs)
     print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs))
-
+
     compiled = torch.compile(m, mode="max-autotune")
     for _ in range(NUM_WARMUPS):
         compiled(*example_inputs)
     print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs))
-
+
     # convert weights to quantized weights
     m.linear.weight = torch.nn.Parameter(
         to_my_dtype(m.linear.weight), requires_grad=False
     )
-
+
     for _ in range(NUM_WARMUPS):
         m(*example_inputs)
-
+
     print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs))
-
+
     m = torch.compile(m, mode="max-autotune")
-
+
     for _ in range(NUM_WARMUPS):
         m(*example_inputs)
-
+
     # NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op
     # we plan to add custom op example in the future and that will help us to get speedup
     print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs))
 
 if __name__ == "__main__":
-    test()
+    main()
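
The dispatch branches added above (clone/_to_copy, split, empty_like, slice along both dims, and t) are the kinds of aten ops that get issued against the layout tensor when a quantized weight is sharded, copied, or transposed in the tensor-parallel flow. A rough single-process probe of those branches, without DTensor (import path assumed; shapes illustrative):

import torch
from my_dtype_tensor_subclass import to_my_dtype  # import path assumed

w = to_my_dtype(torch.randn(1024, 1024))
layout = w.layout_tensor                                      # PlainMyDTypeLayout

top, bottom = torch.ops.aten.split.Tensor(layout, 512)        # aten.split.Tensor branch
row_shard = torch.ops.aten.slice.Tensor(layout, 0, 0, 512)    # dim-0 slice
col_shard = torch.ops.aten.slice.Tensor(layout, 1, 0, 512)    # dim-1 slice
flipped = layout.t()                                          # aten.t.default toggles `transposed`
print(flipped.transposed, row_shard.int_data.shape, col_shard.int_data.shape)
# True torch.Size([512, 1024]) torch.Size([1024, 512])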

tutorials/developer_api_guide/my_trainable_tensor_subclass.py

Lines changed: 3 additions & 3 deletions
@@ -61,7 +61,7 @@ def from_float(
         return _ToMyTrainableDTypeTensor.apply(input_float, layout_type)
 
 class _ToMyTrainableDTypeTensor(torch.autograd.Function):
-    """ 
+    """
     Differentiable constructor for `MyTrainableDTypeTensor`.
     """
 
@@ -163,8 +163,8 @@ def _(func, types, args, kwargs):
 ########
 
 class M(torch.nn.Module):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+    def __init__(self) -> None:
+        super().__init__()
         self.linear = torch.nn.Linear(512, 1024, bias=False)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
