Skip to content

[mlir][amdgpu] Align Chipset with TargetParser #107720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,44 @@
#define MLIR_DIALECT_AMDGPU_UTILS_CHIPSET_H_

#include "mlir/Support/LLVM.h"
#include <utility>
#include <tuple>

namespace mlir::amdgpu {

/// Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
/// Note that the leading digits form a decimal number, while the last two
/// digits form a hexadecimal number. For example:
/// gfx942 --> major = 9, minor = 0x42
/// gfx90a --> major = 9, minor = 0xa
/// gfx1103 --> major = 11, minor = 0x3
/// gfx942 --> major = 9, minor = 0x4, stepping = 0x2
/// gfx90a --> major = 9, minor = 0x0, stepping = 0xa
/// gfx1103 --> major = 11, minor = 0x0, stepping = 0x3
struct Chipset {
Chipset() = default;
Chipset(unsigned majorVersion, unsigned minorVersion)
: majorVersion(majorVersion), minorVersion(minorVersion){};
unsigned majorVersion = 0; // The major version (decimal).
unsigned minorVersion = 0; // The minor version (hexadecimal).
unsigned steppingVersion = 0; // The stepping version (hexadecimal).

constexpr Chipset() = default;
constexpr Chipset(unsigned major, unsigned minor, unsigned stepping)
: majorVersion(major), minorVersion(minor), steppingVersion(stepping) {};

/// Parses the chipset version string and returns the chipset on success, and
/// failure otherwise.
static FailureOr<Chipset> parse(StringRef name);

friend bool operator==(const Chipset &lhs, const Chipset &rhs) {
return lhs.majorVersion == rhs.majorVersion &&
lhs.minorVersion == rhs.minorVersion;
}
friend bool operator!=(const Chipset &lhs, const Chipset &rhs) {
return !(lhs == rhs);
}
friend bool operator<(const Chipset &lhs, const Chipset &rhs) {
return std::make_pair(lhs.majorVersion, lhs.minorVersion) <
std::make_pair(rhs.majorVersion, rhs.minorVersion);
std::tuple<unsigned, unsigned, unsigned> asTuple() const {
return {majorVersion, minorVersion, steppingVersion};
}

unsigned majorVersion = 0; // The major version (decimal).
unsigned minorVersion = 0; // The minor version (hexadecimal).
#define DEFINE_COMP_OPERATOR(OPERATOR) \
friend bool operator OPERATOR(const Chipset &lhs, const Chipset &rhs) { \
return lhs.asTuple() OPERATOR rhs.asTuple(); \
}
DEFINE_COMP_OPERATOR(==)
DEFINE_COMP_OPERATOR(!=)
DEFINE_COMP_OPERATOR(<)
DEFINE_COMP_OPERATOR(<=)
DEFINE_COMP_OPERATOR(>)
DEFINE_COMP_OPERATOR(>=)
#undef DEFINE_COMP_OPERATOR
};

} // namespace mlir::amdgpu
Expand Down
39 changes: 20 additions & 19 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
Expand Down Expand Up @@ -42,6 +43,11 @@ static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
}

