Skip to content

Commit f43ad88

Browse files
authored
[RISCV] Handle zvfhmin and zvfbfmin promotion to f32 in half arith costs (#108361)
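Arithmetic ops on half (f16) vectors with only zvfhmin, and on bfloat (bf16) vectors with zvfbfmin, are promoted to and carried out in f32. This updates getArithmeticInstrCost to detect that promotion and to add the cost of extending the operands and truncating the result.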
1 parent 63b534b commit f43ad88

File tree

2 files changed

+172
-87
lines changed

2 files changed

+172
-87
lines changed

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

+28-7
Original file line number | Diff line number | Diff line change
@@ -1908,6 +1908,29 @@ InstructionCost RISCVTTIImpl::getArithmeticInstrCost(
19081908
return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, Op2Info,
19091909
Args, CxtI);
19101910

1911+
// f16 with zvfhmin and bf16 will be promoted to f32.
1912+
// FIXME: nxv32[b]f16 will be custom lowered and split.
1913+
unsigned ISDOpcode = TLI->InstructionOpcodeToISD(Opcode);
1914+
InstructionCost CastCost = 0;
1915+
if ((LT.second.getVectorElementType() == MVT::f16 ||
1916+
LT.second.getVectorElementType() == MVT::bf16) &&
1917+
TLI->getOperationAction(ISDOpcode, LT.second) ==
1918+
TargetLoweringBase::LegalizeAction::Promote) {
1919+
MVT PromotedVT = TLI->getTypeToPromoteTo(ISDOpcode, LT.second);
1920+
Type *PromotedTy = EVT(PromotedVT).getTypeForEVT(Ty->getContext());
1921+
Type *LegalTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
1922+
// Add cost of extending arguments
1923+
CastCost += LT.first * Args.size() *
1924+
getCastInstrCost(Instruction::FPExt, PromotedTy, LegalTy,
1925+
TTI::CastContextHint::None, CostKind);
1926+
// Add cost of truncating result
1927+
CastCost +=
1928+
LT.first * getCastInstrCost(Instruction::FPTrunc, LegalTy, PromotedTy,
1929+
TTI::CastContextHint::None, CostKind);
1930+
// Compute cost of op in promoted type
1931+
LT.second = PromotedVT;
1932+
}
1933+
19111934
auto getConstantMatCost =
19121935
[&](unsigned Operand, TTI::OperandValueInfo OpInfo) -> InstructionCost {
19131936
if (OpInfo.isUniform() && TLI->canSplatOperand(Opcode, Operand))
@@ -1929,7 +1952,7 @@ InstructionCost RISCVTTIImpl::getArithmeticInstrCost(
19291952
ConstantMatCost += getConstantMatCost(1, Op2Info);
19301953

19311954
unsigned Op;
1932-
switch (TLI->InstructionOpcodeToISD(Opcode)) {
1955+
switch (ISDOpcode) {
19331956
case ISD::ADD:
19341957
case ISD::SUB:
19351958
Op = RISCV::VADD_VV;
@@ -1959,11 +1982,9 @@ InstructionCost RISCVTTIImpl::getArithmeticInstrCost(
19591982
break;
19601983
case ISD::FADD:
19611984
case ISD::FSUB:
1962-
// TODO: Address FP16 with VFHMIN
19631985
Op = RISCV::VFADD_VV;
19641986
break;
19651987
case ISD::FMUL:
1966-
// TODO: Address FP16 with VFHMIN
19671988
Op = RISCV::VFMUL_VV;
19681989
break;
19691990
case ISD::FDIV:
@@ -1975,9 +1996,9 @@ InstructionCost RISCVTTIImpl::getArithmeticInstrCost(
19751996
default:
19761997
// Assuming all other instructions have the same cost until a need arises to
19771998
// differentiate them.
1978-
return ConstantMatCost + BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind,
1979-
Op1Info, Op2Info,
1980-
Args, CxtI);
1999+
return CastCost + ConstantMatCost +
2000+
BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, Op2Info,
2001+
Args, CxtI);
19812002
}
19822003

19832004
InstructionCost InstrCost = getRISCVInstructionCost(Op, LT.second, CostKind);
@@ -1986,7 +2007,7 @@ InstructionCost RISCVTTIImpl::getArithmeticInstrCost(
19862007
// scalar floating point ops aren't cheaper than their vector equivalents.
19872008
if (Ty->isFPOrFPVectorTy())
19882009
InstrCost *= 2;
1989-
return ConstantMatCost + LT.first * InstrCost;
2010+
return CastCost + ConstantMatCost + LT.first * InstrCost;
19902011
}
19912012

19922013
// TODO: Deduplicate from TargetTransformInfoImplCRTPBase.

0 commit comments

Comments
 (0)