Skip to content

Commit f61e3a6

Browse files
committed
[MLIR][OpenMP] Use map format to represent use_device_{addr,ptr}
This patch updates the `omp.target_data` operation to use the same formatting as `map` clauses on `omp.target` for `use_device_addr` and `use_device_ptr`. This is done so the mapping that is being enforced between op arguments and associated entry block arguments is explicit. The way it is achieved is by marking these clauses as entry block argument-defining and adjusting printer/parsers accordingly. As a result of this change, block arguments for `use_device_addr` come before those for `use_device_ptr`, which is the opposite of the previous undocumented situation. Some unit tests are updated based on this change, in addition to those updated because of the format change.
1 parent a757969 commit f61e3a6

File tree

11 files changed

+179
-63
lines changed

11 files changed

+179
-63
lines changed

flang/test/Fir/convert-to-llvm-openmp-and-fir.fir

+3-2
Original file line numberDiff line numberDiff line change
@@ -429,13 +429,14 @@ func.func @_QPopenmp_target_data_region() {
429429

430430
func.func @_QPomp_target_data_empty() {
431431
%0 = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_data_emptyEa"}
432-
omp.target_data use_device_addr(%0 : !fir.ref<!fir.array<1024xi32>>) {
432+
omp.target_data use_device_addr(%0 -> %arg0 : !fir.ref<!fir.array<1024xi32>>) {
433+
omp.terminator
433434
}
434435
return
435436
}
436437

437438
// CHECK-LABEL: llvm.func @_QPomp_target_data_empty
438-
// CHECK: omp.target_data use_device_addr(%1 : !llvm.ptr) {
439+
// CHECK: omp.target_data use_device_addr(%1 -> %{{.*}} : !llvm.ptr) {
439440
// CHECK: }
440441

441442
// -----

flang/test/Lower/OpenMP/target.f90

+2-4
Original file line numberDiff line numberDiff line change
@@ -506,9 +506,8 @@ subroutine omp_target_device_ptr
506506
type(c_ptr) :: a
507507
integer, target :: b
508508
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}}) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "a"}
509-
!CHECK: omp.target_data map_entries(%[[MAP]]{{.*}}) use_device_ptr({{.*}})
509+
!CHECK: omp.target_data map_entries(%[[MAP]]{{.*}}) use_device_ptr({{.*}} -> %[[VAL_1:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>)
510510
!$omp target data map(tofrom: a) use_device_ptr(a)
511-
!CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>):
512511
!CHECK: {{.*}} = fir.coordinate_of %[[VAL_1:.*]], {{.*}} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
513512
a = c_loc(b)
514513
!CHECK: omp.terminator
@@ -529,9 +528,8 @@ subroutine omp_target_device_addr
529528
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
530529
!CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
531530
!CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
532-
!CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[DEV_ADDR_MEMBERS]], %[[DEV_ADDR]] : {{.*}}) {
531+
!CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[DEV_ADDR_MEMBERS]] -> %[[ARG_0:.*]], %[[DEV_ADDR]] -> %[[ARG_1:.*]] : !fir.llvm_ptr<!fir.ref<i32>>, !fir.ref<!fir.box<!fir.ptr<i32>>>) {
533532
!$omp target data map(tofrom: a) use_device_addr(a)
534-
!CHECK: ^bb0(%[[ARG_0:.*]]: !fir.llvm_ptr<!fir.ref<i32>>, %[[ARG_1:.*]]: !fir.ref<!fir.box<!fir.ptr<i32>>>):
535533
!CHECK: %[[VAL_1_DECL:.*]]:2 = hlfir.declare %[[ARG_1]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
536534
!CHECK: %[[C10:.*]] = arith.constant 10 : i32
537535
!CHECK: %[[A_BOX:.*]] = fir.load %[[VAL_1_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<i32>>>

flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90

+4-8
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
! use_device_ptr to use_device_addr works, without breaking any functionality.
77

88
!CHECK: func.func @{{.*}}only_use_device_ptr()
9-
!CHECK: omp.target_data use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
10-
!CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
9+
!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
1110
subroutine only_use_device_ptr
1211
use iso_c_binding
1312
integer, pointer, dimension(:) :: array
@@ -19,8 +18,7 @@ subroutine only_use_device_ptr
1918
end subroutine
2019

2120
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr()
22-
!CHECK: omp.target_data use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
23-
!CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
21+
!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
2422
subroutine mix_use_device_ptr_and_addr
2523
use iso_c_binding
2624
integer, pointer, dimension(:) :: array
@@ -32,8 +30,7 @@ subroutine mix_use_device_ptr_and_addr
3230
end subroutine
3331

3432
!CHECK: func.func @{{.*}}only_use_device_addr()
35-
!CHECK: omp.target_data use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
36-
!CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
33+
!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
3734
subroutine only_use_device_addr
3835
use iso_c_binding
3936
integer, pointer, dimension(:) :: array
@@ -45,8 +42,7 @@ subroutine only_use_device_addr
4542
end subroutine
4643

4744
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr_and_map()
48-
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
49-
!CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
45+
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
5046
subroutine mix_use_device_ptr_and_addr_and_map
5147
use iso_c_binding
5248
integer :: i, j

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

+24-4
Original file line numberDiff line numberDiff line change
@@ -1209,18 +1209,28 @@ class OpenMP_UseDeviceAddrClauseSkip<
12091209
bit description = false, bit extraClassDeclaration = false
12101210
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
12111211
extraClassDeclaration> {
1212+
let traits = [
1213+
BlockArgOpenMPOpInterface
1214+
];
1215+
12121216
let arguments = (ins
12131217
Variadic<OpenMP_PointerLikeType>:$use_device_addr_vars
12141218
);
12151219

1216-
let optAssemblyFormat = [{
1217-
`use_device_addr` `(` $use_device_addr_vars `:` type($use_device_addr_vars) `)`
1220+
let extraClassDeclaration = [{
1221+
unsigned numUseDeviceAddrBlockArgs() {
1222+
return getUseDeviceAddrVars().size();
1223+
}
12181224
}];
12191225

12201226
let description = [{
12211227
The optional `use_device_addr_vars` specifies the address of the objects in
12221228
the device data environment.
12231229
}];
1230+
1231+
// Assembly format not defined because this clause must be processed together
1232+
// with the first region of the operation, as it defines entry block
1233+
// arguments.
12241234
}
12251235

12261236
def OpenMP_UseDeviceAddrClause : OpenMP_UseDeviceAddrClauseSkip<>;
@@ -1234,18 +1244,28 @@ class OpenMP_UseDevicePtrClauseSkip<
12341244
bit description = false, bit extraClassDeclaration = false
12351245
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
12361246
extraClassDeclaration> {
1247+
let traits = [
1248+
BlockArgOpenMPOpInterface
1249+
];
1250+
12371251
let arguments = (ins
12381252
Variadic<OpenMP_PointerLikeType>:$use_device_ptr_vars
12391253
);
12401254

1241-
let optAssemblyFormat = [{
1242-
`use_device_ptr` `(` $use_device_ptr_vars `:` type($use_device_ptr_vars) `)`
1255+
let extraClassDeclaration = [{
1256+
unsigned numUseDevicePtrBlockArgs() {
1257+
return getUseDevicePtrVars().size();
1258+
}
12431259
}];
12441260

12451261
let description = [{
12461262
The optional `use_device_ptr_vars` specifies the device pointers to the
12471263
corresponding list items in the device data environment.
12481264
}];
1265+
1266+
// Assembly format not defined because this clause must be processed together
1267+
// with the first region of the operation, as it defines entry block
1268+
// arguments.
12491269
}
12501270

12511271
def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

+6
Original file line numberDiff line numberDiff line change
@@ -987,6 +987,12 @@ def TargetDataOp: OpenMP_Op<"target_data", traits = [
987987
OpBuilder<(ins CArg<"const TargetDataOperands &">:$clauses)>
988988
];
989989

990+
let assemblyFormat = clausesAssemblyFormat # [{
991+
custom<UseDeviceAddrUseDevicePtrRegion>(
992+
$region, $use_device_addr_vars, type($use_device_addr_vars),
993+
$use_device_ptr_vars, type($use_device_ptr_vars)) attr-dict
994+
}];
995+
990996
let hasVerifier = 1;
991997
}
992998

mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td

+36-1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
4545
"unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{
4646
return 0;
4747
}]>,
48+
InterfaceMethod<"Get number of block arguments defined by `use_device_addr`.",
49+
"unsigned", "numUseDeviceAddrBlockArgs", (ins), [{}], [{
50+
return 0;
51+
}]>,
52+
InterfaceMethod<"Get number of block arguments defined by `use_device_ptr`.",
53+
"unsigned", "numUseDevicePtrBlockArgs", (ins), [{}], [{
54+
return 0;
55+
}]>,
4856

4957
// Unified access methods for clause-associated entry block arguments.
5058
InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
@@ -72,6 +80,16 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
7280
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
7381
return iface.getReductionBlockArgsStart() + $_op.numReductionBlockArgs();
7482
}]>,
83+
InterfaceMethod<"Get start index of block arguments defined by `use_device_addr`.",
84+
"unsigned", "getUseDeviceAddrBlockArgsStart", (ins), [{
85+
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
86+
return iface.getTaskReductionBlockArgsStart() + $_op.numTaskReductionBlockArgs();
87+
}]>,
88+
InterfaceMethod<"Get start index of block arguments defined by `use_device_ptr`.",
89+
"unsigned", "getUseDevicePtrBlockArgsStart", (ins), [{
90+
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
91+
return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs();
92+
}]>,
7593

7694
InterfaceMethod<"Get block arguments defined by `in_reduction`.",
7795
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
@@ -109,13 +127,30 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
109127
iface.getTaskReductionBlockArgsStart(),
110128
$_op.numTaskReductionBlockArgs());
111129
}]>,
130+
InterfaceMethod<"Get block arguments defined by `use_device_addr`.",
131+
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
132+
"getUseDeviceAddrBlockArgs", (ins), [{
133+
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
134+
return $_op->getRegion(0).getArguments().slice(
135+
iface.getUseDeviceAddrBlockArgsStart(),
136+
$_op.numUseDeviceAddrBlockArgs());
137+
}]>,
138+
InterfaceMethod<"Get block arguments defined by `use_device_ptr`.",
139+
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
140+
"getUseDevicePtrBlockArgs", (ins), [{
141+
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
142+
return $_op->getRegion(0).getArguments().slice(
143+
iface.getUseDevicePtrBlockArgsStart(),
144+
$_op.numUseDevicePtrBlockArgs());
145+
}]>,
112146
];
113147

