@@ -216,8 +216,8 @@ static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
216
216
217
217
// Attempt to rebuild a normalized splat vector constant of the requested splat
218
218
// width, built up of potentially smaller scalar values.
219
- static Constant *rebuildSplatableConstant (const Constant *C,
220
- unsigned SplatBitWidth) {
219
+ static Constant *rebuildSplatCst (const Constant *C, unsigned /* NumElts */ ,
220
+ unsigned SplatBitWidth) {
221
221
std::optional<APInt> Splat = getSplatableConstant (C, SplatBitWidth);
222
222
if (!Splat)
223
223
return nullptr ;
@@ -238,8 +238,8 @@ static Constant *rebuildSplatableConstant(const Constant *C,
238
238
return rebuildConstant (OriginalType->getContext (), SclTy, *Splat, NumSclBits);
239
239
}
240
240
241
- static Constant *rebuildZeroUpperConstant (const Constant *C,
242
- unsigned ScalarBitWidth) {
241
+ static Constant *rebuildZeroUpperCst (const Constant *C, unsigned /* NumElts */ ,
242
+ unsigned ScalarBitWidth) {
243
243
Type *Ty = C->getType ();
244
244
Type *SclTy = Ty->getScalarType ();
245
245
unsigned NumBits = Ty->getPrimitiveSizeInBits ();
@@ -265,8 +265,6 @@ static Constant *rebuildZeroUpperConstant(const Constant *C,
265
265
return nullptr ;
266
266
}
267
267
268
- typedef std::function<Constant *(const Constant *, unsigned )> RebuildFn;
269
-
270
268
bool X86FixupVectorConstantsPass::processInstruction (MachineFunction &MF,
271
269
MachineBasicBlock &MBB,
272
270
MachineInstr &MI) {
@@ -277,43 +275,42 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
277
275
bool HasBWI = ST->hasBWI ();
278
276
bool HasVLX = ST->hasVLX ();
279
277
280
- auto FixupConstant =
281
- [&](unsigned OpBcst256, unsigned OpBcst128, unsigned OpBcst64,
282
- unsigned OpBcst32, unsigned OpBcst16, unsigned OpBcst8,
283
- unsigned OpUpper64, unsigned OpUpper32, unsigned OperandNo) {
284
- assert (MI.getNumOperands () >= (OperandNo + X86::AddrNumOperands) &&
285
- " Unexpected number of operands!" );
286
-
287
- if (auto *C = X86::getConstantFromPool (MI, OperandNo)) {
288
- // Attempt to detect a suitable splat/vzload from increasing constant
289
- // bitwidths.
290
- // Prefer vzload vs broadcast for same bitwidth to avoid domain flips.
291
- std::tuple<unsigned , unsigned , RebuildFn> FixupLoad[] = {
292
- {8 , OpBcst8, rebuildSplatableConstant},
293
- {16 , OpBcst16, rebuildSplatableConstant},
294
- {32 , OpUpper32, rebuildZeroUpperConstant},
295
- {32 , OpBcst32, rebuildSplatableConstant},
296
- {64 , OpUpper64, rebuildZeroUpperConstant},
297
- {64 , OpBcst64, rebuildSplatableConstant},
298
- {128 , OpBcst128, rebuildSplatableConstant},
299
- {256 , OpBcst256, rebuildSplatableConstant},
300
- };
301
- for (auto [BitWidth, Op, RebuildConstant] : FixupLoad) {
302
- if (Op) {
303
- // Construct a suitable constant and adjust the MI to use the new
304
- // constant pool entry.
305
- if (Constant *NewCst = RebuildConstant (C, BitWidth)) {
306
- unsigned NewCPI =
307
- CP->getConstantPoolIndex (NewCst, Align (BitWidth / 8 ));
308
- MI.setDesc (TII->get (Op));
309
- MI.getOperand (OperandNo + X86::AddrDisp).setIndex (NewCPI);
310
- return true ;
311
- }
312
- }
278
// One candidate rewrite tried by FixupConstant: a replacement opcode plus the
// arguments used to rebuild a constant for it. NumCstElts * BitWidth is the
// size key compared by the EXPENSIVE_CHECKS sortedness assert.
+ struct FixupEntry {
279
+ int Op; // Replacement opcode passed to TII->get(); entries with Op == 0 are skipped.
280
+ int NumCstElts; // Forwarded as the second argument of RebuildConstant.
281
+ int BitWidth; // Third argument of RebuildConstant; BitWidth / 8 is the new pool entry's alignment.
282
+ std::function<Constant *(const Constant *, unsigned , unsigned )>
283
+ RebuildConstant; // Called as (C, NumCstElts, BitWidth); a null result moves on to the next entry.
284
+ };
285
// Walks Fixups in order and applies the first entry that has a non-zero Op and
// whose RebuildConstant returns a non-null constant for the pool constant read
// from MI at OperandNo: the rebuilt constant is registered in CP, MI's
// descriptor is switched to the entry's Op, and the displacement operand is
// pointed at the new pool index. Returns true if MI was rewritten, false
// otherwise.
+ auto FixupConstant = [&](ArrayRef<FixupEntry> Fixups, unsigned OperandNo) {
286
+ #ifdef EXPENSIVE_CHECKS
// Entries with equal NumCstElts * BitWidth are accepted in either order, since
// the comparator is a strict less-than.
287
+ assert (llvm::is_sorted (Fixups,
288
+ [](const FixupEntry &A, const FixupEntry &B) {
289
+ return (A.NumCstElts * A.BitWidth ) <
290
+ (B.NumCstElts * B.BitWidth );
291
+ }) &&
292
+ " Constant fixup table not sorted in ascending constant size" );
293
+ #endif
294
+ assert (MI.getNumOperands () >= (OperandNo + X86::AddrNumOperands) &&
295
+ " Unexpected number of operands!" );
296
+ if (auto *C = X86::getConstantFromPool (MI, OperandNo)) {
297
+ for (const FixupEntry &Fixup : Fixups) {
298
+ if (Fixup.Op ) {
299
+ // Construct a suitable constant and adjust the MI to use the new
300
+ // constant pool entry.
301
+ if (Constant *NewCst =
302
+ Fixup.RebuildConstant (C, Fixup.NumCstElts , Fixup.BitWidth )) {
303
+ unsigned NewCPI =
304
+ CP->getConstantPoolIndex (NewCst, Align (Fixup.BitWidth / 8 ));
305
+ MI.setDesc (TII->get (Fixup.Op ));
306
+ MI.getOperand (OperandNo + X86::AddrDisp).setIndex (NewCPI);
307
+ return true ; // First applicable entry wins; later entries are not tried.
313
308
}
314
309
}
315
- return false ;
316
- };
310
+ }
311
+ }
312
+ return false ; // No pool constant at OperandNo, or no entry produced a replacement.
313
+ };
317
314
318
315
// Attempt to convert full width vector loads into broadcast/vzload loads.
319
316
switch (Opc) {
@@ -323,82 +320,125 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
323
320
case X86::MOVUPDrm:
324
321
case X86::MOVUPSrm:
325
322
// TODO: SSE3 MOVDDUP Handling
326
- return FixupConstant (0 , 0 , 0 , 0 , 0 , 0 , X86::MOVSDrm, X86::MOVSSrm, 1 );
323
+ return FixupConstant ({{X86::MOVSSrm, 1 , 32 , rebuildZeroUpperCst},
324
+ {X86::MOVSDrm, 1 , 64 , rebuildZeroUpperCst}},
325
+ 1 );
327
326
case X86::VMOVAPDrm:
328
327
case X86::VMOVAPSrm:
329
328
case X86::VMOVUPDrm:
330
329
case X86::VMOVUPSrm:
331
- return FixupConstant (0 , 0 , X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0 , 0 ,
332
- X86::VMOVSDrm, X86::VMOVSSrm, 1 );
330
+ return FixupConstant ({{X86::VMOVSSrm, 1 , 32 , rebuildZeroUpperCst},
331
+ {X86::VBROADCASTSSrm, 1 , 32 , rebuildSplatCst},
332
+ {X86::VMOVSDrm, 1 , 64 , rebuildZeroUpperCst},
333
+ {X86::VMOVDDUPrm, 1 , 64 , rebuildSplatCst}},
334
+ 1 );
333
335
case X86::VMOVAPDYrm:
334
336
case X86::VMOVAPSYrm:
335
337
case X86::VMOVUPDYrm:
336
338
case X86::VMOVUPSYrm:
337
- return FixupConstant (0 , X86::VBROADCASTF128rm, X86::VBROADCASTSDYrm,
338
- X86::VBROADCASTSSYrm, 0 , 0 , 0 , 0 , 1 );
339
+ return FixupConstant ({{X86::VBROADCASTSSYrm, 1 , 32 , rebuildSplatCst},
340
+ {X86::VBROADCASTSDYrm, 1 , 64 , rebuildSplatCst},
341
+ {X86::VBROADCASTF128rm, 1 , 128 , rebuildSplatCst}},
342
+ 1 );
339
343
case X86::VMOVAPDZ128rm:
340
344
case X86::VMOVAPSZ128rm:
341
345
case X86::VMOVUPDZ128rm:
342
346
case X86::VMOVUPSZ128rm:
343
- return FixupConstant (0 , 0 , X86::VMOVDDUPZ128rm, X86::VBROADCASTSSZ128rm, 0 ,
344
- 0 , X86::VMOVSDZrm, X86::VMOVSSZrm, 1 );
347
+ return FixupConstant ({{X86::VMOVSSZrm, 1 , 32 , rebuildZeroUpperCst},
348
+ {X86::VBROADCASTSSZ128rm, 1 , 32 , rebuildSplatCst},
349
+ {X86::VMOVSDZrm, 1 , 64 , rebuildZeroUpperCst},
350
+ {X86::VMOVDDUPZ128rm, 1 , 64 , rebuildSplatCst}},
351
+ 1 );
345
352
case X86::VMOVAPDZ256rm:
346
353
case X86::VMOVAPSZ256rm:
347
354
case X86::VMOVUPDZ256rm:
348
355
case X86::VMOVUPSZ256rm:
349
- return FixupConstant (0 , X86::VBROADCASTF32X4Z256rm, X86::VBROADCASTSDZ256rm,
350
- X86::VBROADCASTSSZ256rm, 0 , 0 , 0 , 0 , 1 );
356
+ return FixupConstant (
357
+ {{X86::VBROADCASTSSZ256rm, 1 , 32 , rebuildSplatCst},
358
+ {X86::VBROADCASTSDZ256rm, 1 , 64 , rebuildSplatCst},
359
+ {X86::VBROADCASTF32X4Z256rm, 1 , 128 , rebuildSplatCst}},
360
+ 1 );
351
361
case X86::VMOVAPDZrm:
352
362
case X86::VMOVAPSZrm:
353
363
case X86::VMOVUPDZrm:
354
364
case X86::VMOVUPSZrm:
355
- return FixupConstant (X86::VBROADCASTF64X4rm, X86::VBROADCASTF32X4rm,
356
- X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0 , 0 , 0 , 0 ,
365
+ return FixupConstant ({{X86::VBROADCASTSSZrm, 1 , 32 , rebuildSplatCst},
366
+ {X86::VBROADCASTSDZrm, 1 , 64 , rebuildSplatCst},
367
+ {X86::VBROADCASTF32X4rm, 1 , 128 , rebuildSplatCst},
368
+ {X86::VBROADCASTF64X4rm, 1 , 256 , rebuildSplatCst}},
357
369
1 );
358
370
/* Integer Loads */
359
371
case X86::MOVDQArm:
360
- case X86::MOVDQUrm:
361
- return FixupConstant (0 , 0 , 0 , 0 , 0 , 0 , X86::MOVQI2PQIrm, X86::MOVDI2PDIrm,
372
+ case X86::MOVDQUrm: {
373
+ return FixupConstant ({{X86::MOVDI2PDIrm, 1 , 32 , rebuildZeroUpperCst},
374
+ {X86::MOVQI2PQIrm, 1 , 64 , rebuildZeroUpperCst}},
362
375
1 );
376
+ }
363
377
case X86::VMOVDQArm:
364
- case X86::VMOVDQUrm:
365
- return FixupConstant (0 , 0 , HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm,
366
- HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm,
367
- HasAVX2 ? X86::VPBROADCASTWrm : 0 ,
368
- HasAVX2 ? X86::VPBROADCASTBrm : 0 , X86::VMOVQI2PQIrm,
369
- X86::VMOVDI2PDIrm, 1 );
378
+ case X86::VMOVDQUrm: {
379
+ FixupEntry Fixups[] = {
380
+ {HasAVX2 ? X86::VPBROADCASTBrm : 0 , 1 , 8 , rebuildSplatCst},
381
+ {HasAVX2 ? X86::VPBROADCASTWrm : 0 , 1 , 16 , rebuildSplatCst},
382
+ {X86::VMOVDI2PDIrm, 1 , 32 , rebuildZeroUpperCst},
383
+ {HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm, 1 , 32 ,
384
+ rebuildSplatCst},
385
+ {X86::VMOVQI2PQIrm, 1 , 64 , rebuildZeroUpperCst},
386
+ {HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm, 1 , 64 ,
387
+ rebuildSplatCst},
388
+ };
389
+ return FixupConstant (Fixups, 1 );
390
+ }
370
391
case X86::VMOVDQAYrm:
371
- case X86::VMOVDQUYrm:
372
- return FixupConstant (
373
- 0 , HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm,
374
- HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm,
375
- HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm,
376
- HasAVX2 ? X86::VPBROADCASTWYrm : 0 , HasAVX2 ? X86::VPBROADCASTBYrm : 0 ,
377
- 0 , 0 , 1 );
392
+ case X86::VMOVDQUYrm: {
393
+ FixupEntry Fixups[] = {
394
+ {HasAVX2 ? X86::VPBROADCASTBYrm : 0 , 1 , 8 , rebuildSplatCst},
395
+ {HasAVX2 ? X86::VPBROADCASTWYrm : 0 , 1 , 16 , rebuildSplatCst},
396
+ {HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm, 1 , 32 ,
397
+ rebuildSplatCst},
398
+ {HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm, 1 , 64 ,
399
+ rebuildSplatCst},
400
+ {HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm, 1 , 128 ,
401
+ rebuildSplatCst}};
402
+ return FixupConstant (Fixups, 1 );
403
+ }
378
404
case X86::VMOVDQA32Z128rm:
379
405
case X86::VMOVDQA64Z128rm:
380
406
case X86::VMOVDQU32Z128rm:
381
- case X86::VMOVDQU64Z128rm:
382
- return FixupConstant (0 , 0 , X86::VPBROADCASTQZ128rm, X86::VPBROADCASTDZ128rm,
383
- HasBWI ? X86::VPBROADCASTWZ128rm : 0 ,
384
- HasBWI ? X86::VPBROADCASTBZ128rm : 0 ,
385
- X86::VMOVQI2PQIZrm, X86::VMOVDI2PDIZrm, 1 );
407
+ case X86::VMOVDQU64Z128rm: {
408
+ FixupEntry Fixups[] = {
409
+ {HasBWI ? X86::VPBROADCASTBZ128rm : 0 , 1 , 8 , rebuildSplatCst},
410
+ {HasBWI ? X86::VPBROADCASTWZ128rm : 0 , 1 , 16 , rebuildSplatCst},
411
+ {X86::VMOVDI2PDIZrm, 1 , 32 , rebuildZeroUpperCst},
412
+ {X86::VPBROADCASTDZ128rm, 1 , 32 , rebuildSplatCst},
413
+ {X86::VMOVQI2PQIZrm, 1 , 64 , rebuildZeroUpperCst},
414
+ {X86::VPBROADCASTQZ128rm, 1 , 64 , rebuildSplatCst}};
415
+ return FixupConstant (Fixups, 1 );
416
+ }
386
417
case X86::VMOVDQA32Z256rm:
387
418
case X86::VMOVDQA64Z256rm:
388
419
case X86::VMOVDQU32Z256rm:
389
- case X86::VMOVDQU64Z256rm:
390
- return FixupConstant (0 , X86::VBROADCASTI32X4Z256rm, X86::VPBROADCASTQZ256rm,
391
- X86::VPBROADCASTDZ256rm,
392
- HasBWI ? X86::VPBROADCASTWZ256rm : 0 ,
393
- HasBWI ? X86::VPBROADCASTBZ256rm : 0 , 0 , 0 , 1 );
420
+ case X86::VMOVDQU64Z256rm: {
421
+ FixupEntry Fixups[] = {
422
+ {HasBWI ? X86::VPBROADCASTBZ256rm : 0 , 1 , 8 , rebuildSplatCst},
423
+ {HasBWI ? X86::VPBROADCASTWZ256rm : 0 , 1 , 16 , rebuildSplatCst},
424
+ {X86::VPBROADCASTDZ256rm, 1 , 32 , rebuildSplatCst},
425
+ {X86::VPBROADCASTQZ256rm, 1 , 64 , rebuildSplatCst},
426
+ {X86::VBROADCASTI32X4Z256rm, 1 , 128 , rebuildSplatCst}};
427
+ return FixupConstant (Fixups, 1 );
428
+ }
394
429
case X86::VMOVDQA32Zrm:
395
430
case X86::VMOVDQA64Zrm:
396
431
case X86::VMOVDQU32Zrm:
397
- case X86::VMOVDQU64Zrm:
398
- return FixupConstant (X86::VBROADCASTI64X4rm, X86::VBROADCASTI32X4rm,
399
- X86::VPBROADCASTQZrm, X86::VPBROADCASTDZrm,
400
- HasBWI ? X86::VPBROADCASTWZrm : 0 ,
401
- HasBWI ? X86::VPBROADCASTBZrm : 0 , 0 , 0 , 1 );
432
+ case X86::VMOVDQU64Zrm: {
433
+ FixupEntry Fixups[] = {
434
+ {HasBWI ? X86::VPBROADCASTBZrm : 0 , 1 , 8 , rebuildSplatCst},
435
+ {HasBWI ? X86::VPBROADCASTWZrm : 0 , 1 , 16 , rebuildSplatCst},
436
+ {X86::VPBROADCASTDZrm, 1 , 32 , rebuildSplatCst},
437
+ {X86::VPBROADCASTQZrm, 1 , 64 , rebuildSplatCst},
438
+ {X86::VBROADCASTI32X4rm, 1 , 128 , rebuildSplatCst},
439
+ {X86::VBROADCASTI64X4rm, 1 , 256 , rebuildSplatCst}};
440
+ return FixupConstant (Fixups, 1 );
441
+ }
402
442
}
403
443
404
444
auto ConvertToBroadcastAVX512 = [&](unsigned OpSrc32, unsigned OpSrc64) {
@@ -423,7 +463,9 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
423
463
424
464
// NOTE(review): the two entries built below pass 32 and 64 in the NumCstElts
// slot, whereas every other fixup table visible in this change passes 1.
// rebuildSplatCst leaves that parameter unnamed, so this looks harmless today,
// but the intended value should be confirmed.
if (OpBcst32 || OpBcst64) {
425
465
unsigned OpNo = OpBcst32 == 0 ? OpNoBcst64 : OpNoBcst32;
426
- return FixupConstant (0 , 0 , OpBcst64, OpBcst32, 0 , 0 , 0 , 0 , OpNo);
466
+ FixupEntry Fixups[] = {{(int )OpBcst32, 32 , 32 , rebuildSplatCst},
467
+ {(int )OpBcst64, 64 , 64 , rebuildSplatCst}};
468
+ return FixupConstant (Fixups, OpNo);
427
469
}
428
470
return false ;
429
471
};
0 commit comments