Skip to content

Commit c282d55

Browse files
committed
[mlir] add support for reductions in OpenMP WsLoopOp
Use a modeling similar to SCF ParallelOp to support arbitrary parallel reductions. The two main differences are: (1) reductions are named and declared beforehand similarly to functions using a special op that provides the neutral element, the reduction code and optionally the atomic reduction code; (2) reductions go through memory instead because this is closer to the OpenMP semantics. See https://llvm.discourse.group/t/rfc-openmp-reduction-support/3367. Reviewed By: kiranchandramohan Differential Revision: https://reviews.llvm.org/D105358
1 parent d4df382 commit c282d55

File tree

9 files changed

+623
-27
lines changed

9 files changed

+623
-27
lines changed

mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ mlir_tablegen(OpenMPOps.h.inc -gen-op-decls)
99
mlir_tablegen(OpenMPOps.cpp.inc -gen-op-defs)
1010
mlir_tablegen(OpenMPOpsEnums.h.inc -gen-enum-decls)
1111
mlir_tablegen(OpenMPOpsEnums.cpp.inc -gen-enum-defs)
12+
mlir_tablegen(OpenMPTypeInterfaces.h.inc -gen-type-interface-decls)
13+
mlir_tablegen(OpenMPTypeInterfaces.cpp.inc -gen-type-interface-defs)
1214
add_mlir_doc(OpenMPOps OpenMPDialect Dialects/ -gen-dialect-doc)
1315
add_public_tablegen_target(MLIROpenMPOpsIncGen)
1416
add_dependencies(OpenMPDialectDocGen omp_common_td)

mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@
1313
#ifndef MLIR_DIALECT_OPENMP_OPENMPDIALECT_H_
1414
#define MLIR_DIALECT_OPENMP_OPENMPDIALECT_H_
1515

16-
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
16+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1717
#include "mlir/IR/Dialect.h"
1818
#include "mlir/IR/OpDefinition.h"
19+
#include "mlir/IR/SymbolTable.h"
1920
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2021
#include "mlir/Interfaces/SideEffectInterfaces.h"
2122

2223
#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.h.inc"
2324
#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.h.inc"
25+
#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.h.inc"
2426

2527
#define GET_OP_CLASSES
2628
#include "mlir/Dialect/OpenMP/OpenMPOps.h.inc"

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 125 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717
include "mlir/IR/OpBase.td"
1818
include "mlir/Interfaces/SideEffectInterfaces.td"
1919
include "mlir/Interfaces/ControlFlowInterfaces.td"
20-
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
20+
include "mlir/IR/SymbolInterfaces.td"
2121
include "mlir/Dialect/OpenMP/OmpCommon.td"
22+
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
2223

2324
def OpenMP_Dialect : Dialect {
2425
let name = "omp";
2526
let cppNamespace = "::mlir::omp";
27+
let dependentDialects = ["::mlir::LLVM::LLVMDialect"];
2628
}
2729

2830
class OpenMP_Op<string mnemonic, list<OpTrait> traits = []> :
@@ -31,6 +33,27 @@ class OpenMP_Op<string mnemonic, list<OpTrait> traits = []> :
3133
// Type which can be constraint accepting standard integers and indices.
3234
def IntLikeType : AnyTypeOf<[AnyInteger, Index]>;
3335

36+
def OpenMP_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> {
37+
let cppNamespace = "::mlir::omp";
38+
39+
let description = [{
40+
An interface for pointer-like types suitable to contain a value that OpenMP
41+
specification refers to as variable.
42+
}];
43+
44+
let methods = [
45+
InterfaceMethod<
46+
/*description=*/"Returns the pointee type.",
47+
/*retTy=*/"::mlir::Type",
48+
/*methodName=*/"getElementType"
49+
>,
50+
];
51+
}
52+
53+
def OpenMP_PointerLikeType : Type<
54+
CPred<"$_self.isa<::mlir::omp::PointerLikeType>()">,
55+
"OpenMP-compatible variable type", "::mlir::omp::PointerLikeType">;
56+
3457
//===----------------------------------------------------------------------===//
3558
// 2.6 parallel Construct
3659
//===----------------------------------------------------------------------===//
@@ -146,6 +169,18 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
146169
that the `linear_vars` and `linear_step_vars` variadic lists should contain
147170
the same number of elements.
148171

