Skip to content

Commit abe66f7

Browse files
committed
AMDGPU: Don't bitcast float typed atomic store in IR
Implement the promotion in the DAG.
1 parent 9a094b7 commit abe66f7

File tree

5 files changed

+73
-9
lines changed

5 files changed

+73
-9
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4946,7 +4946,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
49464946
Node->getOpcode() == ISD::INSERT_VECTOR_ELT) {
49474947
OVT = Node->getOperand(0).getSimpleValueType();
49484948
}
4949-
if (Node->getOpcode() == ISD::STRICT_UINT_TO_FP ||
4949+
if (Node->getOpcode() == ISD::ATOMIC_STORE ||
4950+
Node->getOpcode() == ISD::STRICT_UINT_TO_FP ||
49504951
Node->getOpcode() == ISD::STRICT_SINT_TO_FP ||
49514952
Node->getOpcode() == ISD::STRICT_FSETCC ||
49524953
Node->getOpcode() == ISD::STRICT_FSETCCS)
@@ -5557,7 +5558,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
55575558
Results.push_back(CvtVec);
55585559
break;
55595560
}
5560-
case ISD::ATOMIC_SWAP: {
5561+
case ISD::ATOMIC_SWAP:
5562+
case ISD::ATOMIC_STORE: {
55615563
AtomicSDNode *AM = cast<AtomicSDNode>(Node);
55625564
SDLoc SL(Node);
55635565
SDValue CastVal = DAG.getNode(ISD::BITCAST, SL, NVT, AM->getVal());
@@ -5566,13 +5568,22 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
55665568
assert(AM->getMemoryVT().getSizeInBits() == NVT.getSizeInBits() &&
55675569
"unexpected atomic_swap with illegal type");
55685570

5569-
SDValue NewAtomic
5570-
= DAG.getAtomic(ISD::ATOMIC_SWAP, SL, NVT,
5571-
DAG.getVTList(NVT, MVT::Other),
5572-
{ AM->getChain(), AM->getBasePtr(), CastVal },
5573-
AM->getMemOperand());
5574-
Results.push_back(DAG.getNode(ISD::BITCAST, SL, OVT, NewAtomic));
5575-
Results.push_back(NewAtomic.getValue(1));
5571+
SDValue Op0 = AM->getBasePtr();
5572+
SDValue Op1 = CastVal;
5573+
5574+
// ATOMIC_STORE uses a swapped operand order from every other AtomicSDNode,
5575+
// but really it should merge with ISD::STORE.
5576+
if (AM->getOpcode() == ISD::ATOMIC_STORE)
5577+
std::swap(Op0, Op1);
5578+
5579+
SDValue NewAtomic = DAG.getAtomic(AM->getOpcode(), SL, NVT, AM->getChain(),
5580+
Op0, Op1, AM->getMemOperand());
5581+
5582+
if (AM->getOpcode() != ISD::ATOMIC_STORE) {
5583+
Results.push_back(DAG.getNode(ISD::BITCAST, SL, OVT, NewAtomic));
5584+
Results.push_back(NewAtomic.getValue(1));
5585+
} else
5586+
Results.push_back(NewAtomic);
55765587
break;
55775588
}
55785589
case ISD::SPLAT_VECTOR: {

llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2249,6 +2249,7 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
22492249
case ISD::SELECT_CC: R = PromoteFloatOp_SELECT_CC(N, OpNo); break;
22502250
case ISD::SETCC: R = PromoteFloatOp_SETCC(N, OpNo); break;
22512251
case ISD::STORE: R = PromoteFloatOp_STORE(N, OpNo); break;
2252+
case ISD::ATOMIC_STORE: R = PromoteFloatOp_ATOMIC_STORE(N, OpNo); break;
22522253
}
22532254
// clang-format on
22542255

@@ -2371,6 +2372,23 @@ SDValue DAGTypeLegalizer::PromoteFloatOp_STORE(SDNode *N, unsigned OpNo) {
23712372
ST->getMemOperand());
23722373
}
23732374

