
Commit d90cac9

[DAGCombine] Simplify partial_reduce_*mla with constant. (#138289)
partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1)) -> partial_reduce_*mla(acc, x, C)
1 parent ee7f6a5 commit d90cac9
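For intuition, here is a scalar model of the new constant case (a hypothetical standalone C++ sketch, not LLVM code; all names are illustrative): accumulating mul(zext(x), splat(C)) lane-wise matches a widening multiply against trunc(C) in the narrow element type whenever C is representable there.

#include <cassert>
#include <cstdint>

// Models partial_reduce_umla(acc, mul(zext(x), splat(C)), splat(1))
// -> partial_reduce_umla(acc, x, trunc(C)) for one u8 lane, with C = 255.
int main() {
  const uint32_t C = 255; // fits in u8, so the fold is sound
  for (unsigned v = 0; v <= 255; ++v) {
    uint32_t acc0 = 100, acc1 = 100;
    acc0 += static_cast<uint32_t>(v) * C;                   // original form
    acc1 += static_cast<uint32_t>(static_cast<uint8_t>(v)) *
            static_cast<uint32_t>(static_cast<uint8_t>(C)); // folded form
    assert(acc0 == acc1);
  }
  return 0;
}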

2 files changed: +176 -21 lines

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 34 additions & 20 deletions
@@ -12601,47 +12601,63 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
   return SDValue();
 }
 
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)),
-// Splat(1)) into
-// PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp).
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)),
-// Splat(1)) into
-// PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
+// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
+// -> partial_reduce_*mla(acc, a, b)
+//
+// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, C)
 SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   SDLoc DL(N);
-
+  auto *Context = DAG.getContext();
   SDValue Acc = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
   SDValue Op2 = N->getOperand(2);
 
-  APInt ConstantOne;
+  APInt C;
   if (Op1->getOpcode() != ISD::MUL ||
-      !ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
+      !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
     return SDValue();
 
   SDValue LHS = Op1->getOperand(0);
   SDValue RHS = Op1->getOperand(1);
   unsigned LHSOpcode = LHS->getOpcode();
-  unsigned RHSOpcode = RHS->getOpcode();
-  if (!ISD::isExtOpcode(LHSOpcode) || !ISD::isExtOpcode(RHSOpcode))
+  if (!ISD::isExtOpcode(LHSOpcode))
     return SDValue();
 
   SDValue LHSExtOp = LHS->getOperand(0);
-  SDValue RHSExtOp = RHS->getOperand(0);
   EVT LHSExtOpVT = LHSExtOp.getValueType();
-  if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
-    return SDValue();
 
-  // Only perform the DAG combine if there is custom lowering provided by the
-  // target
-  auto *Context = DAG.getContext();
+  // Only perform these combines if the target supports folding
+  // the extends into the operation.
   if (!TLI.isPartialReduceMLALegalOrCustom(
           TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
     return SDValue();
 
   bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
+  unsigned NewOpcode =
+      ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+
+  // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
+  // -> partial_reduce_*mla(acc, x, C)
+  if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
+    APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
+    unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
+    if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
+        (LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
+      return SDValue();
+
+    return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
+                       DAG.getConstant(CTrunc, DL, LHSExtOpVT));
+  }
+
+  unsigned RHSOpcode = RHS->getOpcode();
+  if (!ISD::isExtOpcode(RHSOpcode))
+    return SDValue();
+
+  SDValue RHSExtOp = RHS->getOperand(0);
+  if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
+    return SDValue();
 
   // For a 2-stage extend the signedness of both of the extends must be the
   // same. This is so the node can be folded into only a signed or unsigned
@@ -12652,8 +12668,6 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
       Op1.getValueType().getVectorElementType() != AccElemVT)
     return SDValue();
 
-  unsigned NewOpcode =
-      ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
   return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
                      RHSExtOp);
 }
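The new guard is the heart of the constant case: the wide splat C may only be folded if truncating it to the narrow element type and re-extending it (with the same extend kind as the other operand) gives back C. A minimal standalone sketch of that round trip with fixed-width integers (the function name and fixed i32/i8 widths are illustrative assumptions, not LLVM APIs):

#include <cassert>
#include <cstdint>

// Mirrors the CTrunc.zext(LHSBits) != C / CTrunc.sext(LHSBits) != C checks
// above, specialized to i32 constants and i8 operands.
static bool constantFitsInNarrowType(int32_t C, bool ExtIsSigned) {
  uint8_t Trunc = static_cast<uint8_t>(C); // C.trunc(8)
  int32_t RoundTrip =
      ExtIsSigned ? static_cast<int32_t>(static_cast<int8_t>(Trunc)) // sext
                  : static_cast<int32_t>(Trunc);                     // zext
  return RoundTrip == C;
}

int main() {
  assert(constantFitsInNarrowType(-1, /*ExtIsSigned=*/true));    // sdot_imm
  assert(!constantFitsInNarrowType(256, /*ExtIsSigned=*/true));  // sdot_imm_does_not_fit
  assert(constantFitsInNarrowType(255, /*ExtIsSigned=*/false));  // udot_imm
  assert(!constantFitsInNarrowType(256, /*ExtIsSigned=*/false)); // udot_imm_does_not_fit
  return 0;
}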

llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Lines changed: 142 additions & 1 deletion
@@ -1139,7 +1139,6 @@ entry:
   ret <vscale x 2 x i16> %partial.reduce
 }
 
-
 define <vscale x 4 x i64> @partial_reduce_only_split_acc(<vscale x 4 x i64> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
 ; CHECK-LABEL: partial_reduce_only_split_acc:
 ; CHECK:       // %bb.0: // %entry
@@ -1178,3 +1177,145 @@ entry:
     <vscale x 4 x i64> %acc, <vscale x 8 x i64> %mult)
   ret <vscale x 4 x i64> %partial.reduce
 }
+
+define <vscale x 4 x i32> @sdot_imm(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: sdot_imm:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sunpklo z2.h, z1.b
+; CHECK-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEXT:    sub z0.s, z0.s, z3.s
+; CHECK-NEXT:    sunpklo z3.s, z1.h
+; CHECK-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEXT:    sub z0.s, z0.s, z2.s
+; CHECK-NEXT:    sub z0.s, z0.s, z3.s
+; CHECK-NEXT:    sub z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_imm:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    mov z2.b, #-1 // =0xffffffffffffffff
+; CHECK-NEWLOWERING-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEWLOWERING-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 -1)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sdot_imm_does_not_fit(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: sdot_imm_does_not_fit:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sunpklo z2.h, z1.b
+; CHECK-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEXT:    sunpklo z4.s, z1.h
+; CHECK-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEXT:    lsl z4.s, z4.s, #8
+; CHECK-NEXT:    lsl z2.s, z2.s, #8
+; CHECK-NEXT:    lsl z3.s, z3.s, #8
+; CHECK-NEXT:    lsl z1.s, z1.s, #8
+; CHECK-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEXT:    add z2.s, z2.s, z4.s
+; CHECK-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_imm_does_not_fit:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    sunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    sunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    sunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    lsl z4.s, z4.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z2.s, z2.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z3.s, z3.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z1.s, z1.s, #8
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEWLOWERING-NEXT:    add z2.s, z2.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEWLOWERING-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 256)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @udot_imm(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: udot_imm:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uunpklo z3.h, z1.b
+; CHECK-NEXT:    mov z2.s, #255 // =0xff
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEXT:    uunpklo z4.s, z3.h
+; CHECK-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEXT:    mla z0.s, p0/m, z4.s, z2.s
+; CHECK-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    mla z0.s, p0/m, z3.s, z2.s
+; CHECK-NEXT:    mla z0.s, p0/m, z4.s, z2.s
+; CHECK-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_imm:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    mov z2.b, #-1 // =0xffffffffffffffff
+; CHECK-NEWLOWERING-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-NEWLOWERING-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 255)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @udot_imm_does_not_fit(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: udot_imm_does_not_fit:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uunpklo z2.h, z1.b
+; CHECK-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    lsl z4.s, z4.s, #8
+; CHECK-NEXT:    lsl z2.s, z2.s, #8
+; CHECK-NEXT:    lsl z3.s, z3.s, #8
+; CHECK-NEXT:    lsl z1.s, z1.s, #8
+; CHECK-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEXT:    add z2.s, z2.s, z4.s
+; CHECK-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_imm_does_not_fit:
+; CHECK-NEWLOWERING:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT:    uunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT:    uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT:    uunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT:    lsl z4.s, z4.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z2.s, z2.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z3.s, z3.s, #8
+; CHECK-NEWLOWERING-NEXT:    lsl z1.s, z1.s, #8
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z3.s
+; CHECK-NEWLOWERING-NEXT:    add z2.s, z2.s, z4.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEWLOWERING-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, splat(i32 256)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
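The *_does_not_fit tests pin down why the guard matters: 256 truncated to i8 is 0, so a folded dot product would multiply every lane by 0 while the original multiplies by 256. A tiny scalar demonstration of the miscompile the combine avoids (illustrative C++ only, not LLVM code):

#include <cassert>
#include <cstdint>

int main() {
  // One lane with v = 1: the original pattern multiplies by 256, but an
  // (incorrect) fold would multiply by trunc_i8(256) == 0.
  uint8_t v = 1;
  uint32_t original = static_cast<uint32_t>(v) * 256u;                   // 256
  uint32_t badFold  = static_cast<uint32_t>(v) *
                      static_cast<uint32_t>(static_cast<uint8_t>(256u)); // 0
  assert(original != badFold); // hence the combine must bail out
  return 0;
}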
