Skip to content

Commit dd09221

Browse files
committed
Revert "[mlir][gpu] Align reduction operations with vector combining kinds (#73423)"
This reverts commit e0aac8c. I'm seeing some nvidia integration test failures: https://lab.llvm.org/buildbot/#/builders/61/builds/52334.
1 parent dd3184c commit dd09221

File tree

11 files changed

+171
-352
lines changed

11 files changed

+171
-352
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 25 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -931,53 +931,38 @@ def GPU_YieldOp : GPU_Op<"yield", [Pure, Terminator]>,
931931
}];
932932
}
933933

934-
// These mirror the reduction combining kinds from the vector dialect.
934+
// add, mul mirror the XLA ComparisonDirection enum.
935935
def GPU_AllReduceOpAdd : I32EnumAttrCase<"ADD", 0, "add">;
936-
def GPU_AllReduceOpMul : I32EnumAttrCase<"MUL", 1, "mul">;
937-
def GPU_AllReduceOpMinUI : I32EnumAttrCase<"MINUI", 2, "minui">;
938-
def GPU_AllReduceOpMinSI : I32EnumAttrCase<"MINSI", 3, "minsi">;
939-
// Follows the `arith.minnumf` semantics.
940-
def GPU_AllReduceOpMinF : I32EnumAttrCase<"MINF", 4, "minf">;
941-
def GPU_AllReduceOpMaxUI : I32EnumAttrCase<"MAXUI", 5, "maxui">;
942-
def GPU_AllReduceOpMaxSI : I32EnumAttrCase<"MAXSI", 6, "maxsi">;
943-
// Follows the `arith.maxnumf` semantics.
944-
def GPU_AllReduceOpMaxF : I32EnumAttrCase<"MAXF", 7, "maxf">;
945-
def GPU_AllReduceOpAnd : I32EnumAttrCase<"AND", 8, "and">;
946-
def GPU_AllReduceOpOr : I32EnumAttrCase<"OR", 9, "or">;
947-
def GPU_AllReduceOpXor : I32EnumAttrCase<"XOR", 10, "xor">;
948-
// Follows the `arith.minimumf` semantics.
949-
def GPU_AllReduceOpMinimumF : I32EnumAttrCase<"MINIMUMF", 11, "minimumf">;
950-
// Follows the `arith.maximumf` semantics.
951-
def GPU_AllReduceOpMaximumF : I32EnumAttrCase<"MAXIMUMF", 12, "maximumf">;
936+
def GPU_AllReduceOpAnd : I32EnumAttrCase<"AND", 1, "and">;
937+
def GPU_AllReduceOpMax : I32EnumAttrCase<"MAX", 2, "max">;
938+
def GPU_AllReduceOpMin : I32EnumAttrCase<"MIN", 3, "min">;
939+
def GPU_AllReduceOpMul : I32EnumAttrCase<"MUL", 4, "mul">;
940+
def GPU_AllReduceOpOr : I32EnumAttrCase<"OR", 5, "or">;
941+
def GPU_AllReduceOpXor : I32EnumAttrCase<"XOR", 6, "xor">;
952942

953943
def GPU_AllReduceOperation : I32EnumAttr<"AllReduceOperation",
954944
"built-in reduction operations supported by gpu.allreduce.",
955945
[
956946
GPU_AllReduceOpAdd,
957-
GPU_AllReduceOpMul,
958-
GPU_AllReduceOpMinUI,
959-
GPU_AllReduceOpMinSI,
960-
GPU_AllReduceOpMinF,
961-
GPU_AllReduceOpMaxUI,
962-
GPU_AllReduceOpMaxSI,
963-
GPU_AllReduceOpMaxF,
964947
GPU_AllReduceOpAnd,
948+
GPU_AllReduceOpMax,
949+
GPU_AllReduceOpMin,
950+
GPU_AllReduceOpMul,
965951
GPU_AllReduceOpOr,
966-
GPU_AllReduceOpXor,
967-
GPU_AllReduceOpMinimumF,
968-
GPU_AllReduceOpMaximumF
952+
GPU_AllReduceOpXor
969953
]>{
970954
let genSpecializedAttr = 0;
971955
let cppNamespace = "::mlir::gpu";
972956
}
973-
974-
def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">;
975-
976957
def GPU_AllReduceOperationAttr : EnumAttr<GPU_Dialect, GPU_AllReduceOperation,
977958
"all_reduce_op">;
978959

