@@ -418,25 +418,13 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
418
418
!strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
419
419
[(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
420
420
Requires<[useFP16Math, allowFMA]>;
421
- def bf16rr_ftz : // Removed by this change: scalar bf16 form with the ".ftz.bf16" asm suffix, gated on doF32FTZ in addition to hasBF16Math and allowFMA.
422
- NVPTXInst<(outs Int16Regs:$dst),
423
- (ins Int16Regs:$a, Int16Regs:$b),
424
- !strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"),
425
- [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
426
- Requires<[hasBF16Math, allowFMA, doF32FTZ]>;
427
421
def bf16rr : // Scalar bf16 form: two Int16Regs inputs and one Int16Regs result; asm is OpcStr followed by ".bf16"; requires hasBF16Math and allowFMA.
428
422
NVPTXInst<(outs Int16Regs:$dst),
429
423
(ins Int16Regs:$a, Int16Regs:$b),
430
424
!strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
431
425
[(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
432
426
Requires<[hasBF16Math, allowFMA]>;
433
427
434
- def bf16x2rr_ftz : // Removed by this change: packed v2bf16 form (Int32Regs operands) with the ".ftz.bf16x2" asm suffix, gated on doF32FTZ in addition to hasBF16Math and allowFMA.
435
- NVPTXInst<(outs Int32Regs:$dst),
436
- (ins Int32Regs:$a, Int32Regs:$b),
437
- !strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"),
438
- [(set (v2bf16 Int32Regs:$dst), (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
439
- Requires<[hasBF16Math, allowFMA, doF32FTZ]>;
440
428
def bf16x2rr :
441
429
NVPTXInst<(outs Int32Regs:$dst),
442
430
(ins Int32Regs:$a, Int32Regs:$b),
@@ -1423,9 +1411,7 @@ defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
1423
1411
defm FMA16 : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>; // Last argument is doF32FTZ for the ".ftz" mnemonics and True otherwise; presumably the gating predicate -- confirm against FMA_F16.
1424
1412
defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
1425
1413
defm FMA16x2 : FMA_F16<"fma.rn.f16x2", v2f16, Int32Regs, True>;
1426
- defm BFMA16_ftz : FMA_BF16<"fma.rn.ftz.bf16", bf16, Int16Regs, doF32FTZ>; // Removed by this change: ".ftz" scalar bf16 instantiation.
1427
1414
defm BFMA16 : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>;
1428
- defm BFMA16x2_ftz : FMA_BF16<"fma.rn.ftz.bf16x2", v2bf16, Int32Regs, doF32FTZ>; // Removed by this change: ".ftz" packed bf16x2 instantiation.
1429
1415
defm BFMA16x2 : FMA_BF16<"fma.rn.bf16x2", v2bf16, Int32Regs, True>;
1430
1416
defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
1431
1417
defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
@@ -3959,3 +3945,54 @@ def atomic_thread_fence_seq_cst_cta :
3959
3945
def atomic_thread_fence_acq_rel_cta : // Operand-less instruction printing "fence.acq_rel.cta;"; it has an empty pattern list and is gated on hasPTX<60> and hasSM<70>.
3960
3946
NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
3961
3947
Requires<[hasPTX<60>, hasSM<70>]>;
3948
+ // Floating-point immediate leaf over fAny that matches whenever Imm.isZero() is true (presumably both +0.0 and -0.0, hence "any" -- confirm).
3949
+ def fpimm_any_zero : FPImmLeaf<fAny, [{
3950
+ return Imm.isZero();
3951
+ }]>;
3952
+ // Operand-less PatFrags matching a packed two-element vector (v2f16 / v2bf16) expressed as a bitconvert of the i32 constant 0, i.e. all bits clear.
3953
+ def fpimm_positive_zero_v2f16 : PatFrag<(ops), (v2f16 (bitconvert (i32 0)))>;
3954
+ def fpimm_positive_zero_v2bf16 : PatFrag<(ops), (v2bf16 (bitconvert (i32 0)))>;
3955
+
3956
+ // Matches an fma node only when it has exactly one use and NaNs may be ignored,
3957
+ // i.e. the node carries the nnan flag or the TargetMachine sets NoNaNsFPMath.
3958
+ def NVPTX_fma_oneuse_and_nnan : PatFrag<(ops node:$a, node:$b, node:$c),
3959
+ (fma node:$a, node:$b, node:$c), [{
3960
+ return N->hasOneUse() &&
3961
+ (N->getFlags().hasNoNaNs() || TM.Options.NoNaNsFPMath);
3962
+ }]>;
3963
+ // fmaxnum will soon distinguish between positive and negative zeros, so this
3964
+ // PatFrag only matches an fmaxnum node with nsz (node flag or NoSignedZerosFPMath).
3965
+ def NVPTX_fmaxnum_nsz : PatFrag<(ops node:$a, node:$b),
3966
+ (fmaxnum node:$a, node:$b), [{
3967
+ return N->getFlags().hasNoSignedZeros() || TM.Options.NoSignedZerosFPMath;
3968
+ }]>;
3969
+ // Three-input register instruction: $dst, $a, $b and $c all use register class RC; the asm string is Instruction followed by "\t$dst, $a, $b, $c;"; the pattern list is empty and the record is gated on Preds.
3970
+ class NVPTXInst_rrr<RegisterClass RC, string Instruction, list<Predicate> Preds>
3971
+ : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
3972
+ !strconcat(Instruction, "\t$dst, $a, $b, $c;"), []>,
3973
+ Requires<Preds>;
3974
+ // fma.rn[.ftz].relu instruction records built on NVPTXInst_rrr: Int16Regs for scalar f16/bf16, Int32Regs for the x2 forms; all require hasPTX<70> and hasSM<80> plus useFP16Math or hasBF16Math. No bf16 ".ftz" form is defined here.
3975
+ def FMARELU_F16 : NVPTXInst_rrr<Int16Regs, "fma.rn.relu.f16", [useFP16Math, hasPTX<70>, hasSM<80>]>;
3976
+ def FMARELU_F16_FTZ : NVPTXInst_rrr<Int16Regs, "fma.rn.ftz.relu.f16", [useFP16Math, hasPTX<70>, hasSM<80>]>;
3977
+ def FMARELU_BF16 : NVPTXInst_rrr<Int16Regs, "fma.rn.relu.bf16", [hasBF16Math, hasPTX<70>, hasSM<80>]>;
3978
+ def FMARELU_F16X2 : NVPTXInst_rrr<Int32Regs, "fma.rn.relu.f16x2", [useFP16Math, hasPTX<70>, hasSM<80>]>;
3979
+ def FMARELU_F16X2_FTZ : NVPTXInst_rrr<Int32Regs, "fma.rn.ftz.relu.f16x2", [useFP16Math, hasPTX<70>, hasSM<80>]>;
3980
+ def FMARELU_BF16X2 : NVPTXInst_rrr<Int32Regs, "fma.rn.relu.bf16x2", [hasBF16Math, hasPTX<70>, hasSM<80>]>;
3981
+
3982
+ // FTZ: when doF32FTZ holds, fmaxnum-with-nsz of a single-use nnan fma against zero selects the ".ftz" relu forms (f16 and v2f16 only).
3983
+ def : Pat<(f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm_any_zero)),
3984
+ (FMARELU_F16_FTZ Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
3985
+ Requires<[doF32FTZ]>;
3986
+ def : Pat<(v2f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int32Regs:$a, Int32Regs:$b, Int32Regs:$c), fpimm_positive_zero_v2f16)),
3987
+ (FMARELU_F16X2_FTZ Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>,
3988
+ Requires<[doF32FTZ]>;
3989
+
3990
+ // NO FTZ: the same fmaxnum(fma, zero) shapes mapped to the non-".ftz" relu forms for f16, bf16, v2f16 and v2bf16. NOTE(review): these Pats carry no Requires<> of their own -- confirm the PTX/SM predicates on the FMARELU_* records are enforced for them.
3991
+ def : Pat<(f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm_any_zero)),
3992
+ (FMARELU_F16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>;
3993
+ def : Pat<(bf16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm_any_zero)),
3994
+ (FMARELU_BF16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>;
3995
+ def : Pat<(v2f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int32Regs:$a, Int32Regs:$b, Int32Regs:$c), fpimm_positive_zero_v2f16)),
3996
+ (FMARELU_F16X2 Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>;
3997
+ def : Pat<(v2bf16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int32Regs:$a, Int32Regs:$b, Int32Regs:$c), fpimm_positive_zero_v2bf16)),
3998
+ (FMARELU_BF16X2 Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>;
0 commit comments