 // replace them with smaller constant pool entries, including:
 // * Converting AVX512 memory-fold instructions to their broadcast-fold form.
 // * Broadcasting of full width loads.
-// * TODO: Sign/Zero extension of full width loads.
+// * TODO: Zero extension of full width loads.
 //
 //===----------------------------------------------------------------------===//
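As a concrete illustration of the sign-extension bullet this patch implements (the sample values are ours, not from the patch): a 32-bit element may be stored as 8 bits exactly when truncating it and sign-extending it back round-trips.

#include <cassert>
#include <cstdint>

// Legality test behind the sign-extension fixup, restated in plain C++.
static bool roundTripsViaI8(int32_t V) {
  return static_cast<int32_t>(static_cast<int8_t>(V)) == V;
}

int main() {
  const int32_t Vec[4] = {-1, 0, 1, 2}; // a 16-byte v4i32 constant
  for (int32_t V : Vec)
    assert(roundTripsViaI8(V)); // every lane fits: a 4-byte pool entry suffices
  assert(!roundTripsViaI8(128)); // 128 truncates to -128, so i8 would be wrong
  return 0;
}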
@@ -265,11 +265,47 @@ static Constant *rebuildZeroUpperCst(const Constant *C, unsigned /*NumElts*/,
   return nullptr;
 }

+static Constant *rebuildExtCst(const Constant *C, bool IsSExt, unsigned NumElts,
+                               unsigned SrcEltBitWidth) {
+  Type *Ty = C->getType();
+  unsigned NumBits = Ty->getPrimitiveSizeInBits();
+  unsigned DstEltBitWidth = NumBits / NumElts;
+  assert((NumBits % NumElts) == 0 && (NumBits % SrcEltBitWidth) == 0 &&
+         (DstEltBitWidth % SrcEltBitWidth) == 0 &&
+         (DstEltBitWidth > SrcEltBitWidth) && "Illegal extension width");
+
+  if (std::optional<APInt> Bits = extractConstantBits(C)) {
+    assert((Bits->getBitWidth() / DstEltBitWidth) == NumElts &&
+           (Bits->getBitWidth() % DstEltBitWidth) == 0 &&
+           "Unexpected constant extension");
+
+    // Ensure every vector element can be represented by the src bitwidth.
+    APInt TruncBits = APInt::getZero(NumElts * SrcEltBitWidth);
+    for (unsigned I = 0; I != NumElts; ++I) {
+      APInt Elt = Bits->extractBits(DstEltBitWidth, I * DstEltBitWidth);
+      if ((IsSExt && Elt.getSignificantBits() > SrcEltBitWidth) ||
+          (!IsSExt && Elt.getActiveBits() > SrcEltBitWidth))
+        return nullptr;
+      TruncBits.insertBits(Elt.trunc(SrcEltBitWidth), I * SrcEltBitWidth);
+    }
+
+    return rebuildConstant(Ty->getContext(), Ty->getScalarType(), TruncBits,
+                           SrcEltBitWidth);
+  }
+
+  return nullptr;
+}
+static Constant *rebuildSExtCst(const Constant *C, unsigned NumElts,
+                                unsigned SrcEltBitWidth) {
+  return rebuildExtCst(C, true, NumElts, SrcEltBitWidth);
+}
+
 bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
                                                      MachineBasicBlock &MBB,
                                                      MachineInstr &MI) {
   unsigned Opc = MI.getOpcode();
   MachineConstantPool *CP = MI.getParent()->getParent()->getConstantPool();
+  bool HasSSE41 = ST->hasSSE41();
   bool HasAVX2 = ST->hasAVX2();
   bool HasDQI = ST->hasDQI();
   bool HasBWI = ST->hasBWI();
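A minimal sketch, assuming LLVM's ADT headers (LLVM 15+ for getSignificantBits), of the per-element test rebuildExtCst performs above: getSignificantBits() bounds the width needed under sign-extension, getActiveBits() under zero-extension (the remaining TODO).

#include "llvm/ADT/APInt.h"
#include <cassert>

using llvm::APInt;

int main() {
  APInt MinusOne(32, -1, /*isSigned=*/true); // 0xFFFFFFFF
  APInt V128(32, 128);                       // 0x00000080

  // -1 needs only 1 significant bit, so it shrinks to i8 under sign-extension.
  assert(MinusOne.getSignificantBits() <= 8);
  // +128 needs 9 bits as a signed value: an i8 sign-extension would give -128.
  assert(V128.getSignificantBits() > 8);
  // But its active (zero-extended) width is 8, so a future zero-extension
  // fixup could still shrink it.
  assert(V128.getActiveBits() <= 8);
  return 0;
}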
@@ -312,7 +348,15 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
     return false;
   };

-  // Attempt to convert full width vector loads into broadcast/vzload loads.
+  // Attempt to detect a suitable vzload/broadcast/vextload from increasing
+  // constant bitwidths. At the same bitwidth, prefer vzload, then broadcast,
+  // then vextload:
+  // - vzload shouldn't ever need a shuffle port to zero the upper elements,
+  //   and the fp/int domain versions are equally available, so we don't
+  //   introduce a domain-crossing penalty.
+  // - broadcast sometimes needs a shuffle port (especially for 8/16-bit
+  //   variants); AVX1 only has fp domain broadcasts, but AVX2+ have good
+  //   fp/int domain equivalents.
+  // - vextload always needs a shuffle port and is only ever int domain.
   switch (Opc) {
   /* FP Loads */
   case X86::MOVAPDrm:
@@ -370,22 +414,34 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
   /* Integer Loads */
   case X86::MOVDQArm:
   case X86::MOVDQUrm: {
-    return FixupConstant({{X86::MOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
-                          {X86::MOVQI2PQIrm, 1, 64, rebuildZeroUpperCst}},
-                         1);
+    FixupEntry Fixups[] = {
+        {HasSSE41 ? X86::PMOVSXBQrm : 0, 2, 8, rebuildSExtCst},
+        {X86::MOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
+        {HasSSE41 ? X86::PMOVSXBDrm : 0, 4, 8, rebuildSExtCst},
+        {HasSSE41 ? X86::PMOVSXWQrm : 0, 2, 16, rebuildSExtCst},
+        {X86::MOVQI2PQIrm, 1, 64, rebuildZeroUpperCst},
+        {HasSSE41 ? X86::PMOVSXBWrm : 0, 8, 8, rebuildSExtCst},
+        {HasSSE41 ? X86::PMOVSXWDrm : 0, 4, 16, rebuildSExtCst},
+        {HasSSE41 ? X86::PMOVSXDQrm : 0, 2, 32, rebuildSExtCst}};
+    return FixupConstant(Fixups, 1);
   }
   case X86::VMOVDQArm:
   case X86::VMOVDQUrm: {
     FixupEntry Fixups[] = {
         {HasAVX2 ? X86::VPBROADCASTBrm : 0, 1, 8, rebuildSplatCst},
         {HasAVX2 ? X86::VPBROADCASTWrm : 0, 1, 16, rebuildSplatCst},
+        {X86::VPMOVSXBQrm, 2, 8, rebuildSExtCst},
         {X86::VMOVDI2PDIrm, 1, 32, rebuildZeroUpperCst},
         {HasAVX2 ? X86::VPBROADCASTDrm : X86::VBROADCASTSSrm, 1, 32,
          rebuildSplatCst},
+        {X86::VPMOVSXBDrm, 4, 8, rebuildSExtCst},
+        {X86::VPMOVSXWQrm, 2, 16, rebuildSExtCst},
         {X86::VMOVQI2PQIrm, 1, 64, rebuildZeroUpperCst},
         {HasAVX2 ? X86::VPBROADCASTQrm : X86::VMOVDDUPrm, 1, 64,
          rebuildSplatCst},
-    };
+        {X86::VPMOVSXBWrm, 8, 8, rebuildSExtCst},
+        {X86::VPMOVSXWDrm, 4, 16, rebuildSExtCst},
+        {X86::VPMOVSXDQrm, 2, 32, rebuildSExtCst}};
     return FixupConstant(Fixups, 1);
   }
   case X86::VMOVDQAYrm:
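The ordering the comment above describes can be checked mechanically: within each Fixups array, candidates appear in non-decreasing order of the constant-pool bits they would consume (NumCstElts * BitWidth), with vzload entries listed before same-width extload entries. The pairs below mirror the X86::MOVDQUrm table; the check itself is our illustration, not part of the pass.

#include <cassert>

int main() {
  // {NumCstElts, BitWidth} in table order: sxbq, vzload32, sxbd, sxwq,
  // vzload64, sxbw, sxwd, sxdq.
  const int Pairs[][2] = {{2, 8},  {1, 32}, {4, 8},  {2, 16},
                          {1, 64}, {8, 8},  {4, 16}, {2, 32}};
  int Prev = 0;
  for (const auto &P : Pairs) {
    int PoolBits = P[0] * P[1]; // bits the shrunken constant occupies
    assert(PoolBits >= Prev);   // ordered smallest-first
    Prev = PoolBits;
  }
  return 0;
}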
@@ -395,10 +451,16 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
         {HasAVX2 ? X86::VPBROADCASTWYrm : 0, 1, 16, rebuildSplatCst},
         {HasAVX2 ? X86::VPBROADCASTDYrm : X86::VBROADCASTSSYrm, 1, 32,
          rebuildSplatCst},
+        {HasAVX2 ? X86::VPMOVSXBQYrm : 0, 4, 8, rebuildSExtCst},
         {HasAVX2 ? X86::VPBROADCASTQYrm : X86::VBROADCASTSDYrm, 1, 64,
          rebuildSplatCst},
+        {HasAVX2 ? X86::VPMOVSXBDYrm : 0, 8, 8, rebuildSExtCst},
+        {HasAVX2 ? X86::VPMOVSXWQYrm : 0, 4, 16, rebuildSExtCst},
         {HasAVX2 ? X86::VBROADCASTI128rm : X86::VBROADCASTF128rm, 1, 128,
-         rebuildSplatCst}};
+         rebuildSplatCst},
+        {HasAVX2 ? X86::VPMOVSXBWYrm : 0, 16, 8, rebuildSExtCst},
+        {HasAVX2 ? X86::VPMOVSXWDYrm : 0, 8, 16, rebuildSExtCst},
+        {HasAVX2 ? X86::VPMOVSXDQYrm : 0, 4, 32, rebuildSExtCst}};
     return FixupConstant(Fixups, 1);
   }
   case X86::VMOVDQA32Z128rm:
@@ -408,10 +470,16 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
     FixupEntry Fixups[] = {
         {HasBWI ? X86::VPBROADCASTBZ128rm : 0, 1, 8, rebuildSplatCst},
         {HasBWI ? X86::VPBROADCASTWZ128rm : 0, 1, 16, rebuildSplatCst},
+        {X86::VPMOVSXBQZ128rm, 2, 8, rebuildSExtCst},
         {X86::VMOVDI2PDIZrm, 1, 32, rebuildZeroUpperCst},
         {X86::VPBROADCASTDZ128rm, 1, 32, rebuildSplatCst},
+        {X86::VPMOVSXBDZ128rm, 4, 8, rebuildSExtCst},
+        {X86::VPMOVSXWQZ128rm, 2, 16, rebuildSExtCst},
         {X86::VMOVQI2PQIZrm, 1, 64, rebuildZeroUpperCst},
-        {X86::VPBROADCASTQZ128rm, 1, 64, rebuildSplatCst}};
+        {X86::VPBROADCASTQZ128rm, 1, 64, rebuildSplatCst},
+        {HasBWI ? X86::VPMOVSXBWZ128rm : 0, 8, 8, rebuildSExtCst},
+        {X86::VPMOVSXWDZ128rm, 4, 16, rebuildSExtCst},
+        {X86::VPMOVSXDQZ128rm, 2, 32, rebuildSExtCst}};
     return FixupConstant(Fixups, 1);
   }
   case X86::VMOVDQA32Z256rm:
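A hedged sketch of the subtarget gating used throughout these tables: entries whose feature test fails carry opcode 0 and are skipped by the scan, so one ordered table serves every subtarget. The FixupEntry fields and the firstUsable helper here are inferred from the call sites, not quoted from the pass, and the real scan additionally requires the entry's rebuild callback to succeed.

// Inferred shape of a fixup table entry; illustrative only.
struct FixupEntry {
  int Op;         // 0 => pattern unavailable on this subtarget (e.g. !HasBWI)
  int NumCstElts; // element count of the shrunken constant
  int BitWidth;   // bits per shrunken element
};

// First-match scan: gated-out (Op == 0) entries are ignored, so the earliest
// available candidate, i.e. the smallest viable pool entry, wins.
static int firstUsable(const FixupEntry *Entries, int N) {
  for (int I = 0; I != N; ++I)
    if (Entries[I].Op != 0)
      return Entries[I].Op;
  return 0; // no fixup applies
}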
@@ -422,8 +490,14 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
         {HasBWI ? X86::VPBROADCASTBZ256rm : 0, 1, 8, rebuildSplatCst},
         {HasBWI ? X86::VPBROADCASTWZ256rm : 0, 1, 16, rebuildSplatCst},
         {X86::VPBROADCASTDZ256rm, 1, 32, rebuildSplatCst},
+        {X86::VPMOVSXBQZ256rm, 4, 8, rebuildSExtCst},
         {X86::VPBROADCASTQZ256rm, 1, 64, rebuildSplatCst},
-        {X86::VBROADCASTI32X4Z256rm, 1, 128, rebuildSplatCst}};
+        {X86::VPMOVSXBDZ256rm, 8, 8, rebuildSExtCst},
+        {X86::VPMOVSXWQZ256rm, 4, 16, rebuildSExtCst},
+        {X86::VBROADCASTI32X4Z256rm, 1, 128, rebuildSplatCst},
+        {HasBWI ? X86::VPMOVSXBWZ256rm : 0, 16, 8, rebuildSExtCst},
+        {X86::VPMOVSXWDZ256rm, 8, 16, rebuildSExtCst},
+        {X86::VPMOVSXDQZ256rm, 4, 32, rebuildSExtCst}};
     return FixupConstant(Fixups, 1);
   }
   case X86::VMOVDQA32Zrm:
@@ -435,8 +509,14 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
         {HasBWI ? X86::VPBROADCASTWZrm : 0, 1, 16, rebuildSplatCst},
         {X86::VPBROADCASTDZrm, 1, 32, rebuildSplatCst},
         {X86::VPBROADCASTQZrm, 1, 64, rebuildSplatCst},
+        {X86::VPMOVSXBQZrm, 8, 8, rebuildSExtCst},
         {X86::VBROADCASTI32X4rm, 1, 128, rebuildSplatCst},
-        {X86::VBROADCASTI64X4rm, 1, 256, rebuildSplatCst}};
+        {X86::VPMOVSXBDZrm, 16, 8, rebuildSExtCst},
+        {X86::VPMOVSXWQZrm, 8, 16, rebuildSExtCst},
+        {X86::VBROADCASTI64X4rm, 1, 256, rebuildSplatCst},
+        {HasBWI ? X86::VPMOVSXBWZrm : 0, 32, 8, rebuildSExtCst},
+        {X86::VPMOVSXWDZrm, 16, 16, rebuildSExtCst},
+        {X86::VPMOVSXDQZrm, 8, 32, rebuildSExtCst}};
     return FixupConstant(Fixups, 1);
   }
   }
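Back-of-the-envelope savings for the widest case above, with our own sample arithmetic: a full 512-bit constant occupies 64 bytes of constant pool, while the {8, 8} VPMOVSXBQZrm entry stores just 8 bytes whenever every i64 lane is sign-representable in 8 bits.

#include <cstdio>

int main() {
  const unsigned VecBits = 512;               // ZMM load width
  const unsigned NumCstElts = 8, EltBits = 8; // VPMOVSXBQZrm table entry
  std::printf("full: %u bytes, shrunk: %u bytes\n", VecBits / 8,
              NumCstElts * EltBits / 8);      // full: 64, shrunk: 8
  return 0;
}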