Skip to content

Commit e52238b

Browse files
authored
[AArch64] Add @llvm.experimental.vector.match (#101974)
This patch introduces an experimental intrinsic for matching the elements of one vector against the elements of another. For AArch64 targets that support SVE2, the intrinsic lowers to a MATCH instruction for supported fixed and scalable vector types.
1 parent debfd7b commit e52238b

File tree

8 files changed

+735
-0
lines changed

8 files changed

+735
-0
lines changed

llvm/docs/LangRef.rst

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20091,6 +20091,44 @@ are undefined.
2009120091
}
2009220092

2009320093

20094+
'``llvm.experimental.vector.match.*``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

This is an overloaded intrinsic.

::

      declare <<n> x i1> @llvm.experimental.vector.match(<<n> x <ty>> %op1, <<m> x <ty>> %op2, <<n> x i1> %mask)
      declare <vscale x <n> x i1> @llvm.experimental.vector.match(<vscale x <n> x <ty>> %op1, <<m> x <ty>> %op2, <vscale x <n> x i1> %mask)

Overview:
"""""""""

Find active elements of the first argument matching any elements of the
second.

Arguments:
""""""""""

The first argument is the search vector, the second argument the vector of
elements we are searching for (i.e. for which we consider a match successful),
and the third argument is a mask that controls which elements of the first
argument are active. The first two arguments must be vectors of matching
integer element types. The first and third arguments and the result type must
have matching element counts (fixed or scalable). The second argument must be a
fixed vector, but its length may be different from the remaining arguments.

