@@ -69,6 +69,41 @@ void OpenACCDialect::initialize() {
69
69
*getContext ());
70
70
}
71
71
72
+ // ===----------------------------------------------------------------------===//
73
+ // device_type support helpers
74
+ // ===----------------------------------------------------------------------===//
75
+
76
+ static bool hasDeviceTypeValues (std::optional<mlir::ArrayAttr> arrayAttr) {
77
+ if (arrayAttr && *arrayAttr && arrayAttr->size () > 0 )
78
+ return true ;
79
+ return false ;
80
+ }
81
+
82
+ static bool hasDeviceType (std::optional<mlir::ArrayAttr> arrayAttr,
83
+ mlir::acc::DeviceType deviceType) {
84
+ if (!hasDeviceTypeValues (arrayAttr))
85
+ return false ;
86
+
87
+ for (auto attr : *arrayAttr) {
88
+ auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
89
+ if (deviceTypeAttr.getValue () == deviceType)
90
+ return true ;
91
+ }
92
+
93
+ return false ;
94
+ }
95
+
96
+ static void printDeviceTypes (mlir::OpAsmPrinter &p,
97
+ std::optional<mlir::ArrayAttr> deviceTypes) {
98
+ if (!hasDeviceTypeValues (deviceTypes))
99
+ return ;
100
+
101
+ p << " [" ;
102
+ llvm::interleaveComma (*deviceTypes, p,
103
+ [&](mlir::Attribute attr) { p << attr; });
104
+ p << " ]" ;
105
+ }
106
+
72
107
// ===----------------------------------------------------------------------===//
73
108
// DataBoundsOp
74
109
// ===----------------------------------------------------------------------===//
@@ -722,11 +757,7 @@ bool acc::ParallelOp::hasAsyncOnly() {
722
757
}
723
758
724
759
bool acc::ParallelOp::hasAsyncOnly (mlir::acc::DeviceType deviceType) {
725
- if (auto arrayAttr = getAsyncOnly ()) {
726
- if (findSegment (*arrayAttr, deviceType))
727
- return true ;
728
- }
729
- return false ;
760
+ return hasDeviceType (getAsyncOnly (), deviceType);
730
761
}
731
762
732
763
mlir::Value acc::ParallelOp::getAsyncValue () {
@@ -789,11 +820,7 @@ bool acc::ParallelOp::hasWaitOnly() {
789
820
}
790
821
791
822
bool acc::ParallelOp::hasWaitOnly (mlir::acc::DeviceType deviceType) {
792
- if (auto arrayAttr = getWaitOnly ()) {
793
- if (findSegment (*arrayAttr, deviceType))
794
- return true ;
795
- }
796
- return false ;
823
+ return hasDeviceType (getWaitOnly (), deviceType);
797
824
}
798
825
799
826
mlir::Operation::operand_range ParallelOp::getWaitValues () {
@@ -1033,23 +1060,6 @@ static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
1033
1060
return success ();
1034
1061
}
1035
1062
1036
- static bool hasDeviceTypeValues (std::optional<mlir::ArrayAttr> arrayAttr) {
1037
- if (arrayAttr && *arrayAttr && arrayAttr->size () > 0 )
1038
- return true ;
1039
- return false ;
1040
- }
1041
-
1042
- static void printDeviceTypes (mlir::OpAsmPrinter &p,
1043
- std::optional<mlir::ArrayAttr> deviceTypes) {
1044
- if (!hasDeviceTypeValues (deviceTypes))
1045
- return ;
1046
-
1047
- p << " [" ;
1048
- llvm::interleaveComma (*deviceTypes, p,
1049
- [&](mlir::Attribute attr) { p << attr; });
1050
- p << " ]" ;
1051
- }
1052
-
1053
1063
static void printDeviceTypeOperandsWithKeywordOnly (
1054
1064
mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
1055
1065
mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
@@ -1093,11 +1103,7 @@ bool acc::SerialOp::hasAsyncOnly() {
1093
1103
}
1094
1104
1095
1105
bool acc::SerialOp::hasAsyncOnly (mlir::acc::DeviceType deviceType) {
1096
- if (auto arrayAttr = getAsyncOnly ()) {
1097
- if (findSegment (*arrayAttr, deviceType))
1098
- return true ;
1099
- }
1100
- return false ;
1106
+ return hasDeviceType (getAsyncOnly (), deviceType);
1101
1107
}
1102
1108
1103
1109
mlir::Value acc::SerialOp::getAsyncValue () {
@@ -1114,11 +1120,7 @@ bool acc::SerialOp::hasWaitOnly() {
1114
1120
}
1115
1121
1116
1122
bool acc::SerialOp::hasWaitOnly (mlir::acc::DeviceType deviceType) {
1117
- if (auto arrayAttr = getWaitOnly ()) {
1118
- if (findSegment (*arrayAttr, deviceType))
1119
- return true ;
1120
- }
1121
- return false ;
1123
+ return hasDeviceType (getWaitOnly (), deviceType);
1122
1124
}
1123
1125
1124
1126
mlir::Operation::operand_range SerialOp::getWaitValues () {
@@ -1177,11 +1179,7 @@ bool acc::KernelsOp::hasAsyncOnly() {
1177
1179
}
1178
1180
1179
1181
bool acc::KernelsOp::hasAsyncOnly (mlir::acc::DeviceType deviceType) {
1180
- if (auto arrayAttr = getAsyncOnly ()) {
1181
- if (findSegment (*arrayAttr, deviceType))
1182
- return true ;
1183
- }
1184
- return false ;
1182
+ return hasDeviceType (getAsyncOnly (), deviceType);
1185
1183
}
1186
1184
1187
1185
mlir::Value acc::KernelsOp::getAsyncValue () {
@@ -1228,11 +1226,7 @@ bool acc::KernelsOp::hasWaitOnly() {
1228
1226
}
1229
1227
1230
1228
bool acc::KernelsOp::hasWaitOnly (mlir::acc::DeviceType deviceType) {
1231
- if (auto arrayAttr = getWaitOnly ()) {
1232
- if (findSegment (*arrayAttr, deviceType))
1233
- return true ;
1234
- }
1235
- return false ;
1229
+ return hasDeviceType (getWaitOnly (), deviceType);
1236
1230
}
1237
1231
1238
1232
mlir::Operation::operand_range KernelsOp::getWaitValues () {
@@ -1646,33 +1640,21 @@ Value LoopOp::getDataOperand(unsigned i) {
1646
1640
bool LoopOp::hasAuto () { return hasAuto (mlir::acc::DeviceType::None); }
1647
1641
1648
1642
bool LoopOp::hasAuto (mlir::acc::DeviceType deviceType) {
1649
- if (auto arrayAttr = getAuto_ ()) {
1650
- if (findSegment (*arrayAttr, deviceType))
1651
- return true ;
1652
- }
1653
- return false ;
1643
+ return hasDeviceType (getAuto_ (), deviceType);
1654
1644
}
1655
1645
1656
1646
bool LoopOp::hasIndependent () {
1657
1647
return hasIndependent (mlir::acc::DeviceType::None);
1658
1648
}
1659
1649
1660
1650
bool LoopOp::hasIndependent (mlir::acc::DeviceType deviceType) {
1661
- if (auto arrayAttr = getIndependent ()) {
1662
- if (findSegment (*arrayAttr, deviceType))
1663
- return true ;
1664
- }
1665
- return false ;
1651
+ return hasDeviceType (getIndependent (), deviceType);
1666
1652
}
1667
1653
1668
1654
bool LoopOp::hasSeq () { return hasSeq (mlir::acc::DeviceType::None); }
1669
1655
1670
1656
bool LoopOp::hasSeq (mlir::acc::DeviceType deviceType) {
1671
- if (auto arrayAttr = getSeq ()) {
1672
- if (findSegment (*arrayAttr, deviceType))
1673
- return true ;
1674
- }
1675
- return false ;
1657
+ return hasDeviceType (getSeq (), deviceType);
1676
1658
}
1677
1659
1678
1660
mlir::Value LoopOp::getVectorValue () {
@@ -1687,11 +1669,7 @@ mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
1687
1669
bool LoopOp::hasVector () { return hasVector (mlir::acc::DeviceType::None); }
1688
1670
1689
1671
bool LoopOp::hasVector (mlir::acc::DeviceType deviceType) {
1690
- if (auto arrayAttr = getVector ()) {
1691
- if (findSegment (*arrayAttr, deviceType))
1692
- return true ;
1693
- }
1694
- return false ;
1672
+ return hasDeviceType (getVector (), deviceType);
1695
1673
}
1696
1674
1697
1675
mlir::Value LoopOp::getWorkerValue () {
@@ -1706,11 +1684,7 @@ mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
1706
1684
bool LoopOp::hasWorker () { return hasWorker (mlir::acc::DeviceType::None); }
1707
1685
1708
1686
bool LoopOp::hasWorker (mlir::acc::DeviceType deviceType) {
1709
- if (auto arrayAttr = getWorker ()) {
1710
- if (findSegment (*arrayAttr, deviceType))
1711
- return true ;
1712
- }
1713
- return false ;
1687
+ return hasDeviceType (getWorker (), deviceType);
1714
1688
}
1715
1689
1716
1690
mlir::Operation::operand_range LoopOp::getTileValues () {
@@ -1771,11 +1745,7 @@ mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
1771
1745
bool LoopOp::hasGang () { return hasGang (mlir::acc::DeviceType::None); }
1772
1746
1773
1747
bool LoopOp::hasGang (mlir::acc::DeviceType deviceType) {
1774
- if (auto arrayAttr = getGang ()) {
1775
- if (findSegment (*arrayAttr, deviceType))
1776
- return true ;
1777
- }
1778
- return false ;
1748
+ return hasDeviceType (getGang (), deviceType);
1779
1749
}
1780
1750
1781
1751
// ===----------------------------------------------------------------------===//
@@ -1815,11 +1785,7 @@ bool acc::DataOp::hasAsyncOnly() {
1815
1785
}
1816
1786
1817
1787
bool acc::DataOp::hasAsyncOnly (mlir::acc::DeviceType deviceType) {
1818
- if (auto arrayAttr = getAsyncOnly ()) {
1819
- if (findSegment (*arrayAttr, deviceType))
1820
- return true ;
1821
- }
1822
- return false ;
1788
+ return hasDeviceType (getAsyncOnly (), deviceType);
1823
1789
}
1824
1790
1825
1791
mlir::Value DataOp::getAsyncValue () {
@@ -1834,11 +1800,7 @@ mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1834
1800
bool DataOp::hasWaitOnly () { return hasWaitOnly (mlir::acc::DeviceType::None); }
1835
1801
1836
1802
bool DataOp::hasWaitOnly (mlir::acc::DeviceType deviceType) {
1837
- if (auto arrayAttr = getWaitOnly ()) {
1838
- if (findSegment (*arrayAttr, deviceType))
1839
- return true ;
1840
- }
1841
- return false ;
1803
+ return hasDeviceType (getWaitOnly (), deviceType);
1842
1804
}
1843
1805
1844
1806
mlir::Operation::operand_range DataOp::getWaitValues () {
@@ -2091,20 +2053,6 @@ LogicalResult acc::DeclareOp::verify() {
2091
2053
// RoutineOp
2092
2054
// ===----------------------------------------------------------------------===//
2093
2055
2094
- static bool hasDeviceType (std::optional<mlir::ArrayAttr> arrayAttr,
2095
- mlir::acc::DeviceType deviceType) {
2096
- if (!hasDeviceTypeValues (arrayAttr))
2097
- return false ;
2098
-
2099
- for (auto attr : *arrayAttr) {
2100
- auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2101
- if (deviceTypeAttr.getValue () == deviceType)
2102
- return true ;
2103
- }
2104
-
2105
- return false ;
2106
- }
2107
-
2108
2056
static unsigned getParallelismForDeviceType (acc::RoutineOp op,
2109
2057
acc::DeviceType dtype) {
2110
2058
unsigned parallelism = 0 ;
0 commit comments