3636 _quantize_affine_float8 ,
3737 choose_qparams_affine ,
3838)
39+ from torchao .quantization .quantize_ .common import KernelPreference
3940from torchao .utils import (
4041 is_sm_at_least_89 ,
4142 is_sm_at_least_90 ,
@@ -732,20 +733,13 @@ def test_preprocess_scale_3d_reshape(self):
732733 self .assertEqual (result .shape , expected_shape )
733734
734735 @torch .no_grad ()
735- @unittest .skip ("test is flaky in CI, will turn on a bit later" )
736736 @unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
737737 @unittest .skipIf (
738738 not is_sm_at_least_90 (), "Requires GPU with compute capability >= 9.0"
739739 )
740740 @common_utils .parametrize ("granularity" , [PerTensor (), PerRow ()])
741- @common_utils .parametrize (
742- "torch_compile_mode" ,
743- [
744- "default" ,
745- "reduce-overhead" ,
746- ],
747- )
748- def test_expected_kernels_on_gpu (self , granularity , torch_compile_mode ):
741+ @common_utils .parametrize ("float8_config_version" , [1 , 2 ])
742+ def test_expected_kernels_on_gpu (self , granularity , float8_config_version ):
749743 """
750744 Verify that float8 quantization + torch.compile results in the
751745 expected number of kernels in the GPU trace.
@@ -756,14 +750,23 @@ def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):
756750 m = torch .nn .Sequential (
757751 torch .nn .Linear (K , N , device = "cuda" , dtype = torch .bfloat16 )
758752 )
753+ if float8_config_version == 1 :
754+ config = Float8DynamicActivationFloat8WeightConfig (
755+ granularity = granularity , version = 1
756+ )
757+ else :
758+ assert float8_config_version == 2
759+ config = Float8DynamicActivationFloat8WeightConfig (
760+ granularity = granularity ,
761+ version = 2 ,
762+ kernel_preference = KernelPreference .TORCH ,
763+ )
759764 quantize_ (
760765 m ,
761- Float8DynamicActivationFloat8WeightConfig (
762- granularity = granularity , version = 1
763- ),
766+ config ,
764767 )
765768
766- m = torch .compile (m , mode = torch_compile_mode )
769+ m = torch .compile (m , mode = "default" )
767770 x = torch .randn (M , K , device = "cuda" , dtype = torch .bfloat16 )
768771
769772 # warm up
@@ -779,34 +782,16 @@ def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):
779782 # kernel 2: x_max = max(x_max_tmp)
780783 # kernel 3: x_float8 = to_float8(x, x_max)
781784 # kernel 4: gemm
782- if torch_compile_mode == "default" :
783- assert len (cuda_kernel_events ) == 4 , (
784- f"too many cuda kernels: { cuda_kernel_events } "
785- )
786- elif torch_compile_mode == "reduce-overhead" :
787- # two extra kernels with reduce-overhead:
788- # void at::native::(anonymous namespace)::multi_tensor...
789- # void at::native::vectorized_elementwise_kernel<2, at...
790- # TODO(future): debug and remove these
791- assert len (cuda_kernel_events ) == 6 , (
792- f"too many cuda kernels: { cuda_kernel_events } "
793- )
785+ assert len (cuda_kernel_events ) == 4 , (
786+ f"too many cuda kernels: { cuda_kernel_events } "
787+ )
794788 else :
795789 assert granularity == PerRow ()
796790 # kernel 1: x_float8 = to_float8(x)
797791 # kernel 2: gemm
798- if torch_compile_mode == "default" :
799- assert len (cuda_kernel_events ) == 2 , (
800- f"too many cuda kernels: { cuda_kernel_events } "
801- )
802- elif torch_compile_mode == "reduce-overhead" :
803- # two extra kernels with reduce-overhead:
804- # void at::native::(anonymous namespace)::multi_tensor...
805- # void at::native::vectorized_elementwise_kernel<2, at...
806- # TODO(future): debug and remove these
807- assert len (cuda_kernel_events ) == 4 , (
808- f"too many cuda kernels: { cuda_kernel_events } "
809- )
792+ assert len (cuda_kernel_events ) == 2 , (
793+ f"too many cuda kernels: { cuda_kernel_events } "
794+ )
810795
811796
812797common_utils .instantiate_parametrized_tests (TestAffineQuantizedFloat8Compile )
0 commit comments