Commit 04bf850

[Inductor][float8] Register qconv weight prepack pass for float8
1 parent b633f89 commit 04bf850

File tree

2 files changed: +129, -51 lines


test/quantization/pt2e/test_x86inductor_fusion.py

Lines changed: 67 additions & 6 deletions
@@ -138,6 +138,40 @@ def forward(self, input):
         return out


+class FP8QDQConv2d(torch.nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+        super().__init__()
+        self.qtype = torch.float8_e4m3fn
+        self.weight = torch.randn((out_channels, in_channels // groups, *kernel_size)).to(self.qtype)
+        self.weight_scale = 2.0
+        self.scale = 2.0
+        self.bias = None
+        if bias:
+            self.bias = torch.randn((out_channels,))
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+
+    def forward(self, input):
+        weight = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+            tensor=self.weight.data,
+            scale=torch.tensor([self.weight_scale]),
+            output_dtype=torch.float,
+        )
+        q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
+            tensor=input,
+            scale=torch.tensor([self.scale]),
+            float8_dtype=self.qtype,
+        )
+        dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
+            tensor=q_input,
+            scale=torch.tensor([self.scale]),
+            output_dtype=torch.float,
+        )
+
+        return torch.nn.functional.conv2d(dq_input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
 def qdq(input, scale):
     dtype = input.dtype
     q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
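For orientation, here is a minimal usage sketch of the FP8QDQConv2d reference module added above. The shapes and the final print are illustrative only and are not part of the test; it assumes the class is in scope and that torchao is importable so the non-decomposed float8 ops are registered.

```python
import torch

# Illustrative only: exercise the FP8QDQConv2d test module defined above.
# Assumes torchao is importable so torch.ops.torchao.*_float8_non_decomposed exist.
mod = FP8QDQConv2d(in_channels=3, out_channels=16, kernel_size=(3, 3))
x = torch.randn(1, 3, 32, 32)
out = mod(x)  # fake-quantize the input to float8, dequantize, then run an fp32 conv2d
print(out.shape)  # torch.Size([1, 16, 30, 30])
```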
@@ -172,7 +206,7 @@ def create_mod_info_recursion(parent):
     for name, mod in model.named_modules():
         mod_type_str = mod.__class__.__name__
         if mod_type_str not in [
-            "Linear",
+            "Linear", "Conv2d"
         ]:
             continue
         param = mod.weight
@@ -190,6 +224,11 @@ def create_mod_info_recursion(parent):
             patched_mod.bias = mod.bias
             patched_mod.weight_scale = weight_scale.item()
             patched_mod.weight.data = q_param
+        elif mod_type_str in ["Conv2d"]:
+            patched_mod = FP8QDQConv2d(mod.in_channels, mod.out_channels, mod.kernel_size, mod.stride, mod.padding, mod.dilation, mod.groups, False)
+            patched_mod.bias = mod.bias
+            patched_mod.weight_scale = weight_scale.item()
+            patched_mod.weight.data = q_param

         parent = parent_child_mod_dict[mod].parent
         name = parent_child_mod_dict[mod].name
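The q_param and weight_scale consumed in the new Conv2d branch are computed earlier in this helper and are not shown in this hunk. Below is a hedged, self-contained sketch of an equivalent per-tensor float8 quantization of a Conv2d weight; the scale formula is an assumption for illustration, and it again assumes FP8QDQConv2d from the hunk above is in scope.

```python
import torch

# Assumed per-tensor scale-and-cast; the real test helper computes weight_scale/q_param elsewhere.
conv = torch.nn.Conv2d(3, 16, kernel_size=3)
weight_scale = conv.weight.detach().abs().max() / torch.finfo(torch.float8_e4m3fn).max
q_param = (conv.weight.detach() / weight_scale).to(torch.float8_e4m3fn)

patched_mod = FP8QDQConv2d(
    conv.in_channels, conv.out_channels, conv.kernel_size,
    conv.stride, conv.padding, conv.dilation, conv.groups, False,
)
patched_mod.bias = conv.bias
patched_mod.weight_scale = weight_scale.item()
patched_mod.weight.data = q_param  # mirrors the patching code in the diff above
```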
@@ -382,7 +421,7 @@ def _test_code_common(
 
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Requires torch 2.8+")
 class TestPatternMatcher(TestPatternMatcherBase):
-    def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False):
+    def _qconv2d_test_helper(self, device="cpu", mixed_bf16=False, is_fp8=False):
         class M(torch.nn.Module):
             def __init__(
                 self,
@@ -408,14 +447,14 @@ def forward(self, x):
         def matcher_check_fn():
             # 1. Dequant-Conv2D pattern matched in QConv2D weight prepack * 1
             # int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution]
-            # int8_mixed_bf16: [dequant_node, optional(convert_element_type_4),
+            # mixed_bf16: [dequant_node, optional(convert_element_type_4),
             #    dequantize_per_channel, optional(convert_element_type_3), clone, convolution]
             self.assertEqual(
                 counters["inductor"]["qconv_weight_prepack_matcher_count"], 3
             )
             self.assertEqual(
                 counters["inductor"]["qconv_weight_prepack_matcher_nodes"],
-                18 if int8_mixed_bf16 else 12,
+                18 if mixed_bf16 else 12,
             )
             self.assertEqual(
                 counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 3
@@ -426,7 +465,8 @@ def matcher_check_fn():
             (v,),
             matcher_check_fn,
             check_quantization=True,
-            check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32,
+            check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32,
+            is_fp8=is_fp8,
         )

     @skipIfNoDynamoSupport
@@ -438,6 +478,16 @@ def test_qconv2d_cpu(self):
         """
         self._qconv2d_test_helper("cpu")

+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skip_if_rocm("Not applicable to ROCm")
+    @skipIfNoFloat8Support
+    def test_qconv2d_fp8_cpu(self):
+        r"""
+        This testcase will quantize a single Conv2d module.
+        """
+        self._qconv2d_test_helper("cpu", is_fp8=True)
+
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
@@ -446,7 +496,18 @@ def test_qconv2d_int8_mixed_bf16(self):
         r"""
         This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.
         """
-        self._qconv2d_test_helper(int8_mixed_bf16=True)
+        self._qconv2d_test_helper(mixed_bf16=True)
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNNBF16
+    @skipIfNoONEDNN
+    @skip_if_rocm("Not applicable to ROCm")
+    @skipIfNoFloat8Support
+    def test_qconv2d_fp8_mixed_bf16(self):
+        r"""
+        This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.
+        """
+        self._qconv2d_test_helper(mixed_bf16=True, is_fp8=True)

     def _qconv2d_unary_test_helper(
         self,
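The two new test cases can be selected by name once torch 2.8+, oneDNN, and float8 support are available. A hedged invocation sketch (pytest also collects these unittest-style tests; the file path is taken from the diff header):

```python
import pytest

# Equivalent to: pytest test/quantization/pt2e/test_x86inductor_fusion.py -k test_qconv2d_fp8 -v
pytest.main([
    "test/quantization/pt2e/test_x86inductor_fusion.py",
    "-k", "test_qconv2d_fp8",
    "-v",
])
```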

torchao/quantization/pt2e/inductor_passes/x86.py

Lines changed: 62 additions & 45 deletions
@@ -167,24 +167,28 @@ def get_dequantize_per_tensor_activation_pattern(
     KeywordArg("w_dtype"),
 )

-dequantize_per_channel_to_bf16_weight_pattern = (
-    _may_generate_pattern_with_dtype_convert(
-        dequantize_per_channel_weight_pattern,
+dequantize_fp8_weight_pattern = CallFunction(
+    torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
+    KeywordArg("q_weight"),
+    KeywordArg("w_scale"),
+    output_dtype=KeywordArg("w_dtype"),
+)
+
+def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern):
+    return _may_generate_pattern_with_dtype_convert(
+        dequant_wgt_pattern,
         KeywordArg("autocast_wgt_dtype"),
     )
-)

-dequantize_per_channel_clone_weight_pattern = CallFunction(
-    aten.clone.default,
-    dequantize_per_channel_weight_pattern,
-    memory_format=KeywordArg("memory_format"),
-)
+def get_dequantize_clone_weight_pattern(dequant_wgt_pattern):
+    return CallFunction(
+        aten.clone.default,
+        dequant_wgt_pattern,
+        memory_format=KeywordArg("memory_format"),
+    )

-dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
-    aten.clone.default,
-    dequantize_per_channel_to_bf16_weight_pattern,
-    memory_format=KeywordArg("memory_format"),
-)
+def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern):
+    return get_dequantize_clone_weight_pattern(get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern))


 def get_qconv_pt2e_pattern(users=1):
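To make the new weight-side patterns concrete, here is a hedged, self-contained sketch of the eager op sequence they are meant to match once traced (float8 weight dequant, optional bf16 cast, optional channels_last clone). Shapes and scale values are illustrative; it assumes torchao is installed and that the op accepts the keyword form used in the test file above.

```python
import torch

# Requires torchao so the non-decomposed float8 dequant op is registered.
q_weight = torch.randn(16, 3, 3, 3).to(torch.float8_e4m3fn)
w_scale = torch.tensor([2.0])

# dequantize_fp8_weight_pattern targets this dequant node:
w_fp32 = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
    tensor=q_weight, scale=w_scale, output_dtype=torch.float
)
# get_dequantize_to_bf16_weight_pattern(...) additionally covers the dtype cast:
w_bf16 = w_fp32.to(torch.bfloat16)
# the *_clone variants cover the channels_last clone that freezing may insert:
w_cl = w_bf16.clone(memory_format=torch.channels_last)
```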
@@ -711,7 +715,7 @@ def _inner(match):
     return _inner


-def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32):
+def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False):
     @register_freezing_graph_pattern(
         pattern,
         extra_check=_is_valid_dequant_conv_pattern(dtype),
@@ -724,7 +728,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
               |
         dequant_per_tensor
               |
-        Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight
+        Conv2d <- optional(aten.clone.default) <- dequant <- int8_weight

         Insert weight prepack node and change the pattern to:
         int8 activation
@@ -747,7 +751,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         )

         if dtype == torch.float32:
-            dequant_per_channel = (
+            dequant = (
                 clone_node.args[0]  # type: ignore[union-attr]
                 if has_clone_to_channel_last_node_in_pattern
                 else conv_node.args[1]
@@ -758,25 +762,25 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
                 if has_clone_to_channel_last_node_in_pattern
                 else conv_node.args[1]
             )
-            dequant_per_channel = weight_to_bf16_node.args[0]  # type: ignore[union-attr]
+            dequant = weight_to_bf16_node.args[0]  # type: ignore[union-attr]

-        assert dequant_per_channel.target in [  # type: ignore[union-attr]
+        assert dequant.target in [  # type: ignore[union-attr]
             quantized_decomposed.dequantize_per_channel.default,
             torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
         ]

         # Activation QParams
         qx, x_zp, x_scale = (
             kwargs["x"],
-            kwargs["x_zp"],
+            kwargs["x_zp"] if "x_zp" in kwargs else None,
             kwargs["x_scale"],
         )

         # Weight QParams
         qw, w_scale, w_zp = (
             kwargs["q_weight"],
             kwargs["w_scale"],
-            kwargs["w_zp"],
+            kwargs["w_zp"] if "w_zp" in kwargs else None,
         )

         # Conv Params
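The `if "x_zp" in kwargs` fallbacks above exist because the fp8 patterns bind no zero-point keywords; a trivial stand-in showing the same idiom (the dict contents are illustrative, not the real match kwargs):

```python
# For an fp8 match, the pattern binds x / x_scale / q_weight / w_scale but no zero-points.
kwargs = {"x": "qx_node", "x_scale": "x_scale_node", "q_weight": "qw_node", "w_scale": "w_scale_node"}
x_zp = kwargs["x_zp"] if "x_zp" in kwargs else None  # None for fp8
w_zp = kwargs["w_zp"] if "w_zp" in kwargs else None  # None for fp8
```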
@@ -792,14 +796,19 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
         if has_free_symbols(x_shape):
             # For dynamic shape case, we can't get activation shape ahead of runtime.
             x_shape = None
+        if is_fp8 and w_scale.target is torch.ops.aten.full.default:
+            with torch.utils._python_dispatch._disable_current_modes():
+                w_scale_tensor = torch.tensor([w_scale.args[1]])
+            match.graph.owning_module.register_buffer("w_scale", w_scale_tensor)
+            w_scale = match.graph.create_node("get_attr", "w_scale")
         graph = match.graph
         with graph.inserting_before(conv_node):
             # Insert weight prepack node and the QConv node
             packed_weight_inputs = (
                 qw,
                 w_scale,
-                x_scale,
-                x_zp,
+                x_scale.args[1] if is_fp8 and x_scale.target is torch.ops.aten.full.default else x_scale,
+                0,
                 stride,
                 padding,
                 dilation,
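A brief note on the scale handling above: aten.full.default takes (size, fill_value, ...), so args[1] on a full() node is the scalar fill value, which is why the pass reads w_scale.args[1] and x_scale.args[1]. The snippet below only illustrates the call shape; the values are arbitrary.

```python
import torch

# aten.full.default(size, fill_value, ...) -> args[1] is the scalar scale value.
scale_tensor = torch.ops.aten.full.default([1], 2.0)
print(scale_tensor)  # tensor([2.])
```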
@@ -830,9 +839,16 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
                 [],  # scalars
                 "",  # algorithm
             )
-            new_conv_node = graph.call_function(
-                torch.ops.onednn.qconv_pointwise.default, args=new_args
-            )
+            Node = torch.fx.node.Node
+            # fp8 not need zp
+            if isinstance(x_scale, Node) and (isinstance(x_zp, Node) or is_fp8):
+                new_conv_node = graph.call_function(
+                    torch.ops.onednn.qconv_pointwise.tensor, args=new_args
+                )
+            else:
+                new_conv_node = graph.call_function(
+                    torch.ops.onednn.qconv_pointwise.default, args=new_args
+                )
             conv_node.replace_all_uses_with(new_conv_node)
             new_conv_node.meta.update(conv_node.meta)

@@ -847,25 +863,25 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
             graph.erase_node(clone_node)  # type: ignore[arg-type]
         if dtype == torch.bfloat16:
             graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined, arg-type]
-        graph.erase_node(dequant_per_channel)  # type: ignore[arg-type]
+        graph.erase_node(dequant)  # type: ignore[arg-type]
         counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1
         counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len(
             match.nodes
         )


 def _generate_dequant_convolution_node_pattern(
-    _dequant_per_channel_pattern, dtype=torch.float32
+    _dequant_pattern, dtype=torch.float32, is_fp8=False
 ):
     assert dtype in [torch.float32, torch.bfloat16]
     dequant_convolution_node_pattern = CallFunction(
         aten.convolution.default,
         _may_generate_pattern_with_dtype_convert(
-            get_dequantize_per_tensor_activation_pattern(),
+            get_dequantize_per_tensor_activation_pattern(is_fp8=is_fp8),
             KeywordArg("autocast_act_dtype"),
             dtype == torch.bfloat16,
         ),
-        _dequant_per_channel_pattern,
+        _dequant_pattern,
         KeywordArg("b"),
         KeywordArg("stride"),
         KeywordArg("padding"),
@@ -877,24 +893,30 @@ def _generate_dequant_convolution_node_pattern(
     return dequant_convolution_node_pattern


-def _generate_qconv_weight_prepack_patterns(dtype=torch.float32):
+def _generate_qconv_weight_prepack_patterns(dtype=torch.float32, is_fp8=False):
     assert dtype in [torch.float32, torch.bfloat16]
+    if is_fp8:
+        dequant_wgt_pattern = dequantize_fp8_weight_pattern
+    else:
+        dequant_wgt_pattern = dequantize_per_channel_weight_pattern
     return (
         _generate_dequant_convolution_node_pattern(
-            dequantize_per_channel_weight_pattern
+            dequant_wgt_pattern
             if dtype == torch.float32
-            else dequantize_per_channel_to_bf16_weight_pattern,
+            else get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern),
             dtype,
+            is_fp8=is_fp8,
         ),
         # There is another pattern due to the pass of convert_conv_weights_to_channels_last
         # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362.
         # Depend on some heuristics, it may or may not insert to(channel_last) node
-        # between convolution and dequant_per_channel node
+        # between convolution and dequant node
         _generate_dequant_convolution_node_pattern(
-            dequantize_per_channel_clone_weight_pattern
+            get_dequantize_clone_weight_pattern(dequant_wgt_pattern)
             if dtype == torch.float32
-            else dequantize_per_channel_to_bf16_clone_weight_pattern,
+            else get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern),
             dtype,
+            is_fp8=is_fp8,
         ),
     )

@@ -1302,12 +1324,7 @@ def _generate_qlinear_weight_prepack_patterns(
     is_fp8=False,
 ):
     if is_fp8:
-        dequant_wgt_pattern = CallFunction(
-            torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
-            KeywordArg("q_weight"),
-            KeywordArg("w_scale"),
-            output_dtype=KeywordArg("w_dtype"),
-        )
+        dequant_wgt_pattern = dequantize_fp8_weight_pattern
     else:
         dequant_wgt_pattern = dequantize_per_channel_weight_pattern
     if input_dim_exceeds_two and not input_contiguous:
@@ -1449,12 +1466,12 @@ def _register_dequant_promotion():
 

 def _register_qconv_weight_prepack():
-    for dtype in [torch.float32, torch.bfloat16]:
-        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype)
+    for dtype, is_fp8 in itertools.product([torch.float32, torch.bfloat16], [True, False]):
+        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype, is_fp8=is_fp8)
         for weight_prepack_pattern in weight_prepack_patterns:
             # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
             _register_qconv_weight_prepack_pass(
-                weight_prepack_pattern, pass_number=1, dtype=dtype
+                weight_prepack_pattern, pass_number=1, dtype=dtype, is_fp8=is_fp8
             )

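With this change, the registration loop above expands to four pattern families (fp8 and per-channel int8 weight dequant, each for fp32 and for bf16 autocast). A trivial sketch of the combinations it iterates:

```python
import itertools

import torch

for dtype, is_fp8 in itertools.product([torch.float32, torch.bfloat16], [True, False]):
    print(dtype, is_fp8)
# torch.float32 True
# torch.float32 False
# torch.bfloat16 True
# torch.bfloat16 False
```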