Skip to content

Commit 273a94b

Browse files
authored
[NVPTX] Add some more immediate instruction variants (#122746)
While this likely won't impact the final SASS, it makes for more compact PTX.
1 parent c4fb718 commit 273a94b

File tree

6 files changed

+175
-151
lines changed

6 files changed

+175
-151
lines changed

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 70 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -207,33 +207,39 @@ class ValueToRegClass<ValueType T> {
207207
// Some Common Instruction Class Templates
208208
//===----------------------------------------------------------------------===//
209209

210+
// Utility class to wrap up information about a register and DAG type for more
211+
// convenient iteration and parameterization
212+
class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm> {
213+
ValueType Ty = ty;
214+
NVPTXRegClass RC = rc;
215+
Operand Imm = imm;
216+
int Size = ty.Size;
217+
}
218+
219+
def I16RT : RegTyInfo<i16, Int16Regs, i16imm>;
220+
def I32RT : RegTyInfo<i32, Int32Regs, i32imm>;
221+
def I64RT : RegTyInfo<i64, Int64Regs, i64imm>;
222+
210223
// Template for instructions which take three int64, int32, or int16 args.
211224
// The instructions are named "<OpcStr><Width>" (e.g. "add.s64").
212-
multiclass I3<string OpcStr, SDNode OpNode> {
213-
def i64rr :
214-
NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, Int64Regs:$b),
215-
!strconcat(OpcStr, "64 \t$dst, $a, $b;"),
216-
[(set i64:$dst, (OpNode i64:$a, i64:$b))]>;
217-
def i64ri :
218-
NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, i64imm:$b),
219-
!strconcat(OpcStr, "64 \t$dst, $a, $b;"),
220-
[(set i64:$dst, (OpNode i64:$a, imm:$b))]>;
221-
def i32rr :
222-
NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b),
223-
!strconcat(OpcStr, "32 \t$dst, $a, $b;"),
224-
[(set i32:$dst, (OpNode i32:$a, i32:$b))]>;
225-
def i32ri :
226-
NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, i32imm:$b),
227-
!strconcat(OpcStr, "32 \t$dst, $a, $b;"),
228-
[(set i32:$dst, (OpNode i32:$a, imm:$b))]>;
229-
def i16rr :
230-
NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b),
231-
!strconcat(OpcStr, "16 \t$dst, $a, $b;"),
232-
[(set i16:$dst, (OpNode i16:$a, i16:$b))]>;
233-
def i16ri :
234-
NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, i16imm:$b),
235-
!strconcat(OpcStr, "16 \t$dst, $a, $b;"),
236-
[(set i16:$dst, (OpNode i16:$a, (imm):$b))]>;
225+
multiclass I3<string OpcStr, SDNode OpNode, bit commutative> {
226+
foreach t = [I16RT, I32RT, I64RT] in {
227+
defvar asmstr = OpcStr # t.Size # " \t$dst, $a, $b;";
228+
229+
def t.Ty # rr :
230+
NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.RC:$b),
231+
asmstr,
232+
[(set t.Ty:$dst, (OpNode t.Ty:$a, t.Ty:$b))]>;
233+
def t.Ty # ri :
234+
NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.Imm:$b),
235+
asmstr,
236+
[(set t.Ty:$dst, (OpNode t.RC:$a, imm:$b))]>;
237+
if !not(commutative) then
238+
def t.Ty # ir :
239+
NVPTXInst<(outs t.RC:$dst), (ins t.Imm:$a, t.RC:$b),
240+
asmstr,
241+
[(set t.Ty:$dst, (OpNode imm:$a, t.RC:$b))]>;
242+
}
237243
}
238244

239245
class I16x2<string OpcStr, SDNode OpNode> :
@@ -870,8 +876,8 @@ defm SUB_i1 : ADD_SUB_i1<sub>;
870876

871877
// int16, int32, and int64 signed addition. Since nvptx is 2's complement, we
872878
// also use these for unsigned arithmetic.
873-
defm ADD : I3<"add.s", add>;
874-
defm SUB : I3<"sub.s", sub>;
879+
defm ADD : I3<"add.s", add, /*commutative=*/ true>;
880+
defm SUB : I3<"sub.s", sub, /*commutative=*/ false>;
875881

