Skip to content

Commit 083361b

Browse files
authored
turn float8 inference kernel check test back on (#2808)
Update [ghstack-poisoned]
1 parent 9473060 commit 083361b

File tree

1 file changed

+22
-37
lines changed

1 file changed

+22
-37
lines changed

test/dtypes/test_affine_quantized_float.py

Lines changed: 22 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@
3636
_quantize_affine_float8,
3737
choose_qparams_affine,
3838
)
39+
from torchao.quantization.quantize_.common import KernelPreference
3940
from torchao.utils import (
4041
is_sm_at_least_89,
4142
is_sm_at_least_90,
@@ -732,20 +733,13 @@ def test_preprocess_scale_3d_reshape(self):
732733
self.assertEqual(result.shape, expected_shape)
733734

734735
@torch.no_grad()
735-
@unittest.skip("test is flaky in CI, will turn on a bit later")
736736
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
737737
@unittest.skipIf(
738738
not is_sm_at_least_90(), "Requires GPU with compute capability >= 9.0"
739739
)
740740
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
741-
@common_utils.parametrize(
742-
"torch_compile_mode",
743-
[
744-
"default",
745-
"reduce-overhead",
746-
],
747-
)
748-
def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):
741+
@common_utils.parametrize("float8_config_version", [1, 2])
742+
def test_expected_kernels_on_gpu(self, granularity, float8_config_version):
749743
"""
750744
Verify that float8 quantization + torch.compile results in the
751745
expected number of kernels in the GPU trace.
@@ -756,14 +750,23 @@ def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):
756750
m = torch.nn.Sequential(
757751
torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
758752
)
753+
if float8_config_version == 1:
754+
config = Float8DynamicActivationFloat8WeightConfig(
755+
granularity=granularity, version=1
756+
)
757+
else:
758+
assert float8_config_version == 2
759+
config = Float8DynamicActivationFloat8WeightConfig(
760+
granularity=granularity,
761+
version=2,
762+
kernel_preference=KernelPreference.TORCH,
763+
)
759764
quantize_(
760765
m,
761-
Float8DynamicActivationFloat8WeightConfig(
762-
granularity=granularity, version=1
763-
),
766+
config,
764767
)
765768

766-
m = torch.compile(m, mode=torch_compile_mode)
769+
m = torch.compile(m, mode="default")
767770
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
768771

769772
# warm up
@@ -779,34 +782,16 @@ def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):
779782
# kernel 2: x_max = max(x_max_tmp)
780783
# kernel 3: x_float8 = to_float8(x, x_max)
781784
# kernel 4: gemm
782-
if torch_compile_mode == "default":
783-
assert len(cuda_kernel_events) == 4, (
784-
f"too many cuda kernels: {cuda_kernel_events}"
785-
)
786-
elif torch_compile_mode == "reduce-overhead":
787-
# two extra kernels with reduce-overhead:
788-
# void at::native::(anonymous namespace)::multi_tensor...
789-
# void at::native::vectorized_elementwise_kernel<2, at...
790-
# TODO(future): debug and remove these
791-
assert len(cuda_kernel_events) == 6, (
792-
f"too many cuda kernels: {cuda_kernel_events}"
793-
)
785+
assert len(cuda_kernel_events) == 4, (
786+
f"too many cuda kernels: {cuda_kernel_events}"
787+
)
794788
else:
795789
assert granularity == PerRow()
796790
# kernel 1: x_float8 = to_float8(x)
797791
# kernel 2: gemm
798-
if torch_compile_mode == "default":
799-
assert len(cuda_kernel_events) == 2, (
800-
f"too many cuda kernels: {cuda_kernel_events}"
801-
)
802-
elif torch_compile_mode == "reduce-overhead":
803-
# two extra kernels with reduce-overhead:
804-
# void at::native::(anonymous namespace)::multi_tensor...
805-
# void at::native::vectorized_elementwise_kernel<2, at...
806-
# TODO(future): debug and remove these
807-
assert len(cuda_kernel_events) == 4, (
808-
f"too many cuda kernels: {cuda_kernel_events}"
809-
)
792+
assert len(cuda_kernel_events) == 2, (
793+
f"too many cuda kernels: {cuda_kernel_events}"
794+
)
810795

811796

812797
common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

0 commit comments

Comments
 (0)