Skip to content

Commit f1faba2

Browse files
committed
[InstCombine][X86] Add constant folding for PMADDWD/PMADDUBSW intrinsics
1 parent 66caf01 commit f1faba2

File tree

3 files changed

+43
-17
lines changed

3 files changed

+43
-17
lines changed

llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -503,21 +503,53 @@ static Value *simplifyX86pack(IntrinsicInst &II,
503503
}
504504

505505
static Value *simplifyX86pmadd(IntrinsicInst &II,
506-
InstCombiner::BuilderTy &Builder) {
506+
InstCombiner::BuilderTy &Builder,
507+
bool IsPMADDWD) {
507508
Value *Arg0 = II.getArgOperand(0);
508509
Value *Arg1 = II.getArgOperand(1);
509510
auto *ResTy = cast<FixedVectorType>(II.getType());
510511
[[maybe_unused]] auto *ArgTy = cast<FixedVectorType>(Arg0->getType());
511512

512-
assert(ArgTy->getNumElements() == (2 * ResTy->getNumElements()) &&
513+
unsigned NumDstElts = ResTy->getNumElements();
514+
assert(ArgTy->getNumElements() == (2 * NumDstElts) &&
513515
ResTy->getScalarSizeInBits() == (2 * ArgTy->getScalarSizeInBits()) &&
514516
"Unexpected PMADD types");
515517

516518
// Multiply by zero.
517519
if (isa<ConstantAggregateZero>(Arg0) || isa<ConstantAggregateZero>(Arg1))
518520
return ConstantAggregateZero::get(ResTy);
519521

520-
return nullptr;
522+
// Constant folding - only attempted when both operands are constants.
523+
if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1))
524+
return nullptr;
525+
526+
// Split Lo/Hi elements pairs, extend and add together.
527+
// PMADDWD(X,Y) =
528+
// add(mul(sext(lhs[0]),sext(rhs[0])),mul(sext(lhs[1]),sext(rhs[1])))
529+
// PMADDUBSW(X,Y) =
530+
// sadd_sat(mul(zext(lhs[0]),sext(rhs[0])),mul(zext(lhs[1]),sext(rhs[1])))
531+
SmallVector<int> LoMask, HiMask;
532+
for (unsigned I = 0; I != NumDstElts; ++I) {
533+
LoMask.push_back(2 * I + 0);
534+
HiMask.push_back(2 * I + 1);
535+
}
536+
537+
auto *LHSLo = Builder.CreateShuffleVector(Arg0, LoMask);
538+
auto *LHSHi = Builder.CreateShuffleVector(Arg0, HiMask);
539+
auto *RHSLo = Builder.CreateShuffleVector(Arg1, LoMask);
540+
auto *RHSHi = Builder.CreateShuffleVector(Arg1, HiMask);
541+
542+
auto LHSCast =
543+
IsPMADDWD ? Instruction::CastOps::SExt : Instruction::CastOps::ZExt;
544+
LHSLo = Builder.CreateCast(LHSCast, LHSLo, ResTy);
545+
LHSHi = Builder.CreateCast(LHSCast, LHSHi, ResTy);
546+
RHSLo = Builder.CreateCast(Instruction::CastOps::SExt, RHSLo, ResTy);
547+
RHSHi = Builder.CreateCast(Instruction::CastOps::SExt, RHSHi, ResTy);
548+
Value *Lo = Builder.CreateMul(LHSLo, RHSLo);
549+
Value *Hi = Builder.CreateMul(LHSHi, RHSHi);
550+
return IsPMADDWD
551+
? Builder.CreateAdd(Lo, Hi)
552+
: Builder.CreateIntrinsic(ResTy, Intrinsic::sadd_sat, {Lo, Hi});
521553
}
522554