Semantics:
""""""""""

The '``llvm.experimental.vector.match``' intrinsic compares each active element
in the first argument against the elements of the second argument, placing
``1`` in the corresponding element of the output vector if any equality
comparison is successful, and ``0`` otherwise. Inactive elements in the mask
are set to ``0`` in the output.
20131+
2009420132
Matrix Intrinsics
2009520133
-----------------
2009620134

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,13 @@ class TargetLoweringBase {
483483
bool ZeroIsPoison,
484484
const ConstantRange *VScaleRange) const;
485485

486+
/// Return true if the @llvm.experimental.vector.match intrinsic should be
/// expanded for vector type `VT' and search size `SearchSize' using generic
/// code in SelectionDAGBuilder.
virtual bool shouldExpandVectorMatch(EVT VT, unsigned SearchSize) const {
  // Default: no target support assumed, so request the generic expansion
  // (splat each search element and OR the lane-wise equality compares).
  // Targets with a native match operation (e.g. SVE2 MATCH) override this.
  return true;
}
492+
486493
// Return true if op(vecreduce(x), vecreduce(y)) should be reassociated to
487494
// vecreduce(op(x, y)) for the reduction opcode RedOpc.
488495
virtual bool shouldReassociateReduction(unsigned RedOpc, EVT VT) const {

llvm/include/llvm/IR/Intrinsics.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1920,6 +1920,14 @@ def int_experimental_vector_histogram_add : DefaultAttrsIntrinsic<[],
19201920
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], // Mask
19211921
[ IntrArgMemOnly ]>;
19221922

1923+
// Experimental match: returns an i1 mask with the same element count as the
// first operand, flagging lanes of %op1 that equal any element of %op2.
// The mask operand gates which lanes of %op1 participate.
def int_experimental_vector_match : DefaultAttrsIntrinsic<
                  [ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ],
                  [ llvm_anyvector_ty,
                    llvm_anyvector_ty,
                    LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ],  // Mask
                  [ IntrNoMem, IntrNoSync, IntrWillReturn ]>;
1930+
19231931
// Operators
19241932
let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn] in {
19251933
// Integer arithmetic

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8175,6 +8175,36 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
81758175
DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ResultVT, Vec, Index));
81768176
return;
81778177
}
8178+
case Intrinsic::experimental_vector_match: {
  SDValue Op1 = getValue(I.getOperand(0));  // Search vector.
  SDValue Op2 = getValue(I.getOperand(1));  // Elements to search for.
  SDValue Mask = getValue(I.getOperand(2)); // Active-lane mask.
  EVT Op1VT = Op1.getValueType();
  EVT Op2VT = Op2.getValueType();
  EVT ResVT = Mask.getValueType();
  // The verifier requires Op2 to be a fixed-length vector, so its element
  // count is a compile-time constant.
  unsigned SearchSize = Op2VT.getVectorNumElements();

  // If the target has native support for this vector match operation, lower
  // the intrinsic untouched; otherwise, expand it below.
  if (!TLI.shouldExpandVectorMatch(Op1VT, SearchSize)) {
    visitTargetIntrinsic(I, Intrinsic);
    return;
  }

  SDValue Ret = DAG.getConstant(0, sdl, ResVT);

  // Generic expansion: for each element of the search list, splat it across
  // a vector of Op1's type, compare lane-wise for equality against Op1, and
  // OR the per-element results together.
  for (unsigned i = 0; i < SearchSize; ++i) {
    SDValue Op2Elem = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl,
                                  Op2VT.getVectorElementType(), Op2,
                                  DAG.getVectorIdxConstant(i, sdl));
    SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, sdl, Op1VT, Op2Elem);
    SDValue Cmp = DAG.getSetCC(sdl, ResVT, Op1, Splat, ISD::SETEQ);
    Ret = DAG.getNode(ISD::OR, sdl, ResVT, Ret, Cmp);
  }

  // Inactive lanes must be 0 in the result, so mask off the accumulated
  // compares before recording the value.
  setValue(&I, DAG.getNode(ISD::AND, sdl, ResVT, Ret, Mask));
  return;
}
81788208
case Intrinsic::vector_reverse:
81798209
visitVectorReverse(I);
81808210
return;

llvm/lib/IR/Verifier.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6150,6 +6150,31 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
61506150
&Call);
61516151
break;
61526152
}
6153+
case Intrinsic::experimental_vector_match: {
  // Enforce the type contract documented in LangRef: all three operands are
  // vectors; op2 is fixed-length; op1/op2 share an integer element type;
  // op1, the mask, and the result agree on element count; the mask and the
  // result are i1 vectors.
  Value *Op1 = Call.getArgOperand(0);
  Value *Op2 = Call.getArgOperand(1);
  Value *Mask = Call.getArgOperand(2);

  VectorType *Op1Ty = dyn_cast<VectorType>(Op1->getType());
  VectorType *Op2Ty = dyn_cast<VectorType>(Op2->getType());
  VectorType *MaskTy = dyn_cast<VectorType>(Mask->getType());

  // NOTE: the Check macro returns early on failure, so the dereferences
  // below are only reached when all three casts succeeded.
  Check(Op1Ty && Op2Ty && MaskTy, "Operands must be vectors.", &Call);
  Check(isa<FixedVectorType>(Op2Ty),
        "Second operand must be a fixed length vector.", &Call);
  Check(Op1Ty->getElementType()->isIntegerTy(),
        "First operand must be a vector of integers.", &Call);
  Check(Op1Ty->getElementType() == Op2Ty->getElementType(),
        "First two operands must have the same element type.", &Call);
  Check(Op1Ty->getElementCount() == MaskTy->getElementCount(),
        "First operand and mask must have the same number of elements.",
        &Call);
  Check(MaskTy->getElementType()->isIntegerTy(1),
        "Mask must be a vector of i1's.", &Call);
  Check(Call.getType() == MaskTy, "Return type must match the mask type.",
        &Call);
  break;
}
61536178
case Intrinsic::vector_insert: {
61546179
Value *Vec = Call.getArgOperand(0);
61556180
Value *SubVec = Call.getArgOperand(1);

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2059,6 +2059,19 @@ bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
20592059
VT != MVT::v4i1 && VT != MVT::v2i1;
20602060
}
20612061

2062+
// Decide whether @llvm.experimental.vector.match must be expanded generically
// (return true) or can be lowered to the SVE2 MATCH instruction (return
// false). MATCH compares each active lane against all elements of a 128-bit
// segment, so the search size must equal the number of elements per segment.
bool AArch64TargetLowering::shouldExpandVectorMatch(EVT VT,
                                                    unsigned SearchSize) const {
  // MATCH is SVE2 and only available in non-streaming mode.
  if (!Subtarget->hasSVE2() || !Subtarget->isSVEAvailable())
    return true;
  // Furthermore, we can only use it for 8-bit or 16-bit elements.
  if (VT == MVT::nxv8i16 || VT == MVT::v8i16)
    return SearchSize != 8;
  if (VT == MVT::nxv16i8 || VT == MVT::v16i8 || VT == MVT::v8i8)
    return SearchSize != 8 && SearchSize != 16;
  return true;
}
2074+
20622075
void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
20632076
assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
20642077

@@ -5780,6 +5793,72 @@ SDValue LowerSMELdrStr(SDValue N, SelectionDAG &DAG, bool IsLoad) {
57805793
DAG.getTargetConstant(ImmAddend, DL, MVT::i32)});
57815794
}
57825795

5796+
// Lower @llvm.experimental.vector.match to the SVE2 MATCH instruction.
// Reached for the fixed/scalable i8/i16 types that shouldExpandVectorMatch
// accepted; fixed-length operands are wrapped in scalable containers so a
// single aarch64_sve_match intrinsic node handles both cases.
SDValue LowerVectorMatch(SDValue Op, SelectionDAG &DAG) {
  SDLoc dl(Op);
  SDValue ID =
      DAG.getTargetConstant(Intrinsic::aarch64_sve_match, dl, MVT::i64);

  auto Op1 = Op.getOperand(1);  // Search vector.
  auto Op2 = Op.getOperand(2);  // Elements to search for.
  auto Mask = Op.getOperand(3); // Active-lane mask.

  EVT Op1VT = Op1.getValueType();
  EVT Op2VT = Op2.getValueType();
  EVT ResVT = Op.getValueType();

  assert((Op1VT.getVectorElementType() == MVT::i8 ||
          Op1VT.getVectorElementType() == MVT::i16) &&
         "Expected 8-bit or 16-bit characters.");

  // Scalable vector type used to wrap operands.
  // A single container is enough for both operands because ultimately the
  // operands will have to be wrapped to the same type (nxv16i8 or nxv8i16).
  EVT OpContainerVT = Op1VT.isScalableVector()
                          ? Op1VT
                          : getContainerForFixedLengthVector(DAG, Op1VT);

  if (Op2VT.is128BitVector()) {
    // If Op2 is a full 128-bit vector, wrap it trivially in a scalable vector.
    Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
    // Further, if the result is scalable, broadcast Op2 to a full SVE register.
    if (ResVT.isScalableVector())
      Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
                        DAG.getTargetConstant(0, dl, MVT::i64));
  } else {
    // If Op2 is not a full 128-bit vector, we always need to broadcast it.
    // Bitcast Op2 to one wide integer, splat that integer across a packed SVE
    // vector, then reinterpret as the element container type so the search
    // elements repeat across every 128-bit segment of the register.
    unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
    MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
    EVT Op2PromotedVT = getPackedSVEVectorVT(Op2IntVT);
    Op2 = DAG.getBitcast(MVT::getVectorVT(Op2IntVT, 1), Op2);
    Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT, Op2,
                      DAG.getConstant(0, dl, MVT::i64));
    Op2 = DAG.getSplatVector(Op2PromotedVT, dl, Op2);
    Op2 = DAG.getBitcast(OpContainerVT, Op2);
  }

  // If the result is scalable, we just need to carry out the MATCH.
  if (ResVT.isScalableVector())
    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ResVT, ID, Mask, Op1, Op2);

  // If the result is fixed, we can still use MATCH but we need to wrap the
  // first operand and the mask in scalable vectors before doing so.

  // Wrap the operands.
  Op1 = convertToScalableVector(DAG, OpContainerVT, Op1);
  // Promote the i1 mask to the element width first so it can be re-expressed
  // as a scalable predicate.
  Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, Op1VT, Mask);
  Mask = convertFixedMaskToScalableVector(Mask, DAG);

  // Carry out the match.
  SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, Mask.getValueType(),
                              ID, Mask, Op1, Op2);

  // Extract and promote the match result (nxv16i1/nxv8i1) to ResVT
  // (v16i8/v8i8).
  Match = DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match);
  Match = convertFromScalableVector(DAG, Op1VT, Match);
  return DAG.getNode(ISD::TRUNCATE, dl, ResVT, Match);
}
5861+
57835862
SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
57845863
SelectionDAG &DAG) const {
57855864
unsigned IntNo = Op.getConstantOperandVal(1);
@@ -6383,6 +6462,9 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
63836462
DAG.getNode(AArch64ISD::CTTZ_ELTS, dl, MVT::i64, CttzOp);
63846463
return DAG.getZExtOrTrunc(NewCttzElts, dl, Op.getValueType());
63856464
}
6465+
case Intrinsic::experimental_vector_match: {
6466+
return LowerVectorMatch(Op, DAG);
6467+
}
63866468
}
63876469
}
63886470

@@ -27153,6 +27235,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
2715327235
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
2715427236
return;
2715527237
}
27238+
case Intrinsic::experimental_vector_match:
2715627239
case Intrinsic::get_active_lane_mask: {
2715727240
if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
2715827241
return;

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,8 @@ class AArch64TargetLowering : public TargetLowering {
985985

986986
bool shouldExpandCttzElements(EVT VT) const override;
987987

988+
bool shouldExpandVectorMatch(EVT VT, unsigned SearchSize) const override;
989+
988990
/// If a change in streaming mode is required on entry to/return from a
989991
/// function call it emits and returns the corresponding SMSTART or SMSTOP
990992
/// node. \p Condition should be one of the enum values from

0 commit comments

Comments
 (0)