[NVPTX] Add patterns for fma.relu.{f16|f16x2|bf16|bf16x2} #114977
Conversation
@llvm/pr-subscribers-backend-nvptx

Author: Hugh Delaney (hdelan)

Changes: Add patterns to lower fma(a, b, c) > 0 ? fma(a, b, c) : 0 for f16 and bf16 types.

Full diff: https://github.com/llvm/llvm-project/pull/114977.diff (2 files affected)
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 5f6cba397c5352..52312fa9afbd7e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3917,3 +3917,19 @@ def atomic_thread_fence_seq_cst_cta :
def atomic_thread_fence_acq_rel_cta :
NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
Requires<[hasPTX<60>, hasSM<70>]>;
+
+def fpimm0 : FPImmLeaf<fAny, [{
+ return Imm.isExactlyValue(+0.0);
+}]>;
+
+def FMARELU :
+ NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
+ "fma.rn.relu \t$dst, $a, $b, $c;", []>;
+
+def : Pat<(f16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
+ (FMARELU Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
+ Requires<[useFP16Math, allowFMA, allowUnsafeFPMath, hasPTX<60>, hasSM<70>]>;
+
+def : Pat<(bf16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
+ (FMARELU Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
+ Requires<[hasBF16Math, allowFMA, allowUnsafeFPMath, hasPTX<60>, hasSM<70>]>;
diff --git a/llvm/test/CodeGen/NVPTX/fma-relu.ll b/llvm/test/CodeGen/NVPTX/fma-relu.ll
new file mode 100644
index 00000000000000..6c340ef9d53015
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fma-relu.ll
@@ -0,0 +1,77 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 --enable-unsafe-fp-math -mcpu=sm_80 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | %ptxas-verify -arch=sm_80 %}
+
+define half @fma_f16(half %a, half %b, half %c) {
+; CHECK-LABEL: fma_f16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [fma_f16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [fma_f16_param_1];
+; CHECK-NEXT: ld.param.b16 %rs3, [fma_f16_param_2];
+; CHECK-NEXT: fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT: ret;
+ %1 = call half @llvm.fma.f16(half %a, half %b, half %c)
+ %2 = fcmp ogt half %1, 0.0
+ %3 = select i1 %2, half %1, half 0.0
+ ret half %3
+}
+
+define half @fma_f16_expanded(half %a, half %b, half %c) {
+; CHECK-LABEL: fma_f16_expanded(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [fma_f16_expanded_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [fma_f16_expanded_param_1];
+; CHECK-NEXT: ld.param.b16 %rs3, [fma_f16_expanded_param_2];
+; CHECK-NEXT: fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT: ret;
+ %1 = fmul half %a, %b
+ %2 = fadd half %1, %c
+ %3 = fcmp ogt half %2, 0.0
+ %4 = select i1 %3, half %2, half 0.0
+ ret half %4
+}
+
+define bfloat @fma_bf16(bfloat %a, bfloat %b, bfloat %c) {
+; CHECK-LABEL: fma_bf16(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [fma_bf16_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [fma_bf16_param_1];
+; CHECK-NEXT: ld.param.b16 %rs3, [fma_bf16_param_2];
+; CHECK-NEXT: fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT: ret;
+ %1 = call bfloat @llvm.fma.bf16(bfloat %a, bfloat %b, bfloat %c)
+ %2 = fcmp ogt bfloat %1, 0.0
+ %3 = select i1 %2, bfloat %1, bfloat 0.0
+ ret bfloat %3
+}
+
+define bfloat @fma_bf16_expanded(bfloat %a, bfloat %b, bfloat %c) {
+; CHECK-LABEL: fma_bf16_expanded(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [fma_bf16_expanded_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [fma_bf16_expanded_param_1];
+; CHECK-NEXT: ld.param.b16 %rs3, [fma_bf16_expanded_param_2];
+; CHECK-NEXT: fma.rn.relu %rs4, %rs1, %rs2, %rs3;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
+; CHECK-NEXT: ret;
+ %1 = fmul bfloat %a, %b
+ %2 = fadd bfloat %1, %c
+ %3 = fcmp ogt bfloat %2, 0.0
+ %4 = select i1 %3, bfloat %2, bfloat 0.0
+ ret bfloat %4
+}
Ping @ldrumm @frasercrmck
Overall, LGTM!
Please wait for @AlexMaclean's review though, as he's more familiar with NVPTXInstrInfo.td than I am.
llvm/test/CodeGen/NVPTX/fma-relu.ll (Outdated)
> %1 = fmul half %a, %b
> %2 = fadd half %1, %c
It would be good to add a couple more test runs:
- one w/o mul/add -> fma contraction, to make sure we do not use fma.rn.relu unintentionally.
- one targeting older GPUs, to make sure we do not emit fma.rn.relu there.
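For illustration, such extra RUN lines might look like the sketch below (the exact -mcpu/-mattr spellings and check prefixes are my assumptions, not taken from the patch):

; No -fp-contract=fast: the separate mul/add should not fold into fma.rn.relu.
; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=ptx70 | FileCheck %s --check-prefix=CHECK-NO-CONTRACT
; Older target without fma.relu support: expect a plain fma.rn followed by a max.
; RUN: llc < %s -march=nvptx64 -mcpu=sm_70 -mattr=ptx60 -fp-contract=fast | FileCheck %s --check-prefix=CHECK-SM70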
I've added more tests to cover these cases.
Is it worth also having a test case that uses llvm.maxnum? I believe that if the IR was given the right fast-math flags, InstCombine would transform this select into an llvm.maxnum anyway.

Speaking of, should we also have tests with fast-math flags? My feeling is that we should see fast-math flags in the IR as if this was really coming from a frontend with -ffast-math (or equivalent). IIRC the NVPTX backend relies on the unsafe-fp-math function attribute being set, which enables these fast-math optimizations. I think we should have a test with fast-math flags, fast-math function attributes, and the default llc flags (no --enable-unsafe-fp-math, no -nvptx-fma-level). We should still generate fma.relu in that case, right? This, imo, should be "the" canonical test of this optimization - using various llc flags like this is a less standardised approach.
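A minimal sketch of what such a canonical test might look like, relying only on IR-level fast-math flags and function attributes (the function name, CHECK lines, and exact attribute set here are illustrative assumptions, not from the patch):

; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=ptx70 | FileCheck %s

define half @fma_relu_canonical(half %a, half %b, half %c) #0 {
; CHECK-LABEL: fma_relu_canonical(
; CHECK: fma.rn.relu
  ; With the right fast-math flags, InstCombine canonicalises the
  ; fcmp-ogt/select idiom into llvm.maxnum, which the backend then matches.
  %fma = call nnan nsz half @llvm.fma.f16(half %a, half %b, half %c)
  %max = call nnan nsz half @llvm.maxnum.f16(half %fma, half 0.0)
  ret half %max
}

declare half @llvm.fma.f16(half, half, half)
declare half @llvm.maxnum.f16(half, half)

attributes #0 = { "unsafe-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" }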
Looks good!
We will eventually want to figure out a way to effectively use f16x2 and bf16x2 variants of fma.rn.relu.

I've changed the pattern matching to make sure FMA relu is only emitted if the FMA DAG has a single use.

I've added support for the f16x2 and bf16x2 variants as well.
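At the IR level, the packed f16x2 case corresponds to something like the following sketch (bf16x2 is analogous; the function name is illustrative):

define <2 x half> @fma_relu_f16x2(<2 x half> %a, <2 x half> %b, <2 x half> %c) {
  ; On sm_80+/ptx70+ this should select a single packed fma.rn.relu.f16x2.
  %fma = call nnan <2 x half> @llvm.fma.v2f16(<2 x half> %a, <2 x half> %b, <2 x half> %c)
  %max = call nnan <2 x half> @llvm.maxnum.v2f16(<2 x half> %fma, <2 x half> zeroinitializer)
  ret <2 x half> %max
}

declare <2 x half> @llvm.fma.v2f16(<2 x half>, <2 x half>, <2 x half>)
declare <2 x half> @llvm.maxnum.v2f16(<2 x half>, <2 x half>)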
FTZ FMA is not supported for bf16, so I have removed some patterns elsewhere that were allowing it.
Another future optimization opportunity -- we could lower those "truncate-to-i16, then recombine into i32" sequences as a single prmt.
How does FMA handle NaNs? fmaxnum requires: "If either operand is a NaN, returns the other non-NaN operand." If we pass a NaN into an fma.relu, we get a NaN back, and that would be incorrect for fmaxnum.

The documentation does not seem to say anything about NaN handling by FMA (it only mentions it for .sat variants).
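To make the mismatch concrete, a worked example (values chosen for illustration):

; Let a = +inf and b = 0, so fma(a, b, c) = NaN (inf * 0 is an invalid operation):
;   fmaxnum(NaN, 0.0)  == 0.0  ; fmaxnum must return the non-NaN operand
;   fma.relu(a, b, c)  == NaN  ; if fma.relu propagates the NaN
; Folding fmaxnum(fma(a, b, c), 0.0) into fma.relu would therefore turn a
; well-defined 0.0 result into NaN.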
We are only applying this pattern if allowUnsafeFPMath is enabled, so AFAIA we don't need to worry about NaN propagation.
The comment for allowUnsafeFPMath says that it only affects the precision of the operations, not NaN handling:

/// UnsafeFPMath - This flag is enabled when the ...

So, it would allow using a regular FMA instead of a*b+c, but NaN propagation would not be affected. This is not the case for fmaxnum -> fma, where fmaxnum has distinctly different requirements for processing NaN inputs.

In any case, we should run that by someone who's familiar with FP handling nuances.

@arsenm @RKSimon -- what would be a sufficient constraint to reflect that the instruction is not expected to deal with NaNs?
Definitely do not rely on allowUnsafeFPMath for no-NaNs. We really should delete allowUnsafeFPMath.

FMA propagates NaN as normal, and may also produce output NaNs based on the operands.

I also think this pattern is broken for signed zeroes. I assume the ftz modifier is expected to preserve -0.
Thanks, I have changed the constraints to Requires<[noNaNsFPMath]>.

> I also think this pattern is broken for signed zeroes. I assume the ftz modifier is expected to preserve -0

I have extended the fpimm pattern to also accommodate negative zero.
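Concretely, IR along these lines should now also be able to match (a sketch; the function name is illustrative and not taken from the patch's tests):

define half @fma_relu_negzero(half %a, half %b, half %c) {
  ; The compared-against constant is negative zero rather than +0.0.
  %fma = call nnan half @llvm.fma.f16(half %a, half %b, half %c)
  %max = call nnan half @llvm.maxnum.f16(half %fma, half -0.0)
  ret half %max
}

declare half @llvm.fma.f16(half, half, half)
declare half @llvm.maxnum.f16(half, half)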
> I also think this pattern is broken for signed zeroes. I assume the ftz modifier is expected to preserve -0

I think I misunderstood what you were saying the first time round. You are concerned with a comparison of 0.0 and -0.0. Since I am pattern matching with fmaxnum, which says either -0.0 or 0.0 may be returned in a comparison of both values, the guarantees of fma.relu (which doesn't specify what is returned when signed and unsigned zeros are in the final comparison) are the same as the fmaxnum comparison.
We are fixing the definition of fmaxnum/fminnum to require correctly ordered signed zeroes. The fuzzy behavior can be achieved with nsz, so it doesn't make sense to define it this way. Best to respect this now to reduce the amount of code that needs changing.
I've added a PatFrag to make sure the fmaxnum has nsz.
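In IR terms the requirement looks like this (a sketch; function names are illustrative):

define half @with_nsz(half %a, half %b, half %c) {
  ; nsz on the fmaxnum permits the relaxed signed-zero ordering, so the
  ; PatFrag can accept this as an fma.relu candidate.
  %fma = call nnan nsz half @llvm.fma.f16(half %a, half %b, half %c)
  %max = call nnan nsz half @llvm.maxnum.f16(half %fma, half 0.0)
  ret half %max
}

; Without nsz, fmaxnum must order signed zeroes correctly, which fma.relu
; does not guarantee, so no fma.relu should be emitted here.
define half @without_nsz(half %a, half %b, half %c) {
  %fma = call nnan half @llvm.fma.f16(half %a, half %b, half %c)
  %max = call nnan half @llvm.maxnum.f16(half %fma, half 0.0)
  ret half %max
}

declare half @llvm.fma.f16(half, half, half)
declare half @llvm.maxnum.f16(half, half)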
The PR description needs updating to reflect some of the new changes - the conditions under which the optimization is legal/performed, and that some instructions are being removed.
I personally wouldn't bother supporting the function attribute case - I don't think it warrants the extra code and patterns. Supporting nnan should be enough. I think that the NVPTX backend should generally be moving towards using instruction flags rather than function attributes to do its various fast-math optimizations. But I'd defer to other people on that point.
@Artem-B do you have an opinion on this?
I'm sure a multiclass could whittle this bulk down quite a bit. The only things changing are the instruction suffix, the reg class, and the instruction suffix in the asm.
I've done so in the latest commit.
Add patterns to lower fma(a, b, c) > 0 ? fma(a, b, c) : 0 for f16 and bf16 types.
FMA relu should only be emitted if the FMA node has a single use. This should limit register pressure in some cases, to avoid computing FMA as well as FMA.relu. Also split tests into two files, one using FMA contraction and the other using the FMA intrinsic.
The PTX ISA docs note that FTZ is not permitted for bf16 types; fma.ftz isn't supported for bf16.
The pattern substitution for fmax(fma(), 0) is only valid if NaNs cannot be emitted. This is because fmax is guaranteed to return the non-NaN arg, whereas fma.relu is NaN preserving.
Extend the pattern to allow negative zeros for scalar types.
If nnan is used in instruction flags then FMA relu can also be emitted.
Use a multiclass to refactor instruction defs.
Instead of adding a flag for noNaNsFPMath, just add the check in the PatFrag which also checks for single use FMA, as well as instruction flags.
fmaxnum currently returns either -0.0 or 0.0 when comparing both values; however, this specification is going to change to return positive zero in a comparison between both values (according to arsenm). fma.relu is not specified to return either positive or negative zero in a comparison of signed zeros, so it shouldn't be emitted in a DAG that uses fmaxnum's new specification unless nsz is also used.
Don't use multiclass. Class is more readable.
TODO: refactor other methods to use this.
Add patterns to lower fmaxnum(fma(a, b, c), 0) to fma.rn{.ftz}.relu for the f16, f16x2, bf16, and bf16x2 types, when nnan is used. fma.relu honours NaN, so the substitution is only made if the fma is nnan, since fmaxnum returns the non-NaN argument when passed a NaN value.

This patch also removes some bf16 ftz instructions, since FTZ is not supported with the bf16 type, according to the PTX ISA docs.