 except RuntimeError:
     pytest.skip("torchao.ops not available")
 
-from torchao.quantization import PerGroup, PerRow, PerTensor
-from torchao.quantization.quant_primitives import (
-    _choose_scale_float8,
-    _dequantize_affine_float8,
-    _quantize_affine_float8,
-)
 from torchao.quantization.utils import (
-    get_block_size,
     get_groupwise_affine_qparams,
     groupwise_affine_dequantize_tensor_from_qparams,
     groupwise_affine_quantize_tensor_from_qparams,
@@ -908,91 +901,5 @@ def _test_scaled_embedding_bag_cpu_helper( |
     torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)
 
 
-@pytest.mark.skipif(
-    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
-    reason="cpp kernels not built",
-)
-@pytest.mark.parametrize(
-    "multi_hot, batch_size, vector_size, index_type",
-    EMBEDINGBAG_TEST_PARAMS,
-    ids=str,
-)
-def test_scaled_embedding_bag_int8_cpu(multi_hot, batch_size, vector_size, index_type):
-    _test_scaled_embedding_bag_cpu_helper(
-        multi_hot, batch_size, vector_size, index_type, torch.int8
-    )
-
-
-@pytest.mark.skipif(
-    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
-    reason="cpp kernels not built",
-)
-@pytest.mark.parametrize(
-    "multi_hot, batch_size, vector_size, index_type",
-    EMBEDINGBAG_TEST_PARAMS,
-    ids=str,
-)
-def test_scaled_embedding_bag_fp8_cpu(multi_hot, batch_size, vector_size, index_type):
-    _test_scaled_embedding_bag_cpu_helper(
-        multi_hot, batch_size, vector_size, index_type, torch.float8_e4m3fn
-    )
-
-
-@pytest.mark.skipif(
-    "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_prepack_cpu")
-    or "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
-    reason="cpp kernels not built",
-)
-@pytest.mark.skipif(
-    not torch_version_at_least("2.6.0"), reason="Test only enabled for 2.6+"
-)
-@pytest.mark.parametrize("shape", [(64, 64), (256, 256)])
-@pytest.mark.parametrize("bs", [1, 160])
-@pytest.mark.parametrize("out_dtype", [torch.float, torch.bfloat16, torch.half])
-@pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize("x_granularity", [PerTensor(), PerRow(), PerGroup(128)])
-@pytest.mark.parametrize("w_granularity", [PerTensor(), PerRow(), PerGroup(128)])
-def test_float8_linear_cpu(shape, bs, out_dtype, bias, x_granularity, w_granularity):
-    in_feature, out_feature = shape
-    if isinstance(x_granularity, PerGroup):
-        if x_granularity.group_size >= in_feature:
-            return
-        if not isinstance(w_granularity, PerGroup):
-            return
-    if isinstance(w_granularity, PerGroup):
-        if w_granularity.group_size >= in_feature:
-            return
-    m = torch.nn.Linear(in_feature, out_feature, bias=bias).eval()
-    b = m.bias
-    x = torch.randn(bs, in_feature)
-    x_block_size = get_block_size(x.shape, x_granularity)
-    x_scale = _choose_scale_float8(
-        x,
-        float8_dtype=torch.float8_e4m3fn,
-        block_size=x_block_size,
-    )
-    x_fp8 = _quantize_affine_float8(x, x_scale, torch.float8_e4m3fn)
-
-    w = m.weight.detach()
-    w_block_size = get_block_size(w.shape, w_granularity)
-    w_scale = _choose_scale_float8(
-        w,
-        float8_dtype=torch.float8_e4m3fn,
-        block_size=w_block_size,
-    )
-    w_fp8 = _quantize_affine_float8(w, w_scale, torch.float8_e4m3fn)
-
-    x_dq = _dequantize_affine_float8(x_fp8, x_scale)
-    w_dq = _dequantize_affine_float8(w_fp8, w_scale)
-    ref = torch.nn.functional.linear(x_dq, w_dq, b).to(out_dtype)
-
-    packed_w, packed_scale = torch.ops.torchao.float8_linear_prepack_cpu(w_fp8, w_scale)
-    y = torch.ops.torchao.float8_linear_cpu(
-        x_fp8, x_scale, packed_w, packed_scale, b, out_dtype
-    )
-
-    torch.testing.assert_close(y, ref, atol=1e-2, rtol=1e-2)
-
-
 if __name__ == "__main__":
     pytest.main(sys.argv)