Skip to content

Commit bfe6bc0

Browse files
committed
AMDGPU: Cleanup check for integral exponents in pow folds
Also improves undef handling https://reviews.llvm.org/D159006
1 parent b0272d8 commit bfe6bc0

File tree

1 file changed

+25
-19
lines changed

1 file changed

+25
-19
lines changed

llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp

+25-19
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,29 @@ bool AMDGPULibCalls::fold_read_write_pipe(CallInst *CI, IRBuilder<> &B,
549549
return true;
550550
}
551551

552+
static bool isKnownIntegral(const Value *V) {
553+
if (isa<UndefValue>(V))
554+
return true;
555+
556+
if (const ConstantFP *CF = dyn_cast<ConstantFP>(V))
557+
return CF->getValueAPF().isInteger();
558+
559+
if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(V)) {
560+
for (unsigned i = 0, e = CDV->getNumElements(); i != e; ++i) {
561+
Constant *ConstElt = CDV->getElementAsConstant(i);
562+
if (isa<UndefValue>(ConstElt))
563+
continue;
564+
const ConstantFP *CFP = dyn_cast<ConstantFP>(ConstElt);
565+
if (!CFP || !CFP->getValue().isInteger())
566+
return false;
567+
}
568+
569+
return true;
570+
}
571+
572+
return false;
573+
}
574+
552575
// This function returns false if no change; return true otherwise.
553576
bool AMDGPULibCalls::fold(CallInst *CI) {
554577
Function *Callee = CI->getCalledFunction();
@@ -972,25 +995,8 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
972995
if (needcopysign && (FInfo.getId() == AMDGPULibFunc::EI_POW)) {
973996
// We cannot handle corner cases for a general pow() function, give up
974997
// unless y is a constant integral value. Then proceed as if it were pown.
975-
if (getVecSize(FInfo) == 1) {
976-
if (const ConstantFP *CF = dyn_cast<ConstantFP>(opr1)) {
977-
double y = (getArgType(FInfo) == AMDGPULibFunc::F32)
978-
? (double)CF->getValueAPF().convertToFloat()
979-
: CF->getValueAPF().convertToDouble();
980-
if (y != (double)(int64_t)y)
981-
return false;
982-
} else
983-
return false;
984-
} else {
985-
if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr1)) {
986-
for (int i=0; i < getVecSize(FInfo); ++i) {
987-
double y = CDV->getElementAsAPFloat(i).convertToDouble();
988-
if (y != (double)(int64_t)y)
989-
return false;
990-
}
991-
} else
992-
return false;
993-
}
998+
if (!isKnownIntegral(opr1))
999+
return false;
9941000
}
9951001

9961002
Value *nval;

0 commit comments

Comments
 (0)