Skip to content

Commit 763bc92

Browse files
authored
[mlir][amdgpu] Align Chipset with TargetParser (#107720)
Update the Chipset struct to follow the `IsaVersion` definition from LLVM's `TargetParser`. This is a follow-up to #106169 (comment).

* Add the stepping version. Note: This may break downstream code that compares against the minor version directly.
* Use comparisons with the full Chipset version where possible.

Note that we can't use the code in `TargetParser` directly because the chipset utility lives outside of `mlir/Target`, which is what re-exports LLVM's target library.
1 parent 6cc3bf7 commit 763bc92

File tree

6 files changed

+75
-57
lines changed

6 files changed

+75
-57
lines changed

mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,39 +9,44 @@
99
#define MLIR_DIALECT_AMDGPU_UTILS_CHIPSET_H_
1010

1111
#include "mlir/Support/LLVM.h"
12-
#include <utility>
12+
#include <tuple>
1313

1414
namespace mlir::amdgpu {
1515

1616
/// Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
1717
/// Note that the leading digits form a decimal number, while the last two
1818
/// digits form a hexadecimal number. For example:
19-
/// gfx942 --> major = 9, minor = 0x42
20-
/// gfx90a --> major = 9, minor = 0xa
21-
/// gfx1103 --> major = 10, minor = 0x3
19+
/// gfx942 --> major = 9, minor = 0x4, stepping = 0x2
20+
/// gfx90a --> major = 9, minor = 0x0, stepping = 0xa
21+
/// gfx1103 --> major = 11, minor = 0x0, stepping = 0x3
2222
struct Chipset {
23-
Chipset() = default;
24-
Chipset(unsigned majorVersion, unsigned minorVersion)
25-
: majorVersion(majorVersion), minorVersion(minorVersion){};
23+
unsigned majorVersion = 0; // The major version (decimal).
24+
unsigned minorVersion = 0; // The minor version (hexadecimal).
25+
unsigned steppingVersion = 0; // The stepping version (hexadecimal).
26+
27+
constexpr Chipset() = default;
28+
constexpr Chipset(unsigned major, unsigned minor, unsigned stepping)
29+
: majorVersion(major), minorVersion(minor), steppingVersion(stepping) {};
2630

2731
/// Parses the chipset version string and returns the chipset on success, and
2832
/// failure otherwise.
2933
static FailureOr<Chipset> parse(StringRef name);
3034

31-
friend bool operator==(const Chipset &lhs, const Chipset &rhs) {
32-
return lhs.majorVersion == rhs.majorVersion &&
33-
lhs.minorVersion == rhs.minorVersion;
34-
}
35-
friend bool operator!=(const Chipset &lhs, const Chipset &rhs) {
36-
return !(lhs == rhs);
37-
}
38-
friend bool operator<(const Chipset &lhs, const Chipset &rhs) {
39-
return std::make_pair(lhs.majorVersion, lhs.minorVersion) <
40-
std::make_pair(rhs.majorVersion, rhs.minorVersion);
35+
std::tuple<unsigned, unsigned, unsigned> asTuple() const {
36+
return {majorVersion, minorVersion, steppingVersion};
4137
}
4238

43-
unsigned majorVersion = 0; // The major version (decimal).
44-
unsigned minorVersion = 0; // The minor version (hexadecimal).
39+
#define DEFINE_COMP_OPERATOR(OPERATOR) \
40+
friend bool operator OPERATOR(const Chipset &lhs, const Chipset &rhs) { \
41+
return lhs.asTuple() OPERATOR rhs.asTuple(); \
42+
}
43+
DEFINE_COMP_OPERATOR(==)
44+
DEFINE_COMP_OPERATOR(!=)
45+
DEFINE_COMP_OPERATOR(<)
46+
DEFINE_COMP_OPERATOR(<=)
47+
DEFINE_COMP_OPERATOR(>)
48+
DEFINE_COMP_OPERATOR(>=)
49+
#undef DEFINE_COMP_OPERATOR
4550
};
4651

4752
} // namespace mlir::amdgpu

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Conversion/LLVMCommon/Pattern.h"
1313
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1414
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15+
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1516
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1617
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
1718
#include "mlir/IR/BuiltinTypes.h"
@@ -42,6 +43,11 @@ static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
4243
}
4344

4445
namespace {
46+
// Define commonly used chipset versions for convenience.
47+
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
48+
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
49+
constexpr Chipset kGfx940 = Chipset(9, 4, 0);
50+
4551
/// Define lowering patterns for raw buffer ops
4652
template <typename GpuOp, typename Intrinsic>
4753
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
@@ -278,10 +284,7 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
278284
LogicalResult
279285
matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
280286
ConversionPatternRewriter &rewriter) const override {
281-
bool requiresInlineAsm =
282-
chipset.majorVersion < 9 ||
283-
(chipset.majorVersion == 9 && chipset.minorVersion < 0x0a) ||
284-
(chipset.majorVersion == 11);
287+
bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;
285288

286289
if (requiresInlineAsm) {
287290
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
@@ -465,7 +468,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
465468
destElem = destType.getElementType();
466469

467470
if (sourceElem.isF32() && destElem.isF32()) {
468-
if (mfma.getReducePrecision() && chipset.minorVersion >= 0x40) {
471+
if (mfma.getReducePrecision() && chipset >= kGfx940) {
469472
if (m == 32 && n == 32 && k == 4 && b == 1)
470473
return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
471474
if (m == 16 && n == 16 && k == 8 && b == 1)
@@ -496,7 +499,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
496499
return ROCDL::mfma_f32_16x16x16f16::getOperationName();
497500
}
498501

499-
if (sourceElem.isBF16() && destElem.isF32() && chipset.minorVersion >= 0x0a) {
502+
if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
500503
if (m == 32 && n == 32 && k == 4 && b == 2)
501504
return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
502505
if (m == 16 && n == 16 && k == 4 && b == 4)
@@ -533,21 +536,20 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
533536
return ROCDL::mfma_i32_32x32x8i8::getOperationName();
534537
if (m == 16 && n == 16 && k == 16 && b == 1)
535538
return ROCDL::mfma_i32_16x16x16i8::getOperationName();
536-
if (m == 32 && n == 32 && k == 16 && b == 1 && chipset.minorVersion >= 0x40)
539+
if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940)
537540
return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
538-
if (m == 16 && n == 16 && k == 32 && b == 1 && chipset.minorVersion >= 0x40)
541+
if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940)
539542
return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
540543
}
541544

542-
if (sourceElem.isF64() && destElem.isF64() && chipset.minorVersion >= 0x0a) {
545+
if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
543546
if (m == 16 && n == 16 && k == 4 && b == 1)
544547
return ROCDL::mfma_f64_16x16x4f64::getOperationName();
545548
if (m == 4 && n == 4 && k == 4 && b == 4)
546549
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
547550
}
548551

549-
if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() &&
550-
chipset.minorVersion >= 0x40) {
552+
if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
551553
// Known to be correct because there are no scalar f8 instructions and
552554
// because a length mismatch will have been caught by the verifier.
553555
Type sourceBElem =
@@ -566,8 +568,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
566568
}
567569
}
568570

569-
if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
570-
chipset.minorVersion >= 0x40) {
571+
if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
571572
Type sourceBElem =
572573
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
573574
if (m == 16 && n == 16 && k == 32 && b == 1) {
@@ -631,12 +632,12 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
631632
if (outVecType.getElementType().isBF16())
632633
intrinsicOutType = outVecType.clone(rewriter.getI16Type());
633634

634-
if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08)
635+
if (chipset.majorVersion != 9 || chipset < kGfx908)
635636
return op->emitOpError("MFMA only supported on gfx908+");
636637
uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
637638
if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
638-
if (chipset.minorVersion < 0x40)
639-
return op.emitOpError("negation unsupported on older than gfx840");
639+
if (chipset < kGfx940)
640+
return op.emitOpError("negation unsupported on older than gfx940");
640641
getBlgpField |=
641642
op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
642643
}
@@ -741,7 +742,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
741742
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
742743
ConversionPatternRewriter &rewriter) const {
743744
Location loc = op.getLoc();
744-
if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
745+
if (chipset.majorVersion != 9 || chipset < kGfx940)
745746
return rewriter.notifyMatchFailure(
746747
loc, "Fp8 conversion instructions are not available on target "
747748
"architecture and their emulation is not implemented");
@@ -785,7 +786,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
785786
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
786787
ConversionPatternRewriter &rewriter) const {
787788
Location loc = op.getLoc();
788-
if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
789+
if (chipset.majorVersion != 9 || chipset < kGfx940)
789790
return rewriter.notifyMatchFailure(
790791
loc, "Fp8 conversion instructions are not available on target "
791792
"architecture and their emulation is not implemented");
@@ -822,7 +823,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
822823
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
823824
ConversionPatternRewriter &rewriter) const {
824825
Location loc = op.getLoc();
825-
if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
826+
if (chipset.majorVersion != 9 || chipset < kGfx940)
826827
return rewriter.notifyMatchFailure(
827828
loc, "Fp8 conversion instructions are not available on target "
828829
"architecture and their emulation is not implemented");

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
384384
}
385385

