Skip to content

Commit d6920f4

Browse files
committed
[MLIR][OpenMP] Improve omp.section block arguments handling
The `omp.section` operation is an outlier in that the block arguments it has are defined by clauses on the required parent `omp.sections` operation. This patch updates the definition of this operation introducing the `BlockArgOpenMPOpInterface` to simplify the handling and verification of these block arguments, implemented based on the parent `omp.sections`.
1 parent a821f44 commit d6920f4

File tree

4 files changed

+53
-2
lines changed

4 files changed

+53
-2
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

+10-2
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,9 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
207207
// 2.8.1 Sections Construct
208208
//===----------------------------------------------------------------------===//
209209

210-
def SectionOp : OpenMP_Op<"section", [HasParent<"SectionsOp">],
211-
singleRegion = true> {
210+
def SectionOp : OpenMP_Op<"section", traits = [
211+
BlockArgOpenMPOpInterface, HasParent<"SectionsOp">
212+
], singleRegion = true> {
212213
let summary = "section directive";
213214
let description = [{
214215
A section operation encloses a region which represents one section in a
@@ -218,6 +219,13 @@ def SectionOp : OpenMP_Op<"section", [HasParent<"SectionsOp">],
218219
operation. This is done to reflect situations where these block arguments
219220
represent variables private to each section.
220221
}];
222+
let extraClassDeclaration = [{
223+
// Override BlockArgOpenMPOpInterface methods based on the parent
224+
// omp.sections operation. Only forward-declare here because SectionsOp is
225+
// not completely defined at this point.
226+
unsigned numPrivateBlockArgs();
227+
unsigned numReductionBlockArgs();
228+
}] # clausesExtraClassDeclaration;
221229
let assemblyFormat = "$region attr-dict";
222230
}
223231

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -1844,6 +1844,18 @@ LogicalResult TeamsOp::verify() {
18441844
getReductionByref());
18451845
}
18461846

1847+
//===----------------------------------------------------------------------===//
1848+
// SectionOp
1849+
//===----------------------------------------------------------------------===//
1850+
1851+
unsigned SectionOp::numPrivateBlockArgs() {
1852+
return getParentOp().numPrivateBlockArgs();
1853+
}
1854+
1855+
unsigned SectionOp::numReductionBlockArgs() {
1856+
return getParentOp().numReductionBlockArgs();
1857+
}
1858+
18471859
//===----------------------------------------------------------------------===//
18481860
// SectionsOp
18491861
//===----------------------------------------------------------------------===//

mlir/test/Dialect/OpenMP/invalid.mlir

+25
Original file line numberDiff line numberDiff line change
@@ -1572,6 +1572,31 @@ func.func @omp_sections() {
15721572

15731573
// -----
15741574

1575+
omp.declare_reduction @add_f32 : f32
1576+
init {
1577+
^bb0(%arg: f32):
1578+
%0 = arith.constant 0.0 : f32
1579+
omp.yield (%0 : f32)
1580+
}
1581+
combiner {
1582+
^bb1(%arg0: f32, %arg1: f32):
1583+
%1 = arith.addf %arg0, %arg1 : f32
1584+
omp.yield (%1 : f32)
1585+
}
1586+
1587+
func.func @omp_sections(%x : !llvm.ptr) {
1588+
omp.sections reduction(@add_f32 %x -> %arg0 : !llvm.ptr) {
1589+
// expected-error @below {{op expected at least 1 entry block argument(s)}}
1590+
omp.section {
1591+
omp.terminator
1592+
}
1593+
omp.terminator
1594+
}
1595+
return
1596+
}
1597+
1598+
// -----
1599+
15751600
func.func @omp_single(%data_var : memref<i32>) -> () {
15761601
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
15771602
"omp.single" (%data_var) ({

mlir/test/Dialect/OpenMP/ops.mlir

+6
Original file line numberDiff line numberDiff line change
@@ -1127,11 +1127,13 @@ func.func @sections_reduction() {
11271127
omp.sections reduction(@add_f32 %0 -> %arg0 : !llvm.ptr) {
11281128
// CHECK: omp.section
11291129
omp.section {
1130+
^bb0(%arg1 : !llvm.ptr):
11301131
%1 = arith.constant 2.0 : f32
11311132
omp.terminator
11321133
}
11331134
// CHECK: omp.section
11341135
omp.section {
1136+
^bb0(%arg1 : !llvm.ptr):
11351137
%1 = arith.constant 3.0 : f32
11361138
omp.terminator
11371139
}
@@ -1148,11 +1150,13 @@ func.func @sections_reduction_byref() {
11481150
omp.sections reduction(byref @add_f32 %0 -> %arg0 : !llvm.ptr) {
11491151
// CHECK: omp.section
11501152
omp.section {
1153+
^bb0(%arg1 : !llvm.ptr):
11511154
%1 = arith.constant 2.0 : f32
11521155
omp.terminator
11531156
}
11541157
// CHECK: omp.section
11551158
omp.section {
1159+
^bb0(%arg1 : !llvm.ptr):
11561160
%1 = arith.constant 3.0 : f32
11571161
omp.terminator
11581162
}
@@ -1246,10 +1250,12 @@ func.func @sections_reduction2() {
12461250
// CHECK: omp.sections reduction(@add2_f32 %{{.+}} -> %{{.+}} : memref<1xf32>)
12471251
omp.sections reduction(@add2_f32 %0 -> %arg0 : memref<1xf32>) {
12481252
omp.section {
1253+
^bb0(%arg1 : !llvm.ptr):
12491254
%1 = arith.constant 2.0 : f32
12501255
omp.terminator
12511256
}
12521257
omp.section {
1258+
^bb0(%arg1 : !llvm.ptr):
12531259
%1 = arith.constant 2.0 : f32
12541260
omp.terminator
12551261
}

0 commit comments

Comments
 (0)