diff --git a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h index 0e2708b1efae0..a5dab1ab89630 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h +++ b/mlir/include/mlir/Dialect/AMDGPU/Utils/Chipset.h @@ -9,39 +9,44 @@ #define MLIR_DIALECT_AMDGPU_UTILS_CHIPSET_H_ #include "mlir/Support/LLVM.h" -#include +#include namespace mlir::amdgpu { /// Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103. /// Note that the leading digits form a decimal number, while the last two /// digits for a hexadecimal number. For example: -/// gfx942 --> major = 9, minor = 0x42 -/// gfx90a --> major = 9, minor = 0xa -/// gfx1103 --> major = 10, minor = 0x3 +/// gfx942 --> major = 9, minor = 0x4, stepping = 0x2 +/// gfx90a --> major = 9, minor = 0x0, stepping = 0xa +/// gfx1103 --> major = 11, minor = 0x0, stepping = 0x3 struct Chipset { - Chipset() = default; - Chipset(unsigned majorVersion, unsigned minorVersion) - : majorVersion(majorVersion), minorVersion(minorVersion){}; + unsigned majorVersion = 0; // The major version (decimal). + unsigned minorVersion = 0; // The minor version (hexadecimal). + unsigned steppingVersion = 0; // The stepping version (hexadecimal). + + constexpr Chipset() = default; + constexpr Chipset(unsigned major, unsigned minor, unsigned stepping) + : majorVersion(major), minorVersion(minor), steppingVersion(stepping) {}; /// Parses the chipset version string and returns the chipset on success, and /// failure otherwise.
static FailureOr parse(StringRef name); - friend bool operator==(const Chipset &lhs, const Chipset &rhs) { - return lhs.majorVersion == rhs.majorVersion && - lhs.minorVersion == rhs.minorVersion; - } - friend bool operator!=(const Chipset &lhs, const Chipset &rhs) { - return !(lhs == rhs); - } - friend bool operator<(const Chipset &lhs, const Chipset &rhs) { - return std::make_pair(lhs.majorVersion, lhs.minorVersion) < - std::make_pair(rhs.majorVersion, rhs.minorVersion); + std::tuple asTuple() const { + return {majorVersion, minorVersion, steppingVersion}; } - unsigned majorVersion = 0; // The major version (decimal). - unsigned minorVersion = 0; // The minor version (hexadecimal). +#define DEFINE_COMP_OPERATOR(OPERATOR) \ + friend bool operator OPERATOR(const Chipset &lhs, const Chipset &rhs) { \ + return lhs.asTuple() OPERATOR rhs.asTuple(); \ + } + DEFINE_COMP_OPERATOR(==) + DEFINE_COMP_OPERATOR(!=) + DEFINE_COMP_OPERATOR(<) + DEFINE_COMP_OPERATOR(<=) + DEFINE_COMP_OPERATOR(>) + DEFINE_COMP_OPERATOR(>=) +#undef DEFINE_COMP_OPERATOR }; } // namespace mlir::amdgpu diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 7e407f1ca528d..96b433294d258 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -12,6 +12,7 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/IR/BuiltinTypes.h" @@ -42,6 +43,11 @@ static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, } namespace { +// Define commonly used chipset versions for convenience.
+constexpr Chipset kGfx908 = Chipset(9, 0, 8); +constexpr Chipset kGfx90a = Chipset(9, 0, 0xa); +constexpr Chipset kGfx940 = Chipset(9, 4, 0); + /// Define lowering patterns for raw buffer ops template struct RawBufferOpLowering : public ConvertOpToLLVMPattern { @@ -278,10 +284,7 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - bool requiresInlineAsm = - chipset.majorVersion < 9 || - (chipset.majorVersion == 9 && chipset.minorVersion < 0x0a) || - (chipset.majorVersion == 11); + bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11; if (requiresInlineAsm) { auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), @@ -465,7 +468,7 @@ static std::optional mfmaOpToIntrinsic(MFMAOp mfma, destElem = destType.getElementType(); if (sourceElem.isF32() && destElem.isF32()) { - if (mfma.getReducePrecision() && chipset.minorVersion >= 0x40) { + if (mfma.getReducePrecision() && chipset >= kGfx940) { if (m == 32 && n == 32 && k == 4 && b == 1) return ROCDL::mfma_f32_32x32x4_xf32::getOperationName(); if (m == 16 && n == 16 && k == 8 && b == 1) @@ -496,7 +499,7 @@ static std::optional mfmaOpToIntrinsic(MFMAOp mfma, return ROCDL::mfma_f32_16x16x16f16::getOperationName(); } - if (sourceElem.isBF16() && destElem.isF32() && chipset.minorVersion >= 0x0a) { + if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) { if (m == 32 && n == 32 && k == 4 && b == 2) return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName(); if (m == 16 && n == 16 && k == 4 && b == 4) @@ -533,21 +536,20 @@ static std::optional mfmaOpToIntrinsic(MFMAOp mfma, return ROCDL::mfma_i32_32x32x8i8::getOperationName(); if (m == 16 && n == 16 && k == 16 && b == 1) return ROCDL::mfma_i32_16x16x16i8::getOperationName(); - if (m == 32 && n == 32 && k == 16 && b == 1 && chipset.minorVersion >= 0x40) + if (m == 32 && n == 32 && k == 16 
&& b == 1 && chipset >= kGfx940) return ROCDL::mfma_i32_32x32x16_i8::getOperationName(); - if (m == 16 && n == 16 && k == 32 && b == 1 && chipset.minorVersion >= 0x40) + if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940) return ROCDL::mfma_i32_16x16x32_i8::getOperationName(); } - if (sourceElem.isF64() && destElem.isF64() && chipset.minorVersion >= 0x0a) { + if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) { if (m == 16 && n == 16 && k == 4 && b == 1) return ROCDL::mfma_f64_16x16x4f64::getOperationName(); if (m == 4 && n == 4 && k == 4 && b == 4) return ROCDL::mfma_f64_4x4x4f64::getOperationName(); } - if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && - chipset.minorVersion >= 0x40) { + if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) { // Known to be correct because there are no scalar f8 instructions and // because a length mismatch will have been caught by the verifier. Type sourceBElem = @@ -566,8 +568,7 @@ static std::optional mfmaOpToIntrinsic(MFMAOp mfma, } } - if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && - chipset.minorVersion >= 0x40) { + if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) { Type sourceBElem = cast(mfma.getSourceB().getType()).getElementType(); if (m == 16 && n == 16 && k == 32 && b == 1) { @@ -631,12 +632,12 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern { if (outVecType.getElementType().isBF16()) intrinsicOutType = outVecType.clone(rewriter.getI16Type()); - if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08) + if (chipset.majorVersion != 9 || chipset < kGfx908) return op->emitOpError("MFMA only supported on gfx908+"); uint32_t getBlgpField = static_cast(op.getBlgp()); if (op.getNegateA() || op.getNegateB() || op.getNegateC()) { - if (chipset.minorVersion < 0x40) - return op.emitOpError("negation unsupported on older than gfx840"); + if (chipset < kGfx940) + return op.emitOpError("negation unsupported on older than 
gfx940"); getBlgpField |= op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2); } @@ -741,7 +742,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); - if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40) + if (chipset.majorVersion != 9 || chipset < kGfx940) return rewriter.notifyMatchFailure( loc, "Fp8 conversion instructions are not available on target " "architecture and their emulation is not implemented"); @@ -785,7 +786,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); - if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40) + if (chipset.majorVersion != 9 || chipset < kGfx940) return rewriter.notifyMatchFailure( loc, "Fp8 conversion instructions are not available on target " "architecture and their emulation is not implemented"); @@ -822,7 +823,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); - if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40) + if (chipset.majorVersion != 9 || chipset < kGfx940) return rewriter.notifyMatchFailure( loc, "Fp8 conversion instructions are not available on target " "architecture and their emulation is not implemented"); diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index d36583c8118ff..6b27ec9947cb0 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -384,7 +384,7 @@ void ArithToAMDGPUConversionPass::runOnOperation() { } bool convertFP8Arithmetic = - (*maybeChipset).majorVersion == 9 && (*maybeChipset).minorVersion >= 
0x40; + maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0); arith::populateArithToAMDGPUConversionPatterns( patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz, *maybeChipset); diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp index f89e2537897e8..21042aff529c9 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp @@ -9,12 +9,12 @@ #include "mlir/Dialect/AMDGPU/Transforms/Passes.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir::amdgpu { #define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS @@ -146,13 +146,12 @@ LogicalResult RawBufferAtomicByCasPattern::matchAndRewrite( void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns( ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) { // gfx10 has no atomic adds. - if (chipset.majorVersion == 10 || chipset.majorVersion < 9 || - (chipset.majorVersion == 9 && chipset.minorVersion < 0x08)) { + if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) { target.addIllegalOp(); } // gfx9 has no to a very limited support for floating-point min and max. if (chipset.majorVersion == 9) { - if (chipset.minorVersion >= 0x0a && chipset.minorVersion != 0x41) { + if (chipset >= Chipset(9, 0, 0xa) && chipset != Chipset(9, 4, 1)) { // gfx90a supports f64 max (and min, but we don't have a min wrapper right // now) but all other types need to be emulated.
target.addDynamicallyLegalOp( @@ -162,7 +161,7 @@ void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns( } else { target.addIllegalOp(); } - if (chipset.minorVersion == 0x41) { + if (chipset == Chipset(9, 4, 1)) { // gfx941 requires non-CAS atomics to be implemented with CAS loops. // The workaround here mirrors HIP and OpenMP. target.addIllegalOp Chipset::parse(StringRef name) { unsigned major = 0; unsigned minor = 0; + unsigned stepping = 0; StringRef majorRef = name.drop_back(2); - StringRef minorRef = name.take_back(2); + StringRef minorRef = name.take_back(2).drop_back(1); + StringRef steppingRef = name.take_back(1); if (majorRef.getAsInteger(10, major)) return failure(); if (minorRef.getAsInteger(16, minor)) return failure(); - return Chipset(major, minor); + if (steppingRef.getAsInteger(16, stepping)) + return failure(); + return Chipset(major, minor, stepping); } } // namespace mlir::amdgpu diff --git a/mlir/unittests/Dialect/AMDGPU/AMDGPUUtilsTest.cpp b/mlir/unittests/Dialect/AMDGPU/AMDGPUUtilsTest.cpp index b08b6681235d3..976ff2e7382ed 100644 --- a/mlir/unittests/Dialect/AMDGPU/AMDGPUUtilsTest.cpp +++ b/mlir/unittests/Dialect/AMDGPU/AMDGPUUtilsTest.cpp @@ -16,17 +16,20 @@ TEST(ChipsetTest, Parsing) { FailureOr chipset = Chipset::parse("gfx90a"); ASSERT_TRUE(succeeded(chipset)); EXPECT_EQ(chipset->majorVersion, 9u); - EXPECT_EQ(chipset->minorVersion, 0x0au); + EXPECT_EQ(chipset->minorVersion, 0u); + EXPECT_EQ(chipset->steppingVersion, 0xau); chipset = Chipset::parse("gfx940"); ASSERT_TRUE(succeeded(chipset)); EXPECT_EQ(chipset->majorVersion, 9u); - EXPECT_EQ(chipset->minorVersion, 0x40u); + EXPECT_EQ(chipset->minorVersion, 4u); + EXPECT_EQ(chipset->steppingVersion, 0u); chipset = Chipset::parse("gfx1103"); ASSERT_TRUE(succeeded(chipset)); EXPECT_EQ(chipset->majorVersion, 11u); - EXPECT_EQ(chipset->minorVersion, 0x03u); + EXPECT_EQ(chipset->minorVersion, 0u); + EXPECT_EQ(chipset->steppingVersion, 3u); } TEST(ChipsetTest, ParsingInvalid) { @@ -43,14 +46,20 
@@ TEST(ChipsetTest, ParsingInvalid) { } TEST(ChipsetTest, Comparison) { - EXPECT_EQ(Chipset(9, 0x40), Chipset(9, 0x40)); - EXPECT_NE(Chipset(9, 0x40), Chipset(9, 0x42)); - EXPECT_NE(Chipset(9, 0x00), Chipset(10, 0x00)); - - EXPECT_LT(Chipset(9, 0x00), Chipset(10, 0x00)); - EXPECT_LT(Chipset(9, 0x0a), Chipset(9, 0x42)); - EXPECT_FALSE(Chipset(9, 0x42) < Chipset(9, 0x42)); - EXPECT_FALSE(Chipset(9, 0x42) < Chipset(9, 0x40)); + EXPECT_EQ(Chipset(9, 4, 0), Chipset(9, 4, 0)); + EXPECT_NE(Chipset(9, 4, 0), Chipset(9, 4, 2)); + EXPECT_NE(Chipset(9, 0, 0), Chipset(10, 0, 0)); + + EXPECT_LT(Chipset(9, 0, 0), Chipset(10, 0, 0)); + EXPECT_LT(Chipset(9, 0, 0), Chipset(9, 4, 2)); + EXPECT_LE(Chipset(9, 4, 1), Chipset(9, 4, 1)); + EXPECT_FALSE(Chipset(9, 4, 2) < Chipset(9, 4, 2)); + EXPECT_FALSE(Chipset(9, 4, 2) < Chipset(9, 4, 0)); + + EXPECT_GT(Chipset(9, 0, 0xa), Chipset(9, 0, 8)); + EXPECT_GE(Chipset(9, 0, 0xa), Chipset(9, 0, 0xa)); + EXPECT_FALSE(Chipset(9, 4, 1) >= Chipset(9, 4, 2)); + EXPECT_FALSE(Chipset(9, 0, 0xa) >= Chipset(9, 4, 0)); } } // namespace