diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 878079c4fe4e8..3155e7dc38b64 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1711,6 +1711,41 @@ m_BitCast(const OpTy &Op) {
   return CastOperator_match<OpTy, Instruction::BitCast>(Op);
 }
 
+/// Matches a BitCastInst that keeps the shape of its operand: scalar to
+/// scalar, or vector to vector with an identical element count. Casts that
+/// switch between scalar and vector, or that change the number of vector
+/// elements, are rejected. \p Op is then matched against the cast's operand.
+template <typename Op_t> struct ElementWiseBitCast_match {
+  Op_t Op;
+
+  ElementWiseBitCast_match(const Op_t &OpMatch) : Op(OpMatch) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    BitCastInst *I = dyn_cast<BitCastInst>(V);
+    if (!I)
+      return false;
+    Type *SrcType = I->getSrcTy();
+    Type *DstType = I->getType();
+    // Make sure the bitcast doesn't change between scalar and vector and
+    // doesn't change the number of vector elements.
+    if (SrcType->isVectorTy() != DstType->isVectorTy())
+      return false;
+    // The check above guarantees DstType is a vector whenever SrcType is one,
+    // so the cast<> below is only reached with a vector destination type.
+    if (VectorType *SrcVecTy = dyn_cast<VectorType>(SrcType);
+        SrcVecTy && SrcVecTy->getElementCount() !=
+                        cast<VectorType>(DstType)->getElementCount())
+      return false;
+    return Op.match(I->getOperand(0));
+  }
+};
+
+/// Matches a shape-preserving bitcast (see ElementWiseBitCast_match).
+template <typename OpTy>
+inline ElementWiseBitCast_match<OpTy> m_ElementWiseBitCast(const OpTy &Op) {
+  return ElementWiseBitCast_match<OpTy>(Op);
+}
+
 /// Matches PtrToInt.
