Skip to content

Commit 8f8016f

Browse files
author
Hugh Delaney
authored
[NVPTX] Add patterns for fma.relu.{f16|f16x2|bf16|bf16x2} (#114977)
Add patterns to lower `fmaxnum(fma(a, b, c), 0)` to `fma.rn{.ftz}.relu` for the `f16`, `f16x2`, `bf16` and `bf16x2` types when `nnan` is used. `fma_relu` honours `NaN`, whereas `fmaxnum` returns the non-NaN argument when passed a NaN value, so the substitution is only made if the `fma` is `nnan`. This patch also removes some `bf16` ftz instructions, since `FTZ` is not supported with the `bf16` type according to the PTX ISA docs.
1 parent ed8019d commit 8f8016f

File tree

4 files changed

+4152
-14
lines changed

4 files changed

+4152
-14
lines changed

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -418,25 +418,13 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
418418
!strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
419419
[(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
420420
Requires<[useFP16Math, allowFMA]>;
421-
def bf16rr_ftz :
422-
NVPTXInst<(outs Int16Regs:$dst),
423-
(ins Int16Regs:$a, Int16Regs:$b),
424-
!strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"),
425-
[(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
426-
Requires<[hasBF16Math, allowFMA, doF32FTZ]>;
427421
def bf16rr :
428422
NVPTXInst<(outs Int16Regs:$dst),
429423
(ins Int16Regs:$a, Int16Regs:$b),
430424
!strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
431425
[(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
432426
Requires<[hasBF16Math, allowFMA]>;
433427

434-
def bf16x2rr_ftz :
435-
NVPTXInst<(outs Int32Regs:$dst),
436-
(ins Int32Regs:$a, Int32Regs:$b),
437-
!strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"),
438-
[(set (v2bf16 Int32Regs:$dst), (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
439-
Requires<[hasBF16Math, allowFMA, doF32FTZ]>;
440428
def bf16x2rr :
441429
NVPTXInst<(outs Int32Regs:$dst),
442430
(ins Int32Regs:$a, Int32Regs:$b),
@@ -1423,9 +1411,7 @@ defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
14231411
defm FMA16 : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>;
14241412
defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
14251413
defm FMA16x2 : FMA_F16<"fma.rn.f16x2", v2f16, Int32Regs, True>;
1426-
defm BFMA16_ftz : FMA_BF16<"fma.rn.ftz.bf16", bf16, Int16Regs, doF32FTZ>;
14271414
defm BFMA16 : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>;
1428-
defm BFMA16x2_ftz : FMA_BF16<"fma.rn.ftz.bf16x2", v2bf16, Int32Regs, doF32FTZ>;
14291415
defm BFMA16x2 : FMA_BF16<"fma.rn.bf16x2", v2bf16, Int32Regs, True>;
14301416
defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
14311417
defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
@@ -3959,3 +3945,54 @@ def atomic_thread_fence_seq_cst_cta :
39593945
def atomic_thread_fence_acq_rel_cta :
39603946
NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
39613947
Requires<[hasPTX<60>, hasSM<70>]>;
3948+
3949+
// Leaf pattern over floating-point immediates of any FP type (fAny); it
// matches when Imm.isZero() returns true.
def fpimm_any_zero : FPImmLeaf<fAny, [{
3950+
return Imm.isZero();
3951+
}]>;
3952+
3953+
// Operand-less fragments matching a packed v2f16 / v2bf16 value that is a
// bitconvert of the i32 constant 0, i.e. a vector whose bits are all zero.
def fpimm_positive_zero_v2f16 : PatFrag<(ops), (v2f16 (bitconvert (i32 0)))>;
3954+
def fpimm_positive_zero_v2bf16 : PatFrag<(ops), (v2bf16 (bitconvert (i32 0)))>;
3955+
3956+
// Perform the substitution only if the fma has a single use and either the
3957+
// node carries the nnan flag or the TM has NoNaNsFPMath set
3958+
// Matches (fma $a, $b, $c) only when the C++ predicate below returns true.
def NVPTX_fma_oneuse_and_nnan : PatFrag<(ops node:$a, node:$b, node:$c),
3959+
(fma node:$a, node:$b, node:$c), [{
3960+
// The fma node must have a single use ...
return N->hasOneUse() &&
3961+
// ... and NaNs must be excluded, by the node's flag or the TM-wide option.
(N->getFlags().hasNoNaNs() || TM.Options.NoNaNsFPMath);
3962+
}]>;
3963+
// fmaxnum will soon differentiate between positive and negative zero, so this
3964+
// PatFrag only matches an fmaxnum node with nsz (or NoSignedZerosFPMath on the TM)
3965+
// Matches (fmaxnum $a, $b) only when the C++ predicate below returns true.
def NVPTX_fmaxnum_nsz : PatFrag<(ops node:$a, node:$b),
3966+
(fmaxnum node:$a, node:$b), [{
3967+
// Accept either the per-node no-signed-zeros flag or the TM-wide option.
return N->getFlags().hasNoSignedZeros() || TM.Options.NoSignedZerosFPMath;
3968+
}]>;
3969+
3970+
// Base class for an instruction with one result ($dst) and three inputs
// ($a, $b, $c), all in register class RC. The assembly string is Instruction
// followed by a tab and "$dst, $a, $b, $c;". The pattern list is empty, so no
// selection pattern is attached here; Preds is forwarded to Requires.
class NVPTXInst_rrr<RegisterClass RC, string Instruction, list<Predicate> Preds>
3971+
: NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
3972+
!strconcat(Instruction, "\t$dst, $a, $b, $c;"), []>,
3973+
Requires<Preds>;
3974+
3975+
// fma.rn{.ftz}.relu instruction definitions. Scalar f16/bf16 forms use
// Int16Regs; packed f16x2/bf16x2 forms use Int32Regs. Every form requires
// hasPTX<70> and hasSM<80>, plus useFP16Math (f16) or hasBF16Math (bf16).
// Only the f16 and f16x2 forms have an .ftz variant here; bf16 has none.
def FMARELU_F16 : NVPTXInst_rrr<Int16Regs, "fma.rn.relu.f16", [useFP16Math, hasPTX<70>, hasSM<80>]>;
3976+
def FMARELU_F16_FTZ : NVPTXInst_rrr<Int16Regs, "fma.rn.ftz.relu.f16", [useFP16Math, hasPTX<70>, hasSM<80>]>;
3977+
def FMARELU_BF16 : NVPTXInst_rrr<Int16Regs, "fma.rn.relu.bf16", [hasBF16Math, hasPTX<70>, hasSM<80>]>;
3978+
def FMARELU_F16X2 : NVPTXInst_rrr<Int32Regs, "fma.rn.relu.f16x2", [useFP16Math, hasPTX<70>, hasSM<80>]>;
3979+
def FMARELU_F16X2_FTZ : NVPTXInst_rrr<Int32Regs, "fma.rn.ftz.relu.f16x2", [useFP16Math, hasPTX<70>, hasSM<80>]>;
3980+
def FMARELU_BF16X2 : NVPTXInst_rrr<Int32Regs, "fma.rn.relu.bf16x2", [hasBF16Math, hasPTX<70>, hasSM<80>]>;
3981+
3982+
// FTZ
3983+
// When doF32FTZ holds, select the .ftz relu forms for an nsz fmaxnum of a
// single-use, nnan fma against zero. The scalar f16 pattern accepts any
// zero immediate; the v2f16 pattern accepts the all-zero-bits vector.
def : Pat<(f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm_any_zero)),
3984+
(FMARELU_F16_FTZ Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
3985+
Requires<[doF32FTZ]>;
3986+
def : Pat<(v2f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int32Regs:$a, Int32Regs:$b, Int32Regs:$c), fpimm_positive_zero_v2f16)),
3987+
(FMARELU_F16X2_FTZ Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>,
3988+
Requires<[doF32FTZ]>;
3989+
3990+
// NO FTZ
3991+
// Non-FTZ selection: map an nsz fmaxnum of a single-use, nnan fma against
// zero onto the plain relu forms for f16, bf16, v2f16 and v2bf16.
// NOTE(review): these patterns carry no Requires clause of their own; confirm
// that the predicates on the FMARELU_* instructions (hasPTX<70>, hasSM<80>,
// useFP16Math / hasBF16Math) are enforced when these patterns are selected.
def : Pat<(f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm_any_zero)),
3992+
(FMARELU_F16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>;
3993+
def : Pat<(bf16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm_any_zero)),
3994+
(FMARELU_BF16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>;
3995+
def : Pat<(v2f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int32Regs:$a, Int32Regs:$b, Int32Regs:$c), fpimm_positive_zero_v2f16)),
3996+
(FMARELU_F16X2 Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>;
3997+
def : Pat<(v2bf16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int32Regs:$a, Int32Regs:$b, Int32Regs:$c), fpimm_positive_zero_v2bf16)),
3998+
(FMARELU_BF16X2 Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>;

0 commit comments

Comments
 (0)