876882
def ADD16x2 : I16x2<"add.s", add>;
877883

@@ -883,18 +889,18 @@ defm SUBCC : ADD_SUB_INT_CARRY<"sub.cc", subc>;
883889
defm ADDCCC : ADD_SUB_INT_CARRY<"addc.cc", adde>;
884890
defm SUBCCC : ADD_SUB_INT_CARRY<"subc.cc", sube>;
885891

886-
defm MULT : I3<"mul.lo.s", mul>;
892+
defm MULT : I3<"mul.lo.s", mul, /*commutative=*/ true>;
887893

888-
defm MULTHS : I3<"mul.hi.s", mulhs>;
889-
defm MULTHU : I3<"mul.hi.u", mulhu>;
894+
defm MULTHS : I3<"mul.hi.s", mulhs, /*commutative=*/ true>;
895+
defm MULTHU : I3<"mul.hi.u", mulhu, /*commutative=*/ true>;
890896

891-
defm SDIV : I3<"div.s", sdiv>;
892-
defm UDIV : I3<"div.u", udiv>;
897+
defm SDIV : I3<"div.s", sdiv, /*commutative=*/ false>;
898+
defm UDIV : I3<"div.u", udiv, /*commutative=*/ false>;
893899

894900
// The ri versions of rem.s and rem.u won't be selected; DAGCombiner::visitSREM
895901
// will lower it.
896-
defm SREM : I3<"rem.s", srem>;
897-
defm UREM : I3<"rem.u", urem>;
902+
defm SREM : I3<"rem.s", srem, /*commutative=*/ false>;
903+
defm UREM : I3<"rem.u", urem, /*commutative=*/ false>;
898904

899905
// Integer absolute value. NumBits should be one minus the bit width of RC.
900906
// This idiom implements the algorithm at
@@ -909,10 +915,10 @@ defm ABS_32 : ABS<i32, Int32Regs, ".s32">;
909915
defm ABS_64 : ABS<i64, Int64Regs, ".s64">;
910916

911917
// Integer min/max.
912-
defm SMAX : I3<"max.s", smax>;
913-
defm UMAX : I3<"max.u", umax>;
914-
defm SMIN : I3<"min.s", smin>;
915-
defm UMIN : I3<"min.u", umin>;
918+
defm SMAX : I3<"max.s", smax, /*commutative=*/ true>;
919+
defm UMAX : I3<"max.u", umax, /*commutative=*/ true>;
920+
defm SMIN : I3<"min.s", smin, /*commutative=*/ true>;
921+
defm UMIN : I3<"min.u", umin, /*commutative=*/ true>;
916922

917923
def SMAX16x2 : I16x2<"max.s", smax>;
918924
def UMAX16x2 : I16x2<"max.u", umax>;
@@ -1392,25 +1398,32 @@ def FDIV32ri_prec :
13921398
//
13931399

13941400
multiclass FMA<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred> {
1395-
def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
1396-
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
1397-
[(set RC:$dst, (fma RC:$a, RC:$b, RC:$c))]>,
1398-
Requires<[Pred]>;
1399-
def rri : NVPTXInst<(outs RC:$dst),
1400-
(ins RC:$a, RC:$b, ImmCls:$c),
1401-
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
1402-
[(set RC:$dst, (fma RC:$a, RC:$b, fpimm:$c))]>,
1403-
Requires<[Pred]>;
1404-
def rir : NVPTXInst<(outs RC:$dst),
1405-
(ins RC:$a, ImmCls:$b, RC:$c),
1406-
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
1407-
[(set RC:$dst, (fma RC:$a, fpimm:$b, RC:$c))]>,
1408-
Requires<[Pred]>;
1409-
def rii : NVPTXInst<(outs RC:$dst),
1410-
(ins RC:$a, ImmCls:$b, ImmCls:$c),
1411-
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
1412-
[(set RC:$dst, (fma RC:$a, fpimm:$b, fpimm:$c))]>,
1413-
Requires<[Pred]>;
1401+
defvar asmstr = OpcStr # " \t$dst, $a, $b, $c;";
1402+
def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
1403+
asmstr,
1404+
[(set RC:$dst, (fma RC:$a, RC:$b, RC:$c))]>,
1405+
Requires<[Pred]>;
1406+
def rri : NVPTXInst<(outs RC:$dst),
1407+
(ins RC:$a, RC:$b, ImmCls:$c),
1408+
asmstr,
1409+
[(set RC:$dst, (fma RC:$a, RC:$b, fpimm:$c))]>,
1410+
Requires<[Pred]>;
1411+
def rir : NVPTXInst<(outs RC:$dst),
1412+
(ins RC:$a, ImmCls:$b, RC:$c),
1413+
asmstr,
1414+
[(set RC:$dst, (fma RC:$a, fpimm:$b, RC:$c))]>,
1415+
Requires<[Pred]>;
1416+
def rii : NVPTXInst<(outs RC:$dst),
1417+
(ins RC:$a, ImmCls:$b, ImmCls:$c),
1418+
asmstr,
1419+
[(set RC:$dst, (fma RC:$a, fpimm:$b, fpimm:$c))]>,
1420+
Requires<[Pred]>;
1421+
def iir : NVPTXInst<(outs RC:$dst),
1422+
(ins ImmCls:$a, ImmCls:$b, RC:$c),
1423+
asmstr,
1424+
[(set RC:$dst, (fma fpimm:$a, fpimm:$b, RC:$c))]>,
1425+
Requires<[Pred]>;
1426+
14141427
}
14151428

