diff --git a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
index c6162cc1..bdbccd2d 100644
--- a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
+++ b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
@@ -94,6 +94,7 @@ static bool isDPASCompatible(linalg::LinalgOp linalgOp, int kTile,
                              ArrayRef<int64_t> dpasTile) {
   if (!(isa<linalg::MatmulOp>(linalgOp) ||
         isa<linalg::BatchReduceMatmulOp>(linalgOp) ||
+        isa<linalg::MatmulTransposeBOp>(linalgOp) ||
         isa<linalg::GenericOp>(linalgOp))) {
     return false;
   }
@@ -633,12 +634,11 @@ static SmallVector<Value> updateTilesOffsets(PatternRewriter &rewriter,
 //
 // The descriptor sub-tiles are ordered in row-major fashion with respect to the
 // whole load tile.
-static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
-                                                Location loc, Value src,
-                                                ArrayRef<int64_t> loadShape,
-                                                ArrayRef<int64_t> loadOffsets,
-                                                ArrayRef<int64_t> descTile,
-                                                int arrayLength = 1) {
+static SmallVector<Value>
+createDescriptorTiles(PatternRewriter &rewriter, Location loc, Value src,
+                      ArrayRef<int64_t> loadShape,
+                      ArrayRef<int64_t> loadOffsets, ArrayRef<int64_t> descTile,
+                      int arrayLength = 1, bool transpose = false) {
   assert(arrayLength == 1 && "Array descriptors are not supported");
 
   auto type = cast<ShapedType>(src.getType());
@@ -669,6 +669,9 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
     Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i);
     for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) {
       Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j);
+      if (transpose) {
+        std::swap(newRowOffs, newColOffs);
+      }
       auto tile = rewriter
                       .create<xegpu::UpdateNdOffsetOp>(
                           loc, descType, rootTile,
@@ -693,7 +696,8 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
 static SmallVector<Value> createCoarseDscTiles(PatternRewriter &rewriter,
                                                Location loc, Value src,
                                                ArrayRef<int64_t> sgTile,
-                                               bool isVnni) {
+                                               bool isVnni,
+                                               bool transpose = false) {
   assert(sgTile.size() <= 2 &&
          "Require at most 2D tile size for eltwise lowering");
 
@@ -727,7 +731,8 @@ static SmallVector<Value> createCoarseDscTiles(PatternRewriter &rewriter,
   // NOLINTEND
 
   return createDescriptorTiles(rewriter, loc, src, sgTile2D, {0, 0},
-                               {sgLoadRows, sgLoadCols}, arrayLength);
+                               {sgLoadRows, sgLoadCols}, arrayLength,
+                               transpose);
 }
 
 // Return vector type with specified VNNI shape.
@@ -745,7 +750,8 @@ static SmallVector<Value>
 loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
                 xegpu::CachePolicyAttr hint,
                 std::optional<VnniConfig> vnniConf = std::nullopt,
-                DenseI64ArrayAttr transpose = nullptr) {
+                DenseI64ArrayAttr transpose = nullptr,
+                IntegerAttr transpose_bit = nullptr) {
   // Assume all tiles have the same shape.
   auto tileType = cast<xegpu::TensorDescType>(loadTiles[0].getType());
   assert(llvm::all_of(loadTiles,
@@ -760,7 +766,6 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
                                          *vnniConf);
     packedAttr = mlir::UnitAttr::get(rewriter.getContext());
   }
-  IntegerAttr transpose_bit = nullptr;
 
   SmallVector<Value> loadVec;
   for (auto tile : loadTiles) {
@@ -860,6 +865,74 @@ extractVecSubTiles(PatternRewriter &rewriter, Location loc,
   return subTiles;
 }
 
+// Checks whether the given `matmulOperand` is produced by a
+// `linalg::TransposeOp` and ensures that the transpose result is only used by
+// valid operations, such as `linalg::MatmulOp`, `linalg::BatchReduceMatmulOp`,
+// or `linalg::GenericOp`.
+//
+// If a valid transpose operation is found, it is erased and its input operand
+// is returned as the new matrix multiplication operand.
+static FailureOr<Value> findAndReplaceTranspose(const Value &matmulOperand,
+                                                size_t operandIdx,
+                                                PatternRewriter &rewriter) {
+  auto defOp = matmulOperand.getDefiningOp();
+  if (!defOp) {
+    return failure();
+  }
+  linalg::TransposeOp transposeOp = nullptr;
+
+  for (auto x : defOp->getUsers()) {
+    if (isa<linalg::TransposeOp>(x)) {
+      if (transposeOp) {
+        return rewriter.notifyMatchFailure(
+            transposeOp, "Only one transpose operation is allowed");
+      }
+
+      transposeOp = dyn_cast<linalg::TransposeOp>(x);
+
+      auto transposeRes = transposeOp.getDpsInits()[0];
+      // Verify that the transpose result has no users other than our matmul.
+      for (auto trUser : transposeRes.getUsers()) {
+        if (isa<linalg::MatmulOp>(trUser) ||
+            isa<linalg::BatchReduceMatmulOp>(trUser) ||
+            isa<linalg::GenericOp>(trUser)) {
+          auto matmulOp = dyn_cast<linalg::LinalgOp>(trUser);
+          auto actualMatmulOperand = matmulOp.getDpsInputs()[operandIdx];
+          if (actualMatmulOperand != matmulOperand) {
+            return rewriter.notifyMatchFailure(
+                trUser,
+                "Transpose result is used by more than one matmul operation");
+          }
+        } else if (isa<memref::DeallocOp>(trUser)) {
+          // Allow deallocs as users.
+          continue;
+        } else if (isa<linalg::TransposeOp>(trUser)) {
+          // Check that this is the same transpose we are currently processing.
+          if (!mlir::OperationEquivalence::isEquivalentTo(trUser, transposeOp,
+                                                          /*flags=*/nullptr)) {
+            return rewriter.notifyMatchFailure(
+                trUser, "Only one transpose operation is allowed");
+          }
+          continue;
+        } else {
+          return rewriter.notifyMatchFailure(
+              trUser,
+              "Transpose result is not allowed to be used by this operation");
+        }
+      }
+    }
+  }
+  if (transposeOp) {
+    auto ret = transposeOp.getDpsInputs()[0];
+    rewriter.eraseOp(transposeOp);
+    return ret;
+  }
+  return rewriter.notifyMatchFailure(
+      defOp, "No transpose operation producing the operand was found");
+}
+
 // Create XeGPU DPAS kernel out of GEMM-like operation.
 static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
                                       ArrayRef<int64_t> dpasTile, int kTile,
@@ -867,6 +940,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
                                       PatternRewriter &rewriter) {
   assert((isa<linalg::MatmulOp>(linalgOp) ||
           isa<linalg::BatchReduceMatmulOp>(linalgOp) ||
+          isa<linalg::MatmulTransposeBOp>(linalgOp) ||
          isa<linalg::GenericOp>(linalgOp)) &&
          "Requires a GEMM-like op for DPAS lowering");
 
@@ -877,6 +951,17 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
   auto matB = linalgOp.getDpsInputs()[1];
   auto matC = linalgOp.getDpsInits()[0];
 
+  bool transposeB = false;
+  if (isa<linalg::MatmulTransposeBOp>(linalgOp)) {
+    transposeB = true;
+  } else {
+    auto newMatB = findAndReplaceTranspose(matB, /*operandIdx=*/1, rewriter);
+    if (!failed(newMatB)) {
+      matB = *newMatB;
+      transposeB = true;
+    }
+  }
+
   auto typeA = cast<ShapedType>(matA.getType());
   auto typeC = cast<ShapedType>(matC.getType());
 
@@ -961,7 +1046,8 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
 
   // Create B sub-tiles.
   SmallVector<Value> tilesB =
-      createCoarseDscTiles(rewriter, loc, matB, {kTile, dimN}, /*isVnni=*/true);
+      createCoarseDscTiles(rewriter, loc, matB, {kTile, dimN},
+                           /*isVnni=*/true, transposeB);
 
   // Create input prefetch tiles.
   int64_t numThreads = 1;
@@ -997,7 +1083,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
                                       {dimM, dimN}, kTile);
   auto prefetchDescB = createGemmCoopPrefetchTile(
       rewriter, linalgOp, /*inputPos=*/1, numThreads, {blockRows, blockCols},
-      {dimM, dimN}, kTile);
+      (transposeB) ? std::vector<int64_t>{dimM, dimN}
+                   : std::vector<int64_t>{dimN, dimM},
+      kTile);
   if (succeeded(prefetchDescA) && succeeded(prefetchDescB)) {
     prefetchA = prefetchDescA->getResult();
     prefetchB = prefetchDescB->getResult();
@@ -1012,7 +1100,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
       prefetchA = updateTilesOffsets(rewriter, loc, ValueRange{prefetchA},
                                      {0, kTile})[0];
       prefetchB = updateTilesOffsets(rewriter, loc, ValueRange{prefetchB},
-                                     {kTile, 0})[0];
+                                     (transposeB)
+                                         ? std::vector<int64_t>{0, kTile}
+                                         : std::vector<int64_t>{kTile, 0})[0];
     }
   } else {
     // Disable coop prefetching on failure.
@@ -1083,15 +1173,26 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
         loadNdDescTiles(rewriter, loc, tilesA, readCacheHint);
     auto tileTypeA = cast<xegpu::TensorDescType>(tilesA[0].getType());
 
+    DenseI64ArrayAttr transpose = nullptr;
+    IntegerAttr transpose_bit = nullptr;
+
+    if (transposeB) {
+      transpose_bit = rewriter.getIntegerAttr(rewriter.getIntegerType(32), 32);
+      transpose = DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0});
+    }
+
     // Load B sub-tiles.
     SmallVector<Value> loadVecB =
-        loadNdDescTiles(rewriter, loc, tilesB, readCacheHint, vnniConfB);
+        loadNdDescTiles(rewriter, loc, tilesB, readCacheHint, vnniConfB,
+                        transpose, transpose_bit);
     auto tileTypeB = cast<xegpu::TensorDescType>(tilesB[0].getType());
 
     // Update offsets of the input tiles.
     // Shift along the reduction dimension.
     tilesA = updateTilesOffsets(rewriter, loc, tilesA, {0, kTile});
-    tilesB = updateTilesOffsets(rewriter, loc, tilesB, {kTile, 0});
+    tilesB = updateTilesOffsets(rewriter, loc, tilesB,
+                                transposeB ? std::vector<int64_t>{0, kTile}
+                                           : std::vector<int64_t>{kTile, 0});
 
     // Prefetch the next set of input tiles.
     if (isCoopPrefetch) {
@@ -1101,7 +1202,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
       prefetchA = updateTilesOffsets(rewriter, loc, ValueRange{prefetchA},
                                      {0, kTile})[0];
       prefetchB =
-          updateTilesOffsets(rewriter, loc, ValueRange{prefetchB}, {kTile, 0})[0];
+          updateTilesOffsets(rewriter, loc, ValueRange{prefetchB},
+                             transposeB ? std::vector<int64_t>{0, kTile}
+                                        : std::vector<int64_t>{kTile, 0})[0];
     } else {
       // Apply naive prefetching for each subgroup separately.
       prefetchTiles(rewriter, loc, tilesA, readCacheHint);
@@ -1288,7 +1391,7 @@ struct ConvertGemmLikeToXeGPU : public OpRewritePattern<LinalgOpTy> {
   // Constrain conversion to the supported GEMM-like ops.
  static_assert(
      llvm::is_one_of<LinalgOpTy, linalg::MatmulOp, linalg::BatchReduceMatmulOp,
-                      linalg::GenericOp>::value);
+                      linalg::GenericOp, linalg::MatmulTransposeBOp>::value);
 
   ConvertGemmLikeToXeGPU(MLIRContext *ctx, LinalgToXeGPUOptions options)
       : OpRewritePattern<LinalgOpTy>(ctx), options(options) {}
@@ -1495,8 +1598,9 @@ struct ConvertMemoryFillToXeGPU : public OpRewritePattern<LinalgOpTy> {
 void populateLinalgGemmToXeGPUPatterns(RewritePatternSet &patterns,
                                        LinalgToXeGPUOptions options) {
   patterns.add<ConvertGemmLikeToXeGPU<linalg::MatmulOp>,
-               ConvertGemmLikeToXeGPU<linalg::GenericOp>>(patterns.getContext(),
-                                                          options);
+               ConvertGemmLikeToXeGPU<linalg::GenericOp>,
+               ConvertGemmLikeToXeGPU<linalg::MatmulTransposeBOp>>(
+      patterns.getContext(), options);
 }
 
 void populateLinalgEltwiseToXeGPUPatterns(RewritePatternSet &patterns,
diff --git a/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas-transpose-sep-alloc.mlir b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas-transpose-sep-alloc.mlir
new file mode 100644
index 00000000..81726450
--- /dev/null
+++ b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas-transpose-sep-alloc.mlir
@@ -0,0 +1,86 @@
+// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file | FileCheck %s
+
+module {
+  func.func @matmul_transpose_b_sep(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf16>) {
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1024x1024xf16>
+    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
+      %subview_0 = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<1024x1024xf16> to memref<32x32xf16, strided<[1024, 1], offset: ?>>
+      %subview_1 = memref.subview %arg0[%arg3, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
+      %subview_2 = memref.subview %arg1[%arg4, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
+      %subview_3 = memref.subview %alloc[0, %arg4] [1024, 32] [1, 1] : memref<1024x1024xf16> to memref<1024x32xf16, strided<[1024, 1], offset: ?>>
+      linalg.transpose ins(%subview_2 : memref<32x1024xf16, strided<[1024, 1], offset: ?>>) outs(%subview_3 : memref<1024x32xf16, strided<[1024, 1], offset: ?>>) permutation = [1, 0]
+      linalg.matmul ins(%subview_1, %subview_3 : memref<32x1024xf16, strided<[1024, 1], offset: ?>>, memref<1024x32xf16, strided<[1024, 1], offset: ?>>) outs(%subview_0 : memref<32x32xf16, strided<[1024, 1], offset: ?>>)
+      scf.reduce
+    }
+    memref.dealloc %alloc : memref<1024x1024xf16>
+    return
+  }
+}
+
+// CHECK-LABEL: func.func @matmul_transpose_b_sep
+// CHECK-SAME: %[[Ap:.+]]: memref<1024x1024xf16>, %[[Bp:.+]]: memref<1024x1024xf16>, %[[Cp:.+]]: memref<1024x1024xf16>
+
+// CHECK-NOT: memref.alloc()
+
+// CHECK: scf.parallel (%[[iter1:.+]], %[[iter2:.+]]) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
+// CHECK: %[[C:.+]] = memref.subview %[[Cp]][%[[iter1]], %[[iter2]]] {{.*}}
+// CHECK: %[[A:.+]] = memref.subview %[[Ap]][%[[iter1]], 0] {{.*}}
+// CHECK: %[[B:.+]] = memref.subview %[[Bp]][%[[iter2]], 0] {{.*}}
+
+// CHECK-NOT: linalg.transpose
+
+// Create output initial value load tiles.
+// CHECK-DAG: %[[rootC:.+]] = xegpu.create_nd_tdesc %[[C]]
+// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [%c0, %c0]
+// CHECK-COUNT-7: xegpu.update_nd_offset
+
+// Load initial accumulator values.
+// CHECK-DAG: %[[vC:.+]] = xegpu.load_nd %[[tC]]
+// CHECK-COUNT-7: xegpu.load_nd
+
+// Extend the type to match DPAS output precision.
+// CHECK: %[[vC_f32:.+]] = arith.extf %[[vC]] +// CHECK-COUNT-7: arith.extf + +// Create input load tiles. +// CHECK: %[[rootA:.+]] = xegpu.create_nd_tdesc %[[A]] +// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [%c0, %c0] +// CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]] +// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0] +// CHECK: %[[tB1:.+]] = xegpu.update_nd_offset %[[rootB]], [%c16, %c0] + +// Create DPAS computation loop over tiled reduction dimension. +// CHECK: %[[res:.+]]:11 = scf.for{{.*}}%c0 to %c1024 step %c16 +// CHECK-SAME: iter_args(%[[acc:.+]] = %[[vC_f32]],{{.*}}%[[iterA:.+]] = %[[tA]],{{.*}}%[[iterB:.+]] = %[[tB]],{{.*}}%[[iterB1:.+]] = %[[tB1]] +// CHECK-SAME: { + +// Load input values and update the load tile position. +// CHECK: %[[vA:.+]] = xegpu.load_nd %[[iterA]] +// CHECK: %[[vB:.+]] = xegpu.load_nd %[[iterB]] {{.*}}transpose = array{{.*}}transpose_bit_width = 32 : i32{{.*}} +// CHECK: %[[vB1:.+]] = xegpu.load_nd %[[iterB1]] {{.*}}transpose = array, transpose_bit_width = 32 : i32{{.*}} + +// CHECK: %[[new_tA:.+]] = xegpu.update_nd_offset %[[iterA]], [%c0, %c16] +// CHECK: %[[new_tB:.+]] = xegpu.update_nd_offset %[[iterB]], [%c0, %c16] +// CHECK: %[[new_tB1:.+]] = xegpu.update_nd_offset %[[iterB1]], [%c0, %c16] + +// Apply simple prefetching scheme - start loading the next set of input +// tiles before computation is started. +// CHECK: xegpu.prefetch_nd %[[new_tA]] +// CHECK: xegpu.prefetch_nd %[[new_tB]] +// CHECK: xegpu.prefetch_nd %[[new_tB1]] + +// Extract DPAS-sized chunks from larger loaded tile A. +// Tile B is already in the correct shape. +// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16> +// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16> +// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16> +// CHECK-COUNT-3: vector.extract_strided_slice + +// Perform DPAS computation. 
+// CHECK: %[[dpas:.+]] = xegpu.dpas %[[vA_dpas]], %[[vB]], %[[acc]] +// CHECK-COUNT-7: xegpu.dpas + +// CHECK-NOT: memref.dealloc() diff --git a/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas-transpose-sep.mlir b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas-transpose-sep.mlir new file mode 100644 index 00000000..1687f0ba --- /dev/null +++ b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas-transpose-sep.mlir @@ -0,0 +1,83 @@ +// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file | FileCheck %s + +module { + func.func @matmul_transpose_b_sep(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf16>) { + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) { + %subview_0 = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<1024x1024xf16> to memref<32x32xf16, strided<[1024, 1], offset: ?>> + %subview_1 = memref.subview %arg0[%arg3, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>> + %subview_2 = memref.subview %arg1[%arg4, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1024x32xf16> + linalg.transpose ins(%subview_2 : memref<32x1024xf16, strided<[1024, 1], offset: ?>>) outs(%alloc : memref<1024x32xf16>) permutation = [1, 0] + linalg.matmul ins(%subview_1, %alloc : memref<32x1024xf16, strided<[1024, 1], offset: ?>>, memref<1024x32xf16>) outs(%subview_0 : memref<32x32xf16, strided<[1024, 1], offset: ?>>) + memref.dealloc %alloc : memref<1024x32xf16> + scf.reduce + } + return + } +} + +// CHECK-LABEL: func.func @matmul_transpose_b_sep +// CHECK-SAME: %[[Ap:.+]]: memref<1024x1024xf16>, %[[Bp:.+]]: memref<1024x1024xf16>, %[[Cp:.+]]: memref<1024x1024xf16> + +// CHECK: scf.parallel (%[[iter1:.+]], %[[iter2:.+]]) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) { +// CHECK: %[[C:.+]] = memref.subview %[[Cp]][%[[iter1]], %[[iter2]]] {{.*}} +// CHECK: %[[A:.+]] = memref.subview %[[Ap]][%[[iter1]], 0] {{.*}} +// CHECK: %[[B:.+]] = memref.subview %[[Bp]][%[[iter2]], 0] {{.*}} + +// CHECK-NOT: memref.alloc() +// CHECK-NOT: linalg.transpose +// CHECK-NOT: memref.dealloc() + +// Create output initial value load tiles. +// CHECK-DAG: %[[rootC:.+]] = xegpu.create_nd_tdesc %[[C]] +// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [%c0, %c0] +// CHECK-COUNT-7: xegpu.update_nd_offset + +// Load initial accumulator values. +// CHECK-DAG: %[[vC:.+]] = xegpu.load_nd %[[tC]] +// CHECK-COUNT-7: xegpu.load_nd + +// Extend the type to match DPAS output precision. +// CHECK: %[[vC_f32:.+]] = arith.extf %[[vC]] +// CHECK-COUNT-7: arith.extf + +// Create input load tiles. +// CHECK: %[[rootA:.+]] = xegpu.create_nd_tdesc %[[A]] +// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [%c0, %c0] +// CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]] +// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0] +// CHECK: %[[tB1:.+]] = xegpu.update_nd_offset %[[rootB]], [%c16, %c0] + +// Create DPAS computation loop over tiled reduction dimension. 
+// CHECK: %[[res:.+]]:11 = scf.for{{.*}}%c0 to %c1024 step %c16 +// CHECK-SAME: iter_args(%[[acc:.+]] = %[[vC_f32]],{{.*}}%[[iterA:.+]] = %[[tA]],{{.*}}%[[iterB:.+]] = %[[tB]],{{.*}}%[[iterB1:.+]] = %[[tB1]] +// CHECK-SAME: { + +// Load input values and update the load tile position. +// CHECK: %[[vA:.+]] = xegpu.load_nd %[[iterA]] +// CHECK: %[[vB:.+]] = xegpu.load_nd %[[iterB]] {{.*}}transpose = array{{.*}}transpose_bit_width = 32 : i32{{.*}} +// CHECK: %[[vB1:.+]] = xegpu.load_nd %[[iterB1]] {{.*}}transpose = array, transpose_bit_width = 32 : i32{{.*}} + +// CHECK: %[[new_tA:.+]] = xegpu.update_nd_offset %[[iterA]], [%c0, %c16] +// CHECK: %[[new_tB:.+]] = xegpu.update_nd_offset %[[iterB]], [%c0, %c16] +// CHECK: %[[new_tB1:.+]] = xegpu.update_nd_offset %[[iterB1]], [%c0, %c16] + +// Apply simple prefetching scheme - start loading the next set of input +// tiles before computation is started. +// CHECK: xegpu.prefetch_nd %[[new_tA]] +// CHECK: xegpu.prefetch_nd %[[new_tB]] +// CHECK: xegpu.prefetch_nd %[[new_tB1]] + +// Extract DPAS-sized chunks from larger loaded tile A. +// Tile B is already in the correct shape. +// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16> +// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16> +// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16> +// CHECK-COUNT-3: vector.extract_strided_slice + +// Perform DPAS computation. +// CHECK: %[[dpas:.+]] = xegpu.dpas %[[vA_dpas]], %[[vB]], %[[acc]] +// CHECK-COUNT-7: xegpu.dpas diff --git a/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas-transpose.mlir b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas-transpose.mlir new file mode 100644 index 00000000..d9dc180b --- /dev/null +++ b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas-transpose.mlir @@ -0,0 +1,76 @@ +// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file | FileCheck %s + +module { + func.func @matmul_transpose_b(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf16>) { + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) { + %subview_0 = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<1024x1024xf16> to memref<32x32xf16, strided<[1024, 1], offset: ?>> + %subview_1 = memref.subview %arg0[%arg3, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>> + %subview_2 = memref.subview %arg1[%arg4, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>> + linalg.matmul_transpose_b ins(%subview_1, %subview_2 : memref<32x1024xf16, strided<[1024, 1], offset: ?>>, memref<32x1024xf16, strided<[1024, 1], offset: ?>>) outs(%subview_0 : memref<32x32xf16, strided<[1024, 1], offset: ?>>) + scf.reduce + } + return + } +} + +// CHECK-LABEL: func.func @matmul_transpose_b +// CHECK-SAME: %[[Ap:.+]]: memref<1024x1024xf16>, %[[Bp:.+]]: memref<1024x1024xf16>, %[[Cp:.+]]: memref<1024x1024xf16> + +// CHECK: scf.parallel (%[[iter1:.+]], %[[iter2:.+]]) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) { +// CHECK: %[[C:.+]] = memref.subview %[[Cp]][%[[iter1]], %[[iter2]]] {{.*}} +// CHECK: %[[A:.+]] = memref.subview %[[Ap]][%[[iter1]], 0] {{.*}} +// CHECK: %[[B:.+]] = memref.subview %[[Bp]][%[[iter2]], 0] {{.*}} + +// Create 
output initial value load tiles. +// CHECK-DAG: %[[rootC:.+]] = xegpu.create_nd_tdesc %[[C]] +// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [%c0, %c0] +// CHECK-COUNT-7: xegpu.update_nd_offset + +// Load initial accumulator values. +// CHECK-DAG: %[[vC:.+]] = xegpu.load_nd %[[tC]] +// CHECK-COUNT-7: xegpu.load_nd + +// Extend the type to match DPAS output precision. +// CHECK: %[[vC_f32:.+]] = arith.extf %[[vC]] +// CHECK-COUNT-7: arith.extf + +// Create input load tiles. +// CHECK: %[[rootA:.+]] = xegpu.create_nd_tdesc %[[A]] +// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [%c0, %c0] +// CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]] +// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0] +// CHECK: %[[tB1:.+]] = xegpu.update_nd_offset %[[rootB]], [%c16, %c0] + +// Create DPAS computation loop over tiled reduction dimension. +// CHECK: %[[res:.+]]:11 = scf.for{{.*}}%c0 to %c1024 step %c16 +// CHECK-SAME: iter_args(%[[acc:.+]] = %[[vC_f32]],{{.*}}%[[iterA:.+]] = %[[tA]],{{.*}}%[[iterB:.+]] = %[[tB]],{{.*}}%[[iterB1:.+]] = %[[tB1]] +// CHECK-SAME: { + +// Load input values and update the load tile position. +// CHECK: %[[vA:.+]] = xegpu.load_nd %[[iterA]] +// CHECK: %[[vB:.+]] = xegpu.load_nd %[[iterB]] {{.*}}transpose = array{{.*}}transpose_bit_width = 32 : i32{{.*}} +// CHECK: %[[vB1:.+]] = xegpu.load_nd %[[iterB1]] {{.*}}transpose = array, transpose_bit_width = 32 : i32{{.*}} + +// CHECK: %[[new_tA:.+]] = xegpu.update_nd_offset %[[iterA]], [%c0, %c16] +// CHECK: %[[new_tB:.+]] = xegpu.update_nd_offset %[[iterB]], [%c0, %c16] +// CHECK: %[[new_tB1:.+]] = xegpu.update_nd_offset %[[iterB1]], [%c0, %c16] + +// Apply simple prefetching scheme - start loading the next set of input +// tiles before computation is started. +// CHECK: xegpu.prefetch_nd %[[new_tA]] +// CHECK: xegpu.prefetch_nd %[[new_tB]] +// CHECK: xegpu.prefetch_nd %[[new_tB1]] + +// Extract DPAS-sized chunks from larger loaded tile A. +// Tile B is already in the correct shape. +// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16> +// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16> +// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16> +// CHECK-COUNT-3: vector.extract_strided_slice + +// Perform DPAS computation. 
+// CHECK: %[[dpas:.+]] = xegpu.dpas %[[vA_dpas]], %[[vB]], %[[acc]] +// CHECK-COUNT-7: xegpu.dpas diff --git a/test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_128x64_transpose.mlir b/test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_128x64_transpose.mlir new file mode 100644 index 00000000..4b3fc3b0 --- /dev/null +++ b/test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_128x64_transpose.mlir @@ -0,0 +1,33 @@ +// RUN: gc-opt %s --gc-gpu-pipeline="is-usm-args=false" \ +// RUN: | gc-cpu-runner -e main --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime | FileCheck %s +module{ + +func.func @linalg_matmul(%arg0: tensor<128x256xf16>, + %arg1: tensor<64x256xf16>, + %arg2: tensor<128x64xf16>) -> tensor<128x64xf16> { + %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<128x256xf16>, tensor<64x256xf16>) + outs(%arg2 : tensor<128x64xf16>) -> tensor<128x64xf16> + return %0 : tensor<128x64xf16> +} + +func.func @main() { + %0 = arith.constant dense<0.1> : tensor<128x256xf16> + %1 = arith.constant dense<0.2> : tensor<64x256xf16> + %2 = arith.constant dense<0.0> : tensor<128x64xf16> + %gpu_res = call @linalg_matmul(%0, %1, %2) : (tensor<128x256xf16>, tensor<64x256xf16>, tensor<128x64xf16>) -> tensor<128x64xf16> + + %slice = tensor.extract_slice %gpu_res[0, 0][32, 1][1, 1] : tensor<128x64xf16> to tensor<32xf16> + %cast = tensor.cast %slice : tensor<32xf16> to tensor<*xf16> + call @printMemrefF16(%cast) : (tensor<*xf16>) -> () + + return +} + +func.func private @printMemrefF16(%ptr : tensor<*xf16>) +} + +// CHECK: Unranked Memref base@{{(0x)?[-0-9a-fA-F]*}} +// CHECK-SAME: rank = 1 offset = 0 sizes = [32] strides = [64] data = +// Computed using numpy: +// CHECK-NEXT: [5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719] diff --git a/test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_128x64_transpose_sep.mlir b/test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_128x64_transpose_sep.mlir new file mode 100644 index 00000000..75db435a --- /dev/null +++ b/test/mlir/test/gc/gpu-runner/XeGPU/f16_matmul_128x64_transpose_sep.mlir @@ -0,0 +1,34 @@ +// RUN: gc-opt %s --gc-gpu-pipeline="is-usm-args=false" \ +// RUN: | gc-cpu-runner -e main --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime | FileCheck %s +module{ + +func.func @linalg_matmul(%arg0: tensor<128x256xf16>, + %arg1: tensor<64x256xf16>, + %arg2: tensor<128x64xf16>) -> tensor<128x64xf16> { + %t = tensor.empty() : tensor<256x64xf16> + %0 = linalg.transpose ins(%arg1 : tensor<64x256xf16>) outs(%t: tensor<256x64xf16>) permutation = [1, 0] + %1 = linalg.matmul ins(%arg0, %0 : tensor<128x256xf16>, tensor<256x64xf16>) + outs(%arg2 : tensor<128x64xf16>) -> tensor<128x64xf16> + return %1 : tensor<128x64xf16> +} + +func.func @main() { + %0 = arith.constant dense<0.1> : tensor<128x256xf16> + %1 = arith.constant dense<0.2> : tensor<64x256xf16> + %2 = arith.constant dense<0.0> : tensor<128x64xf16> + %gpu_res = call @linalg_matmul(%0, %1, %2) : (tensor<128x256xf16>, tensor<64x256xf16>, tensor<128x64xf16>) -> tensor<128x64xf16> + + %slice = tensor.extract_slice %gpu_res[0, 0][32, 1][1, 1] : tensor<128x64xf16> to tensor<32xf16> + %cast = tensor.cast %slice : tensor<32xf16> to tensor<*xf16> + call @printMemrefF16(%cast) : (tensor<*xf16>) -> () + 
return +} + +func.func private @printMemrefF16(%ptr : tensor<*xf16>) +} + +// CHECK: Unranked Memref base@{{(0x)?[-0-9a-fA-F]*}} +// CHECK-SAME: rank = 1 offset = 0 sizes = [32] strides = [64] data = +// Computed using numpy: +// CHECK-NEXT: [5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719, 5.11719] diff --git a/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096_transpose.mlir b/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096_transpose.mlir new file mode 100644 index 00000000..c2ca278b --- /dev/null +++ b/test/mlir/test/gc/gpu-runner/XeGPU/f16_mlp_32x4096x4096x4096_transpose.mlir @@ -0,0 +1,122 @@ +// RUN: gc-opt %s --gc-gpu-pipeline="is-usm-args=false" \ +// RUN: | gc-cpu-runner -e main --entry-point-result=void \ +// RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%opencl_runtime | FileCheck %s + +module { + func.func @linalg_mlp(%arg0: tensor<32x4096xf16>, %arg1: tensor<4096x4096xf16>, %arg2 : tensor<32x4096xf16>, + %arg3: tensor<4096x4096xf16>, %arg4 : tensor<32x4096xf16>) { + %cst = arith.constant 0.000000e+00 : f16 + %0 = tensor.empty() : tensor<32x4096xf16> + %1 = linalg.fill ins(%cst : f16) outs(%0 : tensor<32x4096xf16>) -> tensor<32x4096xf16> + %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<32x4096xf16>, tensor<4096x4096xf16>) + outs(%1 : tensor<32x4096xf16>) -> (tensor<32x4096xf16>) + %3 = tensor.empty() : tensor<32x4096xf16> + %4 = linalg.add ins(%arg2, %2 : tensor<32x4096xf16>, tensor<32x4096xf16>) + outs(%3 : tensor<32x4096xf16>) -> tensor<32x4096xf16> + %5 = arith.constant dense<0.000000e+00> : tensor<32x4096xf16> + %6 = tensor.empty() : tensor<32x4096xf16> + %7 = linalg.max ins(%5, %4 : tensor<32x4096xf16>, tensor<32x4096xf16>) + outs(%6 : tensor<32x4096xf16>) -> tensor<32x4096xf16> + + %8 = tensor.empty() : tensor<32x4096xf16> + %9 = linalg.fill ins(%cst : f16) outs(%8 : tensor<32x4096xf16>) -> tensor<32x4096xf16> + %t = tensor.empty() : tensor<4096x4096xf16> + %transposed = linalg.transpose ins(%arg3 : tensor<4096x4096xf16>) outs(%t : tensor<4096x4096xf16>) permutation = [1, 0] + + %10 = linalg.matmul ins(%7, %transposed : tensor<32x4096xf16>, tensor<4096x4096xf16>) + outs(%9 : tensor<32x4096xf16>) -> (tensor<32x4096xf16>) + %11 = tensor.empty() : tensor<32x4096xf16> + %12 = linalg.add ins(%arg4, %10 : tensor<32x4096xf16>, tensor<32x4096xf16>) + outs(%11 : tensor<32x4096xf16>) -> tensor<32x4096xf16> + %13 = arith.constant dense<0.000000e+00> : tensor<32x4096xf16> + %14 = tensor.empty() : tensor<32x4096xf16> + %15 = linalg.max ins(%13, %12 : tensor<32x4096xf16>, tensor<32x4096xf16>) + outs(%14 : tensor<32x4096xf16>) -> tensor<32x4096xf16> + + %slice = tensor.extract_slice %15[0, 0][32, 2][1, 1] : tensor<32x4096xf16> to tensor<32x2xf16> + %cast = tensor.cast %slice : tensor<32x2xf16> to tensor<*xf16> + call @printMemrefF16(%cast) : (tensor<*xf16>) -> () + + return + } + + // generates asymmetric tensor + func.func @generate_t(%even_val : f16, %odd_val : f16) -> tensor<4096x4096xf16> { + %0 = tensor.generate { + ^bb0(%i : index, %j : index): + %int0 = arith.index_cast %i : index to i32 + %int1 = arith.index_cast %j : index to i32 + + %c2 = arith.constant 2 : i32 + %c0 = arith.constant 0 : i32 + %remeinder = arith.remui %int0, %c2 : i32 + %is_even = arith.cmpi eq, %remeinder, %c0 : i32 + + %val = scf.if 
%is_even -> (f16) { + scf.yield %even_val : f16 + } else { + scf.yield %odd_val : f16 + } + + tensor.yield %val : f16 + } : tensor<4096x4096xf16> + return %0 : tensor<4096x4096xf16> + } + + func.func @main() { + %0 = arith.constant dense<0.01> : tensor<32x4096xf16> + + %even_v1 = arith.constant 0.02 : f16 + %odd_v1 = arith.constant 0.01 : f16 + %1 = call @generate_t(%even_v1, %odd_v1) : (f16, f16) -> tensor<4096x4096xf16> + + %2 = arith.constant dense<0.02> : tensor<32x4096xf16> + + %even_v2 = arith.constant 0.06 : f16 + %odd_v2 = arith.constant 0.03 : f16 + %3 = call @generate_t(%even_v2, %odd_v2) : (f16, f16) -> tensor<4096x4096xf16> + + %4 = arith.constant dense<0.02> : tensor<32x4096xf16> + + func.call @linalg_mlp(%0, %1, %2, %3, %4) : (tensor<32x4096xf16>, tensor<4096x4096xf16>, tensor<32x4096xf16>, + tensor<4096x4096xf16>, tensor<32x4096xf16>) -> () + return + } + + func.func private @printMemrefF16(%ptr : tensor<*xf16>) attributes { llvm.emit_c_interface } +} + +// CHECK: Unranked Memref base@{{(0x)?[-0-9a-fA-F]*}} +// CHECK-SAME: rank = 2 offset = 0 sizes = [32, 2] strides = [4096, 1] data = +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375], +// CHECK-NEXT: [155.875, 77.9375]
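For reference, a minimal sketch of the producer pattern that findAndReplaceTranspose folds away; the function name, value names, and shapes below are illustrative (chosen to mirror the bufferized form of the f16_matmul_128x64_transpose_sep.mlir case above) and are not taken from the patch. When a linalg.transpose writes into a buffer whose only other users are the consuming matmul (and its dealloc), the transpose is erased and the B tiles are instead loaded with the transpose = array<i64: 1, 0> and transpose_bit_width = 32 attributes checked in the tests.

func.func @transpose_then_matmul(%A: memref<128x256xf16>, %B: memref<64x256xf16>,
                                 %C: memref<128x64xf16>) {
  // Temporary buffer holding transpose(B); removed by the rewrite together
  // with the linalg.transpose below.
  %t = memref.alloc() : memref<256x64xf16>
  linalg.transpose ins(%B : memref<64x256xf16>)
                   outs(%t : memref<256x64xf16>) permutation = [1, 0]
  linalg.matmul ins(%A, %t : memref<128x256xf16>, memref<256x64xf16>)
                outs(%C : memref<128x64xf16>)
  memref.dealloc %t : memref<256x64xf16>
  return
}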