diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h index 58865c296ed8a..f2a7f16e19a79 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h +++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h @@ -461,21 +461,14 @@ m_LogicalOr(const Op0_t &Op0, const Op1_t &Op1) { return m_Select(Op0, m_True(), Op1); } -using VPCanonicalIVPHI_match = - Recipe_match, 0, false, VPCanonicalIVPHIRecipe>; - -inline VPCanonicalIVPHI_match m_CanonicalIV() { - return VPCanonicalIVPHI_match(); -} - -template +template using VPScalarIVSteps_match = - Recipe_match, 0, false, VPScalarIVStepsRecipe>; + TernaryRecipe_match; -template -inline VPScalarIVSteps_match m_ScalarIVSteps(const Op0_t &Op0, - const Op1_t &Op1) { - return VPScalarIVSteps_match(Op0, Op1); +template +inline VPScalarIVSteps_match +m_ScalarIVSteps(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) { + return VPScalarIVSteps_match({Op0, Op1, Op2}); } template diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp index 2db4957409c8d..82b2ed242b0cb 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp @@ -62,7 +62,9 @@ bool vputils::isHeaderMask(const VPValue *V, VPlan &Plan) { if (match(V, m_ActiveLaneMask(m_VPValue(A), m_VPValue(B)))) return B == Plan.getTripCount() && - (match(A, m_ScalarIVSteps(m_CanonicalIV(), m_SpecificInt(1))) || + (match(A, m_ScalarIVSteps(m_Specific(Plan.getCanonicalIV()), + m_SpecificInt(1), + m_Specific(&Plan.getVF()))) || IsWideCanonicalIV(A)); return match(V, m_Binary(m_VPValue(A), m_VPValue(B))) && diff --git a/llvm/unittests/Transforms/Vectorize/CMakeLists.txt b/llvm/unittests/Transforms/Vectorize/CMakeLists.txt index 0df39c41a9041..53eeff28c185f 100644 --- a/llvm/unittests/Transforms/Vectorize/CMakeLists.txt +++ b/llvm/unittests/Transforms/Vectorize/CMakeLists.txt @@ -12,6 +12,7 @@ add_llvm_unittest(VectorizeTests VPlanTest.cpp VPDomTreeTest.cpp VPlanHCFGTest.cpp + VPlanPatternMatchTest.cpp VPlanSlpTest.cpp VPlanVerifierTest.cpp ) diff --git a/llvm/unittests/Transforms/Vectorize/VPlanPatternMatchTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanPatternMatchTest.cpp new file mode 100644 index 0000000000000..e38b4fad80b0e --- /dev/null +++ b/llvm/unittests/Transforms/Vectorize/VPlanPatternMatchTest.cpp @@ -0,0 +1,55 @@ +//===- llvm/unittests/Transforms/Vectorize/VPlanPatternMatchTest.cpp ------===// +// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../lib/Transforms/Vectorize/VPlanPatternMatch.h" +#include "../lib/Transforms/Vectorize/LoopVectorizationPlanner.h" +#include "../lib/Transforms/Vectorize/VPlan.h" +#include "../lib/Transforms/Vectorize/VPlanHelpers.h" +#include "VPlanTestBase.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "gtest/gtest.h" + +namespace llvm { + +namespace { +using VPPatternMatchTest = VPlanTestBase; + +TEST_F(VPPatternMatchTest, ScalarIVSteps) { + VPlan &Plan = getPlan(); + VPBasicBlock *VPBB = Plan.createVPBasicBlock(""); + VPBuilder Builder(VPBB); + + IntegerType *I64Ty = IntegerType::get(C, 64); + VPValue *StartV = Plan.getOrAddLiveIn(ConstantInt::get(I64Ty, 0)); + auto *CanonicalIVPHI = new VPCanonicalIVPHIRecipe(StartV, DebugLoc()); + Builder.insert(CanonicalIVPHI); + + VPValue *Inc = Plan.getOrAddLiveIn(ConstantInt::get(I64Ty, 1)); + VPValue *VF = &Plan.getVF(); + VPValue *Steps = Builder.createScalarIVSteps( + Instruction::Add, nullptr, CanonicalIVPHI, Inc, VF, DebugLoc()); + + VPValue *Inc2 = Plan.getOrAddLiveIn(ConstantInt::get(I64Ty, 2)); + VPValue *Steps2 = Builder.createScalarIVSteps( + Instruction::Add, nullptr, CanonicalIVPHI, Inc2, VF, DebugLoc()); + + using namespace VPlanPatternMatch; + + ASSERT_TRUE(match(Steps, m_ScalarIVSteps(m_Specific(CanonicalIVPHI), + m_SpecificInt(1), m_Specific(VF)))); + ASSERT_FALSE( + match(Steps2, m_ScalarIVSteps(m_Specific(CanonicalIVPHI), + m_SpecificInt(1), m_Specific(VF)))); + ASSERT_TRUE(match(Steps2, m_ScalarIVSteps(m_Specific(CanonicalIVPHI), + m_SpecificInt(2), m_Specific(VF)))); +} + +} // namespace +} // namespace llvm