diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 7093eeea68948..57cd1a3806c2e 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -728,4 +728,24 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
+  let description = [{
+    The `nvgpu.warpgroup.mma.store` op stores the fragmented result held in
+    $matrixD to the given memref.
+
+    [See the details of register fragment layout for accumulator matrix D]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
+
+    Note that the op must be run by a full warp group.
+  }];
+
+  let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
+                       Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
+
+  let assemblyFormat = [{
+    `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
index 00e775ce7dd22..a050749eb7da8 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRNVGPUToNVVM
   MLIRLLVMDialect
   MLIRNVGPUDialect
   MLIRNVVMDialect
+  MLIRArithDialect
   MLIRPass
   MLIRSCFTransforms
   MLIRTransforms
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index cce525dcdcbe2..99c4d42233513 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -394,8 +395,8 @@ struct ConvertNVGPUToNVVMPass
   using Base::Base;
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<memref::MemRefDialect, NVVM::NVVMDialect, LLVM::LLVMDialect>();
+    registry.insert<memref::MemRefDialect, NVVM::NVVMDialect,
+                    LLVM::LLVMDialect, arith::ArithDialect>();
   }
 
   void runOnOperation() override {
@@ -436,6 +437,7 @@ struct ConvertNVGPUToNVVMPass
     populateNVGPUToNVVMConversionPatterns(converter, patterns);
     LLVMConversionTarget target(getContext());
     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
+    target.addLegalDialect<::mlir::arith::ArithDialect>();
     target.addLegalDialect<::mlir::memref::MemRefDialect>();
     target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
@@ -1434,6 +1436,116 @@ struct NVGPUWarpgroupMmaOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaStoreOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
+
+  /// This function stores a fragmented register matrix owned by a warp group
+  /// (128 threads) into a memref. Each thread holds 64 registers, which are
+  /// passed in as a single struct value.
+  /// Below is what each thread (T) holds; each `d` is one element of that
+  /// struct, labeled with its index.
+  ///
+  /// Threads in the warp group (128 threads) and what they own in matrixD:
+  /// 0-31   Warp-0  -> MatrixD[0:15 ][0:N]
+  /// 32-63  Warp-1  -> MatrixD[16:31][0:N]
+  /// 64-95  Warp-2  -> MatrixD[32:47][0:N]
+  /// 96-127 Warp-3  -> MatrixD[48:63][0:N]
+  ///
+  /// Matrix-D:
+  ///   +______________________________________________________________________+
+  ///   |   0-1    |   2-3   |   4-5   |   6-7   |  8-9   |   10-11   |..|N-8,N-7|
+  /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
+  /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
+  /// ..| .........|.........|.........|.........|........|...........|........|
+  /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7| T1:d6-d7..|T0:dZ-dW|
+  /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
+  /// ..| .........|.........|.........|.........|........|...........|........|
+  /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
+  /// 16| T32:d0-d1|T33:d0-d1|T34:d0-d1|T35:d0-d1|........|...........|........|
+  /// ..| .........|.........|.........|.........|........|...........|........|
+  /// 32| T64:d0-d1|T65:d0-d1|T66:d0-d1|T67:d0-d1|........|...........|........|
+  /// ..| .........|.........|.........|.........|........|...........|........|
+  /// 48| T96:d0-d1|T97:d0-d1|T98:d0-d1|T99:d0-d1|........|...........|........|
+  /// ..| .........|.........|.........|.........|........|...........|........|
+  ///   +______________________________________________________________________+
+  ///
+  /// \param b: The ImplicitLocOpBuilder that emits the lowered ops.
+  /// \param matrixD: Result of the warp-group MMA operation (fragmented
+  /// matrix). Each thread holds it as a struct with 64 elements.
+  /// \param dstMemref: The memref where the registers will be stored.
+  /// \param offset: The row offset within the memref at which this fragment
+  /// will be stored.
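+  ///
+  /// A worked example (added for exposition; the numbers follow the diagram
+  /// and the index math below): thread 37 is lane 5 of Warp-1
+  /// (warpId = 37 / 32 = 1, laneId = 37 % 32 = 5), giving lane4Id = 5 / 4 = 1
+  /// and lane4modId = 5 % 4 = 1. Its {d0, d1} therefore go to row
+  /// ti = lane4Id + warpId * 16 = 17, columns tj = lane4modId * 2 = 2 and 3
+  /// (the T5 slot shifted into Warp-1's rows); its {d2, d3} go 8 rows below,
+  /// to row 25, columns 2 and 3.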
+  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
+                             TypedValue<MemRefType> dstMemref,
+                             int offset) const {
+    Type i32 = b.getI32Type();
+
+    auto makeConst = [&](int32_t index) -> Value {
+      return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
+    };
+    Value c1 = makeConst(1);
+    Value c2 = makeConst(2);
+    Value c4 = makeConst(4);
+    Value c8 = makeConst(8);
+    Value c16 = makeConst(16);
+    Value warpSize = makeConst(kWarpSize);
+
+    auto makeMul = [&](Value lhs, Value rhs) -> Value {
+      return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
+    };
+    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
+    };
+
+    Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
+    Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
+    Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
+    Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
+    Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
+
+    // Extracts two consecutive struct elements and stores them at
+    // memref[x][y] and memref[x][y+1].
+    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
+                                   TypedValue<::mlir::MemRefType> memref) {
+      Type it = b.getIndexType();
+      Value idx = b.create<arith::IndexCastOp>(it, x);
+      Value idy0 = b.create<arith::IndexCastOp>(it, y);
+      Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
+      Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
+      Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
+      b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
+      b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
+    };
+
+    Value tj = makeMul(lane4modId, c2);
+    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
+    if (offset)
+      ti = makeAdd(ti, makeConst(offset));
+    // Each thread stores two rows (8 apart) times 16 column pairs (8 apart).
+    for (int i = 0; i < 2; ++i) {
+      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
+      for (int j = 0; j < 16; ++j) {
+        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
+        int sIndex = i * 2 + j * 4;
+        makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
+      }
+    }
+  }
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    int offset = 0;
+    ImplicitLocOpBuilder lb(op->getLoc(), rewriter);
+    for (Value matrixD : adaptor.getMatrixD()) {
+      auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
+      storeFragmentedMatrix(lb, matrixD, op.getDstMemref(), offset);
+      // Fragments are stacked along the first dimension; a 64-element struct
+      // covers 64 rows here, so advance the row offset by the struct arity.
+      offset += structType.getBody().size();
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1450,6 +1562,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUMBarrierArriveExpectTxLowering,      // nvgpu.mbarrier.arrive.expect_tx
       NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
       NVGPUWarpgroupMmaOpLowering,              // nvgpu.warpgroup.mma
+      NVGPUWarpgroupMmaStoreOpLowering,         // nvgpu.warpgroup.mma.store
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 0ee0f70eebcf6..e8ecd0faa4c86 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -529,6 +530,39 @@ LogicalResult WarpgroupMmaOp::verify() {
   return success();
 }
 
+LogicalResult WarpgroupMmaStoreOp::verify() {
+  MemRefType dstMemrefType = getDstMemref().getType();
+  VectorType firstVtype = getMatrixD()
+                              .front()
+                              .getType()
+                              .cast<WarpgroupAccumulatorType>()
+                              .getFragmented();
+
+  int64_t totalFirstDimension = 0;
+  for (Value result : getMatrixD()) {
+    VectorType vtype =
+        result.getType().cast<WarpgroupAccumulatorType>().getFragmented();
+    if (vtype != firstVtype)
+      return emitOpError() << "all fragmented types must be the same";
+    // Limitation: only f32 accumulators are supported for now.
+    if (!vtype.getElementType().isF32()) {
+      return emitOpError()
+             << "hit a limitation: only f32 results for the time being";
+    }
+    totalFirstDimension += vtype.getDimSize(0);
+  }
+  if (totalFirstDimension != dstMemrefType.getDimSize(0) ||
+      firstVtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
+    return emitOpError() << "results have [" << totalFirstDimension << "]["
+                         << firstVtype.getDimSize(1)
+                         << "] values, but destination memref["
+                         << dstMemrefType.getDimSize(0) << "]["
+                         << dstMemrefType.getDimSize(1)
+                         << "] does not have the same size";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 3710b06288e2a..e54b62a06d431 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -772,6 +772,135 @@ func.func @warpgroup_mma_128_128_64(
   return
 }
 
+// CHECK-LABEL: @warpgroup_mma_store(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
+func.func @warpgroup_mma_store(
+    %result1 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+    %result2 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+    %matrixD: memref<128x128xf32,3>) {
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32
+// CHECK: %[[S7:.+]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: %[[WarpSize:.+]] = llvm.mlir.constant(32 : i32) : i32
+
+// ### Store {d0, d1} of each thread ###
+
+// CHECK: %[[S8:.+]] = nvvm.read.ptx.sreg.tid.x : i32
+// CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[WarpSize]] : i32
+// CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[WarpSize]] : i32
+// CHECK: %[[S11:.+]] = llvm.udiv %[[S9]], %[[S2]] : i32
+// CHECK: %[[S12:.+]] = llvm.urem %[[S9]], %[[S2]] : i32
+// CHECK: %[[S13:.+]] = llvm.mul %[[S12]], %[[S5]] : i32
+// CHECK: %[[S14:.+]] = llvm.mul %[[S10]], %[[S7]] : i32
+// CHECK: %[[S15:.+]] = llvm.add %[[S11]], %[[S14]] : i32
+// CHECK: %[[S16:.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[S17:.+]] = llvm.mul %[[S16]], %[[S4]] : i32
+// CHECK: %[[S18:.+]] = llvm.add %[[S15]], %[[S17]] : i32
+// CHECK: %[[S19:.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[S20:.+]] = llvm.mul %[[S19]], %[[S4]] : i32
+// CHECK: %[[S21:.+]] = llvm.add %[[S13]], %[[S20]] : i32
+// CHECK: %[[S22:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S23:.+]] = arith.index_cast %[[S21]] : i32 to index
+// CHECK: %[[S24:.+]] = llvm.add %[[S21]], %[[S6]] : i32
+// CHECK: %[[S25:.+]] = arith.index_cast %[[S24]] : i32 to index
+// CHECK: %[[S26:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct
+// CHECK: %[[S27:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct
+// CHECK: memref.store %[[S26]], %[[arg2]][%[[S22]], %[[S23]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S27]], %[[arg2]][%[[S22]], %[[S25]]] : memref<128x128xf32, 3>
+
+// ### Store {d4, d5} of each thread ###
+
+// CHECK: %[[S28:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[S29:.+]] = llvm.mul %[[S28]], %[[S4]] : i32
+// CHECK: %[[S30:.+]] = llvm.add %[[S13]], %[[S29]] : i32
+// CHECK: %[[S31:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S32:.+]] = arith.index_cast %[[S30]] : i32 to index
+// CHECK: %[[S33:.+]] = llvm.add %[[S30]], %[[S6]] : i32
+// CHECK: %[[S34:.+]] = arith.index_cast %[[S33]] : i32 to index
+// CHECK: %[[S35:.+]] = llvm.extractvalue %[[S0]][4] : !llvm.struct<
+// CHECK: %[[S36:.+]] = llvm.extractvalue %[[S0]][5] : !llvm.struct<
+// CHECK: memref.store %[[S35]], %[[arg2]][%[[S31]], %[[S32]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S36]], %[[arg2]][%[[S31]], %[[S34]]] : memref<128x128xf32, 3>
+
+// ### Store {d8, d9} of each thread ###
+
+// CHECK: %[[S37:.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[S38:.+]] = llvm.mul %[[S37]], %[[S4]] : i32
+// CHECK: %[[S39:.+]] = llvm.add %[[S13]], %[[S38]] : i32
+// CHECK: %[[S40:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S41:.+]] = arith.index_cast %[[S39]] : i32 to index
+// CHECK: %[[S42:.+]] = llvm.add %[[S39]], %[[S6]] : i32
+// CHECK: %[[S43:.+]] = arith.index_cast %[[S42]] : i32 to index
+// CHECK: %[[S44:.+]] = llvm.extractvalue %[[S0]][8] : !llvm.struct<
+// CHECK: %[[S45:.+]] = llvm.extractvalue %[[S0]][9] : !llvm.struct<
+// CHECK: memref.store %[[S44]], %[[arg2]][%[[S40]], %[[S41]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S45]], %[[arg2]][%[[S40]], %[[S43]]] : memref<128x128xf32, 3>
+
+// ### Store {d12, d13} of each thread ###
+
+// CHECK: %[[S46:.+]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK: %[[S47:.+]] = llvm.mul %[[S46]], %[[S4]] : i32
+// CHECK: %[[S48:.+]] = llvm.add %[[S13]], %[[S47]] : i32
+// CHECK: %[[S49:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S50:.+]] = arith.index_cast %[[S48]] : i32 to index
+// CHECK: %[[S51:.+]] = llvm.add %[[S48]], %[[S6]] : i32
+// CHECK: %[[S52:.+]] = arith.index_cast %[[S51]] : i32 to index
+// CHECK: %[[S53:.+]] = llvm.extractvalue %[[S0]][12] : !llvm.struct<
+// CHECK: %[[S54:.+]] = llvm.extractvalue %[[S0]][13] : !llvm.struct<
+// CHECK: memref.store %[[S53]], %[[arg2]][%[[S49]], %[[S50]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S54]], %[[arg2]][%[[S49]], %[[S52]]] : memref<128x128xf32, 3>
+
+// Pattern continues similarly 28 more times until {d62, d63}
+
+// CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
+
+// ### Store {d64, d65} of each thread (second accumulator struct) ###
+
+// CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[S312:.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[S311:.+]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK: %[[S313:.+]] = llvm.mlir.constant(8 : i32) : i32
+// CHECK: %[[S316:.+]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: %[[WS2:.+]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[S317:.+]] = nvvm.read.ptx.sreg.tid.x : i32
+// CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[WS2]] : i32
+// CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[WS2]] : i32
+// CHECK: %[[S320:.+]] = llvm.udiv %[[S318]], %[[S311]] : i32
+// CHECK: %[[S321:.+]] = llvm.urem %[[S318]], %[[S311]] : i32
+// CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S312]] : i32
+// CHECK: %[[S323:.+]] = llvm.mul %[[S319]], %[[S316]] : i32
+// CHECK: %[[S324:.+]] = llvm.add %[[S320]], %[[S323]] : i32
+// CHECK: %[[S325:.+]] = llvm.mlir.constant(64 : i32) : i32
+// CHECK: %[[S326:.+]] = llvm.add %[[S324]], %[[S325]] : i32
+// CHECK: %[[S327:.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[S328:.+]] = llvm.mul %[[S327]], %[[S313]] : i32
+// CHECK: %[[S329:.+]] = llvm.add %[[S326]], %[[S328]] : i32
+// CHECK: %[[S330:.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[S331:.+]] = llvm.mul %[[S330]], %[[S313]] : i32
+// CHECK: %[[S332:.+]] = llvm.add %[[S322]], %[[S331]] : i32
+// CHECK: %[[S333:.+]] = arith.index_cast %[[S329]] : i32 to index
+// CHECK: %[[S334:.+]] = arith.index_cast %[[S332]] : i32 to index
+// CHECK: %[[S335:.+]] = llvm.add %[[S332]], %[[S315]] : i32
+// CHECK: %[[S336:.+]] = arith.index_cast %[[S335]] : i32 to index
+// CHECK: %[[S337:.+]] = llvm.extractvalue %[[S1]][0]
+// CHECK: %[[S338:.+]] = llvm.extractvalue %[[S1]][1]
+// CHECK: memref.store %[[S337]], %[[arg2]][%[[S333]], %[[S334]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S338]], %[[arg2]][%[[S333]], %[[S336]]] : memref<128x128xf32, 3>
+
+// Pattern continues similarly 31 more times until {d126, d127}
+
+  nvgpu.warpgroup.mma.store [%result1, %result2], %matrixD :
+    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
+    to memref<128x128xf32,3>
+  return
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index d1e3a9c0f5144..99c3402621b82 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5308,6 +5308,7 @@ cc_library(
         ":LLVMCommonConversion",
         ":LLVMDialect",
        ":MemRefDialect",
+        ":ArithDialect",
        ":NVGPUDialect",
        ":NVVMDialect",
        ":Pass",
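
For intuition about the layout, the index arithmetic in storeFragmentedMatrix can be sanity-checked off-device. The standalone C++ program below (illustrative only, not part of the patch; all names are local to the snippet) mirrors the lowering's laneId/warpId/lane4Id math and asserts that the 128 threads of a warp group together write every element of one 64x128 f32 fragment exactly once:

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  constexpr int kWarpSize = 32;
  constexpr int kWarpgroupSize = 128; // 4 warps
  constexpr int kRows = 64, kCols = 128;
  std::vector<int> hits(kRows * kCols, 0);

  for (int tidx = 0; tidx < kWarpgroupSize; ++tidx) {
    int laneId = tidx % kWarpSize;   // llvm.urem in the lowering
    int warpId = tidx / kWarpSize;   // llvm.udiv in the lowering
    int lane4Id = laneId / 4;
    int lane4modId = laneId % 4;
    int tj = lane4modId * 2;         // first column written by this thread
    int ti = lane4Id + warpId * 16;  // first row written by this thread
    for (int i = 0; i < 2; ++i) {    // two row groups, 8 rows apart
      int row = ti + i * 8;
      for (int j = 0; j < 16; ++j) { // 16 column groups, 8 columns apart
        int col = tj + j * 8;
        // Struct elements {i*2 + j*4, i*2 + j*4 + 1} land at [row][col..col+1].
        ++hits[row * kCols + col];
        ++hits[row * kCols + col + 1];
      }
    }
  }
  for (int h : hits)
    assert(h == 1 && "each element must be written exactly once");
  std::puts("64x128 fragment covered exactly once by 128 threads");
  return 0;
}

Each thread thus issues 2 x 16 = 32 pair-stores, i.e. 64 elements, matching the 64-element struct that nvgpu.warpgroup.mma produces per thread for a 64x128 fragment.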