Skip to content

Commit 57dc23d

Browse files
committed
[X86] X86FixupVectorConstants - load+sign-extend vector constants that can be stored in a truncated form
Reduce the size of the vector constant by storing it in the constant pool in a truncated form, and sign-extend it as part of the load. I intend to add the matching load+zero-extend handling in a future patch, but that requires some alterations to the existing MC shuffle comments handling first. I've extended the existing FixupConstant functionality to support these constant rebuilds as well - we still select the smallest stored constant entry and prefer vzload/broadcast/vextload for same bitwidth to avoid domain flips. NOTE: Some of the FixupConstant tables are currently created on the fly as they are dependent on the supported ISAs (HasAVX2 etc.) - should we split these (to allow initializer lists instead) and have duplicate FixupConstant calls to avoid so much stack use?
1 parent 2e669ff commit 57dc23d

File tree

222 files changed

+15116
-15442
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

222 files changed

+15116
-15442
lines changed

llvm/lib/Target/X86/X86FixupVectorConstants.cpp

+90-10
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
// replace them with smaller constant pool entries, including:
1111
// * Converting AVX512 memory-fold instructions to their broadcast-fold form
1212
// * Broadcasting of full width loads.
13-
// * TODO: Sign/Zero extension of full width loads.
13+
// * TODO: Zero extension of full width loads.
1414
//
1515
//===----------------------------------------------------------------------===//
1616

@@ -265,11 +265,47 @@ static Constant *rebuildZeroUpperCst(const Constant *C, unsigned /*NumElts*/,
265265
return nullptr;
266266
}
267267

268+
static Constant *rebuildExtCst(const Constant *C, bool IsSExt, unsigned NumElts,
269+
unsigned SrcEltBitWidth) {
270+
Type *Ty = C->getType();
271+
unsigned NumBits = Ty->getPrimitiveSizeInBits();
272+
unsigned DstEltBitWidth = NumBits / NumElts;
273+
assert((NumBits % NumElts) == 0 && (NumBits % SrcEltBitWidth) == 0 &&
274+
(DstEltBitWidth % SrcEltBitWidth) == 0 &&
275+
(DstEltBitWidth > SrcEltBitWidth) && "Illegal extension width");
276+
277+
if (std::optional<APInt> Bits = extractConstantBits(C)) {
278+
assert((Bits->getBitWidth() / DstEltBitWidth) == NumElts &&
279+
(Bits->getBitWidth() % DstEltBitWidth) == 0 &&
280+
"Unexpected constant extension");
281+
282+
// Ensure every vector element can be represented by the src bitwidth.
283+
APInt TruncBits = APInt::getZero(NumElts * SrcEltBitWidth);
284+
for (unsigned I = 0; I != NumElts; ++I) {
285+
APInt Elt = Bits->extractBits(DstEltBitWidth, I * DstEltBitWidth);
286+
if ((IsSExt && Elt.getSignificantBits() > SrcEltBitWidth) ||
287+
(!IsSExt && Elt.getActiveBits() > SrcEltBitWidth))
288+
return nullptr;
289+
TruncBits.insertBits(Elt.trunc(SrcEltBitWidth), I * SrcEltBitWidth);
290+
}
291+
292+
return rebuildConstant(Ty->getContext(), Ty->getScalarType(), TruncBits,
293+
SrcEltBitWidth);
294+
}
295+
296+
return nullptr;
297+
}
298+
static Constant *rebuildSExtCst(const Constant *C, unsigned NumElts,
299+
unsigned SrcEltBitWidth) {
300+
return rebuildExtCst(C, true, NumElts, SrcEltBitWidth);
301+
}
302+
268303
bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
269304
MachineBasicBlock &MBB,
270305
MachineInstr &MI) {
271306
unsigned Opc = MI.getOpcode();
272307
MachineConstantPool *CP = MI.getParent()->getParent()->getConstantPool();
308+
bool HasSSE41 = ST->hasSSE41();
273309
bool HasAVX2 = ST->hasAVX2();
274310
bool HasDQI = ST->hasDQI();
275311
bool HasBWI = ST->hasBWI();
@@ -312,7 +348,15 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
312348
return false;
313349
};
314350

