diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 66a60c1f56..0457cba149 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -351,7 +351,13 @@ def groupwise_affine_quantize_tensor_from_qparams( int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT) if TORCH_VERSION_AFTER_2_5: + int_data_device_type = int_data.device.type + # Move to cpu, until issue with MPS memory management of temporary tensors is resolved + if int_data_device_type == 'mps': + int_data = int_data.cpu() int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) + if int_data_device_type == 'mps': + int_data = int_data.to(device='mps') return int_data def groupwise_affine_dequantize_tensor_from_qparams(