Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions test/dtypes/test_affine_quantized_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,48 @@ def test_preprocess_scale_3d_reshape(self):
expected_shape = (8, 1) # Flattened (2*2*2, 1)
self.assertEqual(result.shape, expected_shape)

@common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8_non_decomposed
dequantize_affine_float8 = (
torch.ops.torchao.dequantize_affine_float8_non_decomposed
)
input = torch.randn(10, 10)
with torch.no_grad():
torch._dynamo.reset()
expected_scale = torch.tensor(2.0)
expected_quantized = quantize_affine_float8(
input,
expected_scale,
float8_dtype=float8_dtype,
)
expected_dequantized = dequantize_affine_float8(
expected_quantized,
expected_scale,
output_dtype=hp_dtype,
)
test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
torch.compile(quantize_affine_float8),
input,
expected_scale,
float8_dtype=float8_dtype,
)
torch.testing.FileCheck().check(f"{quantize_affine_float8}.default").run(
code_q
)
test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
torch.compile(dequantize_affine_float8),
test_q,
expected_scale,
hp_dtype,
)
torch.testing.FileCheck().check(f"{dequantize_affine_float8}.default").run(
code_dq
)
torch.testing.assert_close(expected_quantized, test_q)
torch.testing.assert_close(expected_dequantized, test_dq)

@torch.no_grad()
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
Expand Down
45 changes: 42 additions & 3 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2310,8 +2310,6 @@ def _quantize_affine_float8(
return _RoundToFloat8.apply(tensor_clamped, float8_dtype)


# TODO: don't register as custom op?
@_register_custom_op(quant_lib, False)
def _dequantize_affine_float8(
tensor: torch.Tensor,
scale: torch.Tensor,
Expand All @@ -2329,7 +2327,48 @@ def _dequantize_affine_float8(
return hp_tensor.to(output_dtype)


@_register_meta_op(quant_lib, "dequantize_affine_float8")
@_register_custom_op(quant_lib, False)
def _quantize_affine_float8_non_decomposed(
tensor: torch.Tensor,
scale: torch.Tensor,
float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
"""
Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
"""
return _quantize_affine_float8(
tensor=tensor,
scale=scale,
float8_dtype=float8_dtype,
)


@_register_meta_op(quant_lib, "quantize_affine_float8_non_decomposed")
def _quantize_affine_float8_meta(
tensor: torch.Tensor,
scale: torch.Tensor,
float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
return torch.empty_like(tensor, dtype=float8_dtype)


@_register_custom_op(quant_lib, False)
def _dequantize_affine_float8_non_decomposed(
tensor: torch.Tensor,
scale: torch.Tensor,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Dequantizes the float8 tensor to high precision tensor.
"""
return _dequantize_affine_float8(
tensor=tensor,
scale=scale,
output_dtype=output_dtype,
)


@_register_meta_op(quant_lib, "dequantize_affine_float8_non_decomposed")
def _dequantize_affine_float8_meta(
tensor: torch.Tensor,
scale: torch.Tensor,
Expand Down
Loading