
Conversation

@hdelan commented Nov 5, 2024

Add patterns to lower fmaxnum(fma(a, b, c), 0) to fma.rn{.ftz}.relu for the f16, f16x2, bf16, and bf16x2 types when the nnan flag is used.

fma.relu preserves NaN, so the substitution is only made when the fma is flagged nnan, since fmaxnum returns the non-NaN argument when passed a NaN value.

This patch also removes some bf16 ftz instructions, since FTZ is not supported with the bf16 type according to the PTX ISA docs.
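
For illustration only (not taken from the patch): the IR shape these patterns are meant to match looks roughly like the following, an fma whose only use is an fmaxnum against +0.0 carrying the nnan (and, per the review below, nsz) flags. The function name is made up.

  define half @fma_relu_sketch(half %a, half %b, half %c) {
    ; fma feeding a max-with-zero (ReLU); the fma must have no other users
    %fma = call nnan half @llvm.fma.f16(half %a, half %b, half %c)
    ; maxnum against +0.0; the nnan/nsz flags are what make the fma.rn.relu substitution legal
    %res = call nnan nsz half @llvm.maxnum.f16(half %fma, half 0.0)
    ret half %res
  }

  declare half @llvm.fma.f16(half, half, half)
  declare half @llvm.maxnum.f16(half, half)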

@llvmbot (Member) commented Nov 5, 2024

@llvm/pr-subscribers-backend-nvptx

Author: Hugh Delaney (hdelan)

Changes

Add patterns to lower fma(a, b, c) > 0 ? fma(a, b, c) : 0 for f16 and bf16 types.


Full diff: https://github.com/llvm/llvm-project/pull/114977.diff

2 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+16)
  • (added) llvm/test/CodeGen/NVPTX/fma-relu.ll (+77)
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 5f6cba397c5352..52312fa9afbd7e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3917,3 +3917,19 @@ def atomic_thread_fence_seq_cst_cta :
 def atomic_thread_fence_acq_rel_cta :
   NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
   Requires<[hasPTX<60>, hasSM<70>]>;
+
+def fpimm0 : FPImmLeaf<fAny, [{
+  return Imm.isExactlyValue(+0.0);
+}]>;
+
+def FMARELU :
+  NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
+            "fma.rn.relu \t$dst, $a, $b, $c;", []>;
+
+def : Pat<(f16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
+  (FMARELU Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
+  Requires<[useFP16Math, allowFMA, allowUnsafeFPMath, hasPTX<60>, hasSM<70>]>;
+
+def : Pat<(bf16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
+  (FMARELU Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
+  Requires<[hasBF16Math, allowFMA, allowUnsafeFPMath, hasPTX<60>, hasSM<70>]>;
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu.ll b/llvm/test/CodeGen/NVPTX/fma-relu.ll
new file mode 100644
index 00000000000000..6c340ef9d53015
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fma-relu.ll
@@ -0,0 +1,77 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 --enable-unsafe-fp-math -mcpu=sm_80 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | %ptxas-verify -arch=sm_80 %}
+
+define half @fma_f16(half %a, half %b, half %c) {
+; CHECK-LABEL: fma_f16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [fma_f16_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [fma_f16_param_1];
+; CHECK-NEXT:    ld.param.b16 %rs3, [fma_f16_param_2];
+; CHECK-NEXT:    fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT:    ret;
+  %1 = call half @llvm.fma.f16(half %a, half %b, half %c)
+  %2 = fcmp ogt half %1, 0.0
+  %3 = select i1 %2, half %1, half 0.0
+  ret half %3
+}
+
+define half @fma_f16_expanded(half %a, half %b, half %c) {
+; CHECK-LABEL: fma_f16_expanded(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [fma_f16_expanded_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [fma_f16_expanded_param_1];
+; CHECK-NEXT:    ld.param.b16 %rs3, [fma_f16_expanded_param_2];
+; CHECK-NEXT:    fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT:    ret;
+  %1 = fmul half %a, %b
+  %2 = fadd half %1, %c
+  %3 = fcmp ogt half %2, 0.0
+  %4 = select i1 %3, half %2, half 0.0
+  ret half %4
+}
+
+define bfloat @fma_bf16(bfloat %a, bfloat %b, bfloat %c) {
+; CHECK-LABEL: fma_bf16(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [fma_bf16_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [fma_bf16_param_1];
+; CHECK-NEXT:    ld.param.b16 %rs3, [fma_bf16_param_2];
+; CHECK-NEXT:    fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT:    ret;
+  %1 = call bfloat @llvm.fma.bf16(bfloat %a, bfloat %b, bfloat %c)
+  %2 = fcmp ogt bfloat %1, 0.0
+  %3 = select i1 %2, bfloat %1, bfloat 0.0
+  ret bfloat %3
+}
+
+define bfloat @fma_bf16_expanded(bfloat %a, bfloat %b, bfloat %c) {
+; CHECK-LABEL: fma_bf16_expanded(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b16 %rs1, [fma_bf16_expanded_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [fma_bf16_expanded_param_1];
+; CHECK-NEXT:    ld.param.b16 %rs3, [fma_bf16_expanded_param_2];
+; CHECK-NEXT:    fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT:    ret;
+  %1 = fmul bfloat %a, %b
+  %2 = fadd bfloat %1, %c
+  %3 = fcmp ogt bfloat %2, 0.0
+  %4 = select i1 %3, bfloat %2, bfloat 0.0
+  ret bfloat %4
+}

@hdelan hdelan changed the title from "Add patterns for fma.relu.{f16|bf16}" to "[NVPTX] Add patterns for fma.relu.{f16|bf16}" on Nov 5, 2024
@hdelan (Author) commented Nov 5, 2024

Ping @ldrumm @frasercrmck

@justinfargnoli (Contributor) left a comment

Overall, LGTM!

Please wait for @AlexMaclean's review though as he's more familiar with NVPTXInstrInfo.td than I am.

@AlexMaclean (Member)

Suppose the fma has more uses in addition to the fmaxnum. If this optimization kicks in, it may increase register pressure and won't be a clear win in terms of performance. I'm not sure this will be a problem, but to be conservative it may be better to implement this as a DAG combine and verify that the fma has a single use.
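
To illustrate the concern (a hypothetical example, not taken from the patch): if the fma result has a second user besides the max-with-zero, the plain fma value still has to be computed, so matching the pattern would add an fma.rn.relu on top of it rather than replace it.

  define half @fma_two_uses(half %a, half %b, half %c, half %d) {
    %fma = call nnan half @llvm.fma.f16(half %a, half %b, half %c)
    ; first use: the clamp that the pattern would turn into fma.rn.relu
    %relu = call nnan nsz half @llvm.maxnum.f16(half %fma, half 0.0)
    ; second use: the raw fma value is still needed here
    %other = fadd half %fma, %d
    %sum = fadd half %relu, %other
    ret half %sum
  }

  declare half @llvm.fma.f16(half, half, half)
  declare half @llvm.maxnum.f16(half, half)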

@AlexMaclean AlexMaclean requested a review from Artem-B November 5, 2024 17:30
Comment on lines 35 to 304
%1 = fmul half %a, %b
%2 = fadd half %1, %c
Member

It would be good to add a couple more test runs:

  • one without mul/add -> fma contraction, to make sure we do not use fma.rn.relu unintentionally.
  • one targeting older GPUs to make sure we do not emit fma.rn.relu there.

Author

I've added more tests to cover these cases.

Contributor

Is it worth also having a test case that uses llvm.maxnum? I believe that if the IR was given the right fast-math flags, InstCombine would transform this select into an llvm.maxnum anyway.

Speaking of, should we also have tests with fast-math flags? My feeling is that we should see fast-math flags in the IR as if this was really coming from a frontend with -ffast-math (or equivalent). IIRC the NVPTX backend relies on the unsafe-fp-math function attribute being set, which enables these fast math optimizations. I think we should have a test with fast-math flags, fast-math function attributes, and the default llc flags (no --enable-unsafe-fp-math, no -nvptx-fma-level. We should still generate fma.relu in that case, right? This, imo, should be "the" canonical test of this optimization - using various llc flags like this is a less standardised approach.
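
A rough sketch of what such a canonical test might look like (hypothetical; it assumes the fast flags plus the "unsafe-fp-math"/"no-nans-fp-math"/"no-signed-zeros-fp-math" function attributes a -ffast-math frontend would emit, and a default llc invocation with no --enable-unsafe-fp-math or -nvptx-fma-level):

  ; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=ptx70 | FileCheck %s
  define half @fma_relu_fast_flags(half %a, half %b, half %c) #0 {
    ; fast-math flags on the instructions, as a -ffast-math frontend would emit
    %fma = call fast half @llvm.fma.f16(half %a, half %b, half %c)
    %r = call fast half @llvm.maxnum.f16(half %fma, half 0.0)
    ret half %r
  }

  declare half @llvm.fma.f16(half, half, half)
  declare half @llvm.maxnum.f16(half, half)

  attributes #0 = { "unsafe-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" }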

@hdelan hdelan force-pushed the fma-relu branch 4 times, most recently from 4759e15 to 9456007, on November 6, 2024 11:20
@ldrumm (Contributor) left a comment

Looks good!

@hdelan hdelan force-pushed the fma-relu branch 2 times, most recently from b2f6135 to f5eea93, on November 6, 2024 12:55
@Artem-B (Member) left a comment

We will eventually want to figure out a way to effectively use f16x2 and bf16x2 variants of fma.rn.relu.

@hdelan (Author) commented Nov 7, 2024

Suppose the fma has more uses in addition to the fmaxnum. If this optimization kicks in, it may increase register pressure and won't be a clear win in terms of performance. I'm not sure this will be a problem, but to be conservative it may be better to implement this as a DAG combine and verify that the fma has a single use.

I've changed the pattern matching to make sure FMA relu is only emitted if the FMA DAG has a single use.

@hdelan (Author) commented Nov 7, 2024

We will eventually want to figure out a way to effectively use f16x2 and bf16x2 variants of fma.rn.relu.

I've added support for the f16x2 and bf16x2 variants as well.

@hdelan hdelan force-pushed the fma-relu branch 2 times, most recently from fab1893 to 2ca364f, on November 7, 2024 13:37
@hdelan hdelan changed the title from "[NVPTX] Add patterns for fma.relu.{f16|bf16}" to "[NVPTX] Add patterns for fma.relu.{f16|f16x2|bf16|bf16x2}" on Nov 7, 2024
@hdelan (Author) commented Nov 7, 2024

FTZ FMA is not supported for bf16, so I have removed some patterns elsewhere that were allowing fma.rn.ftz.bf16 to be emitted.

Comment on lines +1412 to +1259
Member

Another future optimization opportunity -- we could lower those "truncate-to-i16, then recombine into i32" as a single prmt

Member

How does FMA handle NaNs? fmaxnum requires: "If either operand is a NaN, returns the other non-NaN operand."

If we pass a NaN into an fma.relu and get a NaN back, that would be incorrect for fmaxnum.
The documentation does not seem to say anything about NaN handling by FMA (it only mentions it for the .sat variants).
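
To make the mismatch concrete (an illustrative sketch with a made-up function, not from the discussion): if the fma produces a NaN, fmaxnum must return the +0.0 operand, whereas a NaN-propagating fma.relu would return the NaN, which is why the fold needs nnan.

  define half @nan_mismatch_sketch(half %b, half %c) {
    ; a NaN multiplicand makes the fma result NaN (0xH7E00 is a half qNaN)
    %fma = call half @llvm.fma.f16(half 0xH7E00, half %b, half %c)
    ; fmaxnum must return the non-NaN operand, i.e. +0.0 here;
    ; fma.rn.relu would instead yield a NaN
    %r = call half @llvm.maxnum.f16(half %fma, half 0.0)
    ret half %r
  }

  declare half @llvm.fma.f16(half, half, half)
  declare half @llvm.maxnum.f16(half, half)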

Author

We are only applying this pattern if allowUnsafeFPMath is enabled. So AFAIA we don't need to worry about NaN propagation.

Member

The comment for the allowUnsafeFPMath says that it only affects precision of the operations, not NaN handling:

/// UnsafeFPMath - This flag is enabled when the

So, it would allow using a regular FMA instead of a*b+c, but NaN propagation would not be affected.
This is not the case for fmaxnum -> fma, where fmaxnum has distinctly different requirements for processing NaN inputs.

In any case, we should run that by someone who's familiar with FP handling nuances.
@arsenm @RKSimon -- what would be a sufficient constraint to reflect that the instruction is not expected to deal with NaNs?

Contributor

Definitely do not rely on allowUnsafeFPMath for no nans. We really should delete allowUnsafeFPMath.

FMA propagates nan as normal, and may also produce output nans based on the operands.

I also think this pattern is broken for signed zeroes. I assume the ftz modifier is expected to preserve -0

Author

Thanks, I have changed the constraints to Requires<[noNaNsFPMath]>.

I also think this pattern is broken for signed zeroes. I assume the ftz modifier is expected to preserve -0

I have extended the fpimm pattern to also accommodate negative zero.

Author

I also think this pattern is broken for signed zeroes. I assume the ftz modifier is expected to preserve -0

I think I misunderstood what you were saying the first time round. You are concerned with a comparison of 0.0 and -0.0. Since I am pattern matching on fmaxnum, which allows either -0.0 or 0.0 to be returned when comparing the two, the guarantees of fma.relu (which doesn't specify which zero is returned when positive and negative zeros are compared) are the same as those of the fmaxnum comparison.

Contributor

We are fixing the definition of fmaxnum/fminnum to require correctly ordered signed zeroes. The fuzzy behavior can be achieved with nsz, so it doesn't make sense to define it this way.

Best to respect this now to reduce the amount of code that needs changing
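
Concretely (an illustrative sketch, not from the thread): under the tightened definition, maxnum(-0.0, +0.0) must return +0.0, while fma.relu leaves the choice of zero unspecified; the nsz flag licenses either result, which is what makes the fold acceptable.

  define half @signed_zero_sketch() {
    ; tightened fmaxnum semantics require +0.0 here; with nsz either zero
    ; is acceptable, matching fma.rn.relu's lack of a guarantee
    %r = call nsz half @llvm.maxnum.f16(half -0.0, half 0.0)
    ret half %r
  }

  declare half @llvm.maxnum.f16(half, half)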

Author

I've added a PatFrag to make sure the fmaxnum has nsz.

@hdelan hdelan requested a review from frasercrmck November 11, 2024 17:33
@frasercrmck (Contributor) left a comment

The PR description needs updating to reflect some of the new changes - the conditions under which the optimization is legal/performed, and that some instructions are being removed.

Contributor

I personally wouldn't bother supporting the function attribute case - I don't think it warrants the extra code and patterns. Supporting nnan should be enough. I think that the NVPTX backend should generally be moving towards using instruction flags rather than function attributes to do its various fast-math optimizations. But I'd defer to other people on that point.

Author

@Artem-B do you have an opinion on this?

Contributor

I'm sure a multiclass could whittle this bulk down quite a bit. The only things changing are the instruction suffix, the reg class, and the instruction suffix in the asm.

Author

I've done so in latest commit.

@hdelan hdelan force-pushed the fma-relu branch 3 times, most recently from 5c4d1ec to c261a02, on November 14, 2024 14:12
Hugh Delaney added 18 commits November 18, 2024 14:50
Add patterns to lower fma(a, b, c) > 0 ? fma(a, b, c) : 0 for f16 and
bf16 types.
FMA relu should only be emitted if the FMA node has a single use. This
should limit register pressure in some cases by avoiding computing both the
FMA and the FMA.relu.

Also split tests into two files, one using FMA contraction and the other
using the FMA intrinsic.
The PTX ISA docs note that FTZ is not permitted for bf16 types;
fma.ftz isn't supported for bf16.
The pattern substitution for fmax(fma(), 0) is only valid if NaNs cannot
be emitted. This is because fmax is guaranteed to return the non-NaN
arg, whereas fma.relu is NaN-preserving.
Extend the pattern to allow negative zeros for scalar types.
If nnan is used in instruction flags then FMA relu can also be emitted.
Use a multiclass to refactor instruction defs.
Instead of adding a flag for noNaNsFPMath, just add the check in the
PatFrag which also checks for single use FMA, as well as instruction
flags.
fmaxnum currently returns either -0.0 or 0.0 when comparing the two values;
however, this specification is going to change to return positive zero
in a comparison between the two (according to arsenm).

fma.relu is not specified to return either positive or negative zero in
a comparison of signed zeros, so it shouldn't be emitted in a DAG that
uses fmaxnum's new specification unless nsz is also used.
Don't use multiclass. Class is more readable.
TODO: refactor other methods to use this.
@hdelan hdelan merged commit 8f8016f into llvm:main Nov 18, 2024
5 of 7 checks passed