315-
// Attempt to convert full width vector loads into broadcast/vzload loads.
351+
// Attempt to detect a suitable vzload/broadcast/vextload from increasing
352+
// constant bitwidths. Prefer vzload/broadcast/vextload for same bitwidth:
353+
// - vzload shouldn't ever need a shuffle port to zero the upper elements and
354+
// the fp/int domain versions are equally available so we don't introduce a
355+
// domain crossing penalty.
356+
// - broadcast sometimes need a shuffle port (especially for 8/16-bit
357+
// variants), AVX1 only has fp domain broadcasts but AVX2+ have good fp/int
358+
// domain equivalents.
359+
// - vextload always needs a shuffle port and is only ever int domain.
316360
switch (Opc) {
317361
/* FP Loads */
318362
case X86::MOVAPDrm:
@@ -370,22 +414,34 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
370414
/* Integer Loads */
371415
case X86::MOVDQArm:
372416
case X86::MOVDQUrm: {
373-
return FixupConstant({{X86::MOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
374-
{X86::MOVQI2PQIrm, 1, 64, rebuildZeroUpperCst}},
375-
1);
417+
FixupEntry Fixups[] = {
418+
{HasSSE41 ? X86::PMOVSXBQrm : 0, 2, 8, rebuildSExtCst},
419+
{X86::MOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
420+
{HasSSE41 ? X86::PMOVSXBDrm : 0, 4, 8, rebuildSExtCst},
421+
{HasSSE41 ? X86::PMOVSXWQrm : 0, 2, 16, rebuildSExtCst},
422+
{X86::MOVQI2PQIrm, 1, 64, rebuildZeroUpperCst},
423+
{HasSSE41 ? X86::PMOVSXBWrm : 0, 8, 8, rebuildSExtCst},
424+
{HasSSE41 ? X86::PMOVSXWDrm : 0, 4, 16, rebuildSExtCst},
425+
{HasSSE41 ? X86::PMOVSXDQrm : 0, 2, 32, rebuildSExtCst}};
426+
return FixupConstant(Fixups, 1);
376427
}
377428
case X86::VMOVDQArm:
378429
case X86::VMOVDQUrm: {
379430
FixupEntry Fixups[] = {
380431
{HasAVX2 ? X86::VPBROADCASTBrm : 0, 1, 8, rebuildSplatCst},
381432
{HasAVX2 ? X86::VPBROADCASTWrm : 0, 1, 16, rebuildSplatCst},
433+
{X86::VPMOVSXBQrm, 2, 8, rebuildSExtCst},
382434
{X86::VMOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
383435
{HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm, 1, 32,
384436
rebuildSplatCst},
437+
{X86::VPMOVSXBDrm, 4, 8, rebuildSExtCst},
438+
{X86::VPMOVSXWQrm, 2, 16, rebuildSExtCst},
385439
{X86::VMOVQI2PQIrm, 1, 64, rebuildZeroUpperCst},
386440
{HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm, 1, 64,
387441
rebuildSplatCst},
388-
};
442+
{X86::VPMOVSXBWrm, 8, 8, rebuildSExtCst},
443+
{X86::VPMOVSXWDrm, 4, 16, rebuildSExtCst},
444+
{X86::VPMOVSXDQrm, 2, 32, rebuildSExtCst}};
389445
return FixupConstant(Fixups, 1);
390446
}
391447
case X86::VMOVDQAYrm:
@@ -395,10 +451,16 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
395451
{HasAVX2 ? X86::VPBROADCASTWYrm : 0, 1, 16, rebuildSplatCst},
396452
{HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm, 1, 32,
397453
rebuildSplatCst},
454+
{HasAVX2 ? X86::VPMOVSXBQYrm : 0, 4, 8, rebuildSExtCst},
398455
{HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm, 1, 64,
399456
rebuildSplatCst},
457+
{HasAVX2 ? X86::VPMOVSXBDYrm : 0, 8, 8, rebuildSExtCst},
458+
{HasAVX2 ? X86::VPMOVSXWQYrm : 0, 4, 16, rebuildSExtCst},
400459
{HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm, 1, 128,
401-
rebuildSplatCst}};
460+
rebuildSplatCst},
461+
{HasAVX2 ? X86::VPMOVSXBWYrm : 0, 16, 8, rebuildSExtCst},
462+
{HasAVX2 ? X86::VPMOVSXWDYrm : 0, 8, 16, rebuildSExtCst},
463+
{HasAVX2 ? X86::VPMOVSXDQYrm : 0, 4, 32, rebuildSExtCst}};
402464
return FixupConstant(Fixups, 1);
403465
}
404466
case X86::VMOVDQA32Z128rm:
@@ -408,10 +470,16 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
408470
FixupEntry Fixups[] = {
409471
{HasBWI ? X86::VPBROADCASTBZ128rm : 0, 1, 8, rebuildSplatCst},
410472
{HasBWI ? X86::VPBROADCASTWZ128rm : 0, 1, 16, rebuildSplatCst},
473+
{X86::VPMOVSXBQZ128rm, 2, 8, rebuildSExtCst},
411474
{X86::VMOVDI2PDIZrm, 1, 32, rebuildZeroUpperCst},
412475
{X86::VPBROADCASTDZ128rm, 1, 32, rebuildSplatCst},
476+
{X86::VPMOVSXBDZ128rm, 4, 8, rebuildSExtCst},
477+
{X86::VPMOVSXWQZ128rm, 2, 16, rebuildSExtCst},
413478
{X86::VMOVQI2PQIZrm, 1, 64, rebuildZeroUpperCst},
414-
{X86::VPBROADCASTQZ128rm, 1, 64, rebuildSplatCst}};
479+
{X86::VPBROADCASTQZ128rm, 1, 64, rebuildSplatCst},
480+
{HasBWI ? X86::VPMOVSXBWZ128rm : 0, 8, 8, rebuildSExtCst},
481+
{X86::VPMOVSXWDZ128rm, 4, 16, rebuildSExtCst},
482+
{X86::VPMOVSXDQZ128rm, 2, 32, rebuildSExtCst}};
415483
return FixupConstant(Fixups, 1);
416484
}
417485
case X86::VMOVDQA32Z256rm:
@@ -422,8 +490,14 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
422490
{HasBWI ? X86::VPBROADCASTBZ256rm : 0, 1, 8, rebuildSplatCst},
423491
{HasBWI ? X86::VPBROADCASTWZ256rm : 0, 1, 16, rebuildSplatCst},
424492
{X86::VPBROADCASTDZ256rm, 1, 32, rebuildSplatCst},
493+
{X86::VPMOVSXBQZ256rm, 4, 8, rebuildSExtCst},
425494
{X86::VPBROADCASTQZ256rm, 1, 64, rebuildSplatCst},
426-
{X86::VBROADCASTI32X4Z256rm, 1, 128, rebuildSplatCst}};
495+
{X86::VPMOVSXBDZ256rm, 8, 8, rebuildSExtCst},
496+
{X86::VPMOVSXWQZ256rm, 4, 16, rebuildSExtCst},
497+
{X86::VBROADCASTI32X4Z256rm, 1, 128, rebuildSplatCst},
498+
{HasBWI ? X86::VPMOVSXBWZ256rm : 0, 16, 8, rebuildSExtCst},
499+
{X86::VPMOVSXWDZ256rm, 8, 16, rebuildSExtCst},
500+
{X86::VPMOVSXDQZ256rm, 4, 32, rebuildSExtCst}};
427501
return FixupConstant(Fixups, 1);
428502
}
429503
case X86::VMOVDQA32Zrm:
@@ -435,8 +509,14 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
435509
{HasBWI ? X86::VPBROADCASTWZrm : 0, 1, 16, rebuildSplatCst},
436510
{X86::VPBROADCASTDZrm, 1, 32, rebuildSplatCst},
437511
{X86::VPBROADCASTQZrm, 1, 64, rebuildSplatCst},
512+
{X86::VPMOVSXBQZrm, 8, 8, rebuildSExtCst},
438513
{X86::VBROADCASTI32X4rm, 1, 128, rebuildSplatCst},
439-
{X86::VBROADCASTI64X4rm, 1, 256, rebuildSplatCst}};
514+
{X86::VPMOVSXBDZrm, 16, 8, rebuildSExtCst},
515+
{X86::VPMOVSXWQZrm, 8, 16, rebuildSExtCst},
516+
{X86::VBROADCASTI64X4rm, 1, 256, rebuildSplatCst},
517+
{HasBWI ? X86::VPMOVSXBWZrm : 0, 32, 8, rebuildSExtCst},
518+
{X86::VPMOVSXWDZrm, 16, 16, rebuildSExtCst},
519+
{X86::VPMOVSXDQZrm, 8, 32, rebuildSExtCst}};
440520
return FixupConstant(Fixups, 1);
441521
}
442522
}