172+
Reductions can be performed in a workshare loop by specifying reduction
173+
accumulator variables in `reduction_vars` and symbols referring to reduction
174+
declarations in the `reductions` attribute. Each reduction is identified
175+
by the accumulator it uses and accumulators must not be repeated in the same
176+
reduction. The `omp.reduction` operation accepts the accumulator and a
177+
partial value which is considered to be produced by the current loop
178+
iteration for the given reduction. If multiple values are produced for the
179+
same accumulator, i.e. there are multiple `omp.reduction`s, the last value
180+
is taken. The reduction declaration specifies how to combine the values from
181+
each iteration into the final value, which is available in the accumulator
182+
after the loop completes.
183+
149184
The optional `schedule_val` attribute specifies the loop schedule for this
150185
loop, determining how the loop is distributed across the parallel threads.
151186
The optional `schedule_chunk_var` associated with this determines further
@@ -173,6 +208,9 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
173208
Variadic<AnyType>:$lastprivate_vars,
174209
Variadic<AnyType>:$linear_vars,
175210
Variadic<AnyType>:$linear_step_vars,
211+
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
212+
OptionalAttr<TypedArrayAttrBase<SymbolRefAttr,
213+
"array of symbol references">>:$reductions,
176214
OptionalAttr<ScheduleKind>:$schedule_val,
177215
Optional<AnyType>:$schedule_chunk_var,
178216
Confined<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$collapse_val,
@@ -191,11 +229,11 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
191229
"ValueRange":$upperBound, "ValueRange":$step,
192230
"ValueRange":$privateVars, "ValueRange":$firstprivateVars,
193231
"ValueRange":$lastprivate_vars, "ValueRange":$linear_vars,
194-
"ValueRange":$linear_step_vars, "StringAttr":$schedule_val,
195-
"Value":$schedule_chunk_var, "IntegerAttr":$collapse_val,
196-
"UnitAttr":$nowait, "IntegerAttr":$ordered_val,
197-
"StringAttr":$order_val, "UnitAttr":$inclusive, CArg<"bool",
198-
"true">:$buildBody)>,
232+
"ValueRange":$linear_step_vars, "ValueRange":$reduction_vars,
233+
"StringAttr":$schedule_val, "Value":$schedule_chunk_var,
234+
"IntegerAttr":$collapse_val, "UnitAttr":$nowait,
235+
"IntegerAttr":$ordered_val, "StringAttr":$order_val,
236+
"UnitAttr":$inclusive, CArg<"bool", "true">:$buildBody)>,
199237
OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
200238
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
201239
];
@@ -205,13 +243,18 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
205243
let extraClassDeclaration = [{
206244
/// Returns the number of loops in the workshape loop nest.
207245
unsigned getNumLoops() { return lowerBound().size(); }
246+
247+
/// Returns the number of reduction variables.
248+
unsigned getNumReductionVars() { return reduction_vars().size(); }
208249
}];
209250
let parser = [{ return parseWsLoopOp(parser, result); }];
210251
let printer = [{ return printWsLoopOp(p, *this); }];
252+
let verifier = [{ return ::verifyWsLoopOp(*this); }];
211253
}
212254

