Commit d396c2b

change transpose to non-inplace op
1 parent a76ef4a commit d396c2b

2 files changed: +6 −9 lines changed

tutorials/developer_api_guide/my_dtype_tensor_subclass.py

Lines changed: 2 additions & 3 deletions
@@ -194,7 +194,7 @@ def dequantize(self, output_dtype=None):
         block_size = (1, int_data.shape[-1])
         if hasattr(self.layout_tensor, "transposed") and self.layout_tensor.transposed:
             transposed = True
-        res = dequantize_affine(int_data, block_size, scale, None, int_data.dtype)
+        res = dequantize_affine(int_data, block_size, scale, None, int_data.dtype, output_dtype=output_dtype)
         if transposed:
             res = res.t()
         return res
@@ -327,8 +327,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             elif dim == 1:
                 return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1, 1), self.transposed, self.layout_type)
         elif func is aten.t.default:
-            args[0].transposed = not args[0].transposed
-            return return_and_correct_aliasing(func, args, kwargs, args[0])
+            return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeLayout(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type))
 
         raise NotImplementedError(
             f"PlainMyDTypeLayout dispatch: attempting to run {func}, this is not supported"

tutorials/developer_api_guide/tensor_parallel.py

Lines changed: 4 additions & 6 deletions
@@ -57,10 +57,8 @@ def _(func, types, args, kwargs):
 @implements(aten.t.default)
 def _(func, types, args, kwargs):
     tensor = args[0]
-    print("before transpose, ", tensor.shape)
     shape = tensor.shape[::-1]
     new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype)
-    print("after transpose:", new.shape)
     return return_and_correct_aliasing(func, args, kwargs, new)
 
 @implements(aten.addmm.default)
@@ -80,8 +78,7 @@ def _(func, types, args, kwargs):
         args[1],
         None
     )
-    print("mm input tensor shape:", input_tensor.shape)
-    print("mm weight tensor shape:", weight_tensor.shape)
+    print("mm weight transposed:", weight_tensor.layout_tensor.transposed)
     weight_tensor = weight_tensor.dequantize()
     return aten.mm(input_tensor, weight_tensor)
 
@@ -172,6 +169,7 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
 
 # Shard the models
 d_up = colwise_shard(q_up, mesh)
+print("d_up weight shape:", d_up.linear.weight.shape)
 d_dn = rowwise_shard(q_dn, mesh)
 
 # We need to turn inputs into DTensor form as well -- just a format change
@@ -188,10 +186,10 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
 # [rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(128, 1024)), device_mesh=DeviceMesh('cuda', [0, 1,
 # 2, 3]), placements=(Replicate(),)), DTensor(local_tensor=MyDTypeTensorTP(data=FakeTensor(..., device='cuda:0', size=(128, 1024)), shape=torch.Size([1024, 1024]), device=cuda:0, dtype=torch.float32, requires_grad=False), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), None), **{}):
 # [rank0]: a and b must have same reduction dim, but got [128, 1024] X [128, 1024].
-c_up = torch.compile(d_up, backend="eager")
+c_up = torch.compile(d_up)
 y_up = c_up(input_dtensor)
 print("y_up:", y_up.shape)
-c_dn = torch.compile(d_dn, backend="eager")
+c_dn = torch.compile(d_dn)
 y_dn = c_dn(y_up)
 print("y_dn:", y_dn.shape)
 print("compiled result:", y_dn)
