12
12
#include " mlir/Conversion/LLVMCommon/Pattern.h"
13
13
#include " mlir/Conversion/LLVMCommon/TypeConverter.h"
14
14
#include " mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15
+ #include " mlir/Dialect/AMDGPU/Utils/Chipset.h"
15
16
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
16
17
#include " mlir/Dialect/LLVMIR/ROCDLDialect.h"
17
18
#include " mlir/IR/BuiltinTypes.h"
@@ -42,6 +43,11 @@ static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
42
43
}
43
44
44
45
namespace {
46
+ // Define commonly used chipsets versions for convenience.
47
+ constexpr Chipset kGfx908 = Chipset(9 , 0 , 8 );
48
+ constexpr Chipset kGfx90a = Chipset(9 , 0 , 0xa );
49
+ constexpr Chipset kGfx940 = Chipset(9 , 4 , 0 );
50
+
45
51
// / Define lowering patterns for raw buffer ops
46
52
template <typename GpuOp, typename Intrinsic>
47
53
struct RawBufferOpLowering : public ConvertOpToLLVMPattern <GpuOp> {
@@ -278,10 +284,7 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
278
284
LogicalResult
279
285
matchAndRewrite (LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
280
286
ConversionPatternRewriter &rewriter) const override {
281
- bool requiresInlineAsm =
282
- chipset.majorVersion < 9 ||
283
- (chipset.majorVersion == 9 && chipset.minorVersion < 0x0a ) ||
284
- (chipset.majorVersion == 11 );
287
+ bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11 ;
285
288
286
289
if (requiresInlineAsm) {
287
290
auto asmDialectAttr = LLVM::AsmDialectAttr::get (rewriter.getContext (),
@@ -465,7 +468,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
465
468
destElem = destType.getElementType ();
466
469
467
470
if (sourceElem.isF32 () && destElem.isF32 ()) {
468
- if (mfma.getReducePrecision () && chipset. minorVersion >= 0x40 ) {
471
+ if (mfma.getReducePrecision () && chipset >= kGfx940 ) {
469
472
if (m == 32 && n == 32 && k == 4 && b == 1 )
470
473
return ROCDL::mfma_f32_32x32x4_xf32::getOperationName ();
471
474
if (m == 16 && n == 16 && k == 8 && b == 1 )
@@ -496,7 +499,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
496
499
return ROCDL::mfma_f32_16x16x16f16::getOperationName ();
497
500
}
498
501
499
- if (sourceElem.isBF16 () && destElem.isF32 () && chipset. minorVersion >= 0x0a ) {
502
+ if (sourceElem.isBF16 () && destElem.isF32 () && chipset >= kGfx90a ) {
500
503
if (m == 32 && n == 32 && k == 4 && b == 2 )
501
504
return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName ();
502
505
if (m == 16 && n == 16 && k == 4 && b == 4 )
@@ -533,21 +536,20 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
533
536
return ROCDL::mfma_i32_32x32x8i8::getOperationName ();
534
537
if (m == 16 && n == 16 && k == 16 && b == 1 )
535
538
return ROCDL::mfma_i32_16x16x16i8::getOperationName ();
536
- if (m == 32 && n == 32 && k == 16 && b == 1 && chipset. minorVersion >= 0x40 )
539
+ if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940 )
537
540
return ROCDL::mfma_i32_32x32x16_i8::getOperationName ();
538
- if (m == 16 && n == 16 && k == 32 && b == 1 && chipset. minorVersion >= 0x40 )
541
+ if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940 )
539
542
return ROCDL::mfma_i32_16x16x32_i8::getOperationName ();
540
543
}
541
544
542
- if (sourceElem.isF64 () && destElem.isF64 () && chipset. minorVersion >= 0x0a ) {
545
+ if (sourceElem.isF64 () && destElem.isF64 () && chipset >= kGfx90a ) {
543
546
if (m == 16 && n == 16 && k == 4 && b == 1 )
544
547
return ROCDL::mfma_f64_16x16x4f64::getOperationName ();
545
548
if (m == 4 && n == 4 && k == 4 && b == 4 )
546
549
return ROCDL::mfma_f64_4x4x4f64::getOperationName ();
547
550
}
548
551
549
- if (sourceElem.isFloat8E5M2FNUZ () && destElem.isF32 () &&
550
- chipset.minorVersion >= 0x40 ) {
552
+ if (sourceElem.isFloat8E5M2FNUZ () && destElem.isF32 () && chipset >= kGfx940 ) {
551
553
// Known to be correct because there are no scalar f8 instructions and
552
554
// because a length mismatch will have been caught by the verifier.
553
555
Type sourceBElem =
@@ -566,8 +568,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
566
568
}
567
569
}
568
570
569
- if (sourceElem.isFloat8E4M3FNUZ () && destElem.isF32 () &&
570
- chipset.minorVersion >= 0x40 ) {
571
+ if (sourceElem.isFloat8E4M3FNUZ () && destElem.isF32 () && chipset >= kGfx940 ) {
571
572
Type sourceBElem =
572
573
cast<VectorType>(mfma.getSourceB ().getType ()).getElementType ();
573
574
if (m == 16 && n == 16 && k == 32 && b == 1 ) {
@@ -631,12 +632,12 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
631
632
if (outVecType.getElementType ().isBF16 ())
632
633
intrinsicOutType = outVecType.clone (rewriter.getI16Type ());
633
634
634
- if (chipset.majorVersion != 9 || chipset. minorVersion < 0x08 )
635
+ if (chipset.majorVersion != 9 || chipset < kGfx908 )
635
636
return op->emitOpError (" MFMA only supported on gfx908+" );
636
637
uint32_t getBlgpField = static_cast <uint32_t >(op.getBlgp ());
637
638
if (op.getNegateA () || op.getNegateB () || op.getNegateC ()) {
638
- if (chipset. minorVersion < 0x40 )
639
- return op.emitOpError (" negation unsupported on older than gfx840 " );
639
+ if (chipset < kGfx940 )
640
+ return op.emitOpError (" negation unsupported on older than gfx940 " );
640
641
getBlgpField |=
641
642
op.getNegateA () | (op.getNegateB () << 1 ) | (op.getNegateC () << 2 );
642
643
}
@@ -741,7 +742,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
741
742
ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
742
743
ConversionPatternRewriter &rewriter) const {
743
744
Location loc = op.getLoc ();
744
- if (chipset.majorVersion != 9 || chipset. minorVersion < 0x40 )
745
+ if (chipset.majorVersion != 9 || chipset < kGfx940 )
745
746
return rewriter.notifyMatchFailure (
746
747
loc, " Fp8 conversion instructions are not available on target "
747
748
" architecture and their emulation is not implemented" );
@@ -785,7 +786,7 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
785
786
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
786
787
ConversionPatternRewriter &rewriter) const {
787
788
Location loc = op.getLoc ();
788
- if (chipset.majorVersion != 9 || chipset. minorVersion < 0x40 )
789
+ if (chipset.majorVersion != 9 || chipset < kGfx940 )
789
790
return rewriter.notifyMatchFailure (
790
791
loc, " Fp8 conversion instructions are not available on target "
791
792
" architecture and their emulation is not implemented" );
@@ -822,7 +823,7 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
822
823
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
823
824
ConversionPatternRewriter &rewriter) const {
824
825
Location loc = op.getLoc ();
825
- if (chipset.majorVersion != 9 || chipset. minorVersion < 0x40 )
826
+ if (chipset.majorVersion != 9 || chipset < kGfx940 )
826
827
return rewriter.notifyMatchFailure (
827
828
loc, " Fp8 conversion instructions are not available on target "
828
829
" architecture and their emulation is not implemented" );
0 commit comments