
Commit 5750d7f

implementing aten.slice.Tensor (WIP)
1 parent 9394b93 · commit 5750d7f

File tree

1 file changed: +34 / -1 lines


tutorials/developer_api_guide/tensor_parallel.py

Lines changed: 34 additions & 1 deletion
@@ -10,6 +10,30 @@ class MyDTypeTensorTP(MyDTypeTensor):
 
 aten = torch.ops.aten
 
+def fill_defaults(args, n, defaults_tail):
+    """
+    __torch_dispatch__ doesn't guarantee the number of arguments you are
+    passed (e.g., defaulted arguments are not passed), but usually it is
+    convenient to pad out the argument list with defaults. This function
+    helps you do that.
+    Args:
+        args: the list of positional arguments passed to __torch_dispatch__
+        n: the number of arguments you are expecting to get
+        defaults_tail: default values for the arguments, starting from the
+            end of the list
+    Example:
+        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
+        [1, 2, 3, 4, 5]
+        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
+        [1, 2, 3, None, None]
+    """
+    if n - len(defaults_tail) > len(args):
+        raise RuntimeError("not enough defaults to fill arguments")
+    r = list(args)
+    for i in range(len(args), n):
+        r.append(defaults_tail[i - n + len(defaults_tail)])
+    return r
+
 @implements([aten._to_copy.default, aten.clone.default])
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
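Not part of the diff, just for orientation: fill_defaults pads a short positional-argument list out to the op's full arity using the trailing defaults, so a dispatch handler can unpack a fixed number of names. With the same defaults the slice override below uses (matching aten.slice.Tensor's schema of dim=0, start=None, end=None, step=1):

# Illustration only; the input values are arbitrary.
fill_defaults([1, 2], 5, [0, None, None, 1])   # -> [1, 2, None, None, 1]
fill_defaults([1], 5, [0, None, None, 1])      # -> [1, 0, None, None, 1]
fill_defaults([], 5, [0, None, None, 1])       # -> RuntimeError: not enough defaults to fill arguments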
@@ -27,6 +51,16 @@ def _(func, types, args, kwargs):
     empty_like_layout_tensor = func(args[0].layout_tensor, *args[1:], **kwargs)
     return MyDTypeTensorTP(empty_like_layout_tensor, empty_like_layout_tensor.shape)
 
+@implements([aten.slice.Tensor])
+def _(func, types, args, kwargs):
+    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+    print("slice:", dim, start, end, step)
+    if dim == 0:
+        assert step == 1
+        return self.__class__(aten.slice.Tensor(self.layout_tensor), (end - start + 1,) + self.shape[1:], self.dtype)
+    return
+
+
 class M(torch.nn.Module):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
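A note on the WIP branch above, outside the commit itself: ATen's slice is end-exclusive, so a step-1 slice along dim 0 keeps end - start rows rather than end - start + 1, and the inner aten.slice.Tensor call still needs dim/start/end/step forwarded to it. A minimal sketch of where this could land, assuming the layout tensor is sliced along the same dim 0 and reusing the constructor signature already used on the line above:

@implements([aten.slice.Tensor])
def _(func, types, args, kwargs):
    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
    if dim != 0 or step != 1:
        raise NotImplementedError(f"slice: dim={dim}, step={step} not supported yet")
    # Clamp to the actual size so the reported shape matches the kept data
    # (assumes non-negative bounds, as in the tutorial's usage).
    start = 0 if start is None else min(start, self.shape[0])
    end = self.shape[0] if end is None else max(start, min(end, self.shape[0]))
    sliced = aten.slice.Tensor(self.layout_tensor, dim, start, end, step)
    return self.__class__(sliced, (end - start,) + self.shape[1:], self.dtype)

Whether the result should also pass through return_and_correct_aliasing (slice is a view op, unlike the clone/_to_copy handlers above) is left open here.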
@@ -84,4 +118,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 print("input dtensor:", input_dtensor)
 
 m(input_dtensor)
-
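Also outside the commit: once the shape handling is settled, the override can be smoke-tested without the distributed setup, since basic Python slicing on the subclass lowers to aten.slice.Tensor along dim 0. The from_float entry point below is assumed from the surrounding tutorial, not from this diff:

# Hypothetical smoke test; MyDTypeTensorTP.from_float is an assumed constructor.
x = MyDTypeTensorTP.from_float(torch.randn(128, 1024))
y = x[0:64]          # __getitem__ with a plain slice dispatches to aten.slice.Tensor(dim=0)
print(y.shape)       # expected torch.Size([64, 1024]) once the output shape is end - start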