114148
let verify = [{
115149
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
116150
unsigned expectedArgs = iface.numInReductionBlockArgs() +
117151
iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
118-
iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs();
152+
iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() +
153+
iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs();
119154
if ($_op->getRegion(0).getNumArguments() < expectedArgs)
120155
return $_op->emitOpError() << "expected at least " << expectedArgs
121156
<< " entry block argument(s)";

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

+43
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,8 @@ struct AllRegionParseArgs {
504504
std::optional<PrivateParseArgs> privateArgs;
505505
std::optional<ReductionParseArgs> reductionArgs;
506506
std::optional<ReductionParseArgs> taskReductionArgs;
507+
std::optional<MapParseArgs> useDeviceAddrArgs;
508+
std::optional<MapParseArgs> useDevicePtrArgs;
507509
};
508510
} // namespace
509511

@@ -648,6 +650,16 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
648650
return parser.emitError(parser.getCurrentLocation())
649651
<< "invalid `task_reduction` format";
650652

653+
if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
654+
args.useDeviceAddrArgs)))
655+
return parser.emitError(parser.getCurrentLocation())
656+
<< "invalid `use_device_addr` format";
657+
658+
if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
659+
args.useDevicePtrArgs)))
660+
return parser.emitError(parser.getCurrentLocation())
661+
<< "invalid `use_device_addr` format";
662+
651663
return parser.parseRegion(region, entryBlockArgs);
652664
}
653665