2375+
/// Legalize an ATOMIC_STORE whose stored value is a float that is being
/// promoted.
///
/// The promoted form of the stored value is converted, using the opcode
/// returned by GetPromotionOpcode, into an integer that is exactly as wide as
/// the original value type. A new ATOMIC_STORE of that integer is then built
/// with the original chain, base pointer and memory operand.
///
/// \param N    The ATOMIC_STORE node; it is cast to AtomicSDNode.
/// \param OpNo Index of the operand being promoted. Only the stored value
///             (operand 1) can be promoted, matching the check performed by
///             SoftPromoteHalfOp_ATOMIC_STORE.
/// \returns The replacement ATOMIC_STORE node.
SDValue DAGTypeLegalizer::PromoteFloatOp_ATOMIC_STORE(SDNode *N,
                                                      unsigned OpNo) {
  assert(OpNo == 1 && "Can only promote the stored value!");
  AtomicSDNode *ST = cast<AtomicSDNode>(N);
  SDValue Val = ST->getVal();
  SDLoc DL(N);

  SDValue Promoted = GetPromotedFloat(Val);

  // VT is the type of the value as it appears on the original node; IVT is the
  // integer type with the same number of bits, which is what gets stored.
  EVT VT = ST->getOperand(1).getValueType();
  EVT IVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());

  // Convert from the promoted float type back down to the original width,
  // yielding the integer representation to store.
  SDValue NewVal = DAG.getNode(GetPromotionOpcode(Promoted.getValueType(), VT),
                               DL, IVT, Promoted);

  // For ATOMIC_STORE the value is passed before the pointer.
  return DAG.getAtomic(ISD::ATOMIC_STORE, DL, IVT, ST->getChain(), NewVal,
                       ST->getBasePtr(), ST->getMemOperand());
}
2391+
23742392
//===----------------------------------------------------------------------===//
23752393
// Float Result Promotion
23762394
//===----------------------------------------------------------------------===//
@@ -3154,6 +3172,9 @@ bool DAGTypeLegalizer::SoftPromoteHalfOperand(SDNode *N, unsigned OpNo) {
31543172
case ISD::SELECT_CC: Res = SoftPromoteHalfOp_SELECT_CC(N, OpNo); break;
31553173
case ISD::SETCC: Res = SoftPromoteHalfOp_SETCC(N); break;
31563174
case ISD::STORE: Res = SoftPromoteHalfOp_STORE(N, OpNo); break;
3175+
case ISD::ATOMIC_STORE:
3176+
Res = SoftPromoteHalfOp_ATOMIC_STORE(N, OpNo);
3177+
break;
31573178
case ISD::STACKMAP:
31583179
Res = SoftPromoteHalfOp_STACKMAP(N, OpNo);
31593180
break;
@@ -3307,6 +3328,19 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_STORE(SDNode *N, unsigned OpNo) {
33073328
ST->getMemOperand());
33083329
}
33093330

