Skip to content

Commit ff68bb3

Browse files
committed
add some bug fixes
1 parent 3537a22 commit ff68bb3

File tree

4 files changed

+15
-8
lines changed

4 files changed

+15
-8
lines changed

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,9 @@ def _(func, types, args, kwargs):
456456
def _(func, types, args, kwargs):
457457
self = args[0]
458458
src = args[1]
459+
if type(self) is torch.Tensor and isinstance(src, AffineQuantizedTensor):
460+
func(self, src.dequantize())
461+
return
459462
if _same_metadata(self, src):
460463
self_tensors = self.__tensor_flatten__()[0]
461464
for tensor_name in self_tensors:

torchao/quantization/quant_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ def quantize_(
485485
or ("_default" in config.fqn_to_config and _is_linear(module))
486486
):
487487
module_name = (
488-
module_fqn.rsplit(".", 1) if "." in module_fqn else module_fqn
488+
module_fqn.rsplit(".", 1)[0] if "." in module_fqn else module_fqn
489489
)
490490
# this replaces inplace, so no need to reassign
491491
_fqn_to_config_handler(module, module_name, config)

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,14 @@ def from_hp(
202202
else:
203203
maybe_hp_value_ub_tensor = None
204204
if isinstance(granularity, PerRow):
205-
data, scale = torch.ops.triton.quantize_fp8_row(
206-
hp_tensor, scale_ub=maybe_hp_value_ub_tensor
207-
)
208-
scale_shape = []
209-
for i in range(hp_tensor.ndim):
210-
scale_shape.append(hp_tensor.shape[i] // block_size[i])
211-
scale = scale.reshape(*scale_shape)
205+
with torch.cuda.device(hp_tensor.device):
206+
data, scale = torch.ops.triton.quantize_fp8_row(
207+
hp_tensor, scale_ub=maybe_hp_value_ub_tensor
208+
)
209+
scale_shape = []
210+
for i in range(hp_tensor.ndim):
211+
scale_shape.append(hp_tensor.shape[i] // block_size[i])
212+
scale = scale.reshape(*scale_shape)
212213
else:
213214
assert isinstance(granularity, PerTensor), (
214215
f"Expected per tensor, got {granularity}"

torchao/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,9 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool:
571571
def _(func, types, args, kwargs):
572572
self = args[0]
573573
src = args[1]
574+
if type(self) is torch.Tensor and isinstance(src, TorchAOBaseTensor):
575+
func(self, src.dequantize())
576+
return
574577
if _same_metadata(self, src):
575578
self_tensors = self.__tensor_flatten__()[0]
576579
for tensor_name in self_tensors:

0 commit comments

Comments (0)