979960
def GPU_AllReduceOp : GPU_Op<"all_reduce",
980-
[SameOperandsAndResultType, IsolatedFromAbove]> {
961+
[SameOperandsAndResultType, IsolatedFromAbove]>,
962+
Arguments<(ins AnyType:$value,
963+
OptionalAttr<GPU_AllReduceOperationAttr>:$op,
964+
UnitAttr:$uniform)>,
965+
Results<(outs AnyType)> {
981966
let summary = "Reduce values among workgroup.";
982967
let description = [{
983968
The `all_reduce` op reduces the value of every work item across a local
@@ -996,23 +981,12 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
996981

997982
compute the sum of each work item's %0 value. The first version specifies
998983
the accumulation as operation, whereas the second version specifies the
999-
accumulation as code region. The reduction operation must be one of:
1000-
* Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
1001-
`or`, `xor`
1002-
* Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
1003-
`maximumf`
984+
accumulation as code region. The accumulation operation must be one of:
985+
`add`, `and`, `max`, `min`, `mul`, `or`, `xor`.
1004986

1005987
If `uniform` flag is set either none or all work items of a workgroup
1006988
need to execute this op in convergence.
1007989
}];
1008-
1009-
let arguments = (ins
1010-
AnyIntegerOrFloat:$value,
1011-
OptionalAttr<GPU_AllReduceOperationAttr>:$op,
1012-
UnitAttr:$uniform
1013-
);
1014-
let results = (outs AnyIntegerOrFloat:$result);
1015-
1016990
let regions = (region AnyRegion:$body);
1017991
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
1018992
(`uniform` $uniform^)? $body attr-dict
@@ -1022,7 +996,12 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
1022996
let hasRegionVerifier = 1;
1023997
}
1024998

1025-
def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]> {
999+
def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
1000+
[SameOperandsAndResultType]>,
1001+
Arguments<(ins AnyType:$value,
1002+
GPU_AllReduceOperationAttr:$op,
1003+
UnitAttr:$uniform)>,
1004+
Results<(outs AnyType)> {
10261005
let summary = "Reduce values among subgroup.";
10271006
let description = [{
10281007
The `subgroup_reduce` op reduces the value of every work item across a
@@ -1035,21 +1014,8 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
10351014
```
10361015

10371016
If `uniform` flag is set either none or all work items of a subgroup
1038-
need to execute this op in convergence. The reduction operation must be one
1039-
of:
1040-
* Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
1041-
`or`, `xor`
1042-
* Floating point types: `add`, `mul`, `minf`, `maxf`, `minimumf`,
1043-
`maximumf`
1017+
need to execute this op in convergence.
10441018
}];
1045-
1046-
let arguments = (ins
1047-
AnyIntegerOrFloat:$value,
1048-
GPU_AllReduceOperationAttr:$op,
1049-
UnitAttr:$uniform
1050-
);
1051-
let results = (outs AnyIntegerOrFloat:$result);
1052-
10531019
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
10541020
(`uniform` $uniform^)? attr-dict
10551021
`:` functional-type(operands, results) }];

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
3434
!::llvm::cast<VectorType>($_self).isScalable()}]>;
3535

3636
// Whether a type is a scalable VectorType.
37-
def IsVectorTypeWithAnyDimScalablePred
37+
def IsVectorTypeWithAnyDimScalablePred
3838
: CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
3939
::llvm::cast<VectorType>($_self).isScalable()}]>;
4040

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -65,28 +65,17 @@ convertReduxKind(gpu::AllReduceOperation mode) {
6565
switch (mode) {
6666
case gpu::AllReduceOperation::ADD:
6767
return NVVM::ReduxKind::ADD;
68-
case gpu::AllReduceOperation::MUL:
69-
return std::nullopt;
70-
case gpu::AllReduceOperation::MINSI:
71-
return NVVM::ReduxKind::MIN;
72-
case gpu::AllReduceOperation::MINUI:
73-
return std::nullopt;
74-
case gpu::AllReduceOperation::MINF:
75-
return NVVM::ReduxKind::MIN;
76-
case gpu::AllReduceOperation::MAXSI:
77-
return NVVM::ReduxKind::MAX;
78-
case gpu::AllReduceOperation::MAXUI:
79-
return std::nullopt;
80-
case gpu::AllReduceOperation::MAXF:
81-
return NVVM::ReduxKind::MAX;
8268
case gpu::AllReduceOperation::AND:
8369
return NVVM::ReduxKind::AND;
70+
case gpu::AllReduceOperation::MAX:
71+
return NVVM::ReduxKind::MAX;
72+
case gpu::AllReduceOperation::MIN:
73+
return NVVM::ReduxKind::MIN;
8474
case gpu::AllReduceOperation::OR:
8575
return NVVM::ReduxKind::OR;
8676
case gpu::AllReduceOperation::XOR:
8777
return NVVM::ReduxKind::XOR;
88-
case gpu::AllReduceOperation::MINIMUMF:
89-
case gpu::AllReduceOperation::MAXIMUMF:
78+
case gpu::AllReduceOperation::MUL:
9079
return std::nullopt;
9180
}
9281
return std::nullopt;

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 16 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -503,53 +503,26 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
503503
return std::nullopt;
504504
}
505505

506-
// TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
507-
// does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
508-
// reduction ops. We should account possible precision requirements in this
509-
// conversion.
510-
511506
using ReduceType = gpu::AllReduceOperation;
507+
namespace spv = spirv;
512508
const OpHandler handlers[] = {
513509
{ReduceType::ADD,
514-
&createGroupReduceOpImpl<spirv::GroupIAddOp,
515-
spirv::GroupNonUniformIAddOp>,
516-
&createGroupReduceOpImpl<spirv::GroupFAddOp,
517-
spirv::GroupNonUniformFAddOp>},
510+
&createGroupReduceOpImpl<spv::GroupIAddOp, spv::GroupNonUniformIAddOp>,
511+
&createGroupReduceOpImpl<spv::GroupFAddOp, spv::GroupNonUniformFAddOp>},
518512
{ReduceType::MUL,
519-
&createGroupReduceOpImpl<spirv::GroupIMulKHROp,
520-
spirv::GroupNonUniformIMulOp>,
521-
&createGroupReduceOpImpl<spirv::GroupFMulKHROp,
522-
spirv::GroupNonUniformFMulOp>},
523-
{ReduceType::MINUI,
524-
&createGroupReduceOpImpl<spirv::GroupUMinOp,
525-
spirv::GroupNonUniformUMinOp>,
526-
nullptr},
527-
{ReduceType::MINSI,
528-
&createGroupReduceOpImpl<spirv::GroupSMinOp,
529-
spirv::GroupNonUniformSMinOp>,
530-
nullptr},
531-
{ReduceType::MINF, nullptr,
532-
&createGroupReduceOpImpl<spirv::GroupFMinOp,
533-
spirv::GroupNonUniformFMinOp>},
534-
{ReduceType::MAXUI,
535-
&createGroupReduceOpImpl<spirv::GroupUMaxOp,
536-
spirv::GroupNonUniformUMaxOp>,
537-
nullptr},
538-
{ReduceType::MAXSI,
539-
&createGroupReduceOpImpl<spirv::GroupSMaxOp,
540-
spirv::GroupNonUniformSMaxOp>,
541-
nullptr},
542-
{ReduceType::MAXF, nullptr,
543-
&createGroupReduceOpImpl<spirv::GroupFMaxOp,
544-
spirv::GroupNonUniformFMaxOp>},
545-
{ReduceType::MINIMUMF, nullptr,
546-
&createGroupReduceOpImpl<spirv::GroupFMinOp,
547-
spirv::GroupNonUniformFMinOp>},
548-
{ReduceType::MAXIMUMF, nullptr,
549-
&createGroupReduceOpImpl<spirv::GroupFMaxOp,
550-
spirv::GroupNonUniformFMaxOp>}};
551-
552-
for (const OpHandler &handler : handlers)
513+
&createGroupReduceOpImpl<spv::GroupIMulKHROp,
514+
spv::GroupNonUniformIMulOp>,
515+
&createGroupReduceOpImpl<spv::GroupFMulKHROp,
516+
spv::GroupNonUniformFMulOp>},
517+
{ReduceType::MIN,
518+
&createGroupReduceOpImpl<spv::GroupSMinOp, spv::GroupNonUniformSMinOp>,
519+
&createGroupReduceOpImpl<spv::GroupFMinOp, spv::GroupNonUniformFMinOp>},
520+
{ReduceType::MAX,
521+
&createGroupReduceOpImpl<spv::GroupSMaxOp, spv::GroupNonUniformSMaxOp>,
522+
&createGroupReduceOpImpl<spv::GroupFMaxOp, spv::GroupNonUniformFMaxOp>},
523+
};
524+
525+
for (auto &handler : handlers)
553526
if (handler.type == opType)
554527
return (handler.*handlerPtr)(builder, loc, arg, isGroup, isUniform);
555528

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@
2727
#include "mlir/IR/TypeUtilities.h"
2828
#include "mlir/Interfaces/FunctionImplementation.h"
2929
#include "mlir/Interfaces/SideEffectInterfaces.h"
30-
#include "mlir/Support/LogicalResult.h"
3130
#include "mlir/Transforms/InliningUtils.h"
32-
#include "llvm/ADT/STLExtras.h"
3331
#include "llvm/ADT/TypeSwitch.h"
3432
#include "llvm/Support/CommandLine.h"
3533
#include "llvm/Support/ErrorHandling.h"
@@ -488,23 +486,12 @@ static LogicalResult verifyAttributions(Operation *op,
488486
// AllReduceOp
489487
//===----------------------------------------------------------------------===//
490488

491-
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
492-
Type resType) {
493-
using Kind = gpu::AllReduceOperation;
494-
if (llvm::is_contained(
495-
{Kind::MINF, Kind::MAXF, Kind::MINIMUMF, Kind::MAXIMUMF}, opName)) {
496-
if (!isa<FloatType>(resType))
497-
return failure();
498-
}
499-
500-
if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
501-
Kind::AND, Kind::OR, Kind::XOR},
502-
opName)) {
503-
if (!isa<IntegerType>(resType))
504-
return failure();
505-
}
506-
507-
return success();
489+
static bool verifyReduceOpAndType(gpu::AllReduceOperation opName,
490+
Type resType) {
491+
return (opName != gpu::AllReduceOperation::AND &&
492+
opName != gpu::AllReduceOperation::OR &&
493+
opName != gpu::AllReduceOperation::XOR) ||
494+
llvm::isa<IntegerType>(resType);
508495
}
509496

510497
LogicalResult gpu::AllReduceOp::verifyRegions() {
@@ -531,13 +518,12 @@ LogicalResult gpu::AllReduceOp::verifyRegions() {
531518
return emitError("expected gpu.yield op in region");
532519
} else {
533520
gpu::AllReduceOperation opName = *getOp();
534-
if (failed(verifyReduceOpAndType(opName, getType()))) {
535-
return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
536-
<< "` reduction operation is not compatible with type "
537-
<< getType();
521+
if (!verifyReduceOpAndType(opName, getType())) {
522+
return emitError()
523+
<< '`' << gpu::stringifyAllReduceOperation(opName)
524+
<< "` accumulator is only compatible with Integer type";
538525
}
539526
}
540-
541527
return success();
542528
}
543529

@@ -588,10 +574,9 @@ static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
588574

589575
LogicalResult gpu::SubgroupReduceOp::verify() {
590576
gpu::AllReduceOperation opName = getOp();
591-
if (failed(verifyReduceOpAndType(opName, getType()))) {
577+
if (!verifyReduceOpAndType(opName, getType())) {
592578
return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
593-
<< "` reduction operation is not compatible with type "
594-
<< getType();
579+
<< "` accumulator is only compatible with Integer type";
595580
}
596581
return success();
597582
}

mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -214,49 +214,54 @@ struct GpuAllReduceRewriter {
214214

215215
/// Returns an accumulator factory that creates an op specified by opName.
216216
AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
217-
using Kind = gpu::AllReduceOperation;
218217
bool isFloatingPoint = isa<FloatType>(valueType);
219218
switch (opName) {
220-
case Kind::ADD:
219+
case gpu::AllReduceOperation::ADD:
221220
return isFloatingPoint ? getFactory<arith::AddFOp>()
222221
: getFactory<arith::AddIOp>();
223-
case Kind::MUL:
222+
case gpu::AllReduceOperation::MUL:
224223
return isFloatingPoint ? getFactory<arith::MulFOp>()
225224
: getFactory<arith::MulIOp>();
226-
case Kind::MINSI:
227-
return getFactory<arith::MinSIOp>();
228-
case Kind::MINUI:
229-
return getFactory<arith::MinUIOp>();
230-
case Kind::MINF:
231-
return getFactory<arith::MinNumFOp>();
232-
case Kind::MAXSI:
233-
return getFactory<arith::MaxSIOp>();
234-
case Kind::MAXUI:
235-
return getFactory<arith::MaxUIOp>();
236-
case Kind::MAXF:
237-
return getFactory<arith::MaxNumFOp>();
238-
case Kind::AND:
225+
case gpu::AllReduceOperation::AND:
239226
return getFactory<arith::AndIOp>();
240-
case Kind::OR:
227+
case gpu::AllReduceOperation::OR:
241228
return getFactory<arith::OrIOp>();
242-
case Kind::XOR:
229+
case gpu::AllReduceOperation::XOR:
243230
return getFactory<arith::XOrIOp>();
244-
case Kind::MINIMUMF:
245-
return getFactory<arith::MinimumFOp>();
246-
case Kind::MAXIMUMF:
247-
return getFactory<arith::MaximumFOp>();
231+
case gpu::AllReduceOperation::MAX:
232+
return isFloatingPoint
233+
? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
234+
arith::CmpFPredicate::UGT>()
235+
: getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
236+
arith::CmpIPredicate::ugt>();
237+
case gpu::AllReduceOperation::MIN:
238+
return isFloatingPoint
239+
? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
240+
arith::CmpFPredicate::ULT>()
241+
: getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
242+
arith::CmpIPredicate::ult>();
248243
}
249244
llvm_unreachable("unknown GPU AllReduceOperation");
250245
}
251246

252247
/// Returns an accumulator factory that creates an op of type T.
253248
template <typename T>
254249
AccumulatorFactory getFactory() {
255-
return [this](Value lhs, Value rhs) {
250+
return [&](Value lhs, Value rhs) {
256251
return create<T>(lhs.getType(), lhs, rhs);
257252
};
258253
}
259254

255+
/// Returns an accumulator for comparison such as min, max. T is the type
256+
/// of the compare op.
257+
template <typename T, typename PredicateEnum, PredicateEnum predicate>
258+
AccumulatorFactory getCmpFactory() const {
259+
return [&](Value lhs, Value rhs) {
260+
Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
261+
return rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
262+
};
263+
}
264+
260265
/// Creates an if-block skeleton and calls the two factories to generate the
261266
/// ops in the `then` and `else` block..
262267
///

0 commit comments

Comments
 (0)