@@ -167,24 +167,28 @@ def get_dequantize_per_tensor_activation_pattern(
     KeywordArg("w_dtype"),
 )
 
-dequantize_per_channel_to_bf16_weight_pattern = (
-    _may_generate_pattern_with_dtype_convert(
-        dequantize_per_channel_weight_pattern,
+dequantize_fp8_weight_pattern = CallFunction(
+    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+    KeywordArg("q_weight"),
+    KeywordArg("w_scale"),
+    output_dtype=KeywordArg("w_dtype"),
+)
+
+def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern):
+    return _may_generate_pattern_with_dtype_convert(
+        dequant_wgt_pattern,
         KeywordArg("autocast_wgt_dtype"),
     )
-)
 
-dequantize_per_channel_clone_weight_pattern = CallFunction(
-    aten.clone.default,
-    dequantize_per_channel_weight_pattern,
-    memory_format=KeywordArg("memory_format"),
-)
+def get_dequantize_clone_weight_pattern(dequant_wgt_pattern):
+    return CallFunction(
+        aten.clone.default,
+        dequant_wgt_pattern,
+        memory_format=KeywordArg("memory_format"),
+    )
 
-dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
-    aten.clone.default,
-    dequantize_per_channel_to_bf16_weight_pattern,
-    memory_format=KeywordArg("memory_format"),
-)
+def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern):
+    return get_dequantize_clone_weight_pattern(
+        get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern)
+    )
 
 
 def get_qconv_pt2e_pattern(users=1):
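For orientation, a hedged sketch of how the new helpers compose (names taken from this diff; assumes the module's existing `CallFunction`/`KeywordArg` matcher imports):

```python
# Builds the fp8 analogue of the removed
# dequantize_per_channel_to_bf16_clone_weight_pattern global:
# matches clone(to_bf16(dequantize_fp8(w))) without a dedicated global.
fp8_bf16_clone_pattern = get_dequantize_to_bf16_clone_weight_pattern(
    dequantize_fp8_weight_pattern
)
```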
@@ -711,7 +715,7 @@ def _inner(match):
     return _inner
 
 
-def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32):
+def _register_qconv_weight_prepack_pass(
+    pattern, pass_number, dtype=torch.float32, is_fp8=False
+):
     @register_freezing_graph_pattern(
         pattern,
         extra_check=_is_valid_dequant_conv_pattern(dtype),
@@ -724,7 +728,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
           |
         dequant_per_tensor
           |
-        Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight
+        Conv2d <- optional(aten.clone.default) <- dequant <- int8/fp8 weight
 
         Insert weight prepack node and change the pattern to:
         int8 activation
@@ -747,7 +751,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         )
 
         if dtype == torch.float32:
-            dequant_per_channel = (
+            dequant = (
                 clone_node.args[0]  # type: ignore[union-attr]
                 if has_clone_to_channel_last_node_in_pattern
                 else conv_node.args[1]
@@ -758,25 +762,25 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
                 if has_clone_to_channel_last_node_in_pattern
                 else conv_node.args[1]
             )
-            dequant_per_channel = weight_to_bf16_node.args[0]  # type: ignore[union-attr]
+            dequant = weight_to_bf16_node.args[0]  # type: ignore[union-attr]
 
-        assert dequant_per_channel.target in [  # type: ignore[union-attr]
+        assert dequant.target in [  # type: ignore[union-attr]
             quantized_decomposed.dequantize_per_channel.default,
             torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
         ]
 
         # Activation QParams
         qx, x_zp, x_scale = (
             kwargs["x"],
-            kwargs["x_zp"],
+            kwargs.get("x_zp"),  # absent in the fp8 patterns
             kwargs["x_scale"],
         )
 
         # Weight QParams
         qw, w_scale, w_zp = (
             kwargs["q_weight"],
             kwargs["w_scale"],
-            kwargs["w_zp"],
+            kwargs.get("w_zp"),  # absent in the fp8 patterns
         )
 
         # Conv Params
@@ -792,14 +796,19 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         if has_free_symbols(x_shape):
             # For dynamic shape case, we can't get activation shape ahead of runtime.
             x_shape = None
+        if is_fp8 and w_scale.target is torch.ops.aten.full.default:
+            # The fp8 weight scale may arrive as a scalar aten.full node;
+            # materialize it as a real tensor buffer for the prepack op.
+            with torch.utils._python_dispatch._disable_current_modes():
+                w_scale_tensor = torch.tensor([w_scale.args[1]])
+            # Use a per-match buffer name so multiple matched convs don't collide.
+            w_scale_name = f"{conv_node.name}_w_scale"
+            match.graph.owning_module.register_buffer(w_scale_name, w_scale_tensor)
+            w_scale = match.graph.create_node("get_attr", w_scale_name)
         graph = match.graph
         with graph.inserting_before(conv_node):
             # Insert weight prepack node and the QConv node
             packed_weight_inputs = (
                 qw,
                 w_scale,
-                x_scale,
-                x_zp,
+                x_scale.args[1]
+                if is_fp8 and x_scale.target is torch.ops.aten.full.default
+                else x_scale,
+                0,
                 stride,
                 padding,
                 dilation,
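A note on the scale materialization above: `w_scale.args[1]` is the fill value of the matched `aten.full` node, and the tensor is built under `_disable_current_modes()` so a real (non-fake) tensor is produced even if the pass runs under an active FakeTensorMode. A minimal standalone sketch under those assumptions (toy scale value, simplified buffer name):

```python
import torch

gm = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
scalar_scale = 0.0078125  # stands in for w_scale.args[1] from the matched aten.full

with torch.utils._python_dispatch._disable_current_modes():
    # With dispatch modes off, this creates a real tensor rather than a
    # fake one, so it can be registered as a module buffer.
    w_scale_tensor = torch.tensor([scalar_scale])

gm.register_buffer("w_scale", w_scale_tensor)
w_scale_node = gm.graph.create_node("get_attr", "w_scale")
```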
@@ -830,9 +839,16 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
                 [],  # scalars
                 "",  # algorithm
             )
-            new_conv_node = graph.call_function(
-                torch.ops.onednn.qconv_pointwise.default, args=new_args
-            )
+            Node = torch.fx.node.Node
+            # fp8 carries no zero-point, so a Node-valued x_scale alone is
+            # enough to pick the tensor overload in that case.
+            if isinstance(x_scale, Node) and (isinstance(x_zp, Node) or is_fp8):
+                new_conv_node = graph.call_function(
+                    torch.ops.onednn.qconv_pointwise.tensor, args=new_args
+                )
+            else:
+                new_conv_node = graph.call_function(
+                    torch.ops.onednn.qconv_pointwise.default, args=new_args
+                )
             conv_node.replace_all_uses_with(new_conv_node)
             new_conv_node.meta.update(conv_node.meta)
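The overload selection above reads as a small predicate: take `qconv_pointwise.tensor` when the activation scale is carried as a graph `Node` and the zero-point is either also a `Node` or absent because fp8 quantization is symmetric; otherwise keep the scalar `.default` overload. A hypothetical helper (not part of the PR) expressing the same rule:

```python
import torch
from torch.fx.node import Node


def pick_qconv_overload(x_scale, x_zp, is_fp8: bool):
    # A Node-valued scale plus either a Node-valued zero-point or fp8
    # (which has no zero-point at all) means the qparams live in the
    # graph, so the tensor overload is required.
    if isinstance(x_scale, Node) and (isinstance(x_zp, Node) or is_fp8):
        return torch.ops.onednn.qconv_pointwise.tensor
    return torch.ops.onednn.qconv_pointwise.default
```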
@@ -847,25 +863,25 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
             graph.erase_node(clone_node)  # type: ignore[arg-type]
         if dtype == torch.bfloat16:
             graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined, arg-type]
-        graph.erase_node(dequant_per_channel)  # type: ignore[arg-type]
+        graph.erase_node(dequant)  # type: ignore[arg-type]
         counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1
         counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len(
             match.nodes
         )
 
 
 def _generate_dequant_convolution_node_pattern(
-    _dequant_per_channel_pattern, dtype=torch.float32
+    _dequant_pattern, dtype=torch.float32, is_fp8=False
 ):
     assert dtype in [torch.float32, torch.bfloat16]
     dequant_convolution_node_pattern = CallFunction(
         aten.convolution.default,
         _may_generate_pattern_with_dtype_convert(
-            get_dequantize_per_tensor_activation_pattern(),
+            get_dequantize_per_tensor_activation_pattern(is_fp8=is_fp8),
             KeywordArg("autocast_act_dtype"),
             dtype == torch.bfloat16,
         ),
-        _dequant_per_channel_pattern,
+        _dequant_pattern,
         KeywordArg("b"),
         KeywordArg("stride"),
         KeywordArg("padding"),
@@ -877,24 +893,30 @@ def _generate_dequant_convolution_node_pattern(
     return dequant_convolution_node_pattern
 
 
-def _generate_qconv_weight_prepack_patterns(dtype=torch.float32):
+def _generate_qconv_weight_prepack_patterns(dtype=torch.float32, is_fp8=False):
     assert dtype in [torch.float32, torch.bfloat16]
+    if is_fp8:
+        dequant_wgt_pattern = dequantize_fp8_weight_pattern
+    else:
+        dequant_wgt_pattern = dequantize_per_channel_weight_pattern
     return (
         _generate_dequant_convolution_node_pattern(
-            dequantize_per_channel_weight_pattern
+            dequant_wgt_pattern
             if dtype == torch.float32
-            else dequantize_per_channel_to_bf16_weight_pattern,
+            else get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern),
             dtype,
+            is_fp8=is_fp8,
         ),
         # There is another pattern due to the pass of convert_conv_weights_to_channels_last
         # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362.
-        # Depend on some heuristics, it may or may not insert to(channel_last) node
-        # between convolution and dequant_per_channel node
+        # Depending on some heuristics, it may or may not insert a to(channels_last)
+        # node between the convolution and dequant nodes.
         _generate_dequant_convolution_node_pattern(
-            dequantize_per_channel_clone_weight_pattern
+            get_dequantize_clone_weight_pattern(dequant_wgt_pattern)
             if dtype == torch.float32
-            else dequantize_per_channel_to_bf16_clone_weight_pattern,
+            else get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern),
             dtype,
+            is_fp8=is_fp8,
         ),
     )
 
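As a usage sketch (the same call the registration loop below ends up making): for bf16 with fp8 weights the generator returns both weight-pattern variants, with and without the channels-last clone node:

```python
import torch

# Both patterns dequantize an fp8 weight and upcast it to bf16; the second
# additionally matches the aten.clone(memory_format=...) inserted by the
# channels-last freezing pass.
patterns = _generate_qconv_weight_prepack_patterns(torch.bfloat16, is_fp8=True)
assert len(patterns) == 2
```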
@@ -1302,12 +1324,7 @@ def _generate_qlinear_weight_prepack_patterns(
     is_fp8=False,
 ):
     if is_fp8:
-        dequant_wgt_pattern = CallFunction(
-            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
-            KeywordArg("q_weight"),
-            KeywordArg("w_scale"),
-            output_dtype=KeywordArg("w_dtype"),
-        )
+        dequant_wgt_pattern = dequantize_fp8_weight_pattern
     else:
         dequant_wgt_pattern = dequantize_per_channel_weight_pattern
     if input_dim_exceeds_two and not input_contiguous:
@@ -1449,12 +1466,12 @@ def _register_dequant_promotion():
 
 
 def _register_qconv_weight_prepack():
-    for dtype in [torch.float32, torch.bfloat16]:
-        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype)
+    for dtype, is_fp8 in itertools.product(
+        [torch.float32, torch.bfloat16], [True, False]
+    ):
+        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(
+            dtype, is_fp8=is_fp8
+        )
         for weight_prepack_pattern in weight_prepack_patterns:
             # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
             _register_qconv_weight_prepack_pass(
-                weight_prepack_pattern, pass_number=1, dtype=dtype
+                weight_prepack_pattern, pass_number=1, dtype=dtype, is_fp8=is_fp8
             )
 
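The registration now sweeps four (dtype, is_fp8) configurations; a quick illustration of the product being iterated (assumes `itertools` is already imported at the top of the file, as the loop implies):

```python
import itertools

import torch

# The four registered configurations:
#   (fp32, fp8), (fp32, int8), (bf16, fp8), (bf16, int8)
combos = list(itertools.product([torch.float32, torch.bfloat16], [True, False]))
assert len(combos) == 4
```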