template inline CastOperator_match diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 2793b798f35f3..01b017142cfcb 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -3034,7 +3034,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, // floating-point casts: // icmp slt (bitcast (uitofp X)), 0 --> false // icmp sgt (bitcast (uitofp X)), -1 --> true - if (match(LHS, m_BitCast(m_UIToFP(m_Value(X))))) { + if (match(LHS, m_ElementWiseBitCast(m_UIToFP(m_Value(X))))) { if (Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero())) return ConstantInt::getFalse(ITy); if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes())) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 6a827e2f3a963..7b93848eab351 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -2531,14 +2531,12 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { // Assumes any IEEE-represented type has the sign bit in the high bit. // TODO: Unify with APInt matcher. 
This version allows undef unlike m_APInt Value *CastOp; - if (match(Op0, m_BitCast(m_Value(CastOp))) && + if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) && match(Op1, m_MaxSignedValue()) && !Builder.GetInsertBlock()->getParent()->hasFnAttribute( - Attribute::NoImplicitFloat)) { + Attribute::NoImplicitFloat)) { Type *EltTy = CastOp->getType()->getScalarType(); - if (EltTy->isFloatingPointTy() && EltTy->isIEEE() && - EltTy->getPrimitiveSizeInBits() == - I.getType()->getScalarType()->getPrimitiveSizeInBits()) { + if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) { Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp); return new BitCastInst(FAbs, I.getType()); } @@ -3963,13 +3961,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { // This is generous interpretation of noimplicitfloat, this is not a true // floating-point operation. Value *CastOp; - if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) && + if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) && + match(Op1, m_SignMask()) && !Builder.GetInsertBlock()->getParent()->hasFnAttribute( Attribute::NoImplicitFloat)) { Type *EltTy = CastOp->getType()->getScalarType(); - if (EltTy->isFloatingPointTy() && EltTy->isIEEE() && - EltTy->getPrimitiveSizeInBits() == - I.getType()->getScalarType()->getPrimitiveSizeInBits()) { + if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) { Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp); Value *FNegFAbs = Builder.CreateFNeg(FAbs); return new BitCastInst(FNegFAbs, I.getType()); @@ -4739,13 +4736,12 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { // Assumes any IEEE-represented type has the sign bit in the high bit. // TODO: Unify with APInt matcher. 
This version allows undef unlike m_APInt Value *CastOp; - if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) && + if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) && + match(Op1, m_SignMask()) && !Builder.GetInsertBlock()->getParent()->hasFnAttribute( Attribute::NoImplicitFloat)) { Type *EltTy = CastOp->getType()->getScalarType(); - if (EltTy->isFloatingPointTy() && EltTy->isIEEE() && - EltTy->getPrimitiveSizeInBits() == - I.getType()->getScalarType()->getPrimitiveSizeInBits()) { + if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) { Value *FNeg = Builder.CreateFNeg(CastOp); return new BitCastInst(FNeg, I.getType()); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 58f0763bb0c0c..ed47de287302e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -182,9 +182,15 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) { if (!Cmp || Cmp->getOperand(0)->getType() != Sel->getType() || (CI.getOpcode() == Instruction::Trunc && shouldChangeType(CI.getSrcTy(), CI.getType()))) { - if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) { - replaceAllDbgUsesWith(*Sel, *NV, CI, DT); - return NV; + + // If it's a bitcast involving vectors, make sure it has the same number + // of elements on both sides. 
+ if (CI.getOpcode() != Instruction::BitCast || + match(&CI, m_ElementWiseBitCast(m_Value()))) { + if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) { + replaceAllDbgUsesWith(*Sel, *NV, CI, DT); + return NV; + } } } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 380cb3504209d..cda1061fb35a8 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1834,15 +1834,10 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp, Value *V; if (!Cmp.getParent()->getParent()->hasFnAttribute( Attribute::NoImplicitFloat) && - Cmp.isEquality() && match(X, m_OneUse(m_BitCast(m_Value(V))))) { - Type *SrcType = V->getType(); - Type *DstType = X->getType(); - Type *FPType = SrcType->getScalarType(); - // Make sure the bitcast doesn't change between scalar and vector and - // doesn't change the number of vector elements. - if (SrcType->isVectorTy() == DstType->isVectorTy() && - SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits() && - FPType->isIEEELikeFPTy() && C1 == *C2) { + Cmp.isEquality() && + match(X, m_OneUse(m_ElementWiseBitCast(m_Value(V))))) { + Type *FPType = V->getType()->getScalarType(); + if (FPType->isIEEELikeFPTy() && C1 == *C2) { APInt ExponentMask = APFloat::getInf(FPType->getFltSemantics()).bitcastToAPInt(); if (C1 == ExponentMask) { @@ -7754,9 +7749,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { // Ignore signbit of bitcasted int when comparing equality to FP 0.0: // fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0 if (match(Op1, m_PosZeroFP()) && - match(Op0, m_OneUse(m_BitCast(m_Value(X)))) && - X->getType()->isVectorTy() == OpType->isVectorTy() && - X->getType()->getScalarSizeInBits() == OpType->getScalarSizeInBits()) { + match(Op0, m_OneUse(m_ElementWiseBitCast(m_Value(X))))) { ICmpInst::Predicate IntPred = ICmpInst::BAD_ICMP_PREDICATE; if (Pred 
== FCmpInst::FCMP_OEQ) IntPred = ICmpInst::ICMP_EQ; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 2756f81ed9e62..527037881edb1 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -2365,9 +2365,6 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel, Value *FVal = Sel.getFalseValue(); Type *SelType = Sel.getType(); - if (ICmpInst::makeCmpResultType(TVal->getType()) != Cond->getType()) - return nullptr; - // Match select ?, TC, FC where the constants are equal but negated. // TODO: Generalize to handle a negated variable operand? const APFloat *TC, *FC; @@ -2382,7 +2379,8 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel, const APInt *C; bool IsTrueIfSignSet; ICmpInst::Predicate Pred; - if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) || + if (!match(Cond, m_OneUse(m_ICmp(Pred, m_ElementWiseBitCast(m_Value(X)), + m_APInt(C)))) || !InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) || X->getType() != SelType) return nullptr; @@ -2770,8 +2768,6 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI, // Match select with (icmp slt (bitcast X to int), 0) // or (icmp sgt (bitcast X to int), -1) - if (ICmpInst::makeCmpResultType(SI.getType()) != CondVal->getType()) - return ChangedFMF ? 
&SI : nullptr; for (bool Swap : {false, true}) { Value *TrueVal = SI.getTrueValue(); @@ -2783,7 +2779,8 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI, CmpInst::Predicate Pred; const APInt *C; bool TrueIfSigned; - if (!match(CondVal, m_ICmp(Pred, m_BitCast(m_Specific(X)), m_APInt(C))) || + if (!match(CondVal, + m_ICmp(Pred, m_ElementWiseBitCast(m_Specific(X)), m_APInt(C))) || !IC.isSignBitCheck(Pred, *C, TrueIfSigned)) continue; if (!match(TrueVal, m_FNeg(m_Specific(X)))) diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 4e88a5cc535b1..651e852bf6ed0 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -1474,21 +1474,6 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, if (SI->getType()->isIntOrIntVectorTy(1)) return nullptr; - // If it's a bitcast involving vectors, make sure it has the same number of - // elements on both sides. - if (auto *BC = dyn_cast(&Op)) { - VectorType *DestTy = dyn_cast(BC->getDestTy()); - VectorType *SrcTy = dyn_cast(BC->getSrcTy()); - - // Verify that either both or neither are vectors. - if ((SrcTy == nullptr) != (DestTy == nullptr)) - return nullptr; - - // If vectors, verify that they have the same number of elements. - if (SrcTy && SrcTy->getElementCount() != DestTy->getElementCount()) - return nullptr; - } - // Test if a FCmpInst instruction is used exclusively by a select as // part of a minimum or maximum operation. If so, refrain from doing // any other folding. 
This helps out other analyses which understand
diff --git a/llvm/test/Transforms/InstSimplify/cast-unsigned-icmp-cmp-0.ll b/llvm/test/Transforms/InstSimplify/cast-unsigned-icmp-cmp-0.ll
index 8014133c5d373..5a61a060785ff 100644
--- a/llvm/test/Transforms/InstSimplify/cast-unsigned-icmp-cmp-0.ll
+++ b/llvm/test/Transforms/InstSimplify/cast-unsigned-icmp-cmp-0.ll
@@ -57,6 +57,19 @@ define <2 x i1> @i32_cast_cmp_sgt_int_m1_uitofp_float_vec(<2 x i32> %i) {
   ret <2 x i1> %cmp
 }
 
+define i1 @i32_cast_cmp_sgt_int_m1_uitofp_float_vec_mismatch(<2 x i32> %i) {
+; CHECK-LABEL: @i32_cast_cmp_sgt_int_m1_uitofp_float_vec_mismatch(
+; CHECK-NEXT:    [[F:%.*]] = uitofp <2 x i32> [[I:%.*]] to <2 x float>
+; CHECK-NEXT:    [[B:%.*]] = bitcast <2 x float> [[F]] to i64
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i64 [[B]], -1
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %f = uitofp <2 x i32> %i to <2 x float>
+  %b = bitcast <2 x float> %f to i64
+  %cmp = icmp sgt i64 %b, -1
+  ret i1 %cmp
+}
+
 define <3 x i1> @i32_cast_cmp_sgt_int_m1_uitofp_float_vec_undef(<3 x i32> %i) {
 ; CHECK-LABEL: @i32_cast_cmp_sgt_int_m1_uitofp_float_vec_undef(
 ; CHECK-NEXT:    ret <3 x i1> <i1 true, i1 true, i1 true>
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 9e9e41b8fbad0..885b1346cde1e 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -530,6 +530,50 @@ TEST_F(PatternMatchTest, ZExtSExtSelf) {
   EXPECT_TRUE(m_ZExtOrSExtOrSelf(m_One()).match(One64S));
 }
 
+// Checks that m_BitCast accepts each bitcast built here, while
+// m_ElementWiseBitCast only accepts those keeping shape and element count.
+TEST_F(PatternMatchTest, BitCast) {
+  Value *OneDouble = ConstantFP::get(IRB.getDoubleTy(), APFloat(1.0));
+  Value *ScalableDouble = ConstantFP::get(
+      VectorType::get(IRB.getDoubleTy(), 2, /*Scalable=*/true), APFloat(1.0));
+  // scalar -> scalar
+  Value *DoubleToI64 = IRB.CreateBitCast(OneDouble, IRB.getInt64Ty());
+  // scalar -> vector
+  Value *DoubleToV2I32 = IRB.CreateBitCast(
+      OneDouble, VectorType::get(IRB.getInt32Ty(), 2, /*Scalable=*/false));
+  // vector -> scalar
+  Value *V2I32ToDouble = IRB.CreateBitCast(DoubleToV2I32, IRB.getDoubleTy());
+  // vector -> vector (same count)
+  Value *V2I32ToV2Float = IRB.CreateBitCast(
+      DoubleToV2I32, VectorType::get(IRB.getFloatTy(), 2, /*Scalable=*/false));
+  // vector -> vector (different count)
+  Value *V2I32TOV4I16 = IRB.CreateBitCast(
+      DoubleToV2I32, VectorType::get(IRB.getInt16Ty(), 4, /*Scalable=*/false));
+  // scalable vector -> scalable vector (same count)
+  Value *NXV2DoubleToNXV2I64 = IRB.CreateBitCast(
+      ScalableDouble, VectorType::get(IRB.getInt64Ty(), 2, /*Scalable=*/true));
+  // scalable vector -> scalable vector (different count)
+  Value *NXV2I64ToNXV4I32 = IRB.CreateBitCast(
+      NXV2DoubleToNXV2I64,
+      VectorType::get(IRB.getInt32Ty(), 4, /*Scalable=*/true));
+
+  EXPECT_TRUE(m_BitCast(m_Value()).match(DoubleToI64));
+  EXPECT_TRUE(m_BitCast(m_Value()).match(DoubleToV2I32));
+  EXPECT_TRUE(m_BitCast(m_Value()).match(V2I32ToDouble));
+  EXPECT_TRUE(m_BitCast(m_Value()).match(V2I32ToV2Float));
+  EXPECT_TRUE(m_BitCast(m_Value()).match(V2I32TOV4I16));
+  EXPECT_TRUE(m_BitCast(m_Value()).match(NXV2DoubleToNXV2I64));
+  EXPECT_TRUE(m_BitCast(m_Value()).match(NXV2I64ToNXV4I32));
+
+  EXPECT_TRUE(m_ElementWiseBitCast(m_Value()).match(DoubleToI64));
+  EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(DoubleToV2I32));
+  EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(V2I32ToDouble));
+  EXPECT_TRUE(m_ElementWiseBitCast(m_Value()).match(V2I32ToV2Float));
+  EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(V2I32TOV4I16));
+  EXPECT_TRUE(m_ElementWiseBitCast(m_Value()).match(NXV2DoubleToNXV2I64));
+  EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(NXV2I64ToNXV4I32));
+}
+
 TEST_F(PatternMatchTest, Power2) {
   Value *C128 = IRB.getInt32(128);
   Value *CNeg128 = ConstantExpr::getNeg(cast<Constant>(C128));