@@ -177,24 +177,28 @@ def get_dequantize_per_tensor_activation_pattern(
     KeywordArg("w_dtype"),
 )
 
-dequantize_per_channel_to_bf16_weight_pattern = (
-    _may_generate_pattern_with_dtype_convert(
-        dequantize_per_channel_weight_pattern,
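+# Per-tensor fp8 weight dequant pattern (torchao non-decomposed op), shared by the
+# qconv and qlinear weight-prepack passes below.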
+dequantize_fp8_weight_pattern = CallFunction(
+    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+    KeywordArg("q_weight"),
+    KeywordArg("w_scale"),
+    output_dtype=KeywordArg("w_dtype"),
+)
+
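+# Helpers that wrap a weight dequant pattern (int8 per-channel or fp8 per-tensor)
+# with the optional bf16 cast and/or the channels-last clone.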
+def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern):
+    return _may_generate_pattern_with_dtype_convert(
+        dequant_wgt_pattern,
         KeywordArg("autocast_wgt_dtype"),
     )
-)
 
-dequantize_per_channel_clone_weight_pattern = CallFunction(
-    aten.clone.default,
-    dequantize_per_channel_weight_pattern,
-    memory_format=KeywordArg("memory_format"),
-)
+def get_dequantize_clone_weight_pattern(dequant_wgt_pattern):
+    return CallFunction(
+        aten.clone.default,
+        dequant_wgt_pattern,
+        memory_format=KeywordArg("memory_format"),
+    )
 
-dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
-    aten.clone.default,
-    dequantize_per_channel_to_bf16_weight_pattern,
-    memory_format=KeywordArg("memory_format"),
-)
+def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern):
+    return get_dequantize_clone_weight_pattern(
+        get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern)
+    )
 
 
 def get_qconv_pt2e_pattern(users=1):
@@ -1596,7 +1600,7 @@ def _inner(match):
     return _inner
 
 
-def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32):
+def _register_qconv_weight_prepack_pass(
+    pattern, pass_number, dtype=torch.float32, is_fp8=False
+):
     @register_freezing_graph_pattern(
         pattern,
         extra_check=_is_valid_dequant_conv_pattern(dtype),
@@ -1609,7 +1613,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
           |
         dequant_per_tensor
           |
-        Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight
+        Conv2d <- optional(aten.clone.default) <- dequant <- int8_weight
 
         Insert weight prepack node and change the pattern to:
         int8 activation
@@ -1632,7 +1636,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         )
 
         if dtype == torch.float32:
-            dequant_per_channel = (
+            dequant = (
                 clone_node.args[0]  # type: ignore[union-attr]
                 if has_clone_to_channel_last_node_in_pattern
                 else conv_node.args[1]
@@ -1643,25 +1647,25 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
                 if has_clone_to_channel_last_node_in_pattern
                 else conv_node.args[1]
             )
-            dequant_per_channel = weight_to_bf16_node.args[0]  # type: ignore[union-attr]
+            dequant = weight_to_bf16_node.args[0]  # type: ignore[union-attr]
 
-        assert dequant_per_channel.target in [  # type: ignore[union-attr]
+        assert dequant.target in [  # type: ignore[union-attr]
             quantized_decomposed.dequantize_per_channel.default,
             torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
         ]
 
         # Activation QParams
         qx, x_zp, x_scale = (
             kwargs["x"],
-            kwargs["x_zp"],
+            kwargs.get("x_zp"),  # absent for fp8 activation patterns
             kwargs["x_scale"],
         )
 
         # Weight QParams
         qw, w_scale, w_zp = (
             kwargs["q_weight"],
             kwargs["w_scale"],
-            kwargs["w_zp"],
+            kwargs.get("w_zp"),  # absent for fp8 weight patterns
         )
 
         # Conv Params
@@ -1677,14 +1681,19 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         if has_free_symbols(x_shape):
             # For dynamic shape case, we can't get activation shape ahead of runtime.
             x_shape = None
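+        # For fp8, a per-tensor weight scale may be traced into the graph as an
+        # aten.full constant; materialize it as a real tensor attribute so the
+        # prepack op receives a Tensor input.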
+        if is_fp8 and w_scale.target is torch.ops.aten.full.default:
+            with torch.utils._python_dispatch._disable_current_modes():
+                w_scale_tensor = torch.tensor([w_scale.args[1]])
+                match.graph.owning_module.register_buffer("w_scale", w_scale_tensor)
+                w_scale = match.graph.create_node("get_attr", "w_scale")
         graph = match.graph
         with graph.inserting_before(conv_node):
             # Insert weight prepack node and the QConv node
             packed_weight_inputs = (
                 qw,
                 w_scale,
-                x_scale,
-                x_zp,
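+                # Prepack takes the activation scale as a plain value when it is
+                # a traced fp8 constant, and a fixed zero point of 0.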
+                (
+                    x_scale.args[1]
+                    if is_fp8 and x_scale.target is torch.ops.aten.full.default
+                    else x_scale
+                ),
+                0,
                 stride,
                 padding,
                 dilation,
@@ -1715,9 +1724,16 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
                 [],  # scalars
                 "",  # algorithm
             )
-            new_conv_node = graph.call_function(
-                torch.ops.onednn.qconv_pointwise.default, args=new_args
-            )
+            Node = torch.fx.node.Node
+            # fp8 does not need a zero point; use the tensor overload whenever
+            # the scale (and, for int8, the zero point) is a graph Node.
+            if isinstance(x_scale, Node) and (isinstance(x_zp, Node) or is_fp8):
+                new_conv_node = graph.call_function(
+                    torch.ops.onednn.qconv_pointwise.tensor, args=new_args
+                )
+            else:
+                new_conv_node = graph.call_function(
+                    torch.ops.onednn.qconv_pointwise.default, args=new_args
+                )
             conv_node.replace_all_uses_with(new_conv_node)
             new_conv_node.meta.update(conv_node.meta)
 
@@ -1732,25 +1748,25 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
             graph.erase_node(clone_node)  # type: ignore[arg-type]
         if dtype == torch.bfloat16:
             graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined, arg-type]
-        graph.erase_node(dequant_per_channel)  # type: ignore[arg-type]
+        graph.erase_node(dequant)  # type: ignore[arg-type]
         counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1
         counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len(
             match.nodes
         )
 
 
 def _generate_dequant_convolution_node_pattern(
-    _dequant_per_channel_pattern, dtype=torch.float32
+    _dequant_pattern, dtype=torch.float32, is_fp8=False
 ):
     assert dtype in [torch.float32, torch.bfloat16]
     dequant_convolution_node_pattern = CallFunction(
         aten.convolution.default,
         _may_generate_pattern_with_dtype_convert(
-            get_dequantize_per_tensor_activation_pattern(),
+            get_dequantize_per_tensor_activation_pattern(is_fp8=is_fp8),
             KeywordArg("autocast_act_dtype"),
             dtype == torch.bfloat16,
         ),
-        _dequant_per_channel_pattern,
+        _dequant_pattern,
         KeywordArg("b"),
         KeywordArg("stride"),
         KeywordArg("padding"),
@@ -1762,24 +1778,30 @@ def _generate_dequant_convolution_node_pattern(
     return dequant_convolution_node_pattern
 
 
-def _generate_qconv_weight_prepack_patterns(dtype=torch.float32):
+def _generate_qconv_weight_prepack_patterns(dtype=torch.float32, is_fp8=False):
     assert dtype in [torch.float32, torch.bfloat16]
+    if is_fp8:
+        dequant_wgt_pattern = dequantize_fp8_weight_pattern
+    else:
+        dequant_wgt_pattern = dequantize_per_channel_weight_pattern
     return (
         _generate_dequant_convolution_node_pattern(
-            dequantize_per_channel_weight_pattern
+            dequant_wgt_pattern
             if dtype == torch.float32
-            else dequantize_per_channel_to_bf16_weight_pattern,
+            else get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern),
             dtype,
+            is_fp8=is_fp8,
         ),
         # There is another pattern due to the pass of convert_conv_weights_to_channels_last
         # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362.
         # Depend on some heuristics, it may or may not insert to(channel_last) node
-        # between convolution and dequant_per_channel node
+        # between convolution and dequant node
         _generate_dequant_convolution_node_pattern(
-            dequantize_per_channel_clone_weight_pattern
+            get_dequantize_clone_weight_pattern(dequant_wgt_pattern)
             if dtype == torch.float32
-            else dequantize_per_channel_to_bf16_clone_weight_pattern,
+            else get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern),
             dtype,
+            is_fp8=is_fp8,
         ),
     )
 
@@ -2187,12 +2209,7 @@ def _generate_qlinear_weight_prepack_patterns(
     is_fp8=False,
 ):
     if is_fp8:
-        dequant_wgt_pattern = CallFunction(
-            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
-            KeywordArg("q_weight"),
-            KeywordArg("w_scale"),
-            output_dtype=KeywordArg("w_dtype"),
-        )
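+        # Reuse the shared fp8 weight dequant pattern defined above.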
+        dequant_wgt_pattern = dequantize_fp8_weight_pattern
     else:
         dequant_wgt_pattern = dequantize_per_channel_weight_pattern
     if input_dim_exceeds_two and not input_contiguous:
@@ -2334,12 +2351,12 @@ def _register_dequant_promotion():
 
 
 def _register_qconv_weight_prepack():
-    for dtype in [torch.float32, torch.bfloat16]:
-        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype)
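+    # Register both the int8 per-channel and fp8 per-tensor weight patterns for
+    # fp32 and bf16 output dtypes.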
+    for dtype, is_fp8 in itertools.product(
+        [torch.float32, torch.bfloat16], [True, False]
+    ):
+        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(
+            dtype, is_fp8=is_fp8
+        )
         for weight_prepack_pattern in weight_prepack_patterns:
             # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
             _register_qconv_weight_prepack_pass(
-                weight_prepack_pattern, pass_number=1, dtype=dtype
+                weight_prepack_pattern, pass_number=1, dtype=dtype, is_fp8=is_fp8
             )
 
 