Commit 792ddf3

[Snippets][CPU] Precision enforcement fix (#32914)
### Details:
- Fixed a precision enforcement issue via a workaround in the precision enforcement pipeline: precisions are not enforced if enforcement would break transpose fusion. Although this leads to non-optimal performance, such cases (Subgraph input precisions are f32, but bf16 is forced internally) don't occur in real scenarios.
- The changes were validated on a large model scope to confirm that there are no performance degradations on real models.
- Tests are extended to cover the fixed scenario (f32 in/out precisions and bf16 inference precision).

### Tickets:
- CVS-176621
1 parent 76a84a4 commit 792ddf3
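The fixed scenario can be made concrete with a minimal sketch (hypothetical `make_mha_like_model` helper and arbitrary shapes, not code from this commit): the model's external precisions stay f32, bf16 is only requested through `ov::hint::inference_precision`, and a trailing Transpose makes the MatMul a transpose-fusion candidate. Whether the tokenizer actually forms a Snippets Subgraph from such a toy body depends on the plugin's heuristics.

```cpp
#include <memory>
#include <vector>

#include <openvino/op/ops.hpp>
#include <openvino/openvino.hpp>

// Hypothetical MHA-like body: f32 Parameters/Result, a MatMul that Snippets
// decomposes to Brgemm, and a trailing Transpose eligible for fusion.
std::shared_ptr<ov::Model> make_mha_like_model() {
    auto a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 16, 64, 64});
    auto b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 16, 64, 64});
    auto matmul = std::make_shared<ov::op::v0::MatMul>(a, b);
    auto order = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{4}, std::vector<int32_t>{0, 2, 1, 3});
    auto transpose = std::make_shared<ov::op::v1::Transpose>(matmul, order);
    return std::make_shared<ov::Model>(ov::OutputVector{transpose}, ov::ParameterVector{a, b});
}

int main() {
    ov::Core core;
    // I/O precisions stay f32; bf16 is only forced internally:
    auto compiled = core.compile_model(make_mha_like_model(), "CPU",
                                       ov::hint::inference_precision(ov::element::bf16));
}
```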

File tree

5 files changed, +60 −49 lines changed


src/common/snippets/src/pass/align_element_types.cpp

Lines changed: 0 additions & 14 deletions
```diff
@@ -47,17 +47,6 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr<ov::Model>& m)
             std::shared_ptr<ov::Node> consumer = shape_infer_leaf ? shape_infer_leaf : results[i];
             auto parent_output = consumer->get_input_source_output(0);
 
-            // Snippets supports Transpose only after Parameter or before Result nodes
-            // So we have to insert Convert before Transpose (if there is) on Subgraph outputs
-            const auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(parent_output.get_node_shared_ptr());
-            if (transpose) {
-                OPENVINO_ASSERT(
-                    parent_output.get_target_inputs().size() == 1,
-                    "If Result has Transpose on input, this Result must be single consumer of the Transpose");
-                parent_output = transpose->get_input_source_output(0);
-                consumer = transpose;
-            }
-
             // If there is already Convert[needed_in_type->original_type] and this node has only one consumer, we can
             // remove the Convert, since the sequence existing Convert[needed_in_type->original_type] -> new
             // Convert[original_type->needed_in_type] is redundant
@@ -81,9 +70,6 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr<ov::Model>& m)
 
             consumer->set_argument(0, convert);
             consumer->validate_and_infer_types();
-            if (transpose) {
-                results[i]->validate_and_infer_types();
-            }
             is_modified = true;
         }
     }
```
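For context, the deleted branch effectively hoisted the Convert above a trailing Transpose so that the Transpose stayed directly attached to its Result, the only position Snippets supports. A standalone sketch of that old behavior (hypothetical helper name, simplified from the removed lines):

```cpp
#include <memory>

#include <openvino/op/convert.hpp>
#include <openvino/op/result.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/openvino.hpp>

// Simplified from the removed lines: if a Result consumes a Transpose, insert
// the Convert above the Transpose instead of between Transpose and Result.
void hoist_convert_above_transpose(const std::shared_ptr<ov::op::v0::Result>& result,
                                   const ov::element::Type& original_type) {
    auto parent_output = result->get_input_source_output(0);
    const auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(parent_output.get_node_shared_ptr());
    if (transpose) {
        OPENVINO_ASSERT(parent_output.get_target_inputs().size() == 1,
                        "If Result has Transpose on input, this Result must be single consumer of the Transpose");
        const auto convert =
            std::make_shared<ov::op::v0::Convert>(transpose->get_input_source_output(0), original_type);
        transpose->set_argument(0, convert);
        transpose->validate_and_infer_types();
        result->validate_and_infer_types();
    }
}
```

With EnforcePrecision now ordered relative to FuseTransposeBrgemm in subgraph.cpp (below), this special case is no longer reachable, so the pass can drop it.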

src/plugins/intel_cpu/src/nodes/subgraph.cpp

Lines changed: 33 additions & 27 deletions
```diff
@@ -46,8 +46,9 @@
 
 # include "emitters/snippets/x64/cpu_generator.hpp"
 # include "executors/x64/subgraph.hpp"
+# include "snippets/lowered/port_descriptor.hpp"
 # include "snippets/op/brgemm.hpp"
-# include "snippets/pass/matmul_to_brgemm.hpp"
+# include "snippets/utils/utils.hpp"
 # include "transformations/snippets/x64/op/brgemm_utils.hpp"
 #elif defined(OPENVINO_ARCH_ARM64)
 # include <cpu/aarch64/cpu_isa_traits.hpp>
@@ -86,6 +87,7 @@
 # include "snippets/lowered/pass/init_loops.hpp"
 # include "snippets/lowered/pass/insert_buffers.hpp"
 # include "snippets/lowered/pass/insert_loops.hpp"
+# include "snippets/pass/fuse_transpose_brgemm.hpp"
 # include "transformations/snippets/common/pass/enforce_precision.hpp"
 # include "transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp"
 # include "transformations/snippets/x64/pass/eliminate_brgemm_copy_b.hpp"
@@ -552,34 +554,38 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
 
     if (any_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) &&
         subgraph_attrs->snippet->has_domain_sensitive_ops()) {
-        // enforce BF16 precisions to supported operations
-        // MatMul has to be decomposed to Brgemm operations before enforcement
-        // Notes:
-        //  - MatMul decomposition will be run later again for case if BF16 enforcement is not happened
-        //  - `MatMulToBrgemm` pass fuse `transpose_a` and `transpose_b` from MatMul to inputs of Brgemm as layouts.
-        //    These layouts are resized to ranks of input shapes. But since `Canonicalization` might
-        //    reshape shapes, the pass `MatMulToBrgemm` should be after the pass `Canonicalization` to
-        //    fuse layouts with ranks aligned with updated shapes after RankNormalization insertions.
-        SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
-                                               ov::snippets::pass::Canonicalization,
-                                               ov::snippets::pass::MatMulToBrgemm);
+        SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(
+            Place::After,
+            ov::snippets::pass::FuseTransposeBrgemm,
+            pass::EnforcePrecision,
+            element::f32,
+            context->getConfig().inferencePrecision,
+            [](const std::shared_ptr<ov::Node>& op) {
+                std::set<std::vector<ov::element::Type>> types;
+                if (ov::is_type<ov::snippets::op::Brgemm>(op)) {
+                    const auto& a_port =
+                        ov::snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(op->input(0));
+                    // WA: We can't perform precision enforcement in case of strided access to A matrix:
+                    // snippets eltwise loops for precision conversion are generated by last 2 dims,
+                    // which are not [M, K] in case of strided access in brgemm A
+                    // There are no limitations for B matrix, since precision conversion is fused in BrgemmCopyB
+                    // Ticket: 177121
+                    if (ov::snippets::utils::is_planar_layout(a_port->get_layout())) {
+                        if (ov::intel_cpu::brgemm_utils::is_fp16_supported()) {
+                            types.insert({ov::element::f16, ov::element::f16});
+                        }
+                        if (ov::intel_cpu::brgemm_utils::is_bf16_supported()) {
+                            types.insert({ov::element::bf16, ov::element::bf16});
+                        }
+                    }
+                }
+                return types;
+            });
+        // Note: EnforcePrecision might also eliminate Convert pairs (e.g. bf16->f32->bf16),
+        // so FuseTransposeBrgemm has to be run after it as well
         SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
-                                               ov::snippets::pass::MatMulToBrgemm,
                                                pass::EnforcePrecision,
-                                               element::f32,
-                                               context->getConfig().inferencePrecision,
-                                               [](const std::shared_ptr<ov::Node>& op) {
-                                                   std::set<std::vector<ov::element::Type>> types;
-                                                   if (ov::is_type<ov::snippets::op::Brgemm>(op)) {
-                                                       if (ov::intel_cpu::brgemm_utils::is_fp16_supported()) {
-                                                           types.insert({ov::element::f16, ov::element::f16});
-                                                       }
-                                                       if (ov::intel_cpu::brgemm_utils::is_bf16_supported()) {
-                                                           types.insert({ov::element::bf16, ov::element::bf16});
-                                                       }
-                                                   }
-                                                   return types;
-                                               });
+                                               ov::snippets::pass::FuseTransposeBrgemm);
     }
 
     SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before,
```
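The `is_planar_layout` guard is the core of the workaround: enforcement is skipped whenever the A-matrix port carries a non-identity layout, e.g. one left behind by a fused Transpose. A dependency-free sketch of the assumed semantics:

```cpp
#include <cstddef>
#include <vector>

// Assumed semantics of ov::snippets::utils::is_planar_layout: a layout is
// planar when it is empty or the identity permutation, i.e. the last two
// dims iterated by the generated eltwise loops really are [M, K].
bool is_planar_layout_sketch(const std::vector<size_t>& layout) {
    for (size_t i = 0; i < layout.size(); ++i) {
        if (layout[i] != i) {
            return false;  // e.g. {0, 2, 1, 3} left by a fused Transpose
        }
    }
    return true;  // {} or {0, 1, 2, 3}: plain row-major access
}
```

Under this guard, a non-planar A layout keeps the Brgemm in f32; per the commit message, this trades peak performance for correctness in a configuration that real models don't hit.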

src/plugins/intel_cpu/src/transformations/snippets/x64/pass/eliminate_brgemm_copy_b.cpp

Lines changed: 2 additions & 0 deletions
```diff
@@ -66,6 +66,8 @@ bool pass::EliminateBrgemmCopyB::run_on_model(const std::shared_ptr<ov::Model>&
     // Since repacking is moved out of Subgraph body,
     // the rest weights subgraph must be updated with precision after repacking
     param->set_element_type(copy_b_node->get_config().wei_dt());
+    // Note: validation is called manually since set_element_type doesn't update output element type
+    param->validate_and_infer_types();
     if (pattern_map.count(m_rank_norm)) {
         pattern_map.at(m_rank_norm).get_node_shared_ptr()->validate_and_infer_types();
    }
```
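The manual re-validation matters because `set_element_type` only updates the Parameter's declared type; the node's output keeps the stale type until `validate_and_infer_types` re-runs type inference. A minimal sketch (hypothetical shape):

```cpp
#include <iostream>
#include <memory>

#include <openvino/openvino.hpp>

int main() {
    auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 64});
    param->set_element_type(ov::element::i8);
    // The declared type changed, but the output still reports the old one:
    std::cout << param->get_output_element_type(0) << "\n";  // f32
    param->validate_and_infer_types();
    std::cout << param->get_output_element_type(0) << "\n";  // i8
}
```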

src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp

Lines changed: 16 additions & 3 deletions
```diff
@@ -160,11 +160,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16_4D,
                          MHA,
                          ::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
                                             ::testing::ValuesIn(precision_bf16_if_supported(4)),
-                                            ::testing::Values(ov::element::f32),
+                                            ::testing::Values(ov::element::bf16),
                                             ::testing::Values(false),
                                             ::testing::Values(MHA::default_thread_count),
-                                            ::testing::Values(8),  // decomposed Transpose + MHA + 5 Converts + 1 Transpose on output
-                                            ::testing::Values(6),  // MHA + 5 Converts on inputs and output
+                                            ::testing::Values(3),  // decomposed Transpose + MHA + 1 Transpose on output
+                                            ::testing::Values(2),  // decomposed Transpose + MHA
                                             ::testing::Values(ov::test::utils::DEVICE_CPU),
                                             ::testing::Values(CPUTestUtils::empty_plugin_config)),
                          MHA::getTestCaseName);
@@ -182,6 +182,19 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
                                             ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
                          MHA::getTestCaseName);
 
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16_f32_in_prc,
+                         MHA,
+                         ::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
+                                            ::testing::ValuesIn(precision_f32(4)),
+                                            ::testing::Values(ov::element::f32),
+                                            ::testing::ValuesIn({false}),
+                                            ::testing::Values(MHA::default_thread_count),
+                                            ::testing::Values(3),  // decomposed Transpose + MHA + 1 Transpose on output
+                                            ::testing::Values(2),  // decomposed Transpose + MHA
+                                            ::testing::Values(ov::test::utils::DEVICE_CPU),
+                                            ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
+                         MHA::getTestCaseName);
+
 INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_Without_Multiply,
                          MHA,
                          ::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
```

src/tests/functional/plugin/shared/src/snippets/mha.cpp

Lines changed: 9 additions & 5 deletions
```diff
@@ -28,8 +28,11 @@ void MHABase::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
         const auto& model_input = model_inputs[i];
         ov::Tensor tensor;
         ov::test::utils::InputGenerateData in_data;
+        const bool bf16_precision =
+            configuration.at(ov::hint::inference_precision.name()).as<ov::element::Type>() == ov::element::bf16 ||
+            model_input.get_element_type() == ov::element::bf16;
         // To avoid big relative errors in the vicinity of zero, only positive values are generated for bf16 precision
-        in_data.start_from = model_input.get_element_type() == ov::element::bf16 ? 0 : -1;
+        in_data.start_from = bf16_precision ? 0 : -1;
         in_data.range = 2;
         in_data.resolution = 256;
         tensor =
@@ -55,16 +58,17 @@ void MHABase::SetUp() {
         setInferenceType(prc);
     }
 
-void MHABase::init_thresholds() {
+void MHABase::init_thresholds() {
     // Note: Libxsmm calculates Exp in a slightly different way, so the abs values might differ a bit. Ticket: 130699
 #ifdef SNIPPETS_LIBXSMM_TPP
     abs_threshold = 1e-6;
 #endif
-    if (inType == ov::element::bf16)
+    auto infer_precision = configuration.at(ov::hint::inference_precision.name()).as<ov::element::Type>();
+    if (infer_precision == ov::element::bf16)
         rel_threshold = 0.05f;
-    if (inType == ov::element::f16)
+    if (infer_precision == ov::element::f16)
         abs_threshold = 2e-2;
-}
+}
 
 
 std::string MHA::getTestCaseName(const testing::TestParamInfo<ov::test::snippets::MHAParams>& obj) {
     const auto& [input_shapes,
```
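The widened bf16 check feeds the positive-only input range, which exists because bf16 keeps only 8 significand bits: a result that lands near zero through cancellation carries a huge relative error. A self-contained illustration (plain truncation to bf16 rather than the hardware's round-to-nearest, which is enough to show the effect):

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>

// Truncate an f32 to bf16 precision by dropping the low 16 mantissa bits
// (illustrative only; real hardware rounds to nearest even).
float to_bf16(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits &= 0xFFFF0000u;
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}

int main() {
    float a = to_bf16(1.0039f);  // rounds down to exactly 1.0
    float b = to_bf16(1.0000f);  // stays 1.0
    // The exact difference is ~0.0039, but bf16 cancellation yields 0:
    // a 100% relative error on a value near zero.
    std::cout << a - b << "\n";  // prints 0
}
```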
