diff --git a/llvm/lib/Target/X86/X86LowerTileCopy.cpp b/llvm/lib/Target/X86/X86LowerTileCopy.cpp index e7afc49240e54..2ca21e69e5919 100644 --- a/llvm/lib/Target/X86/X86LowerTileCopy.cpp +++ b/llvm/lib/Target/X86/X86LowerTileCopy.cpp @@ -20,6 +20,7 @@ #include "X86InstrBuilder.h" #include "X86InstrInfo.h" #include "X86Subtarget.h" +#include "llvm/CodeGen/LiveRegUnits.h" #include "llvm/CodeGen/MachineBasicBlock.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineFunction.h" @@ -72,10 +73,16 @@ FunctionPass *llvm::createX86LowerTileCopyPass() { bool X86LowerTileCopy::runOnMachineFunction(MachineFunction &MF) { const X86Subtarget &ST = MF.getSubtarget(); const X86InstrInfo *TII = ST.getInstrInfo(); + const TargetRegisterInfo *TRI = ST.getRegisterInfo(); + BitVector GR64Regs = + TRI->getAllocatableSet(MF, TRI->getRegClass(X86::GR64RegClassID)); bool Changed = false; for (MachineBasicBlock &MBB : MF) { - for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) { + LiveRegUnits UsedRegs(*TRI); + UsedRegs.addLiveOuts(MBB); + for (MachineInstr &MI : llvm::make_early_inc_range(reverse(MBB))) { + UsedRegs.stepBackward(MI); if (!MI.isCopy()) continue; MachineOperand &DstMO = MI.getOperand(0); @@ -85,27 +92,41 @@ bool X86LowerTileCopy::runOnMachineFunction(MachineFunction &MF) { if (!X86::TILERegClass.contains(DstReg, SrcReg)) continue; - const TargetRegisterInfo *TRI = ST.getRegisterInfo(); // Allocate stack slot for tile register unsigned Size = TRI->getSpillSize(X86::TILERegClass); Align Alignment = TRI->getSpillAlign(X86::TILERegClass); int TileSS = MF.getFrameInfo().CreateSpillStackObject(Size, Alignment); - // Allocate stack slot for stride register - Size = TRI->getSpillSize(X86::GR64RegClass); - Alignment = TRI->getSpillAlign(X86::GR64RegClass); - int StrideSS = MF.getFrameInfo().CreateSpillStackObject(Size, Alignment); - // TODO: Pick a killed regiter to avoid save/reload. There is problem - // to get live interval in this stage. - Register GR64Cand = X86::RAX; + int StrideSS = 0; + + // Pick a killed register to avoid a save/reload. + Register GR64Cand = X86::NoRegister; + for (auto RegT : GR64Regs.set_bits()) { + if (UsedRegs.available(RegT)) { + GR64Cand = RegT; + break; + } + } const DebugLoc &DL = MI.getDebugLoc(); - // mov %rax (%sp) - BuildMI(MBB, MI, DL, TII->get(X86::IMPLICIT_DEF), GR64Cand); - addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV64mr)), StrideSS) - .addReg(GR64Cand); - // mov 64 %rax - BuildMI(MBB, MI, DL, TII->get(X86::MOV64ri), GR64Cand).addImm(64); + if (GR64Cand) { + // mov 64 %reg + BuildMI(MBB, MI, DL, TII->get(X86::MOV64ri), GR64Cand).addImm(64); + } else { + // No available register? Save RAX and reload it after use. + + // Allocate stack slot for stride register + Size = TRI->getSpillSize(X86::GR64RegClass); + Alignment = TRI->getSpillAlign(X86::GR64RegClass); + StrideSS = MF.getFrameInfo().CreateSpillStackObject(Size, Alignment); + + // mov %reg (%sp) + addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV64mr)), + StrideSS) + .addReg(X86::RAX); + // mov 64 %reg + BuildMI(MBB, MI, DL, TII->get(X86::MOV64ri), X86::RAX).addImm(64); + } // tilestored %tmm, (%sp, %idx) #define GET_EGPR_IF_ENABLED(OPC) (ST.hasEGPR() ? OPC##_EVEX : OPC) unsigned Opc = GET_EGPR_IF_ENABLED(X86::TILESTORED); @@ -120,10 +141,12 @@ bool X86LowerTileCopy::runOnMachineFunction(MachineFunction &MF) { #undef GET_EGPR_IF_ENABLED NewMI = addFrameReference(BuildMI(MBB, MI, DL, TII->get(Opc), DstReg), TileSS); - // restore %rax - // mov (%sp) %rax - addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV64rm), GR64Cand), - StrideSS); + if (!GR64Cand) { + // restore %rax + // mov (%sp) %rax + addFrameReference( + BuildMI(MBB, MI, DL, TII->get(X86::MOV64rm), GR64Cand), StrideSS); + } MI.eraseFromParent(); Changed = true; } diff --git a/llvm/test/CodeGen/X86/AMX/amx-lower-tile-copy.ll b/llvm/test/CodeGen/X86/AMX/amx-lower-tile-copy.ll index 4686361ad2fcf..a0085afbaf025 100644 --- a/llvm/test/CodeGen/X86/AMX/amx-lower-tile-copy.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-lower-tile-copy.ll @@ -44,12 +44,8 @@ define dso_local void @test1(ptr%buf) nounwind { ; CHECK-NEXT: tileloadd 3024(%rsp,%rax), %tmm3 # 1024-byte Folded Reload ; CHECK-NEXT: tileloadd (%rbx,%r15), %tmm0 ; CHECK-NEXT: tileloadd (%rbx,%r15), %tmm1 -; CHECK-NEXT: # implicit-def: $rax -; CHECK-NEXT: movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill -; CHECK-NEXT: movabsq $64, %rax ; CHECK-NEXT: tilestored %tmm3, 1024(%rsp,%rax) # 1024-byte Folded Spill ; CHECK-NEXT: tileloadd {{[-0-9]+}}(%r{{[sb]}}p), %tmm2 # 1024-byte Folded Reload -; CHECK-NEXT: movq {{[-0-9]+}}(%r{{[sb]}}p), %rax # 8-byte Reload ; CHECK-NEXT: tdpbssd %tmm1, %tmm0, %tmm2 ; CHECK-NEXT: tilestored %tmm2, (%rbx,%r15) ; CHECK-NEXT: incl %r14d @@ -111,16 +107,10 @@ define dso_local void @test1(ptr%buf) nounwind { ; EGPR-NEXT: # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7b,0x4b,0x9c,0x04,0xd0,0x0b,0x00,0x00] ; EGPR-NEXT: tileloadd (%rbx,%r15), %tmm0 # EVEX TO VEX Compression encoding: [0xc4,0xa2,0x7b,0x4b,0x04,0x3b] ; EGPR-NEXT: tileloadd (%rbx,%r15), %tmm1 # EVEX TO VEX Compression encoding: [0xc4,0xa2,0x7b,0x4b,0x0c,0x3b] -; EGPR-NEXT: # implicit-def: $rax -; EGPR-NEXT: movq %rax, {{[-0-9]+}}(%r{{[sb]}}p) # 8-byte Spill -; EGPR-NEXT: # encoding: [0x48,0x89,0x84,0x24,0xb8,0x03,0x00,0x00] -; EGPR-NEXT: movabsq $64, %rax # encoding: [0x48,0xb8,0x40,0x00,0x00,0x00,0x00,0x00,0x00,0x00] ; EGPR-NEXT: tilestored %tmm3, 1024(%rsp,%rax) # 1024-byte Folded Spill ; EGPR-NEXT: # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7a,0x4b,0x9c,0x04,0x00,0x04,0x00,0x00] ; EGPR-NEXT: tileloadd {{[-0-9]+}}(%r{{[sb]}}p), %tmm2 # 1024-byte Folded Reload ; EGPR-NEXT: # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7b,0x4b,0x94,0x24,0x00,0x04,0x00,0x00] -; EGPR-NEXT: movq {{[-0-9]+}}(%r{{[sb]}}p), %rax # 8-byte Reload -; EGPR-NEXT: # encoding: [0x48,0x8b,0x84,0x24,0xb8,0x03,0x00,0x00] ; EGPR-NEXT: tdpbssd %tmm1, %tmm0, %tmm2 # encoding: [0xc4,0xe2,0x73,0x5e,0xd0] ; EGPR-NEXT: tilestored %tmm2, (%rbx,%r15) # EVEX TO VEX Compression encoding: [0xc4,0xa2,0x7a,0x4b,0x14,0x3b] ; EGPR-NEXT: incl %r14d # encoding: [0x41,0xff,0xc6]