386386
bool convertFP8Arithmetic =
387-
(*maybeChipset).majorVersion == 9 && (*maybeChipset).minorVersion >= 0x40;
387+
maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
388388
arith::populateArithToAMDGPUConversionPatterns(
389389
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
390390
*maybeChipset);

mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
1010

1111
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
12+
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1213
#include "mlir/Dialect/Arith/IR/Arith.h"
1314
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
1415
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1516
#include "mlir/IR/BuiltinAttributes.h"
1617
#include "mlir/Transforms/DialectConversion.h"
17-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1818

1919
namespace mlir::amdgpu {
2020
#define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
@@ -146,13 +146,12 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
146146
void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
147147
ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) {
148148
// gfx10 has no atomic adds.
149-
if (chipset.majorVersion == 10 || chipset.majorVersion < 9 ||
150-
(chipset.majorVersion == 9 && chipset.minorVersion < 0x08)) {
149+
if (chipset >= Chipset(10, 0, 0) || chipset < Chipset(9, 0, 8)) {
151150
target.addIllegalOp<RawBufferAtomicFaddOp>();
152151
}
153152
// gfx9 has no to very limited support for floating-point min and max.
154153
if (chipset.majorVersion == 9) {
155-
if (chipset.minorVersion >= 0x0a && chipset.minorVersion != 0x41) {
154+
if (chipset >= Chipset(9, 0, 0xa) && chipset != Chipset(9, 4, 1)) {
156155
// gfx90a supports f64 max (and min, but we don't have a min wrapper right
157156
// now) but all other types need to be emulated.
158157
target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
@@ -162,7 +161,7 @@ void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
162161
} else {
163162
target.addIllegalOp<RawBufferAtomicFmaxOp>();
164163
}
165-
if (chipset.minorVersion == 0x41) {
164+
if (chipset == Chipset(9, 4, 1)) {
166165
// gfx941 requires non-CAS atomics to be implemented with CAS loops.
167166
// The workaround here mirrors HIP and OpenMP.
168167
target.addIllegalOp<RawBufferAtomicFaddOp, RawBufferAtomicFmaxOp,

mlir/lib/Dialect/AMDGPU/Utils/Chipset.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,18 @@ FailureOr<Chipset> Chipset::parse(StringRef name) {
1919

2020
unsigned major = 0;
2121
unsigned minor = 0;
22+
unsigned stepping = 0;
2223

2324
StringRef majorRef = name.drop_back(2);
24-
StringRef minorRef = name.take_back(2);
25+
StringRef minorRef = name.take_back(2).drop_back(1);
26+
StringRef steppingRef = name.take_back(1);
2527
if (majorRef.getAsInteger(10, major))
2628
return failure();
2729
if (minorRef.getAsInteger(16, minor))
2830
return failure();
29-
return Chipset(major, minor);
31+
if (steppingRef.getAsInteger(16, stepping))
32+
return failure();
33+
return Chipset(major, minor, stepping);
3034
}
3135

3236
} // namespace mlir::amdgpu

mlir/unittests/Dialect/AMDGPU/AMDGPUUtilsTest.cpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,20 @@ TEST(ChipsetTest, Parsing) {
1616
FailureOr<Chipset> chipset = Chipset::parse("gfx90a");
1717
ASSERT_TRUE(succeeded(chipset));
1818
EXPECT_EQ(chipset->majorVersion, 9u);
19-
EXPECT_EQ(chipset->minorVersion, 0x0au);
19+
EXPECT_EQ(chipset->minorVersion, 0u);
20+
EXPECT_EQ(chipset->steppingVersion, 0xau);
2021

2122
chipset = Chipset::parse("gfx940");
2223
ASSERT_TRUE(succeeded(chipset));
2324
EXPECT_EQ(chipset->majorVersion, 9u);
24-
EXPECT_EQ(chipset->minorVersion, 0x40u);
25+
EXPECT_EQ(chipset->minorVersion, 4u);
26+
EXPECT_EQ(chipset->steppingVersion, 0u);
2527

2628
chipset = Chipset::parse("gfx1103");
2729
ASSERT_TRUE(succeeded(chipset));
2830
EXPECT_EQ(chipset->majorVersion, 11u);
29-
EXPECT_EQ(chipset->minorVersion, 0x03u);
31+
EXPECT_EQ(chipset->minorVersion, 0u);
32+
EXPECT_EQ(chipset->steppingVersion, 3u);
3033
}
3134

3235
TEST(ChipsetTest, ParsingInvalid) {
@@ -43,14 +46,20 @@ TEST(ChipsetTest, ParsingInvalid) {
4346
}
4447

4548
TEST(ChipsetTest, Comparison) {
46-
EXPECT_EQ(Chipset(9, 0x40), Chipset(9, 0x40));
47-
EXPECT_NE(Chipset(9, 0x40), Chipset(9, 0x42));
48-
EXPECT_NE(Chipset(9, 0x00), Chipset(10, 0x00));
49-
50-
EXPECT_LT(Chipset(9, 0x00), Chipset(10, 0x00));
51-
EXPECT_LT(Chipset(9, 0x0a), Chipset(9, 0x42));
52-
EXPECT_FALSE(Chipset(9, 0x42) < Chipset(9, 0x42));
53-
EXPECT_FALSE(Chipset(9, 0x42) < Chipset(9, 0x40));
49+
EXPECT_EQ(Chipset(9, 4, 0), Chipset(9, 4, 0));
50+
EXPECT_NE(Chipset(9, 4, 0), Chipset(9, 4, 2));
51+
EXPECT_NE(Chipset(9, 0, 0), Chipset(10, 0, 0));
52+
53+
EXPECT_LT(Chipset(9, 0, 0), Chipset(10, 0, 0));
54+
EXPECT_LT(Chipset(9, 0, 0), Chipset(9, 4, 2));
55+
EXPECT_LE(Chipset(9, 4, 1), Chipset(9, 4, 1));
56+
EXPECT_FALSE(Chipset(9, 4, 2) < Chipset(9, 4, 2));
57+
EXPECT_FALSE(Chipset(9, 4, 2) < Chipset(9, 4, 0));
58+
59+
EXPECT_GT(Chipset(9, 0, 0xa), Chipset(9, 0, 8));
60+
EXPECT_GE(Chipset(9, 0, 0xa), Chipset(9, 0, 0xa));
61+
EXPECT_FALSE(Chipset(9, 4, 1) >= Chipset(9, 4, 2));
62+
EXPECT_FALSE(Chipset(9, 0, 0xa) >= Chipset(9, 4, 0));
5463
}
5564

5665
} // namespace

0 commit comments

Comments
 (0)