Skip to content

Commit 6ac4fe8

Browse files
committed
[X86] X86FixupVectorConstants.cpp - refactor constant search loop to take array of sorted candidates
Pulled out of #79815 - refactors the internal FixupConstant logic to just accept an array of vzload/broadcast candidates that are pre-sorted in ascending constant pool size
1 parent 390d66b commit 6ac4fe8

File tree

1 file changed

+125
-83
lines changed

1 file changed

+125
-83
lines changed

llvm/lib/Target/X86/X86FixupVectorConstants.cpp

+125-83
Original file line number · Diff line number · Diff line change
@@ -216,8 +216,8 @@ static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
216216

217217
// Attempt to rebuild a normalized splat vector constant of the requested splat
218218
// width, built up of potentially smaller scalar values.
219-
static Constant *rebuildSplatableConstant(const Constant *C,
220-
unsigned SplatBitWidth) {
219+
static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumElts*/,
220+
unsigned SplatBitWidth) {
221221
std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
222222
if (!Splat)
223223
return nullptr;
@@ -238,8 +238,8 @@ static Constant *rebuildSplatableConstant(const Constant *C,
238238
return rebuildConstant(OriginalType->getContext(), SclTy, *Splat, NumSclBits);
239239
}
240240

241-
static Constant *rebuildZeroUpperConstant(const Constant *C,
242-
unsigned ScalarBitWidth) {
241+
static Constant *rebuildZeroUpperCst(const Constant *C, unsigned /*NumElts*/,
242+
unsigned ScalarBitWidth) {
243243
Type *Ty = C->getType();
244244
Type *SclTy = Ty->getScalarType();
245245
unsigned NumBits = Ty->getPrimitiveSizeInBits();
@@ -265,8 +265,6 @@ static Constant *rebuildZeroUpperConstant(const Constant *C,
265265
return nullptr;
266266
}
267267

268-
typedef std::function<Constant *(const Constant *, unsigned)> RebuildFn;
269-
270268
bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
271269
MachineBasicBlock &MBB,
272270
MachineInstr &MI) {
@@ -277,43 +275,42 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
277275
bool HasBWI = ST->hasBWI();
278276
bool HasVLX = ST->hasVLX();
279277

280-
auto FixupConstant =
281-
[&](unsigned OpBcst256, unsigned OpBcst128, unsigned OpBcst64,
282-
unsigned OpBcst32, unsigned OpBcst16, unsigned OpBcst8,
283-
unsigned OpUpper64, unsigned OpUpper32, unsigned OperandNo) {
284-
assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
285-
"Unexpected number of operands!");
286-
287-
if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
288-
// Attempt to detect a suitable splat/vzload from increasing constant
289-
// bitwidths.
290-
// Prefer vzload vs broadcast for same bitwidth to avoid domain flips.
291-
std::tuple<unsigned, unsigned, RebuildFn> FixupLoad[] = {
292-
{8, OpBcst8, rebuildSplatableConstant},
293-
{16, OpBcst16, rebuildSplatableConstant},
294-
{32, OpUpper32, rebuildZeroUpperConstant},
295-
{32, OpBcst32, rebuildSplatableConstant},
296-
{64, OpUpper64, rebuildZeroUpperConstant},
297-
{64, OpBcst64, rebuildSplatableConstant},
298-
{128, OpBcst128, rebuildSplatableConstant},
299-
{256, OpBcst256, rebuildSplatableConstant},
300-
};
301-
for (auto [BitWidth, Op, RebuildConstant] : FixupLoad) {
302-
if (Op) {
303-
// Construct a suitable constant and adjust the MI to use the new
304-
// constant pool entry.
305-
if (Constant *NewCst = RebuildConstant(C, BitWidth)) {
306-
unsigned NewCPI =
307-
CP->getConstantPoolIndex(NewCst, Align(BitWidth / 8));
308-
MI.setDesc(TII->get(Op));
309-
MI.getOperand(OperandNo + X86::AddrDisp).setIndex(NewCPI);
310-
return true;
311-
}
312-
}
278+
struct FixupEntry {
279+
int Op;
280+
int NumCstElts;
281+
int BitWidth;
282+
std::function<Constant *(const Constant *, unsigned, unsigned)>
283+
RebuildConstant;
284+
};
285+
auto FixupConstant = [&](ArrayRef<FixupEntry> Fixups, unsigned OperandNo) {
286+
#ifdef EXPENSIVE_CHECKS
287+
assert(llvm::is_sorted(Fixups,
288+
[](const FixupEntry &A, const FixupEntry &B) {
289+
return (A.NumCstElts * A.BitWidth) <
290+
(B.NumCstElts * B.BitWidth);
291+
}) &&
292+
"Constant fixup table not sorted in ascending constant size");
293+
#endif
294+
assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
295+
"Unexpected number of operands!");
296+
if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
297+
for (const FixupEntry &Fixup : Fixups) {
298+
if (Fixup.Op) {
299+
// Construct a suitable constant and adjust the MI to use the new
300+
// constant pool entry.
301+
if (Constant *NewCst =
302+
Fixup.RebuildConstant(C, Fixup.NumCstElts, Fixup.BitWidth)) {
303+
unsigned NewCPI =
304+
CP->getConstantPoolIndex(NewCst, Align(Fixup.BitWidth / 8));
305+
MI.setDesc(TII->get(Fixup.Op));
306+
MI.getOperand(OperandNo + X86::AddrDisp).setIndex(NewCPI);
307+
return true;
313308
}
314309
}
315-
return false;
316-
};
310+
}
311+
}
312+
return false;
313+
};
317314

318315
// Attempt to convert full width vector loads into broadcast/vzload loads.
319316
switch (Opc) {
@@ -323,82 +320,125 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
323320
case X86::MOVUPDrm:
324321
case X86::MOVUPSrm:
325322
// TODO: SSE3 MOVDDUP Handling
326-
return FixupConstant(0, 0, 0, 0, 0, 0, X86::MOVSDrm, X86::MOVSSrm, 1);
323+
return FixupConstant({{X86::MOVSSrm, 1, 32, rebuildZeroUpperCst},
324+
{X86::MOVSDrm, 1, 64, rebuildZeroUpperCst}},
325+
1);
327326
case X86::VMOVAPDrm:
328327
case X86::VMOVAPSrm:
329328
case X86::VMOVUPDrm:
330329
case X86::VMOVUPSrm:
331-
return FixupConstant(0, 0, X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0, 0,
332-
X86::VMOVSDrm, X86::VMOVSSrm, 1);
330+
return FixupConstant({{X86::VMOVSSrm, 1, 32, rebuildZeroUpperCst},
331+
{X86::VBROADCASTSSrm, 1, 32, rebuildSplatCst},
332+
{X86::VMOVSDrm, 1, 64, rebuildZeroUpperCst},
333+
{X86::VMOVDDUPrm, 1, 64, rebuildSplatCst}},
334+
1);
333335
case X86::VMOVAPDYrm:
334336
case X86::VMOVAPSYrm:
335337
case X86::VMOVUPDYrm:
336338
case X86::VMOVUPSYrm:
337-
return FixupConstant(0, X86::VBROADCASTF128rm, X86::VBROADCASTSDYrm,
338-
X86::VBROADCASTSSYrm, 0, 0, 0, 0, 1);
339+
return FixupConstant({{X86::VBROADCASTSSYrm, 1, 32, rebuildSplatCst},
340+
{X86::VBROADCASTSDYrm, 1, 64, rebuildSplatCst},
341+
{X86::VBROADCASTF128rm, 1, 128, rebuildSplatCst}},
342+
1);
339343
case X86::VMOVAPDZ128rm:
340344
case X86::VMOVAPSZ128rm:
341345
case X86::VMOVUPDZ128rm:
342346
case X86::VMOVUPSZ128rm:
343-
return FixupConstant(0, 0, X86::VMOVDDUPZ128rm, X86::VBROADCASTSSZ128rm, 0,
344-
0, X86::VMOVSDZrm, X86::VMOVSSZrm, 1);
347+
return FixupConstant({{X86::VMOVSSZrm, 1, 32, rebuildZeroUpperCst},
348+
{X86::VBROADCASTSSZ128rm, 1, 32, rebuildSplatCst},
349+
{X86::VMOVSDZrm, 1, 64, rebuildZeroUpperCst},
350+
{X86::VMOVDDUPZ128rm, 1, 64, rebuildSplatCst}},
351+
1);
345352
case X86::VMOVAPDZ256rm:
346353
case X86::VMOVAPSZ256rm:
347354
case X86::VMOVUPDZ256rm:
348355
case X86::VMOVUPSZ256rm:
349-
return FixupConstant(0, X86::VBROADCASTF32X4Z256rm, X86::VBROADCASTSDZ256rm,
350-
X86::VBROADCASTSSZ256rm, 0, 0, 0, 0, 1);
356+
return FixupConstant(
357+
{{X86::VBROADCASTSSZ256rm, 1, 32, rebuildSplatCst},
358+
{X86::VBROADCASTSDZ256rm, 1, 64, rebuildSplatCst},
359+
{X86::VBROADCASTF32X4Z256rm, 1, 128, rebuildSplatCst}},
360+
1);
351361
case X86::VMOVAPDZrm:
352362
case X86::VMOVAPSZrm:
353363
case X86::VMOVUPDZrm:
354364
case X86::VMOVUPSZrm:
355-
return FixupConstant(X86::VBROADCASTF64X4rm, X86::VBROADCASTF32X4rm,
356-
X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0, 0, 0, 0,
365+
return FixupConstant({{X86::VBROADCASTSSZrm, 1, 32, rebuildSplatCst},
366+
{X86::VBROADCASTSDZrm, 1, 64, rebuildSplatCst},
367+
{X86::VBROADCASTF32X4rm, 1, 128, rebuildSplatCst},
368+
{X86::VBROADCASTF64X4rm, 1, 256, rebuildSplatCst}},
357369
1);
358370
/* Integer Loads */
359371
case X86::MOVDQArm:
360-
case X86::MOVDQUrm:
361-
return FixupConstant(0, 0, 0, 0, 0, 0, X86::MOVQI2PQIrm, X86::MOVDI2PDIrm,
372+
case X86::MOVDQUrm: {
373+
return FixupConstant({{X86::MOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
374+
{X86::MOVQI2PQIrm, 1, 64, rebuildZeroUpperCst}},
362375
1);
376+
}
363377
case X86::VMOVDQArm:
364-
case X86::VMOVDQUrm:
365-
return FixupConstant(0, 0, HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm,
366-
HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm,
367-
HasAVX2 ? X86::VPBROADCASTWrm : 0,
368-
HasAVX2 ? X86::VPBROADCASTBrm : 0, X86::VMOVQI2PQIrm,
369-
X86::VMOVDI2PDIrm, 1);
378+
case X86::VMOVDQUrm: {
379+
FixupEntry Fixups[] = {
380+
{HasAVX2 ? X86::VPBROADCASTBrm : 0, 1, 8, rebuildSplatCst},
381+
{HasAVX2 ? X86::VPBROADCASTWrm : 0, 1, 16, rebuildSplatCst},
382+
{X86::VMOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
383+
{HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm, 1, 32,
384+
rebuildSplatCst},
385+
{X86::VMOVQI2PQIrm, 1, 64, rebuildZeroUpperCst},
386+
{HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm, 1, 64,
387+
rebuildSplatCst},
388+
};
389+
return FixupConstant(Fixups, 1);
390+
}
370391
case X86::VMOVDQAYrm:
371-
case X86::VMOVDQUYrm:
372-
return FixupConstant(
373-
0, HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm,
374-
HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm,
375-
HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm,
376-
HasAVX2 ? X86::VPBROADCASTWYrm : 0, HasAVX2 ? X86::VPBROADCASTBYrm : 0,
377-
0, 0, 1);
392+
case X86::VMOVDQUYrm: {
393+
FixupEntry Fixups[] = {
394+
{HasAVX2 ? X86::VPBROADCASTBYrm : 0, 1, 8, rebuildSplatCst},
395+
{HasAVX2 ? X86::VPBROADCASTWYrm : 0, 1, 16, rebuildSplatCst},
396+
{HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm, 1, 32,
397+
rebuildSplatCst},
398+
{HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm, 1, 64,
399+
rebuildSplatCst},
400+
{HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm, 1, 128,
401+
rebuildSplatCst}};
402+
return FixupConstant(Fixups, 1);
403+
}
378404
case X86::VMOVDQA32Z128rm:
379405
case X86::VMOVDQA64Z128rm:
380406
case X86::VMOVDQU32Z128rm:
381-
case X86::VMOVDQU64Z128rm:
382-
return FixupConstant(0, 0, X86::VPBROADCASTQZ128rm, X86::VPBROADCASTDZ128rm,
383-
HasBWI ? X86::VPBROADCASTWZ128rm : 0,
384-
HasBWI ? X86::VPBROADCASTBZ128rm : 0,
385-
X86::VMOVQI2PQIZrm, X86::VMOVDI2PDIZrm, 1);
407+
case X86::VMOVDQU64Z128rm: {
408+
FixupEntry Fixups[] = {
409+
{HasBWI ? X86::VPBROADCASTBZ128rm : 0, 1, 8, rebuildSplatCst},
410+
{HasBWI ? X86::VPBROADCASTWZ128rm : 0, 1, 16, rebuildSplatCst},
411+
{X86::VMOVDI2PDIZrm, 1, 32, rebuildZeroUpperCst},
412+
{X86::VPBROADCASTDZ128rm, 1, 32, rebuildSplatCst},
413+
{X86::VMOVQI2PQIZrm, 1, 64, rebuildZeroUpperCst},
414+
{X86::VPBROADCASTQZ128rm, 1, 64, rebuildSplatCst}};
415+
return FixupConstant(Fixups, 1);
416+
}
386417
case X86::VMOVDQA32Z256rm:
387418
case X86::VMOVDQA64Z256rm:
388419
case X86::VMOVDQU32Z256rm:
389-
case X86::VMOVDQU64Z256rm:
390-
return FixupConstant(0, X86::VBROADCASTI32X4Z256rm, X86::VPBROADCASTQZ256rm,
391-
X86::VPBROADCASTDZ256rm,
392-
HasBWI ? X86::VPBROADCASTWZ256rm : 0,
393-
HasBWI ? X86::VPBROADCASTBZ256rm : 0, 0, 0, 1);
420+
case X86::VMOVDQU64Z256rm: {
421+
FixupEntry Fixups[] = {
422+
{HasBWI ? X86::VPBROADCASTBZ256rm : 0, 1, 8, rebuildSplatCst},
423+
{HasBWI ? X86::VPBROADCASTWZ256rm : 0, 1, 16, rebuildSplatCst},
424+
{X86::VPBROADCASTDZ256rm, 1, 32, rebuildSplatCst},
425+
{X86::VPBROADCASTQZ256rm, 1, 64, rebuildSplatCst},
426+
{X86::VBROADCASTI32X4Z256rm, 1, 128, rebuildSplatCst}};
427+
return FixupConstant(Fixups, 1);
428+
}
394429
case X86::VMOVDQA32Zrm:
395430
case X86::VMOVDQA64Zrm:
396431
case X86::VMOVDQU32Zrm:
397-
case X86::VMOVDQU64Zrm:
398-
return FixupConstant(X86::VBROADCASTI64X4rm, X86::VBROADCASTI32X4rm,
399-
X86::VPBROADCASTQZrm, X86::VPBROADCASTDZrm,
400-
HasBWI ? X86::VPBROADCASTWZrm : 0,
401-
HasBWI ? X86::VPBROADCASTBZrm : 0, 0, 0, 1);
432+
case X86::VMOVDQU64Zrm: {
433+
FixupEntry Fixups[] = {
434+
{HasBWI ? X86::VPBROADCASTBZrm : 0, 1, 8, rebuildSplatCst},
435+
{HasBWI ? X86::VPBROADCASTWZrm : 0, 1, 16, rebuildSplatCst},
436+
{X86::VPBROADCASTDZrm, 1, 32, rebuildSplatCst},
437+
{X86::VPBROADCASTQZrm, 1, 64, rebuildSplatCst},
438+
{X86::VBROADCASTI32X4rm, 1, 128, rebuildSplatCst},
439+
{X86::VBROADCASTI64X4rm, 1, 256, rebuildSplatCst}};
440+
return FixupConstant(Fixups, 1);
441+
}
402442
}
403443

404444
auto ConvertToBroadcastAVX512 = [&](unsigned OpSrc32, unsigned OpSrc64) {
@@ -423,7 +463,9 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
423463

424464
if (OpBcst32 || OpBcst64) {
425465
unsigned OpNo = OpBcst32 == 0 ? OpNoBcst64 : OpNoBcst32;
426-
return FixupConstant(0, 0, OpBcst64, OpBcst32, 0, 0, 0, 0, OpNo);
466+
FixupEntry Fixups[] = {{(int)OpBcst32, 32, 32, rebuildSplatCst},
467+
{(int)OpBcst64, 64, 64, rebuildSplatCst}};
468+
return FixupConstant(Fixups, OpNo);
427469
}
428470
return false;
429471
};

0 commit comments

Comments (0)