Skip to content

Commit 5277507

Browse files
authored
[Tp Test] Fixe the placment of the device tensor (#1054)
1 parent ec860a1 commit 5277507

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

torchao/testing/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
309309
y_q = dn_quant(up_quant(example_input))
310310

311311
mesh = self.build_device_mesh()
312+
mesh.device_type = "cuda"
313+
312314
# Shard the models
313315
up_dist = self.colwise_shard(up_quant, mesh)
314316
dn_dist = self.rowwise_shard(dn_quant, mesh)

0 commit comments

Comments
 (0)