523555
static Value *simplifyX86movmsk(const IntrinsicInst &II,
@@ -2499,15 +2531,15 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
24992531
case Intrinsic::x86_sse2_pmadd_wd:
25002532
case Intrinsic::x86_avx2_pmadd_wd:
25012533
case Intrinsic::x86_avx512_pmaddw_d_512:
2502-
if (Value *V = simplifyX86pmadd(II, IC.Builder)) {
2534+
if (Value *V = simplifyX86pmadd(II, IC.Builder, true)) {
25032535
return IC.replaceInstUsesWith(II, V);
25042536
}
25052537
break;
25062538

25072539
case Intrinsic::x86_ssse3_pmadd_ub_sw_128:
25082540
case Intrinsic::x86_avx2_pmadd_ub_sw:
25092541
case Intrinsic::x86_avx512_pmaddubs_w_512:
2510-
if (Value *V = simplifyX86pmadd(II, IC.Builder)) {
2542+
if (Value *V = simplifyX86pmadd(II, IC.Builder, false)) {
25112543
return IC.replaceInstUsesWith(II, V);
25122544
}
25132545
break;

llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,26 +117,23 @@ define <32 x i16> @zero_pmaddubsw_512_commute(<64 x i8> %a0) {
117117

118118
define <8 x i16> @fold_pmaddubsw_128() {
119119
; CHECK-LABEL: @fold_pmaddubsw_128(
120-
; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
121-
; CHECK-NEXT: ret <8 x i16> [[TMP1]]
120+
; CHECK-NEXT: ret <8 x i16> <i16 -32768, i16 18, i16 50, i16 1694, i16 162, i16 242, i16 338, i16 450>
122121
;
123122
%1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
124123
ret <8 x i16> %1
125124
}
126125

127126
define <16 x i16> @fold_pmaddubsw_256() {
128127
; CHECK-LABEL: @fold_pmaddubsw_256(
129-
; CHECK-NEXT: [[TMP1:%.*]] = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>, <32 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>)
130-
; CHECK-NEXT: ret <16 x i16> [[TMP1]]
128+
; CHECK-NEXT: ret <16 x i16> <i16 -32768, i16 18, i16 50, i16 1694, i16 162, i16 242, i16 338, i16 450, i16 -256, i16 18, i16 50, i16 1694, i16 162, i16 242, i16 338, i16 450>
131129
;
132130
%1 = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>, <32 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>)
133131
ret <16 x i16> %1
134132
}
135133

136134
define <32 x i16> @fold_pmaddubsw_512() {
137135
; CHECK-LABEL: @fold_pmaddubsw_512(
138-
; CHECK-NEXT: [[TMP1:%.*]] = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <64 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
139-
; CHECK-NEXT: ret <32 x i16> [[TMP1]]
136+
; CHECK-NEXT: ret <32 x i16> <i16 -32768, i16 18, i16 50, i16 1694, i16 162, i16 242, i16 338, i16 450, i16 -256, i16 18, i16 50, i16 1694, i16 162, i16 242, i16 338, i16 450, i16 -256, i16 18, i16 50, i16 1694, i16 162, i16 242, i16 338, i16 450, i16 -32768, i16 18, i16 50, i16 1694, i16 162, i16 242, i16 338, i16 450>
140137
;
141138
%1 = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <64 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
142139
ret <32 x i16> %1

llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,26 +117,23 @@ define <16 x i32> @zero_pmaddwd_512_commute(<32 x i16> %a0) {
117117

118118
define <4 x i32> @fold_pmaddwd_128() {
119119
; CHECK-LABEL: @fold_pmaddwd_128(
120-
; CHECK-NEXT: [[TMP1:%.*]] = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -32768, i16 32765, i16 -9, i16 -11, i16 -32763, i16 32761>)
121-
; CHECK-NEXT: ret <4 x i32> [[TMP1]]
120+
; CHECK-NEXT: ret <4 x i32> <i32 19, i32 -229364, i32 -21, i32 -491429>
122121
;
123122
%1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -32768, i16 32765, i16 -9, i16 -11, i16 -32763, i16 32761>)
124123
ret <4 x i32> %1
125124
}
126125

127126
define <8 x i32> @fold_pmaddwd_256() {
128127
; CHECK-LABEL: @fold_pmaddwd_256(
129-
; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>, <16 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>)
130-
; CHECK-NEXT: ret <8 x i32> [[TMP1]]
128+
; CHECK-NEXT: ret <8 x i32> <i32 -7, i32 32762, i32 91, i32 32750, i32 -239, i32 687938, i32 -451, i32 -32756>
131129
;
132130
%1 = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>, <16 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>)
133131
ret <8 x i32> %1
134132
}
135133

136134
define <16 x i32> @fold_pmaddwd_512() {
137135
; CHECK-LABEL: @fold_pmaddwd_512(
138-
; CHECK-NEXT: [[TMP1:%.*]] = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15, i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>, <32 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756, i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>)
139-
; CHECK-NEXT: ret <16 x i32> [[TMP1]]
136+
; CHECK-NEXT: ret <16 x i32> <i32 -7, i32 32762, i32 91, i32 32750, i32 -239, i32 687938, i32 -451, i32 -32756, i32 -7, i32 32762, i32 91, i32 32750, i32 -239, i32 687938, i32 -451, i32 -32756>
140137
;
141138
%1 = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15, i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>, <32 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756, i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>)
142139
ret <16 x i32> %1

0 commit comments

Comments
 (0)