Skip to content

Commit ee6199c

Browse files
authored
[mlir][openacc][NFC] Cleanup hasOnly functions for device_type support (#78800)
Just a cleanup for all the `has.*Only()` function to avoid code duplication
1 parent b5df6a9 commit ee6199c

File tree

1 file changed

+49
-101
lines changed

1 file changed

+49
-101
lines changed

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

+49-101
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,41 @@ void OpenACCDialect::initialize() {
6969
*getContext());
7070
}
7171

72+
//===----------------------------------------------------------------------===//
73+
// device_type support helpers
74+
//===----------------------------------------------------------------------===//
75+
76+
static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
77+
if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
78+
return true;
79+
return false;
80+
}
81+
82+
static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
83+
mlir::acc::DeviceType deviceType) {
84+
if (!hasDeviceTypeValues(arrayAttr))
85+
return false;
86+
87+
for (auto attr : *arrayAttr) {
88+
auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
89+
if (deviceTypeAttr.getValue() == deviceType)
90+
return true;
91+
}
92+
93+
return false;
94+
}
95+
96+
static void printDeviceTypes(mlir::OpAsmPrinter &p,
97+
std::optional<mlir::ArrayAttr> deviceTypes) {
98+
if (!hasDeviceTypeValues(deviceTypes))
99+
return;
100+
101+
p << "[";
102+
llvm::interleaveComma(*deviceTypes, p,
103+
[&](mlir::Attribute attr) { p << attr; });
104+
p << "]";
105+
}
106+
72107
//===----------------------------------------------------------------------===//
73108
// DataBoundsOp
74109
//===----------------------------------------------------------------------===//
@@ -722,11 +757,7 @@ bool acc::ParallelOp::hasAsyncOnly() {
722757
}
723758

724759
bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
725-
if (auto arrayAttr = getAsyncOnly()) {
726-
if (findSegment(*arrayAttr, deviceType))
727-
return true;
728-
}
729-
return false;
760+
return hasDeviceType(getAsyncOnly(), deviceType);
730761
}
731762

732763
mlir::Value acc::ParallelOp::getAsyncValue() {
@@ -789,11 +820,7 @@ bool acc::ParallelOp::hasWaitOnly() {
789820
}
790821

791822
bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
792-
if (auto arrayAttr = getWaitOnly()) {
793-
if (findSegment(*arrayAttr, deviceType))
794-
return true;
795-
}
796-
return false;
823+
return hasDeviceType(getWaitOnly(), deviceType);
797824
}
798825

799826
mlir::Operation::operand_range ParallelOp::getWaitValues() {
@@ -1033,23 +1060,6 @@ static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
10331060
return success();
10341061
}
10351062

1036-
static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
1037-
if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
1038-
return true;
1039-
return false;
1040-
}
1041-
1042-
static void printDeviceTypes(mlir::OpAsmPrinter &p,
1043-
std::optional<mlir::ArrayAttr> deviceTypes) {
1044-
if (!hasDeviceTypeValues(deviceTypes))
1045-
return;
1046-
1047-
p << "[";
1048-
llvm::interleaveComma(*deviceTypes, p,
1049-
[&](mlir::Attribute attr) { p << attr; });
1050-
p << "]";
1051-
}
1052-
10531063
static void printDeviceTypeOperandsWithKeywordOnly(
10541064
mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
10551065
mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
@@ -1093,11 +1103,7 @@ bool acc::SerialOp::hasAsyncOnly() {
10931103
}
10941104

10951105
bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1096-
if (auto arrayAttr = getAsyncOnly()) {
1097-
if (findSegment(*arrayAttr, deviceType))
1098-
return true;
1099-
}
1100-
return false;
1106+
return hasDeviceType(getAsyncOnly(), deviceType);
11011107
}
11021108

11031109
mlir::Value acc::SerialOp::getAsyncValue() {
@@ -1114,11 +1120,7 @@ bool acc::SerialOp::hasWaitOnly() {
11141120
}
11151121

11161122
bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1117-
if (auto arrayAttr = getWaitOnly()) {
1118-
if (findSegment(*arrayAttr, deviceType))
1119-
return true;
1120-
}
1121-
return false;
1123+
return hasDeviceType(getWaitOnly(), deviceType);
11221124
}
11231125

11241126
mlir::Operation::operand_range SerialOp::getWaitValues() {
@@ -1177,11 +1179,7 @@ bool acc::KernelsOp::hasAsyncOnly() {
11771179
}
11781180

11791181
bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1180-
if (auto arrayAttr = getAsyncOnly()) {
1181-
if (findSegment(*arrayAttr, deviceType))
1182-
return true;
1183-
}
1184-
return false;
1182+
return hasDeviceType(getAsyncOnly(), deviceType);
11851183
}
11861184

