
Commit 765bda6

[Inductor][float8] Register qconv weight prepack pass for float8
1 parent 37a5f5c commit 765bda6

2 files changed: +129 −51 lines changed
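
Context for the change: this commit teaches Inductor's x86 weight-prepack pass to match the float8 quantize/dequantize (QDQ) pattern around aten.convolution, mirroring the existing int8 path. Below is a rough usage sketch of the kind of model the pass targets. It is an illustration, not the test itself, and it assumes torch 2.8+ with float8 support, torchao installed (so the float8 affine ops and x86 Inductor passes are registered), and Inductor freezing enabled, as the test harness in this commit arranges.

import torch
import torch._inductor.config as inductor_config
import torchao  # noqa: F401  # assumed to register the torchao float8 affine ops

class ToyFP8QDQConv(torch.nn.Module):
    # Mirrors the FP8QDQConv2d test module added in this commit.
    def __init__(self):
        super().__init__()
        self.qtype = torch.float8_e4m3fn
        self.weight = torch.randn(16, 3, 3, 3).to(self.qtype)
        self.scale = 2.0

    def forward(self, x):
        w = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
            tensor=self.weight, scale=torch.tensor([self.scale]), output_dtype=torch.float
        )
        q = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
            tensor=x, scale=torch.tensor([self.scale]), float8_dtype=self.qtype
        )
        dq = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
            tensor=q, scale=torch.tensor([self.scale]), output_dtype=torch.float
        )
        return torch.nn.functional.conv2d(dq, w)

# With freezing on, the dequant -> convolution subgraph above is what the newly
# registered qconv weight-prepack pattern can rewrite into a prepacked
# onednn.qconv_pointwise call.
with torch.no_grad(), inductor_config.patch(freezing=True):
    compiled = torch.compile(ToyFP8QDQConv().eval())
    _ = compiled(torch.randn(1, 3, 8, 8))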

test/quantization/pt2e/test_x86inductor_fusion.py

Lines changed: 67 additions & 6 deletions
@@ -138,6 +138,40 @@ def forward(self, input):
         return out
 
 
+class FP8QDQConv2d(torch.nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+        super().__init__()
+        self.qtype = torch.float8_e4m3fn
+        self.weight = torch.randn((out_channels, in_channels // groups, *kernel_size)).to(self.qtype)
+        self.weight_scale = 2.0
+        self.scale = 2.0
+        self.bias = None
+        if bias:
+            self.bias = torch.randn((out_channels,))
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+
+    def forward(self, input):
+        weight = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+            tensor=self.weight.data,
+            scale=torch.tensor([self.weight_scale]),
+            output_dtype=torch.float,
+        )
+        q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
+            tensor=input,
+            scale=torch.tensor([self.scale]),
+            float8_dtype=self.qtype,
+        )
+        dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+            tensor=q_input,
+            scale=torch.tensor([self.scale]),
+            output_dtype=torch.float,
+        )
+
+        return torch.nn.functional.conv2d(dq_input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
 def qdq(input, scale):
     dtype = input.dtype
     q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
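
As a quick illustration (not part of the diff), the test module above can be exercised eagerly, assuming torchao's float8 affine ops are available. Note that kernel_size must be a tuple because __init__ unpacks it with *kernel_size:

mod = FP8QDQConv2d(in_channels=3, out_channels=16, kernel_size=(3, 3), bias=False)
x = torch.randn(1, 3, 8, 8)
out = mod(x)      # QDQ on the activation, dequant of the fp8 weight, then conv2d
print(out.shape)  # torch.Size([1, 16, 6, 6]) for an 8x8 input and a 3x3 kernel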
@@ -172,7 +206,7 @@ def create_mod_info_recursion(parent):
     for name, mod in model.named_modules():
         mod_type_str = mod.__class__.__name__
         if mod_type_str not in [
-            "Linear",
+            "Linear", "Conv2d"
         ]:
             continue
         param = mod.weight
@@ -190,6 +224,11 @@ def create_mod_info_recursion(parent):
             patched_mod.bias = mod.bias
             patched_mod.weight_scale = weight_scale.item()
             patched_mod.weight.data = q_param
+        elif mod_type_str in ["Conv2d"]:
+            patched_mod = FP8QDQConv2d(mod.in_channels, mod.out_channels, mod.kernel_size, mod.stride, mod.padding, mod.dilation, mod.groups, False)
+            patched_mod.bias = mod.bias
+            patched_mod.weight_scale = weight_scale.item()
+            patched_mod.weight.data = q_param
 
         parent = parent_child_mod_dict[mod].parent
         name = parent_child_mod_dict[mod].name
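
The q_param / weight_scale pair assigned above is produced by the surrounding test helper, which is outside this hunk. For illustration only, a minimal sketch of a typical per-tensor float8 weight quantization that would yield such a pair — this is an assumption, not the helper the test actually uses:

import torch

def quantize_weight_fp8_per_tensor(w: torch.Tensor, dtype=torch.float8_e4m3fn):
    # Choose a per-tensor scale so the largest magnitude maps to the fp8 max.
    fp8_max = torch.finfo(dtype).max
    weight_scale = w.abs().max() / fp8_max
    q_param = (w / weight_scale).clamp(-fp8_max, fp8_max).to(dtype)
    return q_param, weight_scale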
@@ -382,7 +421,7 @@ def _test_code_common(
 
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Requires torch 2.8+")
 class TestPatternMatcher(TestPatternMatcherBase):
-    def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False):
+    def _qconv2d_test_helper(self, device="cpu", mixed_bf16=False, is_fp8=False):
         class M(torch.nn.Module):
             def __init__(
                 self,
@@ -408,14 +447,14 @@ def forward(self, x):
         def matcher_check_fn():
             # 1. Dequant-Conv2D pattern matched in QConv2D weight prepack * 1
             # int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution]
-            # int8_mixed_bf16: [dequant_node, optional(convert_element_type_4),
+            # mixed_bf16: [dequant_node, optional(convert_element_type_4),
             #            dequantize_per_channel, optional(convert_element_type_3), clone, convolution]
             self.assertEqual(
                 counters["inductor"]["qconv_weight_prepack_matcher_count"], 3
             )
             self.assertEqual(
                 counters["inductor"]["qconv_weight_prepack_matcher_nodes"],
-                18 if int8_mixed_bf16 else 12,
+                18 if mixed_bf16 else 12,
             )
             self.assertEqual(
                 counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 3
@@ -426,7 +465,8 @@ def matcher_check_fn():
             (v,),
             matcher_check_fn,
             check_quantization=True,
-            check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32,
+            check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32,
+            is_fp8=is_fp8,
         )
 
     @skipIfNoDynamoSupport
@@ -438,6 +478,16 @@ def test_qconv2d_cpu(self):
         """
         self._qconv2d_test_helper("cpu")
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skip_if_rocm("Not applicable to ROCm")
+    @skipIfNoFloat8Support
+    def test_qconv2d_fp8_cpu(self):
+        r"""
+        This testcase will quantize a single Conv2d module.
+        """
+        self._qconv2d_test_helper("cpu", is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
@@ -446,7 +496,18 @@ def test_qconv2d_int8_mixed_bf16(self):
         r"""
         This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.
         """
-        self._qconv2d_test_helper(int8_mixed_bf16=True)
+        self._qconv2d_test_helper(mixed_bf16=True)
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skip_if_rocm("Not applicable to ROCm")
+    @skipIfNoFloat8Support
+    def test_qconv2d_fp8_mixed_bf16(self):
+        r"""
+        This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.
+        """
+        self._qconv2d_test_helper(mixed_bf16=True, is_fp8=True)
 
     def _qconv2d_unary_test_helper(
         self,

torchao/quantization/pt2e/inductor_passes/x86.py

Lines changed: 62 additions & 45 deletions
@@ -177,24 +177,28 @@ def get_dequantize_per_tensor_activation_pattern(
     KeywordArg("w_dtype"),
 )
 
-dequantize_per_channel_to_bf16_weight_pattern = (
-    _may_generate_pattern_with_dtype_convert(
-        dequantize_per_channel_weight_pattern,
+dequantize_fp8_weight_pattern = CallFunction(
+    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+    KeywordArg("q_weight"),
+    KeywordArg("w_scale"),
+    output_dtype=KeywordArg("w_dtype"),
+)
+
+def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern):
+    return _may_generate_pattern_with_dtype_convert(
+        dequant_wgt_pattern,
         KeywordArg("autocast_wgt_dtype"),
     )
-)
 
-dequantize_per_channel_clone_weight_pattern = CallFunction(
-    aten.clone.default,
-    dequantize_per_channel_weight_pattern,
-    memory_format=KeywordArg("memory_format"),
-)
+def get_dequantize_clone_weight_pattern(dequant_wgt_pattern):
+    return CallFunction(
+        aten.clone.default,
+        dequant_wgt_pattern,
+        memory_format=KeywordArg("memory_format"),
+    )
 
-dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
-    aten.clone.default,
-    dequantize_per_channel_to_bf16_weight_pattern,
-    memory_format=KeywordArg("memory_format"),
-)
+def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern):
+    return get_dequantize_clone_weight_pattern(get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern))
 
 
 def get_qconv_pt2e_pattern(users=1):
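
The refactor above turns the fixed per-channel weight patterns into helpers parameterized by the weight-dequant pattern, so the same bf16 and clone (channels-last) wrappers can be reused for float8. For example, the bf16 + clone variants are now simple compositions, using the names defined in the diff above:

fp8_bf16_clone_weight_pattern = get_dequantize_to_bf16_clone_weight_pattern(
    dequantize_fp8_weight_pattern
)
int8_bf16_clone_weight_pattern = get_dequantize_to_bf16_clone_weight_pattern(
    dequantize_per_channel_weight_pattern
)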
@@ -1596,7 +1600,7 @@ def _inner(match):
     return _inner
 
 
-def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32):
+def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False):
     @register_freezing_graph_pattern(
         pattern,
         extra_check=_is_valid_dequant_conv_pattern(dtype),
@@ -1609,7 +1613,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
             |
         dequant_per_tensor
             |
-        Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight
+        Conv2d <- optional(aten.clone.default) <- dequant <- int8_weight
 
         Insert weight prepack node and change the pattern to:
         int8 activation
@@ -1632,7 +1636,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         )
 
         if dtype == torch.float32:
-            dequant_per_channel = (
+            dequant = (
                 clone_node.args[0]  # type: ignore[union-attr]
                 if has_clone_to_channel_last_node_in_pattern
                 else conv_node.args[1]
@@ -1643,25 +1647,25 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
                 if has_clone_to_channel_last_node_in_pattern
                 else conv_node.args[1]
             )
-            dequant_per_channel = weight_to_bf16_node.args[0]  # type: ignore[union-attr]
+            dequant = weight_to_bf16_node.args[0]  # type: ignore[union-attr]
 
-        assert dequant_per_channel.target in [  # type: ignore[union-attr]
+        assert dequant.target in [  # type: ignore[union-attr]
             quantized_decomposed.dequantize_per_channel.default,
             torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
         ]
 
         # Activation QParams
         qx, x_zp, x_scale = (
             kwargs["x"],
-            kwargs["x_zp"],
+            kwargs["x_zp"] if "x_zp" in kwargs else None,
             kwargs["x_scale"],
         )
 
         # Weight QParams
         qw, w_scale, w_zp = (
             kwargs["q_weight"],
             kwargs["w_scale"],
-            kwargs["w_zp"],
+            kwargs["w_zp"] if "w_zp" in kwargs else None,
         )
 
         # Conv Params
@@ -1677,14 +1681,19 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         if has_free_symbols(x_shape):
             # For dynamic shape case, we can't get activation shape ahead of runtime.
             x_shape = None
+        if is_fp8 and w_scale.target is torch.ops.aten.full.default:
+            with torch.utils._python_dispatch._disable_current_modes():
+                w_scale_tensor = torch.tensor([w_scale.args[1]])
+                match.graph.owning_module.register_buffer("w_scale", w_scale_tensor)
+            w_scale = match.graph.create_node("get_attr", "w_scale")
         graph = match.graph
         with graph.inserting_before(conv_node):
             # Insert weight prepack node and the QConv node
             packed_weight_inputs = (
                 qw,
                 w_scale,
-                x_scale,
-                x_zp,
+                x_scale.args[1] if is_fp8 and x_scale.target is torch.ops.aten.full.default else x_scale,
+                0,
                 stride,
                 padding,
                 dilation,
@@ -1715,9 +1724,16 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
                 [],  # scalars
                 "",  # algorithm
             )
-            new_conv_node = graph.call_function(
-                torch.ops.onednn.qconv_pointwise.default, args=new_args
-            )
+            Node = torch.fx.node.Node
+            # fp8 not need zp
+            if isinstance(x_scale, Node) and (isinstance(x_zp, Node) or is_fp8):
+                new_conv_node = graph.call_function(
+                    torch.ops.onednn.qconv_pointwise.tensor, args=new_args
+                )
+            else:
+                new_conv_node = graph.call_function(
+                    torch.ops.onednn.qconv_pointwise.default, args=new_args
+                )
             conv_node.replace_all_uses_with(new_conv_node)
             new_conv_node.meta.update(conv_node.meta)
 
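
The branch above picks between the two qconv_pointwise overloads: the .tensor overload is used when the activation scale is carried as an FX node (a tensor in the graph), and since float8 quantization here is symmetric, no zero point is required, so x_zp may be absent. A condensed restatement of that decision for clarity, reusing the names from the diff (illustrative only):

use_tensor_overload = isinstance(x_scale, torch.fx.Node) and (
    isinstance(x_zp, torch.fx.Node) or is_fp8  # fp8 path has no zero point
)
qconv_op = (
    torch.ops.onednn.qconv_pointwise.tensor
    if use_tensor_overload
    else torch.ops.onednn.qconv_pointwise.default
)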
@@ -1732,25 +1748,25 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
             graph.erase_node(clone_node)  # type: ignore[arg-type]
         if dtype == torch.bfloat16:
             graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined, arg-type]
-        graph.erase_node(dequant_per_channel)  # type: ignore[arg-type]
+        graph.erase_node(dequant)  # type: ignore[arg-type]
         counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1
         counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len(
             match.nodes
         )
 
 
 def _generate_dequant_convolution_node_pattern(
-    _dequant_per_channel_pattern, dtype=torch.float32
+    _dequant_pattern, dtype=torch.float32, is_fp8=False
 ):
     assert dtype in [torch.float32, torch.bfloat16]
     dequant_convolution_node_pattern = CallFunction(
         aten.convolution.default,
         _may_generate_pattern_with_dtype_convert(
-            get_dequantize_per_tensor_activation_pattern(),
+            get_dequantize_per_tensor_activation_pattern(is_fp8=is_fp8),
             KeywordArg("autocast_act_dtype"),
             dtype == torch.bfloat16,
         ),
-        _dequant_per_channel_pattern,
+        _dequant_pattern,
         KeywordArg("b"),
         KeywordArg("stride"),
         KeywordArg("padding"),
@@ -1762,24 +1778,30 @@ def _generate_dequant_convolution_node_pattern(
     return dequant_convolution_node_pattern
 
 
-def _generate_qconv_weight_prepack_patterns(dtype=torch.float32):
+def _generate_qconv_weight_prepack_patterns(dtype=torch.float32, is_fp8=False):
     assert dtype in [torch.float32, torch.bfloat16]
+    if is_fp8:
+        dequant_wgt_pattern = dequantize_fp8_weight_pattern
+    else:
+        dequant_wgt_pattern = dequantize_per_channel_weight_pattern
     return (
         _generate_dequant_convolution_node_pattern(
-            dequantize_per_channel_weight_pattern
+            dequant_wgt_pattern
             if dtype == torch.float32
-            else dequantize_per_channel_to_bf16_weight_pattern,
+            else get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern),
             dtype,
+            is_fp8=is_fp8,
         ),
         # There is another pattern due to the pass of convert_conv_weights_to_channels_last
        # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362.
         # Depend on some heuristics, it may or may not insert to(channel_last) node
-        # between convolution and dequant_per_channel node
+        # between convolution and dequant node
         _generate_dequant_convolution_node_pattern(
-            dequantize_per_channel_clone_weight_pattern
+            get_dequantize_clone_weight_pattern(dequant_wgt_pattern)
             if dtype == torch.float32
-            else dequantize_per_channel_to_bf16_clone_weight_pattern,
+            else get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern),
             dtype,
+            is_fp8=is_fp8,
         ),
     )
 
@@ -2187,12 +2209,7 @@ def _generate_qlinear_weight_prepack_patterns(
     is_fp8=False,
 ):
     if is_fp8:
-        dequant_wgt_pattern = CallFunction(
-            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
-            KeywordArg("q_weight"),
-            KeywordArg("w_scale"),
-            output_dtype=KeywordArg("w_dtype"),
-        )
+        dequant_wgt_pattern = dequantize_fp8_weight_pattern
     else:
         dequant_wgt_pattern = dequantize_per_channel_weight_pattern
     if input_dim_exceeds_two and not input_contiguous:
@@ -2334,12 +2351,12 @@ def _register_dequant_promotion():
 
 
 def _register_qconv_weight_prepack():
-    for dtype in [torch.float32, torch.bfloat16]:
-        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype)
+    for dtype, is_fp8 in itertools.product([torch.float32, torch.bfloat16], [True, False]):
+        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype, is_fp8=is_fp8)
         for weight_prepack_pattern in weight_prepack_patterns:
             # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
             _register_qconv_weight_prepack_pass(
-                weight_prepack_pattern, pass_number=1, dtype=dtype
+                weight_prepack_pattern, pass_number=1, dtype=dtype, is_fp8=is_fp8
             )
 
 
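
With the loop above, the prepack pass is now registered for all four (dtype, is_fp8) combinations: float32/int8, float32/fp8, bfloat16/int8, and bfloat16/fp8. Mirroring the test's matcher_check_fn, a hedged way to confirm the pass fired after compiling an fp8 QDQ conv model under Inductor freezing (assuming the counters dictionary the test reads comes from torch._dynamo.utils):

from torch._dynamo.utils import counters  # assumption: the same counters the test asserts on

counters.clear()
# ... compile and run an fp8 QDQ conv model with freezing enabled here ...
assert counters["inductor"]["qconv_weight_prepack_matcher_count"] >= 1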