3331+
/// Legalize an ATOMIC_STORE whose stored value is a soft-promoted half.
///
/// The stored value is swapped for the value returned by GetSoftPromotedHalf,
/// and a new ATOMIC_STORE is emitted whose memory type is that value's type.
/// Chain, base pointer and memory operand are carried over unchanged.
///
/// \param N    The ATOMIC_STORE node; it is cast to AtomicSDNode.
/// \param OpNo Index of the operand being legalized; must be 1 (the value).
/// \returns The replacement ATOMIC_STORE node.
SDValue DAGTypeLegalizer::SoftPromoteHalfOp_ATOMIC_STORE(SDNode *N,
                                                         unsigned OpNo) {
  assert(OpNo == 1 && "Can only soften the stored value!");

  auto *AtomicSt = cast<AtomicSDNode>(N);
  SDValue SoftVal = GetSoftPromotedHalf(AtomicSt->getVal());
  EVT SoftVT = SoftVal.getValueType();

  // For ATOMIC_STORE the value is passed before the pointer.
  return DAG.getAtomic(ISD::ATOMIC_STORE, SDLoc(N), SoftVT,
                       AtomicSt->getChain(), SoftVal, AtomicSt->getBasePtr(),
                       AtomicSt->getMemOperand());
}
3343+
33103344
SDValue DAGTypeLegalizer::SoftPromoteHalfOp_STACKMAP(SDNode *N, unsigned OpNo) {
33113345
assert(OpNo > 1); // Because the first two arguments are guaranteed legal.
33123346
SmallVector<SDValue> NewOps(N->ops().begin(), N->ops().end());

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
708708
SDValue PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo);
709709
SDValue PromoteFloatOp_FP_TO_XINT_SAT(SDNode *N, unsigned OpNo);
710710
SDValue PromoteFloatOp_STORE(SDNode *N, unsigned OpNo);
711+
SDValue PromoteFloatOp_ATOMIC_STORE(SDNode *N, unsigned OpNo);
711712
SDValue PromoteFloatOp_SELECT_CC(SDNode *N, unsigned OpNo);
712713
SDValue PromoteFloatOp_SETCC(SDNode *N, unsigned OpNo);
713714

@@ -751,6 +752,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
751752
SDValue SoftPromoteHalfOp_SETCC(SDNode *N);
752753
SDValue SoftPromoteHalfOp_SELECT_CC(SDNode *N, unsigned OpNo);
753754
SDValue SoftPromoteHalfOp_STORE(SDNode *N, unsigned OpNo);
755+
SDValue SoftPromoteHalfOp_ATOMIC_STORE(SDNode *N, unsigned OpNo);
754756
SDValue SoftPromoteHalfOp_STACKMAP(SDNode *N, unsigned OpNo);
755757
SDValue SoftPromoteHalfOp_PATCHPOINT(SDNode *N, unsigned OpNo);
756758

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,18 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
148148
setOperationAction(ISD::LOAD, MVT::i128, Promote);
149149
AddPromotedToType(ISD::LOAD, MVT::i128, MVT::v4i32);
150150

151+
setOperationAction(ISD::ATOMIC_STORE, MVT::f32, Promote);
152+
AddPromotedToType(ISD::ATOMIC_STORE, MVT::f32, MVT::i32);
153+
154+
setOperationAction(ISD::ATOMIC_STORE, MVT::f64, Promote);
155+
AddPromotedToType(ISD::ATOMIC_STORE, MVT::f64, MVT::i64);
156+
157+
setOperationAction(ISD::ATOMIC_STORE, MVT::f16, Promote);
158+
AddPromotedToType(ISD::ATOMIC_STORE, MVT::f16, MVT::i16);
159+
160+
setOperationAction(ISD::ATOMIC_STORE, MVT::bf16, Promote);
161+
AddPromotedToType(ISD::ATOMIC_STORE, MVT::bf16, MVT::i16);
162+
151163
// There are no 64-bit extloads. These should be done as a 32-bit extload and
152164
// an extension to 64-bit.
153165
for (MVT VT : MVT::integer_valuetypes())

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,11 @@ class AMDGPUTargetLowering : public TargetLowering {
230230
bool isCheapToSpeculateCtlz(Type *Ty) const override;
231231

232232
bool isSDNodeAlwaysUniform(const SDNode *N) const override;
233+
234+
/// Decide whether a float-typed atomic store should be cast in IR.
///
/// Always answers AtomicExpansionKind::None. Per this commit's description,
/// float-typed atomic stores are not bitcast in IR; the promotion is
/// implemented in the DAG instead.
AtomicExpansionKind shouldCastAtomicStoreInIR(StoreInst *SI) const override {
  return AtomicExpansionKind::None;
}
237+
233238
static CCAssignFn *CCAssignFnForCall(CallingConv::ID CC, bool IsVarArg);
234239
static CCAssignFn *CCAssignFnForReturn(CallingConv::ID CC, bool IsVarArg);
235240

0 commit comments

Comments
 (0)