llvm/lib/Target/X86/X86MCInstLower.cpp

+61-1
Original file line numberDiff line numberDiff line change
@@ -1582,6 +1582,36 @@ static void printBroadcast(const MachineInstr *MI, MCStreamer &OutStreamer,
15821582
}
15831583
}
15841584

1585+
static bool printSignExtend(const MachineInstr *MI, MCStreamer &OutStreamer,
1586+
int SrcEltBits, int DstEltBits) {
1587+
auto *C = X86::getConstantFromPool(*MI, 1);
1588+
if (C && C->getType()->getScalarSizeInBits() == SrcEltBits) {
1589+
if (auto *CDS = dyn_cast<ConstantDataSequential>(C)) {
1590+
int NumElts = CDS->getNumElements();
1591+
std::string Comment;
1592+
raw_string_ostream CS(Comment);
1593+
1594+
const MachineOperand &DstOp = MI->getOperand(0);
1595+
CS << X86ATTInstPrinter::getRegisterName(DstOp.getReg()) << " = ";
1596+
CS << "[";
1597+
for (int i = 0; i != NumElts; ++i) {
1598+
if (i != 0)
1599+
CS << ",";
1600+
if (CDS->getElementType()->isIntegerTy()) {
1601+
APInt Elt = CDS->getElementAsAPInt(i).sext(DstEltBits);
1602+
printConstant(Elt, CS);
1603+
} else
1604+
CS << "?";
1605+
}
1606+
CS << "]";
1607+
OutStreamer.AddComment(CS.str());
1608+
return true;
1609+
}
1610+
}
1611+
1612+
return false;
1613+
}
1614+
15851615
void X86AsmPrinter::EmitSEHInstruction(const MachineInstr *MI) {
15861616
assert(MF->hasWinCFI() && "SEH_ instruction in function without WinCFI?");
15871617
assert((getSubtarget().isOSWindows() || TM.getTargetTriple().isUEFI()) &&
@@ -1844,7 +1874,7 @@ static void addConstantComments(const MachineInstr *MI,
18441874
case X86::VMOVQI2PQIrm:
18451875
case X86::VMOVQI2PQIZrm:
18461876
printZeroUpperMove(MI, OutStreamer, 64, 128, "mem[0],zero");
1847-
break;
1877+
break;
18481878

18491879
case X86::MOVSSrm:
18501880
case X86::VMOVSSrm:
@@ -1979,6 +2009,36 @@ static void addConstantComments(const MachineInstr *MI,
19792009
case X86::VPBROADCASTBZrm:
19802010
printBroadcast(MI, OutStreamer, 64, 8);
19812011
break;
2012+
2013+
#define MOVX_CASE(Prefix, Ext, Type, Suffix) \
2014+
case X86::Prefix##PMOV##Ext##Type##Suffix##rm:
2015+
2016+
#define CASE_MOVX_RM(Ext, Type) \
2017+
MOVX_CASE(, Ext, Type, ) \
2018+
MOVX_CASE(V, Ext, Type, ) \
2019+
MOVX_CASE(V, Ext, Type, Y) \
2020+
MOVX_CASE(V, Ext, Type, Z128) \
2021+
MOVX_CASE(V, Ext, Type, Z256) \
2022+
MOVX_CASE(V, Ext, Type, Z)
2023+
2024+
CASE_MOVX_RM(SX, BD)
2025+
printSignExtend(MI, OutStreamer, 8, 32);
2026+
break;
2027+
CASE_MOVX_RM(SX, BQ)
2028+
printSignExtend(MI, OutStreamer, 8, 64);
2029+
break;
2030+
CASE_MOVX_RM(SX, BW)
2031+
printSignExtend(MI, OutStreamer, 8, 16);
2032+
break;
2033+
CASE_MOVX_RM(SX, DQ)
2034+
printSignExtend(MI, OutStreamer, 32, 64);
2035+
break;
2036+
CASE_MOVX_RM(SX, WD)
2037+
printSignExtend(MI, OutStreamer, 16, 32);
2038+
break;
2039+
CASE_MOVX_RM(SX, WQ)
2040+
printSignExtend(MI, OutStreamer, 16, 64);
2041+
break;
19822042
}
19832043
}
19842044

0 commit comments

Comments
 (0)