14161429
multiclass FMA_F16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> {

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,6 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
// Utility class to wrap up information about a register and DAG type for more
10-
// convenient iteration and parameterization
11-
class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm> {
12-
ValueType Ty = ty;
13-
NVPTXRegClass RC = rc;
14-
Operand Imm = imm;
15-
int Size = ty.Size;
16-
}
17-
18-
def I32RT : RegTyInfo<i32, Int32Regs, i32imm>;
19-
def I64RT : RegTyInfo<i64, Int64Regs, i64imm>;
20-
21-
229
def immFloat0 : PatLeaf<(fpimm), [{
2310
float f = (float)N->getValueAPF().convertToFloat();
2411
return (f==0.0f);

llvm/test/CodeGen/NVPTX/arithmetic-int.ll

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,26 @@ define i16 @lshr_i16(i16 %a, i16 %b) {
317317
%ret = lshr i16 %a, %b
318318
ret i16 %ret
319319
}
320+
321+
;; Immediate cases
322+
323+
define i16 @srem_i16_ir(i16 %a) {
324+
; CHECK: rem.s16 %rs{{[0-9]+}}, 12, %rs{{[0-9]+}}
325+
; CHECK: ret
326+
%ret = srem i16 12, %a
327+
ret i16 %ret
328+
}
329+
330+
define i32 @udiv_i32_ir(i32 %a) {
331+
; CHECK: div.u32 %r{{[0-9]+}}, 34, %r{{[0-9]+}}
332+
; CHECK: ret
333+
%ret = udiv i32 34, %a
334+
ret i32 %ret
335+
}
336+
337+
define i64 @sub_i64_ir(i64 %a) {
338+
; CHECK: sub.s64 %rd{{[0-9]+}}, 56, %rd{{[0-9]+}}
339+
; CHECK: ret
340+
%ret = sub i64 56, %a
341+
ret i64 %ret
342+
}

llvm/test/CodeGen/NVPTX/fma.ll

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,17 @@ define ptx_device double @t2_f64(double %x, double %y, double %z, double %w) {
4141
%d = call double @dummy_f64(double %b, double %c)
4242
ret double %d
4343
}
44+
45+
define ptx_device float @f32_iir(float %x) {
46+
; CHECK: fma.rn.f32 %f{{[0-9]+}}, 0f52E8D4A5, 0f4A52FC54, %f{{[0-9]+}};
47+
; CHECK: ret;
48+
%r = call float @llvm.fma.f32(float 499999997952.0, float 3456789.0, float %x)
49+
ret float %r
50+
}
51+
52+
define ptx_device float @f32_iii(float %x) {
53+
; CHECK: mov.f32 %f{{[0-9]+}}, 0f41200000;
54+
; CHECK: ret;
55+
%r = call float @llvm.fma.f32(float 2.0, float 3.0, float 4.0)
56+
ret float %r
57+
}

0 commit comments

Comments
 (0)