diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py
index a0aef11541..cf1bac235d 100644
--- a/torchao/quantization/pt2e/inductor_passes/x86.py
+++ b/torchao/quantization/pt2e/inductor_passes/x86.py
@@ -167,24 +167,28 @@ def get_dequantize_per_tensor_activation_pattern(
     KeywordArg("w_dtype"),
 )
 
-dequantize_per_channel_to_bf16_weight_pattern = (
-    _may_generate_pattern_with_dtype_convert(
-        dequantize_per_channel_weight_pattern,
+dequantize_fp8_weight_pattern = CallFunction(
+    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+    KeywordArg("q_weight"),
+    KeywordArg("w_scale"),
+    output_dtype=KeywordArg("w_dtype"),
+)
+
+def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern):
+    return _may_generate_pattern_with_dtype_convert(
+        dequant_wgt_pattern,
         KeywordArg("autocast_wgt_dtype"),
     )
-)
 
-dequantize_per_channel_clone_weight_pattern = CallFunction(
-    aten.clone.default,
-    dequantize_per_channel_weight_pattern,
-    memory_format=KeywordArg("memory_format"),
-)
+def get_dequantize_clone_weight_pattern(dequant_wgt_pattern):
+    return CallFunction(
+        aten.clone.default,
+        dequant_wgt_pattern,
+        memory_format=KeywordArg("memory_format"),
+    )
 
-dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
-    aten.clone.default,
-    dequantize_per_channel_to_bf16_weight_pattern,
-    memory_format=KeywordArg("memory_format"),
-)
+def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern):
+    return get_dequantize_clone_weight_pattern(get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern))
 
 
 def get_qconv_pt2e_pattern(users=1):
@@ -711,7 +715,7 @@ def _inner(match):
     return _inner
 
 
-def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32):
+def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False):
     @register_freezing_graph_pattern(
         pattern,
         extra_check=_is_valid_dequant_conv_pattern(dtype),
@@ -724,7 +728,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
           |
         dequant_per_tensor
           |
-        Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight
+        Conv2d <- optional(aten.clone.default) <- dequant <- int8_weight
 
         Insert weight prepack node and change the pattern to:
         int8 activation
@@ -747,7 +751,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         )
 
         if dtype == torch.float32:
-            dequant_per_channel = (
+            dequant = (
                 clone_node.args[0]  # type: ignore[union-attr]
                 if has_clone_to_channel_last_node_in_pattern
                 else conv_node.args[1]
@@ -758,9 +762,9 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
                 if has_clone_to_channel_last_node_in_pattern
                 else conv_node.args[1]
             )
-            dequant_per_channel = weight_to_bf16_node.args[0]  # type: ignore[union-attr]
+            dequant = weight_to_bf16_node.args[0]  # type: ignore[union-attr]
 
-        assert dequant_per_channel.target in [  # type: ignore[union-attr]
+        assert dequant.target in [  # type: ignore[union-attr]
             quantized_decomposed.dequantize_per_channel.default,
             torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
         ]
@@ -768,7 +772,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
 
         # Activation QParams
         qx, x_zp, x_scale = (
             kwargs["x"],
-            kwargs["x_zp"],
+            kwargs["x_zp"] if "x_zp" in kwargs else None,
             kwargs["x_scale"],
         )
@@ -776,7 +780,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         qw, w_scale, w_zp = (
             kwargs["q_weight"],
             kwargs["w_scale"],
-            kwargs["w_zp"],
+            kwargs["w_zp"] if "w_zp" in kwargs else None,
         )
 
         # Conv Params
@@ -792,14 +796,19 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         if has_free_symbols(x_shape):
             # For dynamic shape case, we can't get activation shape ahead of runtime.
             x_shape = None
+        if w_scale.target is torch.ops.aten.full.default:
+            with torch.utils._python_dispatch._disable_current_modes():
+                w_scale_tensor = torch.tensor([w_scale.args[1]])
+            match.graph.owning_module.register_buffer("w_scale", w_scale_tensor)
+            w_scale = match.graph.create_node("get_attr", "w_scale")
         graph = match.graph
         with graph.inserting_before(conv_node):
             # Insert weight prepack node and the QConv node
             packed_weight_inputs = (
                 qw,
                 w_scale,
-                x_scale,
-                x_zp,
+                x_scale.args[1] if x_scale.target is torch.ops.aten.full.default else x_scale,
+                0,
                 stride,
                 padding,
                 dilation,
@@ -830,9 +839,16 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
                 [],  # scalars
                 "",  # algorithm
             )
-            new_conv_node = graph.call_function(
-                torch.ops.onednn.qconv_pointwise.default, args=new_args
-            )
+            Node = torch.fx.node.Node
+            # fp8 does not need a zero point
+            if isinstance(x_scale, Node) and (isinstance(x_zp, Node) or is_fp8):
+                new_conv_node = graph.call_function(
+                    torch.ops.onednn.qconv_pointwise.tensor, args=new_args
+                )
+            else:
+                new_conv_node = graph.call_function(
+                    torch.ops.onednn.qconv_pointwise.default, args=new_args
+                )
             conv_node.replace_all_uses_with(new_conv_node)
             new_conv_node.meta.update(conv_node.meta)
 
@@ -847,7 +863,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
             graph.erase_node(clone_node)  # type: ignore[arg-type]
         if dtype == torch.bfloat16:
            graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined, arg-type]
-        graph.erase_node(dequant_per_channel)  # type: ignore[arg-type]
+        graph.erase_node(dequant)  # type: ignore[arg-type]
         counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1
         counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len(
             match.nodes
@@ -855,17 +871,17 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
 
 
 def _generate_dequant_convolution_node_pattern(
-    _dequant_per_channel_pattern, dtype=torch.float32
+    _dequant_pattern, dtype=torch.float32, is_fp8=False
 ):
     assert dtype in [torch.float32, torch.bfloat16]
     dequant_convolution_node_pattern = CallFunction(
         aten.convolution.default,
         _may_generate_pattern_with_dtype_convert(
-            get_dequantize_per_tensor_activation_pattern(),
+            get_dequantize_per_tensor_activation_pattern(is_fp8=is_fp8),
             KeywordArg("autocast_act_dtype"),
             dtype == torch.bfloat16,
         ),
-        _dequant_per_channel_pattern,
+        _dequant_pattern,
         KeywordArg("b"),
         KeywordArg("stride"),
         KeywordArg("padding"),
@@ -877,24 +893,30 @@ def _generate_dequant_convolution_node_pattern(
     return dequant_convolution_node_pattern
 
 
-def _generate_qconv_weight_prepack_patterns(dtype=torch.float32):
+def _generate_qconv_weight_prepack_patterns(dtype=torch.float32, is_fp8=False):
     assert dtype in [torch.float32, torch.bfloat16]
+    if is_fp8:
+        dequant_wgt_pattern = dequantize_fp8_weight_pattern
+    else:
+        dequant_wgt_pattern = dequantize_per_channel_weight_pattern
     return (
         _generate_dequant_convolution_node_pattern(
-            dequantize_per_channel_weight_pattern
+            dequant_wgt_pattern
             if dtype == torch.float32
-            else dequantize_per_channel_to_bf16_weight_pattern,
+            else get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern),
             dtype,
+            is_fp8=is_fp8,
         ),
         # There is another pattern due to the pass of convert_conv_weights_to_channels_last
        # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362.
         # Depend on some heuristics, it may or may not insert to(channel_last) node
-        # between convolution and dequant_per_channel node
+        # between convolution and dequant node
         _generate_dequant_convolution_node_pattern(
-            dequantize_per_channel_clone_weight_pattern
+            get_dequantize_clone_weight_pattern(dequant_wgt_pattern)
             if dtype == torch.float32
-            else dequantize_per_channel_to_bf16_clone_weight_pattern,
+            else get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern),
             dtype,
+            is_fp8=is_fp8,
         ),
     )
@@ -1302,12 +1324,7 @@ def _generate_qlinear_weight_prepack_patterns(
     is_fp8=False,
 ):
     if is_fp8:
-        dequant_wgt_pattern = CallFunction(
-            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
-            KeywordArg("q_weight"),
-            KeywordArg("w_scale"),
-            output_dtype=KeywordArg("w_dtype"),
-        )
+        dequant_wgt_pattern = dequantize_fp8_weight_pattern
     else:
         dequant_wgt_pattern = dequantize_per_channel_weight_pattern
     if input_dim_exceeds_two and not input_contiguous:
@@ -1449,12 +1466,12 @@ def _register_dequant_promotion():
 
 
 def _register_qconv_weight_prepack():
-    for dtype in [torch.float32, torch.bfloat16]:
-        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype)
+    for dtype, is_fp8 in itertools.product([torch.float32, torch.bfloat16], [True, False]):
+        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype, is_fp8=is_fp8)
         for weight_prepack_pattern in weight_prepack_patterns:
             # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
             _register_qconv_weight_prepack_pass(
-                weight_prepack_pattern, pass_number=1, dtype=dtype
+                weight_prepack_pattern, pass_number=1, dtype=dtype, is_fp8=is_fp8
             )
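
A minimal usage sketch of the refactored helpers above (the variable names on the left are hypothetical and only for illustration): the bf16 and channels-last clone weight-dequant patterns are now obtained by composing the new helper functions with either weight-dequant pattern, instead of referencing the removed per-channel-only module-level constants.

    # Hypothetical illustration, not part of the patch: the int8 per-channel and
    # fp8 weight variants share the same composition helpers.
    int8_bf16_clone_pattern = get_dequantize_to_bf16_clone_weight_pattern(
        dequantize_per_channel_weight_pattern
    )
    fp8_bf16_clone_pattern = get_dequantize_to_bf16_clone_weight_pattern(
        dequantize_fp8_weight_pattern
    )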