11871185
mlir::Value acc::KernelsOp::getAsyncValue() {
@@ -1228,11 +1226,7 @@ bool acc::KernelsOp::hasWaitOnly() {
12281226
}
12291227

12301228
bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1231-
if (auto arrayAttr = getWaitOnly()) {
1232-
if (findSegment(*arrayAttr, deviceType))
1233-
return true;
1234-
}
1235-
return false;
1229+
return hasDeviceType(getWaitOnly(), deviceType);
12361230
}
12371231

12381232
mlir::Operation::operand_range KernelsOp::getWaitValues() {
@@ -1646,33 +1640,21 @@ Value LoopOp::getDataOperand(unsigned i) {
16461640
bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
16471641

16481642
bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
1649-
if (auto arrayAttr = getAuto_()) {
1650-
if (findSegment(*arrayAttr, deviceType))
1651-
return true;
1652-
}
1653-
return false;
1643+
return hasDeviceType(getAuto_(), deviceType);
16541644
}
16551645

16561646
bool LoopOp::hasIndependent() {
16571647
return hasIndependent(mlir::acc::DeviceType::None);
16581648
}
16591649

16601650
bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
1661-
if (auto arrayAttr = getIndependent()) {
1662-
if (findSegment(*arrayAttr, deviceType))
1663-
return true;
1664-
}
1665-
return false;
1651+
return hasDeviceType(getIndependent(), deviceType);
16661652
}
16671653

16681654
bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
16691655

16701656
bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
1671-
if (auto arrayAttr = getSeq()) {
1672-
if (findSegment(*arrayAttr, deviceType))
1673-
return true;
1674-
}
1675-
return false;
1657+
return hasDeviceType(getSeq(), deviceType);
16761658
}
16771659

16781660
mlir::Value LoopOp::getVectorValue() {
@@ -1687,11 +1669,7 @@ mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
16871669
bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
16881670

16891671
bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
1690-
if (auto arrayAttr = getVector()) {
1691-
if (findSegment(*arrayAttr, deviceType))
1692-
return true;
1693-
}
1694-
return false;
1672+
return hasDeviceType(getVector(), deviceType);
16951673
}
16961674

16971675
mlir::Value LoopOp::getWorkerValue() {
@@ -1706,11 +1684,7 @@ mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
17061684
bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
17071685

17081686
bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
1709-
if (auto arrayAttr = getWorker()) {
1710-
if (findSegment(*arrayAttr, deviceType))
1711-
return true;
1712-
}
1713-
return false;
1687+
return hasDeviceType(getWorker(), deviceType);
17141688
}
17151689

17161690
mlir::Operation::operand_range LoopOp::getTileValues() {
@@ -1771,11 +1745,7 @@ mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
17711745
bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
17721746

17731747
bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
1774-
if (auto arrayAttr = getGang()) {
1775-
if (findSegment(*arrayAttr, deviceType))
1776-
return true;
1777-
}
1778-
return false;
1748+
return hasDeviceType(getGang(), deviceType);
17791749
}
17801750

17811751
//===----------------------------------------------------------------------===//
@@ -1815,11 +1785,7 @@ bool acc::DataOp::hasAsyncOnly() {
18151785
}
18161786

18171787
bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1818-
if (auto arrayAttr = getAsyncOnly()) {
1819-
if (findSegment(*arrayAttr, deviceType))
1820-
return true;
1821-
}
1822-
return false;
1788+
return hasDeviceType(getAsyncOnly(), deviceType);
18231789
}
18241790

18251791
mlir::Value DataOp::getAsyncValue() {
@@ -1834,11 +1800,7 @@ mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
18341800
bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
18351801

18361802
bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1837-
if (auto arrayAttr = getWaitOnly()) {
1838-
if (findSegment(*arrayAttr, deviceType))
1839-
return true;
1840-
}
1841-
return false;
1803+
return hasDeviceType(getWaitOnly(), deviceType);
18421804
}
18431805

18441806
mlir::Operation::operand_range DataOp::getWaitValues() {
@@ -2091,20 +2053,6 @@ LogicalResult acc::DeclareOp::verify() {
20912053
// RoutineOp
20922054
//===----------------------------------------------------------------------===//
20932055

2094-
static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
2095-
mlir::acc::DeviceType deviceType) {
2096-
if (!hasDeviceTypeValues(arrayAttr))
2097-
return false;
2098-
2099-
for (auto attr : *arrayAttr) {
2100-
auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2101-
if (deviceTypeAttr.getValue() == deviceType)
2102-
return true;
2103-
}
2104-
2105-
return false;
2106-
}
2107-
21082056
static unsigned getParallelismForDeviceType(acc::RoutineOp op,
21092057
acc::DeviceType dtype) {
21102058
unsigned parallelism = 0;

0 commit comments

Comments
 (0)