Skip to content

Commit 11bf02e

Browse files
authored
DAG: Fix ABI lowering with FP promote in strictfp functions (#74405)
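Previously, ABI lowering emitted non-strict casts for illegal types even inside strictfp functions. The FP_ROUND created in getCopyFromParts is now emitted as STRICT_FP_ROUND, chained to the incoming chain, when the function has the strictfp attribute.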
1 parent fdcb76f commit 11bf02e

File tree

7 files changed

+1084
-225
lines changed

7 files changed

+1084
-225
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ static const unsigned MaxParallelChains = 64;
153153
static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
154154
const SDValue *Parts, unsigned NumParts,
155155
MVT PartVT, EVT ValueVT, const Value *V,
156+
SDValue InChain,
156157
std::optional<CallingConv::ID> CC);
157158

158159
/// getCopyFromParts - Create a value that contains the specified legal parts
@@ -163,6 +164,7 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
163164
static SDValue
164165
getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
165166
unsigned NumParts, MVT PartVT, EVT ValueVT, const Value *V,
167+
SDValue InChain,
166168
std::optional<CallingConv::ID> CC = std::nullopt,
167169
std::optional<ISD::NodeType> AssertOp = std::nullopt) {
168170
// Let the target assemble the parts if it wants to
@@ -173,7 +175,7 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
173175

174176
if (ValueVT.isVector())
175177
return getCopyFromPartsVector(DAG, DL, Parts, NumParts, PartVT, ValueVT, V,
176-
CC);
178+
InChain, CC);
177179

178180
assert(NumParts > 0 && "No parts to assemble!");
179181
SDValue Val = Parts[0];
@@ -194,10 +196,10 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
194196
EVT HalfVT = EVT::getIntegerVT(*DAG.getContext(), RoundBits/2);
195197

196198
if (RoundParts > 2) {
197-
Lo = getCopyFromParts(DAG, DL, Parts, RoundParts / 2,
198-
PartVT, HalfVT, V);
199-
Hi = getCopyFromParts(DAG, DL, Parts + RoundParts / 2,
200-
RoundParts / 2, PartVT, HalfVT, V);
199+
Lo = getCopyFromParts(DAG, DL, Parts, RoundParts / 2, PartVT, HalfVT, V,
200+
InChain);
201+
Hi = getCopyFromParts(DAG, DL, Parts + RoundParts / 2, RoundParts / 2,
202+
PartVT, HalfVT, V, InChain);
201203
} else {
202204
Lo = DAG.getNode(ISD::BITCAST, DL, HalfVT, Parts[0]);
203205
Hi = DAG.getNode(ISD::BITCAST, DL, HalfVT, Parts[1]);
@@ -213,7 +215,7 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
213215
unsigned OddParts = NumParts - RoundParts;
214216
EVT OddVT = EVT::getIntegerVT(*DAG.getContext(), OddParts * PartBits);
215217
Hi = getCopyFromParts(DAG, DL, Parts + RoundParts, OddParts, PartVT,
216-
OddVT, V, CC);
218+
OddVT, V, InChain, CC);
217219

218220
// Combine the round and odd parts.
219221
Lo = Val;
@@ -243,7 +245,8 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
243245
assert(ValueVT.isFloatingPoint() && PartVT.isInteger() &&
244246
!PartVT.isVector() && "Unexpected split");
245247
EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), ValueVT.getSizeInBits());
246-
Val = getCopyFromParts(DAG, DL, Parts, NumParts, PartVT, IntVT, V, CC);
248+
Val = getCopyFromParts(DAG, DL, Parts, NumParts, PartVT, IntVT, V,
249+
InChain, CC);
247250
}
248251
}
249252

@@ -283,10 +286,20 @@ getCopyFromParts(SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
283286

284287
if (PartEVT.isFloatingPoint() && ValueVT.isFloatingPoint()) {
285288
// FP_ROUND's are always exact here.
286-
if (ValueVT.bitsLT(Val.getValueType()))
287-
return DAG.getNode(
288-
ISD::FP_ROUND, DL, ValueVT, Val,
289-
DAG.getTargetConstant(1, DL, TLI.getPointerTy(DAG.getDataLayout())));
289+
if (ValueVT.bitsLT(Val.getValueType())) {
290+
291+
SDValue NoChange =
292+
DAG.getTargetConstant(1, DL, TLI.getPointerTy(DAG.getDataLayout()));
293+
294+
if (DAG.getMachineFunction().getFunction().getAttributes().hasFnAttr(
295+
llvm::Attribute::StrictFP)) {
296+
return DAG.getNode(ISD::STRICT_FP_ROUND, DL,
297+
DAG.getVTList(ValueVT, MVT::Other), InChain, Val,
298+
NoChange);
299+
}
300+
301+
return DAG.getNode(ISD::FP_ROUND, DL, ValueVT, Val, NoChange);
302+
}
290303

291304
return DAG.getNode(ISD::FP_EXTEND, DL, ValueVT, Val);
292305
}
@@ -324,6 +337,7 @@ static void diagnosePossiblyInvalidConstraint(LLVMContext &Ctx, const Value *V,
324337
static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
325338
const SDValue *Parts, unsigned NumParts,
326339
MVT PartVT, EVT ValueVT, const Value *V,
340+
SDValue InChain,
327341
std::optional<CallingConv::ID> CallConv) {
328342
assert(ValueVT.isVector() && "Not a vector value");
329343
assert(NumParts > 0 && "No parts to assemble!");
@@ -362,17 +376,17 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
362376
// If the register was not expanded, truncate or copy the value,
363377
// as appropriate.
364378
for (unsigned i = 0; i != NumParts; ++i)
365-
Ops[i] = getCopyFromParts(DAG, DL, &Parts[i], 1,
366-
PartVT, IntermediateVT, V, CallConv);
379+
Ops[i] = getCopyFromParts(DAG, DL, &Parts[i], 1, PartVT, IntermediateVT,
380+
V, InChain, CallConv);
367381
} else if (NumParts > 0) {
368382
// If the intermediate type was expanded, build the intermediate
369383
// operands from the parts.
370384
assert(NumParts % NumIntermediates == 0 &&
371385
"Must expand into a divisible number of parts!");
372386
unsigned Factor = NumParts / NumIntermediates;
373387
for (unsigned i = 0; i != NumIntermediates; ++i)
374-
Ops[i] = getCopyFromParts(DAG, DL, &Parts[i * Factor], Factor,
375-
PartVT, IntermediateVT, V, CallConv);
388+
Ops[i] = getCopyFromParts(DAG, DL, &Parts[i * Factor], Factor, PartVT,
389+
IntermediateVT, V, InChain, CallConv);
376390
}
377391

378392
// Build a vector with BUILD_VECTOR or CONCAT_VECTORS from the
@@ -926,7 +940,7 @@ SDValue RegsForValue::getCopyFromRegs(SelectionDAG &DAG,
926940
}
927941

928942
Values[Value] = getCopyFromParts(DAG, dl, Parts.begin(), NumRegs,
929-
RegisterVT, ValueVT, V, CallConv);
943+
RegisterVT, ValueVT, V, Chain, CallConv);
930944
Part += NumRegs;
931945
Parts.clear();
932946
}
@@ -10628,9 +10642,9 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
1062810642
unsigned NumRegs = getNumRegistersForCallingConv(CLI.RetTy->getContext(),
1062910643
CLI.CallConv, VT);
1063010644

10631-
ReturnValues.push_back(getCopyFromParts(CLI.DAG, CLI.DL, &InVals[CurReg],
10632-
NumRegs, RegisterVT, VT, nullptr,
10633-
CLI.CallConv, AssertOp));
10645+
ReturnValues.push_back(getCopyFromParts(
10646+
CLI.DAG, CLI.DL, &InVals[CurReg], NumRegs, RegisterVT, VT, nullptr,
10647+
CLI.Chain, CLI.CallConv, AssertOp));
1063410648
CurReg += NumRegs;
1063510649
}
1063610650

@@ -11109,8 +11123,9 @@ void SelectionDAGISel::LowerArguments(const Function &F) {
1110911123
MVT VT = ValueVTs[0].getSimpleVT();
1111011124
MVT RegVT = TLI->getRegisterType(*CurDAG->getContext(), VT);
1111111125
std::optional<ISD::NodeType> AssertOp;
11112-
SDValue ArgValue = getCopyFromParts(DAG, dl, &InVals[0], 1, RegVT, VT,
11113-
nullptr, F.getCallingConv(), AssertOp);
11126+
SDValue ArgValue =
11127+
getCopyFromParts(DAG, dl, &InVals[0], 1, RegVT, VT, nullptr, NewRoot,
11128+
F.getCallingConv(), AssertOp);
1111411129

1111511130
MachineFunction& MF = SDB->DAG.getMachineFunction();
1111611131
MachineRegisterInfo& RegInfo = MF.getRegInfo();
@@ -11182,7 +11197,7 @@ void SelectionDAGISel::LowerArguments(const Function &F) {
1118211197
AssertOp = ISD::AssertZext;
1118311198

1118411199
ArgValues.push_back(getCopyFromParts(DAG, dl, &InVals[i], NumParts,
11185-
PartVT, VT, nullptr,
11200+
PartVT, VT, nullptr, NewRoot,
1118611201
F.getCallingConv(), AssertOp));
1118711202
}
1118811203

llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll

Lines changed: 5 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,52 +1094,11 @@ define <4 x i1> @isnan_v4bf16(<4 x bfloat> %x) nounwind {
10941094
ret <4 x i1> %1
10951095
}
10961096

1097-
define i1 @isnan_bf16_strictfp(bfloat %x) strictfp nounwind {
1098-
; GFX7CHECK-LABEL: isnan_bf16_strictfp:
1099-
; GFX7CHECK: ; %bb.0:
1100-
; GFX7CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
1101-
; GFX7CHECK-NEXT: v_bfe_u32 v0, v0, 16, 15
1102-
; GFX7CHECK-NEXT: s_movk_i32 s4, 0x7f80
1103-
; GFX7CHECK-NEXT: v_cmp_lt_i32_e32 vcc, s4, v0
1104-
; GFX7CHECK-NEXT: v_cndmask_b32_e64 v0, 0, 1, vcc
1105-
; GFX7CHECK-NEXT: s_setpc_b64 s[30:31]
1106-
;
1107-
; GFX8CHECK-LABEL: isnan_bf16_strictfp:
1108-
; GFX8CHECK: ; %bb.0:
1109-
; GFX8CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
1110-
; GFX8CHECK-NEXT: v_and_b32_e32 v0, 0x7fff, v0
1111-
; GFX8CHECK-NEXT: s_movk_i32 s4, 0x7f80
1112-
; GFX8CHECK-NEXT: v_cmp_lt_i16_e32 vcc, s4, v0
1113-
; GFX8CHECK-NEXT: v_cndmask_b32_e64 v0, 0, 1, vcc
1114-
; GFX8CHECK-NEXT: s_setpc_b64 s[30:31]
1115-
;
1116-
; GFX9CHECK-LABEL: isnan_bf16_strictfp:
1117-
; GFX9CHECK: ; %bb.0:
1118-
; GFX9CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
1119-
; GFX9CHECK-NEXT: v_and_b32_e32 v0, 0x7fff, v0
1120-
; GFX9CHECK-NEXT: s_movk_i32 s4, 0x7f80
1121-
; GFX9CHECK-NEXT: v_cmp_lt_i16_e32 vcc, s4, v0
1122-
; GFX9CHECK-NEXT: v_cndmask_b32_e64 v0, 0, 1, vcc
1123-
; GFX9CHECK-NEXT: s_setpc_b64 s[30:31]
1124-
;
1125-
; GFX10CHECK-LABEL: isnan_bf16_strictfp:
1126-
; GFX10CHECK: ; %bb.0:
1127-
; GFX10CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
1128-
; GFX10CHECK-NEXT: v_and_b32_e32 v0, 0x7fff, v0
1129-
; GFX10CHECK-NEXT: v_cmp_lt_i16_e32 vcc_lo, 0x7f80, v0
1130-
; GFX10CHECK-NEXT: v_cndmask_b32_e64 v0, 0, 1, vcc_lo
1131-
; GFX10CHECK-NEXT: s_setpc_b64 s[30:31]
1132-
;
1133-
; GFX11CHECK-LABEL: isnan_bf16_strictfp:
1134-
; GFX11CHECK: ; %bb.0:
1135-
; GFX11CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
1136-
; GFX11CHECK-NEXT: v_and_b32_e32 v0, 0x7fff, v0
1137-
; GFX11CHECK-NEXT: v_cmp_lt_i16_e32 vcc_lo, 0x7f80, v0
1138-
; GFX11CHECK-NEXT: v_cndmask_b32_e64 v0, 0, 1, vcc_lo
1139-
; GFX11CHECK-NEXT: s_setpc_b64 s[30:31]
1140-
%1 = call i1 @llvm.is.fpclass.bf16(bfloat %x, i32 3) strictfp ; nan
1141-
ret i1 %1
1142-
}
1097+
; FIXME: Broken for gfx6/7
1098+
; define i1 @isnan_bf16_strictfp(bfloat %x) strictfp nounwind {
1099+
; %1 = call i1 @llvm.is.fpclass.bf16(bfloat %x, i32 3) strictfp ; nan
1100+
; ret i1 %1
1101+
; }
11431102

11441103
define i1 @isinf_bf16(bfloat %x) nounwind {
11451104
; GFX7CHECK-LABEL: isinf_bf16:

llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.f16.ll

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,6 +1316,9 @@ define i1 @isnan_f16_strictfp(half %x) strictfp nounwind {
13161316
; GFX7SELDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
13171317
; GFX7SELDAG-NEXT: v_cvt_f16_f32_e32 v0, v0
13181318
; GFX7SELDAG-NEXT: s_movk_i32 s4, 0x7c00
1319+
; GFX7SELDAG-NEXT: v_and_b32_e32 v0, 0xffff, v0
1320+
; GFX7SELDAG-NEXT: v_cvt_f32_f16_e32 v0, v0
1321+
; GFX7SELDAG-NEXT: v_cvt_f16_f32_e32 v0, v0
13191322
; GFX7SELDAG-NEXT: v_and_b32_e32 v0, 0x7fff, v0
13201323
; GFX7SELDAG-NEXT: v_cmp_lt_i32_e32 vcc, s4, v0
13211324
; GFX7SELDAG-NEXT: v_cndmask_b32_e64 v0, 0, 1, vcc

llvm/test/CodeGen/AMDGPU/strict_fp_casts.ll

Lines changed: 0 additions & 110 deletions
This file was deleted.

0 commit comments

Comments
 (0)