107 changes: 62 additions & 45 deletions torchao/quantization/pt2e/inductor_passes/x86.py
@@ -167,24 +167,28 @@ def get_dequantize_per_tensor_activation_pattern(
KeywordArg("w_dtype"),
)

dequantize_per_channel_to_bf16_weight_pattern = (
_may_generate_pattern_with_dtype_convert(
dequantize_per_channel_weight_pattern,
dequantize_fp8_weight_pattern = CallFunction(
torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
KeywordArg("q_weight"),
KeywordArg("w_scale"),
output_dtype=KeywordArg("w_dtype"),
)
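# Unlike the per-channel int8 weight pattern above, the fp8 weight pattern carries
# no zero-point argument: only the quantized weight, its scale, and the output
# dtype are matched.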

def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern):
return _may_generate_pattern_with_dtype_convert(
dequant_wgt_pattern,
KeywordArg("autocast_wgt_dtype"),
)
)

dequantize_per_channel_clone_weight_pattern = CallFunction(
aten.clone.default,
dequantize_per_channel_weight_pattern,
memory_format=KeywordArg("memory_format"),
)
def get_dequantize_clone_weight_pattern(dequant_wgt_pattern):
return CallFunction(
aten.clone.default,
dequant_wgt_pattern,
memory_format=KeywordArg("memory_format"),
)

dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
aten.clone.default,
dequantize_per_channel_to_bf16_weight_pattern,
memory_format=KeywordArg("memory_format"),
)
def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern):
    return get_dequantize_clone_weight_pattern(
        get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern)
    )
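# These helpers compose; e.g. the bf16 + channels-last variant of the fp8 weight
# pattern used by the conv weight-prepack patterns below is built as
# get_dequantize_to_bf16_clone_weight_pattern(dequantize_fp8_weight_pattern).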


def get_qconv_pt2e_pattern(users=1):
@@ -711,7 +715,7 @@ def _inner(match):
return _inner


def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32):
def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False):
@register_freezing_graph_pattern(
pattern,
extra_check=_is_valid_dequant_conv_pattern(dtype),
@@ -724,7 +728,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
|
dequant_per_tensor
|
Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight
Conv2d <- optional(aten.clone.default) <- dequant <- int8/fp8 weight

Insert weight prepack node and change the pattern to:
int8 activation
@@ -747,7 +751,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
)

if dtype == torch.float32:
dequant_per_channel = (
dequant = (
clone_node.args[0] # type: ignore[union-attr]
if has_clone_to_channel_last_node_in_pattern
else conv_node.args[1]
@@ -758,25 +762,25 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
if has_clone_to_channel_last_node_in_pattern
else conv_node.args[1]
)
dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr]
dequant = weight_to_bf16_node.args[0] # type: ignore[union-attr]

assert dequant_per_channel.target in [ # type: ignore[union-attr]
assert dequant.target in [ # type: ignore[union-attr]
quantized_decomposed.dequantize_per_channel.default,
torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
]

# Activation QParams
qx, x_zp, x_scale = (
kwargs["x"],
kwargs["x_zp"],
kwargs["x_zp"] if "x_zp" in kwargs else None,
kwargs["x_scale"],
)

# Weight QParams
qw, w_scale, w_zp = (
kwargs["q_weight"],
kwargs["w_scale"],
kwargs["w_zp"],
kwargs["w_zp"] if "w_zp" in kwargs else None,
)

# Conv Params
@@ -792,14 +796,19 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
if has_free_symbols(x_shape):
# For dynamic shape case, we can't get activation shape ahead of runtime.
x_shape = None
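# If the weight scale was traced as a scalar aten.full constant, materialize it
# as a module buffer so it can be referenced through a get_attr node below.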
if w_scale.target is torch.ops.aten.full.default:
with torch.utils._python_dispatch._disable_current_modes():
w_scale_tensor = torch.tensor([w_scale.args[1]])
match.graph.owning_module.register_buffer("w_scale", w_scale_tensor)
w_scale = match.graph.create_node("get_attr", "w_scale")
graph = match.graph
with graph.inserting_before(conv_node):
# Insert weight prepack node and the QConv node
packed_weight_inputs = (
qw,
w_scale,
x_scale,
x_zp,
(
    x_scale.args[1]
    if x_scale.target is torch.ops.aten.full.default
    else x_scale
),
0,
stride,
padding,
dilation,
@@ -830,9 +839,16 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
[], # scalars
"", # algorithm
)
new_conv_node = graph.call_function(
torch.ops.onednn.qconv_pointwise.default, args=new_args
)
# fp8 has no zero point, so only the activation scale has to be a graph node
# to select the tensor overload of qconv_pointwise.
if isinstance(x_scale, torch.fx.Node) and (
    isinstance(x_zp, torch.fx.Node) or is_fp8
):
new_conv_node = graph.call_function(
torch.ops.onednn.qconv_pointwise.tensor, args=new_args
)
else:
new_conv_node = graph.call_function(
torch.ops.onednn.qconv_pointwise.default, args=new_args
)
conv_node.replace_all_uses_with(new_conv_node)
new_conv_node.meta.update(conv_node.meta)

@@ -847,25 +863,25 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
graph.erase_node(clone_node) # type: ignore[arg-type]
if dtype == torch.bfloat16:
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type]
graph.erase_node(dequant_per_channel) # type: ignore[arg-type]
graph.erase_node(dequant) # type: ignore[arg-type]
counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1
counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len(
match.nodes
)


def _generate_dequant_convolution_node_pattern(
_dequant_per_channel_pattern, dtype=torch.float32
_dequant_pattern, dtype=torch.float32, is_fp8=False
):
assert dtype in [torch.float32, torch.bfloat16]
dequant_convolution_node_pattern = CallFunction(
aten.convolution.default,
_may_generate_pattern_with_dtype_convert(
get_dequantize_per_tensor_activation_pattern(),
get_dequantize_per_tensor_activation_pattern(is_fp8=is_fp8),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
_dequant_per_channel_pattern,
_dequant_pattern,
KeywordArg("b"),
KeywordArg("stride"),
KeywordArg("padding"),
@@ -877,24 +893,30 @@ def _generate_dequant_convolution_node_pattern(
return dequant_convolution_node_pattern


def _generate_qconv_weight_prepack_patterns(dtype=torch.float32):
def _generate_qconv_weight_prepack_patterns(dtype=torch.float32, is_fp8=False):
assert dtype in [torch.float32, torch.bfloat16]
if is_fp8:
dequant_wgt_pattern = dequantize_fp8_weight_pattern
else:
dequant_wgt_pattern = dequantize_per_channel_weight_pattern
return (
_generate_dequant_convolution_node_pattern(
dequantize_per_channel_weight_pattern
dequant_wgt_pattern
if dtype == torch.float32
else dequantize_per_channel_to_bf16_weight_pattern,
else get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern),
dtype,
is_fp8=is_fp8,
),
# There is another pattern due to the pass of convert_conv_weights_to_channels_last
# https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362.
# Depending on some heuristics, it may or may not insert a to(channels_last) node
# between convolution and dequant_per_channel node
# between convolution and dequant node
_generate_dequant_convolution_node_pattern(
dequantize_per_channel_clone_weight_pattern
get_dequantize_clone_weight_pattern(dequant_wgt_pattern)
if dtype == torch.float32
else dequantize_per_channel_to_bf16_clone_weight_pattern,
else get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern),
dtype,
is_fp8=is_fp8,
),
)

@@ -1302,12 +1324,7 @@ def _generate_qlinear_weight_prepack_patterns(
is_fp8=False,
):
if is_fp8:
dequant_wgt_pattern = CallFunction(
torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
KeywordArg("q_weight"),
KeywordArg("w_scale"),
output_dtype=KeywordArg("w_dtype"),
)
dequant_wgt_pattern = dequantize_fp8_weight_pattern
else:
dequant_wgt_pattern = dequantize_per_channel_weight_pattern
if input_dim_exceeds_two and not input_contiguous:
@@ -1449,12 +1466,12 @@ def _register_dequant_promotion():


def _register_qconv_weight_prepack():
for dtype in [torch.float32, torch.bfloat16]:
weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype)
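# Register prepack patterns for every combination of activation dtype
# (fp32/bf16) and weight quantization scheme (int8 per-channel or fp8).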
for dtype, is_fp8 in itertools.product(
    [torch.float32, torch.bfloat16], [True, False]
):
    weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(
        dtype, is_fp8=is_fp8
    )
for weight_prepack_pattern in weight_prepack_patterns:
# Register to pass_number 1, so we can do dequant promotion in pass_number 0.
_register_qconv_weight_prepack_pass(
weight_prepack_pattern, pass_number=1, dtype=dtype
weight_prepack_pattern, pass_number=1, dtype=dtype, is_fp8=is_fp8
)

