@@ -57,10 +57,8 @@ def _(func, types, args, kwargs):
 @implements(aten.t.default)
 def _(func, types, args, kwargs):
     tensor = args[0]
-    print("before transpose, ", tensor.shape)
     shape = tensor.shape[::-1]
     new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype)
-    print("after transpose:", new.shape)
     return return_and_correct_aliasing(func, args, kwargs, new)
 
 @implements(aten.addmm.default)
@@ -80,8 +78,7 @@ def _(func, types, args, kwargs):
         args[1],
         None
     )
-    print("mm input tensor shape:", input_tensor.shape)
-    print("mm weight tensor shape:", weight_tensor.shape)
+    print("mm weight transposed:", weight_tensor.layout_tensor.transposed)
     weight_tensor = weight_tensor.dequantize()
     return aten.mm(input_tensor, weight_tensor)
 
@@ -172,6 +169,7 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
 
 # Shard the models
 d_up = colwise_shard(q_up, mesh)
+print("d_up weight shape:", d_up.linear.weight.shape)
 d_dn = rowwise_shard(q_dn, mesh)
 
 # We need to turn inputs into DTensor form as well -- just a format change
@@ -188,10 +186,10 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
 # [rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(128, 1024)), device_mesh=DeviceMesh('cuda', [0, 1,
 # 2, 3]), placements=(Replicate(),)), DTensor(local_tensor=MyDTypeTensorTP(data=FakeTensor(..., device='cuda:0', size=(128, 1024)), shape=torch.Size([1024, 1024]), device=cuda:0, dtype=torch.float32, requires_grad=False), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), None), **{}):
 # [rank0]: a and b must have same reduction dim, but got [128, 1024] X [128, 1024].
-c_up = torch.compile(d_up, backend="eager")
+c_up = torch.compile(d_up)
 y_up = c_up(input_dtensor)
 print("y_up:", y_up.shape)
-c_dn = torch.compile(d_dn, backend="eager")
+c_dn = torch.compile(d_dn)
 y_dn = c_dn(y_up)
 print("y_dn:", y_dn.shape)
 print("compiled result:", y_dn)