@@ -735,6 +747,18 @@ static ParseResult parseTaskReductionRegion(
735747
return parseBlockArgRegion(parser, region, args);
736748
}
737749

750+
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(
751+
OpAsmParser &parser, Region &region,
752+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDeviceAddrVars,
753+
SmallVectorImpl<Type> &useDeviceAddrTypes,
754+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDevicePtrVars,
755+
SmallVectorImpl<Type> &useDevicePtrTypes) {
756+
AllRegionParseArgs args;
757+
args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
758+
args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
759+
return parseBlockArgRegion(parser, region, args);
760+
}
761+
738762
//===----------------------------------------------------------------------===//
739763
// Printers for operations including clauses that define entry block arguments.
740764
//===----------------------------------------------------------------------===//
@@ -767,6 +791,8 @@ struct AllRegionPrintArgs {
767791
std::optional<PrivatePrintArgs> privateArgs;
768792
std::optional<ReductionPrintArgs> reductionArgs;
769793
std::optional<ReductionPrintArgs> taskReductionArgs;
794+
std::optional<MapPrintArgs> useDeviceAddrArgs;
795+
std::optional<MapPrintArgs> useDevicePtrArgs;
770796
};
771797
} // namespace
772798

@@ -849,6 +875,11 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
849875
printBlockArgClause(p, ctx, "task_reduction",
850876
iface.getTaskReductionBlockArgs(),
851877
args.taskReductionArgs);
878+
printBlockArgClause(p, ctx, "use_device_addr",
879+
iface.getUseDeviceAddrBlockArgs(),
880+
args.useDeviceAddrArgs);
881+
printBlockArgClause(p, ctx, "use_device_ptr",
882+
iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
852883

853884
p.printRegion(region, /*printEntryBlockArgs=*/false);
854885
}
@@ -925,6 +956,18 @@ static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
925956
printBlockArgRegion(p, op, region, args);
926957
}
927958

959+
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op,
960+
Region &region,
961+
ValueRange useDeviceAddrVars,
962+
TypeRange useDeviceAddrTypes,
963+
ValueRange useDevicePtrVars,
964+
TypeRange useDevicePtrTypes) {
965+
AllRegionPrintArgs args;
966+
args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
967+
args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
968+
printBlockArgRegion(p, op, region, args);
969+
}
970+
928971
/// Verifies Reduction Clause
929972
static LogicalResult
930973
verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,

0 commit comments

Comments
 (0)