@@ -1,5 +1,5 @@
 import torch
-from my_dtype_tensor_subclass import MyDTypeTensor
+from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
 # a tensor subclass that supports tensor parallelism with DTensor
@@ -10,30 +10,6 @@ class MyDTypeTensorTP(MyDTypeTensor):
 
 aten = torch.ops.aten
 
-def fill_defaults(args, n, defaults_tail):
-    """
-    __torch_dispatch__ doesn't guarantee the number of arguments you are
-    passed (e.g., defaulted arguments are not passed); but usually it is
-    convenient to pad out the arguments list with defaults. This function
-    helps you do that.
-    Args:
-        args: the list of positional arguments passed to __torch_dispatch__
-        n: the number of arguments you are expecting to get
-        defaults_tail: default values for the arguments, starting from the
-            end of the list
-    Example:
-        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
-        [1, 2, 3, 4, 5]
-        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
-        [1, 2, 3, None, None]
-    """
-    if n - len(defaults_tail) > len(args):
-        raise RuntimeError("not enough defaults to fill arguments")
-    r = list(args)
-    for i in range(len(args), n):
-        r.append(defaults_tail[i - n + len(defaults_tail)])
-    return r
-
 @implements([aten._to_copy.default, aten.clone.default])
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
@@ -51,20 +27,67 @@ def _(func, types, args, kwargs):
     empty_like_layout_tensor = func(args[0].layout_tensor, *args[1:], **kwargs)
     return MyDTypeTensorTP(empty_like_layout_tensor, empty_like_layout_tensor.shape)
 
-@implements([aten.slice.Tensor])
+@implements(aten.slice.Tensor)
 def _(func, types, args, kwargs):
     self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
-    print("slice:", dim, start, end, step)
-    if dim == 0:
-        assert step == 1
-        return self.__class__(aten.slice.Tensor(self.layout_tensor), (end - start + 1,) + self.shape[1:], self.dtype)
-    return
+    assert step == 1
+    if end >= self.shape[dim]:
+        end = self.shape[dim]
+    print("dim:", dim, "start:", start, " end:", end, " shape:", end - start)
+    print("manual shape:", (end - start,) + self.shape[1:])
+    return self.__class__(aten.slice.Tensor(self.layout_tensor, dim, start, end, step), (end - start,) + self.shape[1:], self.dtype)
+
+# this is needed for DTensor.from_local() and for flattening tensor
+@implements(aten.view.default)
+def _(func, types, args, kwargs):
+    x, shape = args
+
+    if tuple(x.shape) == tuple(shape):
+        return x.__class__(x.layout_tensor, x.shape, x.dtype)
+
+    if len(shape) == 1 and shape[0] == -1:
+        return x.__class__(x.layout_tensor, (x.numel(),), x.dtype)
+
+    raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]")
+
+@implements(aten.t.default)
+def _(func, types, args, kwargs):
+    tensor = args[0]
+    shape = tensor.shape[::-1]
+    new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype)
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+@implements(aten.addmm.default)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[1],
+        args[2],
+        args[0],
+    )
+    transposed = weight_tensor.layout_tensor.transposed
+    weight_tensor = weight_tensor.dequantize()
+    if transposed:
+        weight_tensor = weight_tensor.t()
+    return aten.addmm(input_tensor, weight_tensor, bias)
+
+@implements(aten.mm.default)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        None
+    )
+    transposed = weight_tensor.layout_tensor.transposed
+    weight_tensor = weight_tensor.dequantize()
+    if transposed:
+        weight_tensor = weight_tensor.t()
+    return aten.mm(input_tensor, weight_tensor)
 
 
 class M(torch.nn.Module):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self.linear = torch.nn.Linear(1024, 1024)
+        self.linear = torch.nn.Linear(1024, 1024, bias=False, device="cuda")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.linear(x)
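Note: a minimal sketch of how fill_defaults (now imported from my_dtype_tensor_subclass, and shown in the removed block above) pads the truncated argument list that the aten.slice.Tensor handler receives; the sample tensor and slice arguments below are made up for illustration:

import torch
from my_dtype_tensor_subclass import fill_defaults

# __torch_dispatch__ may drop defaulted trailing arguments, so a slice call can
# arrive as (self, dim, start, end) rather than the full five-argument form.
args = (torch.randn(4, 4), 0, 1, 9223372036854775807)
self_, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
print(dim, start, end, step)  # 0 1 9223372036854775807 1; the handler then clamps end to self.shape[dim]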
@@ -79,7 +102,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 torch.manual_seed(5)
 
 m = M()
-example_input = 100 * torch.randn(128, 1024)
+example_input = 100 * torch.randn(128, 1024, device="cuda")
 m(example_input)
 
 
@@ -103,7 +126,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 quantized_shard = quantized_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
 print("quantized shard:", quantized_shard)
 # Construct DTensor from local shard
-quantized_dtensor = DTensor.from_local(quantized_shard, device_mesh, [Shard(0)])
+quantized_dtensor = DTensor.from_local(quantized_shard, mesh, [Shard(0)])
 print("quantized dtensor:", quantized_dtensor)
 
 # Replace parameter in module
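Note: the mesh and rank used above come from setup code outside this diff; a minimal sketch of how they could be created, assuming the script is launched with torchrun so the process-group environment variables are set (the real file's setup may differ):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import DTensor, Shard

dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)
# A one-dimensional mesh over all ranks; Shard(0) then splits the weight rows across it.
mesh = init_device_mesh("cuda", (world_size,))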
@@ -117,4 +140,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 )
 print("input dtensor:", input_dtensor)
 
-m(input_dtensor)
+print("result:", m(input_dtensor))
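Note on why both aten.mm and aten.addmm handlers are registered above: with bias=False, the linear layer typically decomposes into a transpose of the weight (aten.t) followed by aten.mm, while a biased linear decomposes into aten.addmm with the bias as the first argument, which is what the handlers' argument unpacking assumes. A quick plain-torch check of that equivalence, independent of the subclass:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4)
w = torch.randn(3, 4)
b = torch.randn(3)
assert torch.allclose(F.linear(x, w), torch.mm(x, w.t()))
assert torch.allclose(F.linear(x, w, b), torch.addmm(b, x, w.t()))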