@@ -1908,6 +1908,29 @@ InstructionCost RISCVTTIImpl::getArithmeticInstrCost(
1908
1908
return BaseT::getArithmeticInstrCost (Opcode, Ty, CostKind, Op1Info, Op2Info,
1909
1909
Args, CxtI);
1910
1910
1911
+ // f16 with zvfhmin and bf16 will be promoted to f32.
1912
+ // FIXME: nxv32[b]f16 will be custom lowered and split.
1913
+ unsigned ISDOpcode = TLI->InstructionOpcodeToISD (Opcode);
1914
+ InstructionCost CastCost = 0 ;
1915
+ if ((LT.second .getVectorElementType () == MVT::f16 ||
1916
+ LT.second .getVectorElementType () == MVT::bf16) &&
1917
+ TLI->getOperationAction (ISDOpcode, LT.second ) ==
1918
+ TargetLoweringBase::LegalizeAction::Promote) {
1919
+ MVT PromotedVT = TLI->getTypeToPromoteTo (ISDOpcode, LT.second );
1920
+ Type *PromotedTy = EVT (PromotedVT).getTypeForEVT (Ty->getContext ());
1921
+ Type *LegalTy = EVT (LT.second ).getTypeForEVT (Ty->getContext ());
1922
+ // Add cost of extending arguments
1923
+ CastCost += LT.first * Args.size () *
1924
+ getCastInstrCost (Instruction::FPExt, PromotedTy, LegalTy,
1925
+ TTI::CastContextHint::None, CostKind);
1926
+ // Add cost of truncating result
1927
+ CastCost +=
1928
+ LT.first * getCastInstrCost (Instruction::FPTrunc, LegalTy, PromotedTy,
1929
+ TTI::CastContextHint::None, CostKind);
1930
+ // Compute cost of op in promoted type
1931
+ LT.second = PromotedVT;
1932
+ }
1933
+
1911
1934
auto getConstantMatCost =
1912
1935
[&](unsigned Operand, TTI::OperandValueInfo OpInfo) -> InstructionCost {
1913
1936
if (OpInfo.isUniform () && TLI->canSplatOperand (Opcode, Operand))
@@ -1929,7 +1952,7 @@ InstructionCost RISCVTTIImpl::getArithmeticInstrCost(
1929
1952
ConstantMatCost += getConstantMatCost (1 , Op2Info);
1930
1953
1931
1954
unsigned Op;
1932
- switch (TLI-> InstructionOpcodeToISD (Opcode) ) {
1955
+ switch (ISDOpcode ) {
1933
1956
case ISD::ADD:
1934
1957
case ISD::SUB:
1935
1958
Op = RISCV::VADD_VV;
@@ -1959,11 +1982,9 @@ InstructionCost RISCVTTIImpl::getArithmeticInstrCost(
1959
1982
break ;
1960
1983
case ISD::FADD:
1961
1984
case ISD::FSUB:
1962
- // TODO: Address FP16 with VFHMIN
1963
1985
Op = RISCV::VFADD_VV;
1964
1986
break ;
1965
1987
case ISD::FMUL:
1966
- // TODO: Address FP16 with VFHMIN
1967
1988
Op = RISCV::VFMUL_VV;
1968
1989
break ;
1969
1990
case ISD::FDIV:
@@ -1975,9 +1996,9 @@ InstructionCost RISCVTTIImpl::getArithmeticInstrCost(
1975
1996
default :
1976
1997
// Assuming all other instructions have the same cost until a need arises to
1977
1998
// differentiate them.
1978
- return ConstantMatCost + BaseT::getArithmeticInstrCost (Opcode, Ty, CostKind,
1979
- Op1Info, Op2Info,
1980
- Args, CxtI);
1999
+ return CastCost + ConstantMatCost +
2000
+ BaseT::getArithmeticInstrCost (Opcode, Ty, CostKind, Op1Info, Op2Info,
2001
+ Args, CxtI);
1981
2002
}
1982
2003
1983
2004
InstructionCost InstrCost = getRISCVInstructionCost (Op, LT.second , CostKind);
@@ -1986,7 +2007,7 @@ InstructionCost RISCVTTIImpl::getArithmeticInstrCost(
1986
2007
// scalar floating point ops aren't cheaper than their vector equivalents.
1987
2008
if (Ty->isFPOrFPVectorTy ())
1988
2009
InstrCost *= 2 ;
1989
- return ConstantMatCost + LT.first * InstrCost;
2010
+ return CastCost + ConstantMatCost + LT.first * InstrCost;
1990
2011
}
1991
2012
1992
2013
// TODO: Deduplicate from TargetTransformInfoImplCRTPBase.
0 commit comments