Skip to content

Commit f3705e5

Browse files
committed
fix test to add in layout kwarg
1 parent 8f3bdcd commit f3705e5

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

test/sparsity/test_sparse_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ def test_fp8_cutlass_sparse_lowering_op_to(self):
180180
quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())
181181

182182
original = torch.ops.aten.to.dtype_layout(
183-
model.weight.original_weight_tensor.tensor_impl, dtype=torch.float
183+
model.weight.original_weight_tensor.tensor_impl,
184+
dtype=torch.float,
185+
layout=torch.strided,
184186
)
185187
torch.testing.assert_close(expected, original, atol=1e-1, rtol=1e-1)
186188

torchao/dtypes/floatx/cutlass_semi_sparse_layout.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
106106
)
107107
elif func is aten.to.dtype_layout:
108108
dense, scale, _ = args[0].get_plain()
109-
dense = dense.to(*args[1:], **kwargs)
109+
dense = dense.to(
110+
*args[1:],
111+
dtype=kwargs.get("dtype", dense.dtype),
112+
device=kwargs.get("device", dense.device),
113+
)
110114
return scale * dense
111115

112116
raise NotImplementedError(

0 commit comments

Comments
 (0)