We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e6cf304 commit 49b4e0fCopy full SHA for 49b4e0f
torchao/testing/utils.py
@@ -309,6 +309,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
309
y_q = dn_quant(up_quant(example_input))
310
311
mesh = self.build_device_mesh()
312
+ mesh.device_type = "cuda"
313
+
314
# Shard the models
315
up_dist = self.colwise_shard(up_quant, mesh)
316
dn_dist = self.rowwise_shard(dn_quant, mesh)
0 commit comments