diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 436cdbff75669..5645fdd73a0d4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1283,6 +1283,64 @@ reassociateMinMaxWithConstantInOperand(IntrinsicInst *II,
   return CallInst::Create(MinMax, {NewInner, C});
 }
 
+/// Reduce a sequence of min/max intrinsics using distributivity.
+static Instruction *
+factorizeMinMaxDistributivity(IntrinsicInst *II,
+                              InstCombiner::BuilderTy &Builder) {
+  auto *LHS = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
+  auto *RHS = dyn_cast<IntrinsicInst>(II->getArgOperand(1));
+
+  if (!LHS || !RHS || LHS->getIntrinsicID() != RHS->getIntrinsicID() ||
+      LHS->getCalledFunction()->arg_size() != 2)
+    return nullptr;
+
+  Value *A = LHS->getArgOperand(0);
+  Value *B = LHS->getArgOperand(1);
+  Value *C = RHS->getArgOperand(0);
+  Value *D = RHS->getArgOperand(1);
+  Value *Outer, *InnerLHS, *InnerRHS;
+
+  if (A != C && A != D && B != C && B != D)
+    return nullptr;
+
+  if (A == C) {
+    Outer = A;
+    InnerLHS = B;
+    InnerRHS = D;
+  } else if (A == D) {
+    Outer = A;
+    InnerLHS = B;
+    InnerRHS = C;
+  } else if (B == C) {
+    Outer = B;
+    InnerLHS = A;
+    InnerRHS = D;
+  } else { // B == D, guaranteed by the early exit above.
+    Outer = B;
+    InnerLHS = A;
+    InnerRHS = C;
+  }
+
+  Intrinsic::ID OuterID = II->getIntrinsicID();
+  Intrinsic::ID LHSID = LHS->getIntrinsicID();
+
+  // umax(umin(a, c), umin(b, c)) --> umin(umax(a, b), c)
+  // umin(umax(a, c), umax(b, c)) --> umax(umin(a, b), c)
+  // smax(smin(a, c), smin(b, c)) --> smin(smax(a, b), c)
+  // smin(smax(a, c), smax(b, c)) --> smax(smin(a, b), c)
+  if ((LHSID == Intrinsic::umin && OuterID == Intrinsic::umax) ||
+      (LHSID == Intrinsic::umax && OuterID == Intrinsic::umin) ||
+      (LHSID == Intrinsic::smin && OuterID == Intrinsic::smax) ||
+      (LHSID == Intrinsic::smax && OuterID == Intrinsic::smin)) {
+    Module *Mod = II->getModule();
+    Function *OuterFn = Intrinsic::getDeclaration(Mod, LHSID, II->getType());
+    return CallInst::Create(OuterFn, {Outer, Builder.CreateBinaryIntrinsic(
+                                                 OuterID, InnerLHS, InnerRHS)});
+  }
+
+  return nullptr;
+}
+
 /// Reduce a sequence of min/max intrinsics with a common operand.
 static Instruction *factorizeMinMaxTree(IntrinsicInst *II) {
   // Match 3 of the same min/max ops. Example: umin(umin(), umin()).
@@ -1843,6 +1901,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     if (Instruction *NewMinMax = factorizeMinMaxTree(II))
       return NewMinMax;
 
+    if (Instruction *I = factorizeMinMaxDistributivity(II, Builder))
+      return I;
+
     // Try to fold minmax with constant RHS based on range information
     if (match(I1, m_APIntAllowPoison(RHSC))) {
       ICmpInst::Predicate Pred =
diff --git a/llvm/test/Transforms/InstCombine/minmax-factor.ll b/llvm/test/Transforms/InstCombine/minmax-factor.ll
new file mode 100644
index 0000000000000..667d165cae9f7
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/minmax-factor.ll
@@ -0,0 +1,98 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i8 @umin_umax(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @umin_umax(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.umin.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.umin.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.umin.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.umax.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define i8 @umax_umin(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @umax_umin(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.umin.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.umax.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.umax.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.umax.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.umin.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define i8 @smin_smax(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @smin_smax(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.smax.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.smin.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.smin.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.smin.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.smax.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define i8 @smax_smin(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: @smax_smin(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call i8 @llvm.smax.i8(i8 [[C:%.*]], i8 [[TMP1]])
+; CHECK-NEXT:    ret i8 [[F]]
+;
+  %d = call i8 @llvm.smax.i8(i8 %a, i8 %c)
+  %e = call i8 @llvm.smax.i8(i8 %b, i8 %c)
+  %f = call i8 @llvm.smin.i8(i8 %d, i8 %e)
+  ret i8 %f
+}
+
+define <2 x i8> @umin_umax_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @umin_umax_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.umax.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call <2 x i8> @llvm.umin.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT:    ret <2 x i8> [[F]]
+;
+  %d = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %a, <2 x i8> %c)
+  %e = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %b, <2 x i8> %c)
+  %f = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %d, <2 x i8> %e)
+  ret <2 x i8> %f
+}
+
+define <2 x i8> @umax_umin_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @umax_umin_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.umin.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call <2 x i8> @llvm.umax.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT:    ret <2 x i8> [[F]]
+;
+  %d = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %a, <2 x i8> %c)
+  %e = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %b, <2 x i8> %c)
+  %f = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %d, <2 x i8> %e)
+  ret <2 x i8> %f
+}
+
+define <2 x i8> @smin_smax_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @smin_smax_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.smax.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call <2 x i8> @llvm.smin.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT:    ret <2 x i8> [[F]]
+;
+  %d = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %a, <2 x i8> %c)
+  %e = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %b, <2 x i8> %c)
+  %f = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %d, <2 x i8> %e)
+  ret <2 x i8> %f
+}
+
+define <2 x i8> @smax_smin_vector(<2 x i8> %a, <2 x i8> %b, <2 x i8> %c) {
+; CHECK-LABEL: @smax_smin_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.smin.v2i8(<2 x i8> [[A:%.*]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    [[F:%.*]] = call <2 x i8> @llvm.smax.v2i8(<2 x i8> [[C:%.*]], <2 x i8> [[TMP1]])
+; CHECK-NEXT:    ret <2 x i8> [[F]]
+;
+  %d = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %a, <2 x i8> %c)
+  %e = call <2 x i8> @llvm.smax.v2i8(<2 x i8> %b, <2 x i8> %c)
+  %f = call <2 x i8> @llvm.smin.v2i8(<2 x i8> %d, <2 x i8> %e)
+  ret <2 x i8> %f
+}