
Commit c862f80

RandySheriff authored and facebook-github-bot committed
Minor fix on TAO op to support lowering
Summary: Fix a few functionality gaps so the TAO operator works during AOTI lowering.

Differential Revision: D82492826
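For context, AOTI (AOT Inductor) lowering compiles an exported program ahead of time. Below is a minimal sketch of that flow, assuming a recent PyTorch where torch._inductor.aoti_compile_and_package accepts an ExportedProgram; the Linear model is only a placeholder for a torchao-quantized model:

import torch

# Placeholder model; a torchao-quantized/sparse model would be
# exported and lowered the same way.
model = torch.nn.Linear(4, 4).eval()
example_inputs = (torch.randn(2, 4),)

# Export to an ExportedProgram, then compile ahead of time with AOT Inductor.
ep = torch.export.export(model, example_inputs)
package_path = torch._inductor.aoti_compile_and_package(ep)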
1 parent 18dbe87 commit c862f80

2 files changed: +22 additions, -10 deletions

torchao/dtypes/floatx/cutlass_semi_sparse_layout.py

Lines changed: 8 additions & 5 deletions
@@ -106,12 +106,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             )
         elif func is aten.to.dtype_layout:
             dense, scale, _ = args[0].get_plain()
-            dense = dense.to(
+            product = dense.to(torch.float) * scale.to(torch.float)
+            return product.to(
                 *args[1:],
                 dtype=kwargs.get("dtype", dense.dtype),
                 device=kwargs.get("device", dense.device),
             )
-            return scale * dense
 
         raise NotImplementedError(
             f"CutlassSemiSparseTensorImpl dispatch: attempting to run {func}, this is not supported"
@@ -135,11 +135,14 @@ def get_plain(self):
         # semi-structured format, so multiplying with identity matrix,
         # and using identity scale factors, for the conversion.
         cols = self.shape[1]
-        input = torch.eye(cols, dtype=self.sparse.dtype, device=self.sparse.device)
-        input_scale = torch.ones(
-            (cols,), dtype=self.scale.dtype, device=self.sparse.device
+        input_fp16 = torch.eye(cols, dtype=torch.float16, device=self.sparse.device)
+        input = input_fp16.to(dtype=self.sparse.dtype)
+        input_scale_fp16 = torch.ones(
+            (cols,), dtype=torch.float16, device=self.sparse.device
         )
+        input_scale = input_scale_fp16.to(dtype=self.scale.dtype)
         sparse_scale = torch.ones_like(self.scale)
+
         out_dtype = torch.bfloat16
         dense = (
             rowwise_scaled_linear_sparse_cutlass_f8f8(
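The construct-then-cast change above is presumably needed because factory ops such as torch.eye and torch.ones do not lower cleanly when asked to produce the sparse/scale dtypes directly during AOTI compilation; building in float16 and casting afterwards sidesteps that. A small sketch of the pattern, with torch.float8_e4m3fn as an assumed stand-in for the actual sparse dtype:

import torch

# Construct-then-cast: build the identity and unit scales in float16,
# then cast to the target dtypes, instead of constructing in those
# dtypes directly.
cols = 4
input_ = torch.eye(cols, dtype=torch.float16).to(dtype=torch.float8_e4m3fn)
input_scale = torch.ones((cols,), dtype=torch.float16).to(dtype=torch.float32)
print(input_.dtype, input_scale.dtype)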

torchao/quantization/linear_activation_quantized_tensor.py

Lines changed: 14 additions & 5 deletions
@@ -133,11 +133,15 @@ def _same_metadata(
 
 @implements([torch.nn.functional.linear, aten.linear.default])
 def _(func, types, args, kwargs):
-    input_tensor, weight_tensor, bias = (
-        args[0],
-        args[1],
-        args[2] if len(args) > 2 else None,
-    )
+    if len(args) > 1:
+        input_tensor, weight_tensor, bias = (
+            args[0],
+            args[1],
+            args[2] if len(args) > 2 else None,
+        )
+    else:
+        input_tensor, weight_tensor, bias = kwargs["input"], kwargs["weight"], kwargs["bias"] if "bias" in kwargs else None
+
     if isinstance(weight_tensor, LinearActivationQuantizedTensor):
         return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)

@@ -216,6 +220,11 @@ def _(func, types, args, kwargs):
         for tensor_name in self_tensors:
             getattr(self, tensor_name).copy_(getattr(src, tensor_name))
         return
+    elif type(self) is torch.Tensor and type(src) is LinearActivationQuantizedTensor:
+        new_src = src.to(dtype=self.dtype, device=self.device)
+        self.copy_(new_src)
+        return
+
     raise ValueError(
         f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
     )
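The added elif handles copying a LinearActivationQuantizedTensor into a plain torch.Tensor: the source is first materialized at the destination's dtype and device via .to() (which dequantizes, per the dtype_layout path fixed above), then copied with a regular copy_. A generic sketch of the pattern, with an ordinary bfloat16 tensor standing in for the subclass:

import torch

# Materialize the source at the destination's dtype/device, then do a
# plain copy_.
dst = torch.empty(2, 3, dtype=torch.float32)
src = torch.ones(2, 3, dtype=torch.bfloat16)

dst.copy_(src.to(dtype=dst.dtype, device=dst.device))
print(dst.dtype, dst.sum().item())  # torch.float32 6.0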