namespace {
// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx940 = Chipset(9, 4, 0);

/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Expand Down Expand Up @@ -278,10 +284,7 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
LogicalResult
matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool requiresInlineAsm =
chipset.majorVersion < 9 ||
(chipset.majorVersion == 9 && chipset.minorVersion < 0x0a) ||
(chipset.majorVersion == 11);
bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

if (requiresInlineAsm) {
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
Expand Down Expand Up @@ -465,7 +468,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
destElem = destType.getElementType();

if (sourceElem.isF32() && destElem.isF32()) {
if (mfma.getReducePrecision() && chipset.minorVersion >= 0x40) {
if (mfma.getReducePrecision() && chipset >= kGfx940) {
if (m == 32 && n == 32 && k == 4 && b == 1)
return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
if (m == 16 && n == 16 && k == 8 && b == 1)
Expand Down Expand Up @@ -496,7 +499,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_f32_16x16x16f16::getOperationName();
}

if (sourceElem.isBF16() && destElem.isF32() && chipset.minorVersion >= 0x0a) {
if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
if (m == 32 && n == 32 && k == 4 && b == 2)
return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
if (m == 16 && n == 16 && k == 4 && b == 4)
Expand Down Expand Up @@ -533,21 +536,20 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_i32_32x32x8i8::getOperationName();
if (m == 16 && n == 16 && k == 16 && b == 1)
return ROCDL::mfma_i32_16x16x16i8::getOperationName();
if (m == 32 && n == 32 && k == 16 && b == 1 && chipset.minorVersion >= 0x40)
if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940)
return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
if (m == 16 && n == 16 && k == 32 && b == 1 && chipset.minorVersion >= 0x40)
if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940)
return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
}

if (sourceElem.isF64() && destElem.isF64() && chipset.minorVersion >= 0x0a) {
if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
if (m == 16 && n == 16 && k == 4 && b == 1)
return ROCDL::mfma_f64_16x16x4f64::getOperationName();
if (m == 4 && n == 4 && k == 4 && b == 4)
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
}

if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() &&
chipset.minorVersion >= 0x40) {
if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
// Known to be correct because there are no scalar f8 instructions and
// because a length mismatch will have been caught by the verifier.
Type sourceBElem =
Expand All @@ -566,8 +568,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
}
}

if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
chipset.minorVersion >= 0x40) {
if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
Expand Down Expand Up @@ -631,12 +632,12 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
if (outVecType.getElementType().isBF16())
intrinsicOutType = outVecType.clone(rewriter.getI16Type());

if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08)
if (chipset.majorVersion != 9 || chipset < kGfx908)
return op->emitOpError("MFMA only supported on gfx908+");
uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
if (chipset.minorVersion < 0x40)
return op.emitOpError("negation unsupported on older than gfx840");
if (chipset < kGfx940)
return op.emitOpError("negation unsupported on older than gfx940");
getBlgpField |=
op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
}
Expand Down Expand Up @@ -741,7 +742,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
if (chipset.majorVersion != 9 || chipset < kGfx940)
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Expand Down Expand Up @@ -785,7 +786,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
if (chipset.majorVersion != 9 || chipset < kGfx940)
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Expand Down Expand Up @@ -822,7 +823,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
if (chipset.majorVersion != 9 || chipset < kGfx940)
return rewriter.notifyMatchFailure(
loc, "Fp8 conversion instructions are not available on target "
"architecture and their emulation is not implemented");
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
}