213-
def YieldOp : OpenMP_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
214-
HasParent<"WsLoopOp">]> {
255+
def YieldOp : OpenMP_Op<"yield",
256+
[NoSideEffect, ReturnLike, Terminator,
257+
ParentOneOf<["WsLoopOp", "ReductionDeclareOp"]>]> {
215258
let summary = "loop yield and termination operation";
216259
let description = [{
217260
"omp.yield" yields SSA values from the OpenMP dialect op region and
@@ -334,4 +377,78 @@ def TaskwaitOp : OpenMP_Op<"taskwait"> {
334377
let assemblyFormat = "attr-dict";
335378
}
336379

380+
//===----------------------------------------------------------------------===//
381+
// 2.19.5.7 declare reduction Directive
382+
//===----------------------------------------------------------------------===//
383+
384+
def ReductionDeclareOp : OpenMP_Op<"reduction.declare", [Symbol]> {
385+
let summary = "declares a reduction kind";
386+
387+
let description = [{
388+
Declares an OpenMP reduction kind. This requires two mandatory and one
389+
optional region.
390+
391+
1. The initializer region specifies how to initialize the thread-local
392+
reduction value. This is usually the neutral element of the reduction.
393+
For convenience, the region has an argument that contains the value
394+
of the reduction accumulator at the start of the reduction. It is
395+
expected to `omp.yield` the new value on all control flow paths.
396+
2. The reduction region specifies how to combine two values into one, i.e.
397+
the reduction operator. It accepts the two values as arguments and is
398+
expected to `omp.yield` the combined value on all control flow paths.
399+
3. The atomic reduction region is optional and specifies how two values
400+
can be combined atomically given local accumulator variables. It is
401+
expected to store the combined value in the first accumulator variable.
402+
403+
Note that the MLIR type system does not allow for type-polymorphic
404+
reductions. Separate reduction declarations should be created for different
405+
element and accumulator types.
406+
}];
407+
408+
let arguments = (ins SymbolNameAttr:$sym_name,
409+
TypeAttr:$type);
410+
411+
let regions = (region AnyRegion:$initializerRegion,
412+
AnyRegion:$reductionRegion,
413+
AnyRegion:$atomicReductionRegion);
414+
let verifier = "return ::verifyReductionDeclareOp(*this);";
415+
416+
let assemblyFormat = "$sym_name `:` $type attr-dict-with-keyword "
417+
"`init` $initializerRegion "
418+
"`combiner` $reductionRegion "
419+
"custom<AtomicReductionRegion>($atomicReductionRegion)";
420+
421+
let extraClassDeclaration = [{
422+
PointerLikeType getAccumulatorType() {
423+
if (atomicReductionRegion().empty())
424+
return {};
425+
426+
return atomicReductionRegion().front().getArgument(0).getType();
427+
}
428+
}];
429+
}
430+
431+
//===----------------------------------------------------------------------===//
432+
// 2.19.5.4 reduction clause
433+
//===----------------------------------------------------------------------===//
434+
435+
def ReductionOp : OpenMP_Op<"reduction", [
436+
TypesMatchWith<"value types matches accumulator element type",
437+
"accumulator", "operand",
438+
"$_self.cast<::mlir::omp::PointerLikeType>().getElementType()">
439+
]> {
440+
let summary = "reduction construct";
441+
let description = [{
442+
Indicates the value that is produced by the current reduction-participating
443+
entity for a reduction requested in some ancestor. The reduction is
444+
identified by the accumulator, but the value of the accumulator may not be
445+
updated immediately.
446+
}];
447+
448+
let arguments= (ins AnyType:$operand, OpenMP_PointerLikeType:$accumulator);
449+
let assemblyFormat =
450+
"$operand `,` $accumulator attr-dict `:` type($accumulator)";
451+
let verifier = "return ::verifyReductionOp(*this);";
452+
}
453+
337454
#endif // OPENMP_OPS

mlir/lib/Dialect/OpenMP/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ add_mlir_dialect_library(MLIROpenMP
99

1010
LINK_LIBS PUBLIC
1111
MLIRIR
12+
MLIRLLVMIR
1213
)

0 commit comments

Comments
 (0)