Skip to content

Commit b2f1d06

Browse files
authored
[mlir][amdgpu] Improve Chipset version utility (#106169)
* Fix an OOB access * Add comparison operators * Add documentation * Add unit tests
1 parent b8c0e8a commit b2f1d06

File tree

6 files changed

+116
-12
lines changed

6 files changed

+116
-12
lines changed

mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,41 @@
99
#define MLIR_DIALECT_AMDGPU_UTILS_CHIPSET_H_
1010

1111
#include "mlir/Support/LLVM.h"
12+
#include <utility>
1213

13-
namespace mlir {
14-
namespace amdgpu {
14+
namespace mlir::amdgpu {
15+
16+
/// Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
17+
/// Note that the leading digits form a decimal number, while the last two
18+
/// digits for a hexadecimal number. For example:
19+
/// gfx942 --> major = 9, minor = 0x42
20+
/// gfx90a --> major = 9, minor = 0xa
21+
/// gfx1103 --> major = 10, minor = 0x3
1522
struct Chipset {
1623
Chipset() = default;
1724
Chipset(unsigned majorVersion, unsigned minorVersion)
1825
: majorVersion(majorVersion), minorVersion(minorVersion){};
26+
27+
/// Parses the chipset version string and returns the chipset on success, and
28+
/// failure otherwise.
1929
static FailureOr<Chipset> parse(StringRef name);
2030

21-
unsigned majorVersion = 0;
22-
unsigned minorVersion = 0;
31+
friend bool operator==(const Chipset &lhs, const Chipset &rhs) {
32+
return lhs.majorVersion == rhs.majorVersion &&
33+
lhs.minorVersion == rhs.minorVersion;
34+
}
35+
friend bool operator!=(const Chipset &lhs, const Chipset &rhs) {
36+
return !(lhs == rhs);
37+
}
38+
friend bool operator<(const Chipset &lhs, const Chipset &rhs) {
39+
return std::make_pair(lhs.majorVersion, lhs.minorVersion) <
40+
std::make_pair(rhs.majorVersion, rhs.minorVersion);
41+
}
42+
43+
unsigned majorVersion = 0; // The major version (decimal).
44+
unsigned minorVersion = 0; // The minor version (hexadecimal).
2345
};
24-
} // end namespace amdgpu
25-
} // end namespace mlir
46+
47+
} // namespace mlir::amdgpu
2648

2749
#endif
Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- Chipset.cpp - AMDGPU Chipset version struct parsing -----------===//
1+
//===- Chipset.cpp - AMDGPU Chipset version struct parsing ----------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,22 +7,26 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
10-
#include "mlir/Support/LLVM.h"
1110
#include "llvm/ADT/StringRef.h"
1211

13-
using namespace mlir;
14-
using namespace mlir::amdgpu;
12+
namespace mlir::amdgpu {
1513

1614
FailureOr<Chipset> Chipset::parse(StringRef name) {
17-
if (!name.starts_with("gfx"))
15+
if (!name.consume_front("gfx"))
1816
return failure();
17+
if (name.size() < 3)
18+
return failure();
19+
1920
unsigned major = 0;
2021
unsigned minor = 0;
21-
StringRef majorRef = name.drop_front(3).drop_back(2);
22+
23+
StringRef majorRef = name.drop_back(2);
2224
StringRef minorRef = name.take_back(2);
2325
if (majorRef.getAsInteger(10, major))
2426
return failure();
2527
if (minorRef.getAsInteger(16, minor))
2628
return failure();
2729
return Chipset(major, minor);
2830
}
31+
32+
} // namespace mlir::amdgpu
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
//===- AMDGPUUtilsTest.cpp - Unit tests for AMDGPU dialect utils ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
10+
#include "gtest/gtest.h"
11+
12+
namespace mlir::amdgpu {
13+
namespace {
14+
15+
TEST(ChipsetTest, Parsing) {
16+
FailureOr<Chipset> chipset = Chipset::parse("gfx90a");
17+
ASSERT_TRUE(succeeded(chipset));
18+
EXPECT_EQ(chipset->majorVersion, 9u);
19+
EXPECT_EQ(chipset->minorVersion, 0x0au);
20+
21+
chipset = Chipset::parse("gfx940");
22+
ASSERT_TRUE(succeeded(chipset));
23+
EXPECT_EQ(chipset->majorVersion, 9u);
24+
EXPECT_EQ(chipset->minorVersion, 0x40u);
25+
26+
chipset = Chipset::parse("gfx1103");
27+
ASSERT_TRUE(succeeded(chipset));
28+
EXPECT_EQ(chipset->majorVersion, 11u);
29+
EXPECT_EQ(chipset->minorVersion, 0x03u);
30+
}
31+
32+
TEST(ChipsetTest, ParsingInvalid) {
33+
EXPECT_TRUE(failed(Chipset::parse("navi33")));
34+
EXPECT_TRUE(failed(Chipset::parse("rdna2")));
35+
EXPECT_TRUE(failed(Chipset::parse("sm_80")));
36+
EXPECT_TRUE(failed(Chipset::parse("GFX940")));
37+
EXPECT_TRUE(failed(Chipset::parse("Gfx940")));
38+
EXPECT_TRUE(failed(Chipset::parse("gfx9")));
39+
EXPECT_TRUE(failed(Chipset::parse("gfx_940")));
40+
EXPECT_TRUE(failed(Chipset::parse("gfx940_")));
41+
EXPECT_TRUE(failed(Chipset::parse("gfxmeow")));
42+
EXPECT_TRUE(failed(Chipset::parse("gfx1fff")));
43+
}
44+
45+
TEST(ChipsetTest, Comparison) {
46+
EXPECT_EQ(Chipset(9, 0x40), Chipset(9, 0x40));
47+
EXPECT_NE(Chipset(9, 0x40), Chipset(9, 0x42));
48+
EXPECT_NE(Chipset(9, 0x00), Chipset(10, 0x00));
49+
50+
EXPECT_LT(Chipset(9, 0x00), Chipset(10, 0x00));
51+
EXPECT_LT(Chipset(9, 0x0a), Chipset(9, 0x42));
52+
EXPECT_FALSE(Chipset(9, 0x42) < Chipset(9, 0x42));
53+
EXPECT_FALSE(Chipset(9, 0x42) < Chipset(9, 0x40));
54+
}
55+
56+
} // namespace
57+
} // namespace mlir::amdgpu
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
add_mlir_unittest(MLIRAMDGPUTests
2+
AMDGPUUtilsTest.cpp
3+
)
4+
target_link_libraries(MLIRAMDGPUTests
5+
PRIVATE
6+
MLIRAMDGPUUtils
7+
)

mlir/unittests/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ target_link_libraries(MLIRDialectTests
66
MLIRIR
77
MLIRDialect)
88

9+
add_subdirectory(AMDGPU)
910
add_subdirectory(ArmSME)
1011
add_subdirectory(Index)
1112
add_subdirectory(LLVMIR)

utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,19 @@ cc_test(
137137
],
138138
)
139139

140+
cc_test(
141+
name = "amdgpu_tests",
142+
size = "small",
143+
srcs = glob([
144+
"Dialect/AMDGPU/*.cpp",
145+
]),
146+
deps = [
147+
"//mlir:AMDGPUUtils",
148+
"//third-party/unittest:gtest",
149+
"//third-party/unittest:gtest_main",
150+
],
151+
)
152+
140153
cc_test(
141154
name = "memref_tests",
142155
size = "small",

0 commit comments

Comments
 (0)