bool convertFP8Arithmetic =
(*maybeChipset).majorVersion == 9 && (*maybeChipset).minorVersion >= 0x40;
maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
arith::populateArithToAMDGPUConversionPatterns(
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
*maybeChipset);
Expand Down
9 changes: 4 additions & 5 deletions mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
Expand Down Expand Up @@ -146,13 +146,12 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) {
// gfx10 has no atomic adds.
if (chipset.majorVersion == 10 || chipset.majorVersion < 9 ||
(chipset.majorVersion == 9 && chipset.minorVersion < 0x08)) {
if (chipset >= Chipset(10, 0, 0) || chipset < Chipset(9, 0, 8)) {
target.addIllegalOp<RawBufferAtomicFaddOp>();
}
// gfx9 has no to very limited support for floating-point min and max.
if (chipset.majorVersion == 9) {
if (chipset.minorVersion >= 0x0a && chipset.minorVersion != 0x41) {
if (chipset >= Chipset(9, 0, 0xa) && chipset != Chipset(9, 4, 1)) {
// gfx90a supports f64 max (and min, but we don't have a min wrapper right
// now) but all other types need to be emulated.
target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
Expand All @@ -162,7 +161,7 @@ void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
} else {
target.addIllegalOp<RawBufferAtomicFmaxOp>();
}
if (chipset.minorVersion == 0x41) {
if (chipset == Chipset(9, 4, 1)) {
// gfx941 requires non-CAS atomics to be implemented with CAS loops.
// The workaround here mirrors HIP and OpenMP.
target.addIllegalOp<RawBufferAtomicFaddOp, RawBufferAtomicFmaxOp,
Expand Down
8 changes: 6 additions & 2 deletions mlir/lib/Dialect/AMDGPU/Utils/Chipset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,18 @@ FailureOr<Chipset> Chipset::parse(StringRef name) {

unsigned major = 0;
unsigned minor = 0;
unsigned stepping = 0;

StringRef majorRef = name.drop_back(2);
StringRef minorRef = name.take_back(2);
StringRef minorRef = name.take_back(2).drop_back(1);
StringRef steppingRef = name.take_back(1);
if (majorRef.getAsInteger(10, major))
return failure();
if (minorRef.getAsInteger(16, minor))
return failure();
return Chipset(major, minor);
if (steppingRef.getAsInteger(16, stepping))
return failure();
return Chipset(major, minor, stepping);
}

} // namespace mlir::amdgpu
31 changes: 20 additions & 11 deletions mlir/unittests/Dialect/AMDGPU/AMDGPUUtilsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,20 @@ TEST(ChipsetTest, Parsing) {
FailureOr<Chipset> chipset = Chipset::parse("gfx90a");
ASSERT_TRUE(succeeded(chipset));
EXPECT_EQ(chipset->majorVersion, 9u);
EXPECT_EQ(chipset->minorVersion, 0x0au);
EXPECT_EQ(chipset->minorVersion, 0u);
EXPECT_EQ(chipset->steppingVersion, 0xau);

chipset = Chipset::parse("gfx940");
ASSERT_TRUE(succeeded(chipset));
EXPECT_EQ(chipset->majorVersion, 9u);
EXPECT_EQ(chipset->minorVersion, 0x40u);
EXPECT_EQ(chipset->minorVersion, 4u);
EXPECT_EQ(chipset->steppingVersion, 0u);

chipset = Chipset::parse("gfx1103");
ASSERT_TRUE(succeeded(chipset));
EXPECT_EQ(chipset->majorVersion, 11u);
EXPECT_EQ(chipset->minorVersion, 0x03u);
EXPECT_EQ(chipset->minorVersion, 0u);
EXPECT_EQ(chipset->steppingVersion, 3u);
}

TEST(ChipsetTest, ParsingInvalid) {
Expand All @@ -43,14 +46,20 @@ TEST(ChipsetTest, ParsingInvalid) {
}

TEST(ChipsetTest, Comparison) {
EXPECT_EQ(Chipset(9, 0x40), Chipset(9, 0x40));
EXPECT_NE(Chipset(9, 0x40), Chipset(9, 0x42));
EXPECT_NE(Chipset(9, 0x00), Chipset(10, 0x00));

EXPECT_LT(Chipset(9, 0x00), Chipset(10, 0x00));
EXPECT_LT(Chipset(9, 0x0a), Chipset(9, 0x42));
EXPECT_FALSE(Chipset(9, 0x42) < Chipset(9, 0x42));
EXPECT_FALSE(Chipset(9, 0x42) < Chipset(9, 0x40));
EXPECT_EQ(Chipset(9, 4, 0), Chipset(9, 4, 0));
EXPECT_NE(Chipset(9, 4, 0), Chipset(9, 4, 2));
EXPECT_NE(Chipset(9, 0, 0), Chipset(10, 0, 0));

EXPECT_LT(Chipset(9, 0, 0), Chipset(10, 0, 0));
EXPECT_LT(Chipset(9, 0, 0), Chipset(9, 4, 2));
EXPECT_LE(Chipset(9, 4, 1), Chipset(9, 4, 1));
EXPECT_FALSE(Chipset(9, 4, 2) < Chipset(9, 4, 2));
EXPECT_FALSE(Chipset(9, 4, 2) < Chipset(9, 4, 0));

EXPECT_GT(Chipset(9, 0, 0xa), Chipset(9, 0, 8));
EXPECT_GE(Chipset(9, 0, 0xa), Chipset(9, 0, 0xa));
EXPECT_FALSE(Chipset(9, 4, 1) >= Chipset(9, 4, 2));
EXPECT_FALSE(Chipset(9, 0, 0xa) >= Chipset(9, 4, 0));
}

} // namespace
Expand Down
Loading