@@ -67,6 +67,9 @@ FunctionPass *llvm::createX86FixupVectorConstants() {
67
67
static std::optional<APInt> extractConstantBits (const Constant *C) {
68
68
unsigned NumBits = C->getType ()->getPrimitiveSizeInBits ();
69
69
70
+ if (auto *CUndef = dyn_cast<UndefValue>(C))
71
+ return APInt::getZero (NumBits);
72
+
70
73
if (auto *CInt = dyn_cast<ConstantInt>(C))
71
74
return CInt->getValue ();
72
75
@@ -80,6 +83,18 @@ static std::optional<APInt> extractConstantBits(const Constant *C) {
80
83
return APInt::getSplat (NumBits, *Bits);
81
84
}
82
85
}
86
+
87
+ APInt Bits = APInt::getZero (NumBits);
88
+ for (unsigned I = 0 , E = CV->getNumOperands (); I != E; ++I) {
89
+ Constant *Elt = CV->getOperand (I);
90
+ std::optional<APInt> SubBits = extractConstantBits (Elt);
91
+ if (!SubBits)
92
+ return std::nullopt;
93
+ assert (NumBits == (E * SubBits->getBitWidth ()) &&
94
+ " Illegal vector element size" );
95
+ Bits.insertBits (*SubBits, I * SubBits->getBitWidth ());
96
+ }
97
+ return Bits;
83
98
}
84
99
85
100
if (auto *CDS = dyn_cast<ConstantDataSequential>(C)) {
@@ -223,6 +238,35 @@ static Constant *rebuildSplatableConstant(const Constant *C,
223
238
return rebuildConstant (OriginalType->getContext (), SclTy, *Splat, NumSclBits);
224
239
}
225
240
241
+ static Constant *rebuildZeroUpperConstant (const Constant *C,
242
+ unsigned ScalarBitWidth) {
243
+ Type *Ty = C->getType ();
244
+ Type *SclTy = Ty->getScalarType ();
245
+ unsigned NumBits = Ty->getPrimitiveSizeInBits ();
246
+ unsigned NumSclBits = SclTy->getPrimitiveSizeInBits ();
247
+ LLVMContext &Ctx = C->getContext ();
248
+
249
+ if (NumBits > ScalarBitWidth) {
250
+ // Determine if the upper bits are all zero.
251
+ if (std::optional<APInt> Bits = extractConstantBits (C)) {
252
+ if (Bits->countLeadingZeros () >= (NumBits - ScalarBitWidth)) {
253
+ // If the original constant was made of smaller elements, try to retain
254
+ // those types.
255
+ if (ScalarBitWidth > NumSclBits && (ScalarBitWidth % NumSclBits) == 0 )
256
+ return rebuildConstant (Ctx, SclTy, *Bits, NumSclBits);
257
+
258
+ // Fallback to raw integer bits.
259
+ APInt RawBits = Bits->zextOrTrunc (ScalarBitWidth);
260
+ return ConstantInt::get (Ctx, RawBits);
261
+ }
262
+ }
263
+ }
264
+
265
+ return nullptr ;
266
+ }
267
+
268
+ typedef std::function<Constant *(const Constant *, unsigned )> RebuildFn;
269
+
226
270
bool X86FixupVectorConstantsPass::processInstruction (MachineFunction &MF,
227
271
MachineBasicBlock &MBB,
228
272
MachineInstr &MI) {
@@ -233,117 +277,128 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
233
277
bool HasBWI = ST->hasBWI ();
234
278
bool HasVLX = ST->hasVLX ();
235
279
236
- auto ConvertToBroadcast = [&](unsigned OpBcst256, unsigned OpBcst128,
237
- unsigned OpBcst64, unsigned OpBcst32,
238
- unsigned OpBcst16, unsigned OpBcst8,
239
- unsigned OperandNo) {
240
- assert (MI.getNumOperands () >= (OperandNo + X86::AddrNumOperands) &&
241
- " Unexpected number of operands!" );
242
-
243
- if (auto *C = X86::getConstantFromPool (MI, OperandNo)) {
244
- // Attempt to detect a suitable splat from increasing splat widths.
245
- std::pair<unsigned , unsigned > Broadcasts[] = {
246
- {8 , OpBcst8}, {16 , OpBcst16}, {32 , OpBcst32},
247
- {64 , OpBcst64}, {128 , OpBcst128}, {256 , OpBcst256},
248
- };
249
- for (auto [BitWidth, OpBcst] : Broadcasts) {
250
- if (OpBcst) {
251
- // Construct a suitable splat constant and adjust the MI to
252
- // use the new constant pool entry.
253
- if (Constant *NewCst = rebuildSplatableConstant (C, BitWidth)) {
254
- unsigned NewCPI =
255
- CP->getConstantPoolIndex (NewCst, Align (BitWidth / 8 ));
256
- MI.setDesc (TII->get (OpBcst));
257
- MI.getOperand (OperandNo + X86::AddrDisp).setIndex (NewCPI);
258
- return true ;
280
+ auto FixupConstant =
281
+ [&](unsigned OpBcst256, unsigned OpBcst128, unsigned OpBcst64,
282
+ unsigned OpBcst32, unsigned OpBcst16, unsigned OpBcst8,
283
+ unsigned OpUpper64, unsigned OpUpper32, unsigned OperandNo) {
284
+ assert (MI.getNumOperands () >= (OperandNo + X86::AddrNumOperands) &&
285
+ " Unexpected number of operands!" );
286
+
287
+ if (auto *C = X86::getConstantFromPool (MI, OperandNo)) {
288
+ // Attempt to detect a suitable splat/vzload from increasing constant
289
+ // bitwidths.
290
+ // Prefer vzload vs broadcast for same bitwidth to avoid domain flips.
291
+ std::tuple<unsigned , unsigned , RebuildFn> FixupLoad[] = {
292
+ {8 , OpBcst8, rebuildSplatableConstant},
293
+ {16 , OpBcst16, rebuildSplatableConstant},
294
+ {32 , OpUpper32, rebuildZeroUpperConstant},
295
+ {32 , OpBcst32, rebuildSplatableConstant},
296
+ {64 , OpUpper64, rebuildZeroUpperConstant},
297
+ {64 , OpBcst64, rebuildSplatableConstant},
298
+ {128 , OpBcst128, rebuildSplatableConstant},
299
+ {256 , OpBcst256, rebuildSplatableConstant},
300
+ };
301
+ for (auto [BitWidth, Op, RebuildConstant] : FixupLoad) {
302
+ if (Op) {
303
+ // Construct a suitable constant and adjust the MI to use the new
304
+ // constant pool entry.
305
+ if (Constant *NewCst = RebuildConstant (C, BitWidth)) {
306
+ unsigned NewCPI =
307
+ CP->getConstantPoolIndex (NewCst, Align (BitWidth / 8 ));
308
+ MI.setDesc (TII->get (Op));
309
+ MI.getOperand (OperandNo + X86::AddrDisp).setIndex (NewCPI);
310
+ return true ;
311
+ }
312
+ }
259
313
}
260
314
}
261
- }
262
- }
263
- return false ;
264
- };
315
+ return false ;
316
+ };
265
317
266
- // Attempt to convert full width vector loads into broadcast loads.
318
+ // Attempt to convert full width vector loads into broadcast/vzload loads.
267
319
switch (Opc) {
268
320
/* FP Loads */
269
321
case X86::MOVAPDrm:
270
322
case X86::MOVAPSrm:
271
323
case X86::MOVUPDrm:
272
324
case X86::MOVUPSrm:
273
325
// TODO: SSE3 MOVDDUP Handling
274
- return false ;
326
+ return FixupConstant ( 0 , 0 , 0 , 0 , 0 , 0 , X86::MOVSDrm, X86::MOVSSrm, 1 ) ;
275
327
case X86::VMOVAPDrm:
276
328
case X86::VMOVAPSrm:
277
329
case X86::VMOVUPDrm:
278
330
case X86::VMOVUPSrm:
279
- return ConvertToBroadcast (0 , 0 , X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0 , 0 ,
280
- 1 );
331
+ return FixupConstant (0 , 0 , X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0 , 0 ,
332
+ X86::VMOVSDrm, X86::VMOVSSrm, 1 );
281
333
case X86::VMOVAPDYrm:
282
334
case X86::VMOVAPSYrm:
283
335
case X86::VMOVUPDYrm:
284
336
case X86::VMOVUPSYrm:
285
- return ConvertToBroadcast (0 , X86::VBROADCASTF128rm, X86::VBROADCASTSDYrm,
286
- X86::VBROADCASTSSYrm, 0 , 0 , 1 );
337
+ return FixupConstant (0 , X86::VBROADCASTF128rm, X86::VBROADCASTSDYrm,
338
+ X86::VBROADCASTSSYrm, 0 , 0 , 0 , 0 , 1 );
287
339
case X86::VMOVAPDZ128rm:
288
340
case X86::VMOVAPSZ128rm:
289
341
case X86::VMOVUPDZ128rm:
290
342
case X86::VMOVUPSZ128rm:
291
- return ConvertToBroadcast (0 , 0 , X86::VMOVDDUPZ128rm,
292
- X86::VBROADCASTSSZ128rm, 0 , 0 , 1 );
343
+ return FixupConstant (0 , 0 , X86::VMOVDDUPZ128rm, X86::VBROADCASTSSZ128rm, 0 ,
344
+ 0 , X86::VMOVSDZrm, X86::VMOVSSZrm , 1 );
293
345
case X86::VMOVAPDZ256rm:
294
346
case X86::VMOVAPSZ256rm:
295
347
case X86::VMOVUPDZ256rm:
296
348
case X86::VMOVUPSZ256rm:
297
- return ConvertToBroadcast (0 , X86::VBROADCASTF32X4Z256rm,
298
- X86::VBROADCASTSDZ256rm, X86::VBROADCASTSSZ256rm,
299
- 0 , 0 , 1 );
349
+ return FixupConstant (0 , X86::VBROADCASTF32X4Z256rm, X86::VBROADCASTSDZ256rm,
350
+ X86::VBROADCASTSSZ256rm, 0 , 0 , 0 , 0 , 1 );
300
351
case X86::VMOVAPDZrm:
301
352
case X86::VMOVAPSZrm:
302
353
case X86::VMOVUPDZrm:
303
354
case X86::VMOVUPSZrm:
304
- return ConvertToBroadcast (X86::VBROADCASTF64X4rm, X86::VBROADCASTF32X4rm,
305
- X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0 , 0 ,
306
- 1 );
355
+ return FixupConstant (X86::VBROADCASTF64X4rm, X86::VBROADCASTF32X4rm,
356
+ X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0 , 0 , 0 , 0 ,
357
+ 1 );
307
358
/* Integer Loads */
359
+ case X86::MOVDQArm:
360
+ case X86::MOVDQUrm:
361
+ return FixupConstant (0 , 0 , 0 , 0 , 0 , 0 , X86::MOVQI2PQIrm, X86::MOVDI2PDIrm,
362
+ 1 );
308
363
case X86::VMOVDQArm:
309
364
case X86::VMOVDQUrm:
310
- return ConvertToBroadcast (
311
- 0 , 0 , HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm ,
312
- HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm ,
313
- HasAVX2 ? X86::VPBROADCASTWrm : 0 , HasAVX2 ? X86::VPBROADCASTBrm : 0 ,
314
- 1 );
365
+ return FixupConstant ( 0 , 0 , HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm,
366
+ HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm ,
367
+ HasAVX2 ? X86::VPBROADCASTWrm : 0 ,
368
+ HasAVX2 ? X86::VPBROADCASTBrm : 0 , X86::VMOVQI2PQIrm ,
369
+ X86::VMOVDI2PDIrm, 1 );
315
370
case X86::VMOVDQAYrm:
316
371
case X86::VMOVDQUYrm:
317
- return ConvertToBroadcast (
372
+ return FixupConstant (
318
373
0 , HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm,
319
374
HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm,
320
375
HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm,
321
376
HasAVX2 ? X86::VPBROADCASTWYrm : 0 , HasAVX2 ? X86::VPBROADCASTBYrm : 0 ,
322
- 1 );
377
+ 0 , 0 , 1 );
323
378
case X86::VMOVDQA32Z128rm:
324
379
case X86::VMOVDQA64Z128rm:
325
380
case X86::VMOVDQU32Z128rm:
326
381
case X86::VMOVDQU64Z128rm:
327
- return ConvertToBroadcast (0 , 0 , X86::VPBROADCASTQZ128rm,
328
- X86::VPBROADCASTDZ128rm ,
329
- HasBWI ? X86::VPBROADCASTWZ128rm : 0 ,
330
- HasBWI ? X86::VPBROADCASTBZ128rm : 0 , 1 );
382
+ return FixupConstant (0 , 0 , X86::VPBROADCASTQZ128rm, X86::VPBROADCASTDZ128rm ,
383
+ HasBWI ? X86::VPBROADCASTWZ128rm : 0 ,
384
+ HasBWI ? X86::VPBROADCASTBZ128rm : 0 ,
385
+ X86::VMOVQI2PQIZrm, X86::VMOVDI2PDIZrm , 1 );
331
386
case X86::VMOVDQA32Z256rm:
332
387
case X86::VMOVDQA64Z256rm:
333
388
case X86::VMOVDQU32Z256rm:
334
389
case X86::VMOVDQU64Z256rm:
335
- return ConvertToBroadcast (0 , X86::VBROADCASTI32X4Z256rm,
336
- X86::VPBROADCASTQZ256rm, X86::VPBROADCASTDZ256rm,
337
- HasBWI ? X86::VPBROADCASTWZ256rm : 0 ,
338
- HasBWI ? X86::VPBROADCASTBZ256rm : 0 , 1 );
390
+ return FixupConstant (0 , X86::VBROADCASTI32X4Z256rm, X86::VPBROADCASTQZ256rm ,
391
+ X86::VPBROADCASTDZ256rm,
392
+ HasBWI ? X86::VPBROADCASTWZ256rm : 0 ,
393
+ HasBWI ? X86::VPBROADCASTBZ256rm : 0 , 0 , 0 , 1 );
339
394
case X86::VMOVDQA32Zrm:
340
395
case X86::VMOVDQA64Zrm:
341
396
case X86::VMOVDQU32Zrm:
342
397
case X86::VMOVDQU64Zrm:
343
- return ConvertToBroadcast (X86::VBROADCASTI64X4rm, X86::VBROADCASTI32X4rm,
344
- X86::VPBROADCASTQZrm, X86::VPBROADCASTDZrm,
345
- HasBWI ? X86::VPBROADCASTWZrm : 0 ,
346
- HasBWI ? X86::VPBROADCASTBZrm : 0 , 1 );
398
+ return FixupConstant (X86::VBROADCASTI64X4rm, X86::VBROADCASTI32X4rm,
399
+ X86::VPBROADCASTQZrm, X86::VPBROADCASTDZrm,
400
+ HasBWI ? X86::VPBROADCASTWZrm : 0 ,
401
+ HasBWI ? X86::VPBROADCASTBZrm : 0 , 0 , 0 , 1 );
347
402
}
348
403
349
404
auto ConvertToBroadcastAVX512 = [&](unsigned OpSrc32, unsigned OpSrc64) {
@@ -368,7 +423,7 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
368
423
369
424
if (OpBcst32 || OpBcst64) {
370
425
unsigned OpNo = OpBcst32 == 0 ? OpNoBcst64 : OpNoBcst32;
371
- return ConvertToBroadcast (0 , 0 , OpBcst64, OpBcst32, 0 , 0 , OpNo);
426
+ return FixupConstant (0 , 0 , OpBcst64, OpBcst32, 0 , 0 , 0 , 0 , OpNo);
372
427
}
373
428
return false ;
374
429
};
0 commit comments