@@ -10,6 +10,30 @@ class MyDTypeTensorTP(MyDTypeTensor):
 
 aten = torch.ops.aten
 
+def fill_defaults(args, n, defaults_tail):
+    """
+    __torch_dispatch__ doesn't guarantee the number of arguments you are
+    passed (e.g., defaulted arguments are not passed); but usually it is
+    convenient to pad out the argument list with defaults.  This function
+    helps you do that.
+    Args:
+        args: the list of positional arguments passed to __torch_dispatch__
+        n: the number of arguments you are expecting to get
+        defaults_tail: default values for the arguments, starting from the
+            end of the list
+    Example:
+        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
+        [1, 2, 3, 4, 5]
+        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
+        [1, 2, 3, None, None]
+    """
+    if n - len(defaults_tail) > len(args):
+        raise RuntimeError("not enough defaults to fill arguments")
+    r = list(args)
+    for i in range(len(args), n):
+        r.append(defaults_tail[i - n + len(defaults_tail)])
+    return r
+
 @implements([aten._to_copy.default, aten.clone.default])
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
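A quick sketch of how the fill_defaults helper added above behaves when
__torch_dispatch__ omits defaulted trailing arguments (the tensor here is a
stand-in, and this assumes fill_defaults is in scope):

    import torch

    t = torch.zeros(4, 3)  # stand-in for the tensor subclass instance
    # `step` was defaulted at the call site, so only four args arrive:
    self_, dim, start, end, step = fill_defaults((t, 0, 1, 3), 5, [0, None, None, 1])
    assert (dim, start, end, step) == (0, 1, 3, 1)
    # Fewer caller-supplied args still work; missing slots come from the tail:
    self_, dim, start, end, step = fill_defaults((t, 0), 5, [0, None, None, 1])
    assert (start, end, step) == (None, None, 1)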
@@ -27,6 +51,21 @@ def _(func, types, args, kwargs):
     empty_like_layout_tensor = func(args[0].layout_tensor, *args[1:], **kwargs)
     return MyDTypeTensorTP(empty_like_layout_tensor, empty_like_layout_tensor.shape)
 
+@implements([aten.slice.Tensor])
+def _(func, types, args, kwargs):
+    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+    print("slice:", dim, start, end, step)
+    if dim == 0:
+        assert step == 1
+        # Normalize defaulted bounds, then forward the slice parameters to
+        # the layout tensor; a slice [start, end) has length end - start.
+        start = 0 if start is None else start
+        end = self.shape[0] if end is None or end > self.shape[0] else end
+        sliced = aten.slice.Tensor(self.layout_tensor, dim, start, end, step)
+        return self.__class__(sliced, (end - start,) + self.shape[1:], self.dtype)
+    raise NotImplementedError(f"slice is only supported along dim 0, got dim={dim}")
+
+
 class M(torch.nn.Module):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
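For reference, the op intercepted above can be called directly to check the
shape math; this is plain torch, independent of the subclass:

    import torch

    x = torch.arange(12).reshape(4, 3)
    # aten.slice.Tensor(self, dim, start, end, step) backs Python slicing
    # such as x[1:3]; the handler above sees exactly these arguments.
    y = torch.ops.aten.slice.Tensor(x, 0, 1, 3, 1)
    print(y.shape)  # torch.Size([2, 3]) -- the length is end - start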
@@ -84,4 +123,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 print("input dtensor:", input_dtensor)
 
 m(input_dtensor)
-
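For context, a sketch of how an input DTensor like the one printed above is
typically constructed (the mesh/world-size setup here is an assumption; the
real setup lives outside this diff):

    import torch
    import torch.distributed as dist
    from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

    # Assumes a process group was already initialized elsewhere in the script.
    mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))
    x = torch.randn(128, 1024)
    # Sharding along dim 0 is what routes aten.slice.Tensor(dim=0) through
    # the handler registered above when the subclass is distributed.
    input_dtensor = distribute_tensor(x, mesh, [Shard(0)])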