Commit 1207cc5

working

1 parent 5750d7f commit 1207cc5

2 files changed, +105 -41 lines changed

tutorials/developer_api_guide/my_dtype_tensor_subclass.py

Lines changed: 47 additions & 6 deletions
@@ -24,6 +24,32 @@
 
 aten = torch.ops.aten
 
+# TODO: move to torchao/utils.py
+def fill_defaults(args, n, defaults_tail):
+    """
+    __torch_dispatch__ doesn't guarantee the number of arguments you are
+    passed (e.g., defaulted arguments are not passed); but usually it is
+    convenient to pad out the arguments list with defaults. This function
+    helps you do that.
+    Args:
+        args: the list of positional arguments passed to __torch_dispatch__
+        n: the number of arguments you are expecting to get
+        defaults_tail: default values for the arguments, starting from the
+            end of the list
+    Example:
+        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
+        [1, 2, 3, 4, 5]
+        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
+        [1, 2, 3, None, None]
+    """
+    if n - len(defaults_tail) > len(args):
+        raise RuntimeError("not enough defaults to fill arguments")
+    r = list(args)
+    for i in range(len(args), n):
+        r.append(defaults_tail[i - n + len(defaults_tail)])
+    return r
+
+
 ###############################
 # Base Layout Tensor Subclass #
 ###############################
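
For orientation (not part of this commit): a minimal sketch of how fill_defaults pads a truncated aten.slice.Tensor argument list inside __torch_dispatch__, assuming the module context above (torch imported, fill_defaults defined); `t` is a hypothetical stand-in tensor.

    t = torch.zeros(4, 4)
    # the dispatcher may pass only (tensor, dim, start); end and step are defaulted
    self, dim, start, end, step = fill_defaults((t, 1, 0), 5, [0, None, None, 1])
    # -> dim == 1, start == 0, end is None, step == 1
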
@@ -204,6 +230,7 @@ def __new__(
         cls,
         int_data: torch.Tensor,
         scale: torch.Tensor,
+        transposed: bool,
         layout_type: LayoutType,
     ):
         kwargs = {}
@@ -220,22 +247,24 @@ def __init__(
         self,
         int_data: torch.Tensor,
         scale: torch.Tensor,
+        transposed: bool,
         layout_type: LayoutType,
     ):
         self.int_data = int_data
         self.scale = scale
+        self.transposed = transposed
         self.layout_type = layout_type
 
     def __tensor_flatten__(self):
-        return ["int_data", "scale"], [self.layout_type]
+        return ["int_data", "scale"], [self.transposed, self.layout_type]
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         int_data, scale = tensor_data_dict["int_data"], tensor_data_dict["scale"]
-        layout_type, = tensor_attributes
-        return cls(int_data, scale, layout_type)
+        transposed, layout_type, = tensor_attributes
+        return cls(int_data, scale, transposed, layout_type)
 
     @classmethod
     def from_plain(
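
A quick aside (not part of this commit) on why `transposed` goes into tensor_attributes: non-tensor state must round-trip through __tensor_flatten__/__tensor_unflatten__ so that tracing and DTensor utilities can rebuild the subclass. A minimal sketch, assuming `plain` is a PlainMyDTypeLayout instance and noting that this implementation ignores outer_size/outer_stride:

    tensor_names, attrs = plain.__tensor_flatten__()
    rebuilt = PlainMyDTypeLayout.__tensor_unflatten__(
        {name: getattr(plain, name) for name in tensor_names}, attrs, None, None
    )
    assert rebuilt.transposed == plain.transposed
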
@@ -248,12 +277,13 @@ def from_plain(
         extra metadata for packing etc.
         """
         assert isinstance(layout_type, PlainLayoutType)
-        return cls(int_data, scale, layout_type)
+        return cls(int_data, scale, False, layout_type)
 
     def _apply_fn_to_data(self, fn):
         return self.__class__(
             fn(self.int_data),
             fn(self.scale),
+            self.transposed,
             self.layout_type,
         )
 
@@ -273,11 +303,22 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         elif func is aten.split.Tensor:
             int_data_list = func(args[0].int_data, *args[1:], **kwargs)
             scale_list = func(args[0].scale, *args[1:], **kwargs)
-            out = [PlainMyDTypeLayout(int_data, scale, args[0].layout_type) for int_data, scale in zip(int_data_list, scale_list)]
+            out = [PlainMyDTypeLayout(int_data, scale, args[0].transposed, args[0].layout_type) for int_data, scale in zip(int_data_list, scale_list)]
             return out
         elif func is aten.empty_like.default:
             int_data_empty_like = func(args[0].int_data, *args[1:], **kwargs)
-            return PlainMyDTypeLayout(int_data_empty_like, args[0].scale, args[0].layout_type)
+            return PlainMyDTypeLayout(int_data_empty_like, args[0].scale, args[0].transposed, args[0].layout_type)
+        elif func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                return return_and_correct_aliasing(
+                    func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+                )
+            elif dim == 1:
+                return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1, 1), self.transposed, self.layout_type)
+        elif func is aten.t.default:
+            args[0].transposed = not args[0].transposed
+            return return_and_correct_aliasing(func, args, kwargs, args[0])
 
         raise NotImplementedError(
             f"PlainMyDTypeLayout dispatch: attempting to run {func}, this is not supported"

tutorials/developer_api_guide/tensor_parallel.py

Lines changed: 58 additions & 35 deletions
@@ -1,5 +1,5 @@
 import torch
-from my_dtype_tensor_subclass import MyDTypeTensor
+from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
 # a tensor subclass that supports tensor parallelism with DTensor
@@ -10,30 +10,6 @@ class MyDTypeTensorTP(MyDTypeTensor):
 
 aten = torch.ops.aten
 
-def fill_defaults(args, n, defaults_tail):
-    """
-    __torch_dispatch__ doesn't guarantee the number of arguments you are
-    passed (e.g., defaulted arguments are not passed); but usually it is
-    convenient to pad out the arguments list with defaults. This function
-    helps you do that.
-    Args:
-        args: the list of positional arguments passed to __torch_dispatch__
-        n: the number of arguments you are expecting to get
-        defaults_tail: default values for the arguments, starting from the
-            end of the list
-    Example:
-        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
-        [1, 2, 3, 4, 5]
-        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
-        [1, 2, 3, None, None]]
-    """
-    if n - len(defaults_tail) > len(args):
-        raise RuntimeError("not enough defaults to fill arguments")
-    r = list(args)
-    for i in range(len(args), n):
-        r.append(defaults_tail[i - n + len(defaults_tail)])
-    return r
-
 @implements([aten._to_copy.default, aten.clone.default])
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
@@ -51,20 +27,67 @@ def _(func, types, args, kwargs):
     empty_like_layout_tensor = func(args[0].layout_tensor, *args[1:], **kwargs)
     return MyDTypeTensorTP(empty_like_layout_tensor, empty_like_layout_tensor.shape)
 
-@implements([aten.slice.Tensor])
+@implements(aten.slice.Tensor)
 def _(func, types, args, kwargs):
     self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
-    print("slice:", dim, start, end, step)
-    if dim == 0:
-        assert step == 1
-        return self.__class__(aten.slice.Tensor(self.layout_tensor), (end - start + 1,) + self.shape[1:], self.dtype)
-    return
+    assert step == 1
+    if end >= self.shape[dim]:
+        end = self.shape[dim]
+    print("dim:", dim, "start:", start, " end:", end, " shape:", end - start)
+    print("manual shape:", (end - start,) + self.shape[1:])
+    return self.__class__(aten.slice.Tensor(self.layout_tensor, dim, start, end, step), (end - start,) + self.shape[1:], self.dtype)
+
+# this is needed for DTensor.from_local() and for flattening tensor
+@implements(aten.view.default)
+def _(func, types, args, kwargs):
+    x, shape = args
+
+    if tuple(x.shape) == tuple(shape):
+        return x.__class__(x.layout_tensor, x.shape, x.dtype)
+
+    if len(shape) == 1 and shape[0] == -1:
+        return x.__class__(x.layout_tensor, (x.numel(),), x.dtype)
+
+    raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]")
+
+@implements(aten.t.default)
+def _(func, types, args, kwargs):
+    tensor = args[0]
+    shape = tensor.shape[::-1]
+    new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype)
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+@implements(aten.addmm.default)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[1],
+        args[2],
+        args[0],
+    )
+    transposed = weight_tensor.layout_tensor.transposed
+    weight_tensor = weight_tensor.dequantize()
+    if transposed:
+        weight_tensor = weight_tensor.t()
+    return aten.addmm(input_tensor, weight_tensor, bias)
+
+@implements(aten.mm.default)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        None
+    )
+    transposed = weight_tensor.layout_tensor.transposed
+    weight_tensor = weight_tensor.dequantize()
+    if transposed:
+        weight_tensor = weight_tensor.t()
+    return aten.mm(input_tensor, weight_tensor)
 
 
 class M(torch.nn.Module):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self.linear = torch.nn.Linear(1024, 1024)
+        self.linear = torch.nn.Linear(1024, 1024, bias=False, device="cuda")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.linear(x)
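
A rough sketch of the control flow the mm/addmm overrides above are written for (not part of this commit; `x` and the quantized weight `w` are hypothetical stand-ins):

    # y = torch.nn.functional.linear(x, w)
    #   -> w.t() only flips w.layout_tensor.transposed (no data movement)
    #   -> the call lands in the aten.mm / aten.addmm handlers above, which
    #      dequantize w and apply the deferred transpose before the real matmul
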
@@ -79,7 +102,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 torch.manual_seed(5)
 
 m = M()
-example_input = 100 * torch.randn(128, 1024)
+example_input = 100 * torch.randn(128, 1024, device="cuda")
 m(example_input)
 
 
@@ -103,7 +126,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 quantized_shard = quantized_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
 print("quantized shard:", quantized_shard)
 # Construct DTensor from local shard
-quantized_dtensor = DTensor.from_local(quantized_shard, device_mesh, [Shard(0)])
+quantized_dtensor = DTensor.from_local(quantized_shard, mesh, [Shard(0)])
 print("quantized dtensor:", quantized_dtensor)
 
 # Replace parameter in module
@@ -117,4 +140,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 )
 print("input dtensor:", input_dtensor)
 
-m(input_dtensor)
+print("result:", m(input_dtensor))
