diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 208f26489d6c3..2ab32836c80b1 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -78,6 +78,7 @@
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
+#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
 
 namespace mlir {
 class Pass;
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 2915cf7d5bb01..4d272ba219c6f 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1421,4 +1421,18 @@ def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv"> {
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// VectorToXeGPU
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
+  let summary = "Lower the operations from the vector dialect into the XeGPU "
+                "dialect";
+  let constructor = "mlir::createConvertVectorToXeGPUPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect", "arith::ArithDialect",
+    "vector::VectorDialect", "xegpu::XeGPUDialect"
+  ];
+}
+
 #endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h b/mlir/include/mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h
new file mode 100644
index 0000000000000..ac4915901fdec
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h
@@ -0,0 +1,29 @@
+//===- VectorToXeGPU.h - Convert vector to XeGPU dialect --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOXEGPU_VECTORTOXEGPU_H
+#define MLIR_CONVERSION_VECTORTOXEGPU_VECTORTOXEGPU_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTVECTORTOXEGPU
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Collect a set of patterns to convert from the vector to XeGPU ops.
+void populateVectorToXeGPUConversionPatterns(RewritePatternSet &patterns);
+
+/// Create a pass to convert ops from vector to XeGPU.
+std::unique_ptr<Pass> createConvertVectorToXeGPUPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOXEGPU_VECTORTOXEGPU_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 813f700c5556e..6651d87162257 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -69,3 +69,4 @@ add_subdirectory(VectorToGPU)
 add_subdirectory(VectorToLLVM)
 add_subdirectory(VectorToSCF)
 add_subdirectory(VectorToSPIRV)
+add_subdirectory(VectorToXeGPU)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
new file mode 100644
index 0000000000000..567083da00239
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRVectorToXeGPU
+  VectorToXeGPU.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToXeGPU
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRMemRefDialect
+  MLIRTransforms
+  MLIRVectorDialect
+  MLIRXeGPUDialect
+  )
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
new file mode 100644
index 0000000000000..be1581d619a8b
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -0,0 +1,257 @@
+//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of vector operations to XeGPU dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <algorithm>
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+static bool isZeroConstant(Value val) {
+  auto constant = val.getDefiningOp<arith::ConstantOp>();
+  if (!constant)
+    return false;
+
+  return TypeSwitch<Attribute, bool>(constant.getValue())
+      .Case<FloatAttr>(
+          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
+      .Case<IntegerAttr>(
+          [](auto intAttr) { return intAttr.getValue().isZero(); })
+      .Default([](auto) { return false; });
+}
+
+static LogicalResult transferPreconditions(PatternRewriter &rewriter,
+                                           VectorTransferOpInterface xferOp) {
+  if (xferOp.getMask())
+    return rewriter.notifyMatchFailure(xferOp,
+                                       "Masked transfer is not supported");
+
+  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
+  if (!srcTy)
+    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
+  VectorType vecTy = xferOp.getVectorType();
+  unsigned vecRank = vecTy.getRank();
+  if (!(vecRank == 1 || vecRank == 2))
+    return rewriter.notifyMatchFailure(xferOp, "Expects 1D or 2D vector");
+
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
+      strides.back() != 1)
+    return rewriter.notifyMatchFailure(
+        xferOp, "Buffer must be contiguous in the innermost dimension");
+
+  AffineMap map = xferOp.getPermutationMap();
+  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
+    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
+  unsigned numInputDims = map.getNumInputs();
+  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
+    auto dim = dyn_cast<AffineDimExpr>(expr);
+    if (dim.getPosition() < (numInputDims - vecRank))
+      return rewriter.notifyMatchFailure(
+          xferOp, "Only the innermost dimensions can be accessed");
+  }
+
+  return success();
+}
+
+static xegpu::CreateNdDescOp
+createNdDescriptor(PatternRewriter &rewriter, Location loc,
+                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
+                   Operation::operand_range offsets) {
+  MemRefType srcTy = src.getType();
+  auto [strides, offset] = getStridesAndOffset(srcTy);
+
+  xegpu::CreateNdDescOp ndDesc;
+  if (srcTy.hasStaticShape()) {
+    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
+                                                    getAsOpFoldResult(offsets));
+  } else {
+    // In case of any dynamic shapes, source's shape and strides have to be
+    // explicitly provided.
+    SmallVector<Value> sourceDims;
+    unsigned srcRank = srcTy.getRank();
+    for (unsigned i = 0; i < srcRank; ++i)
+      sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));
+
+    SmallVector<int64_t> constOffsets;
+    SmallVector<Value> dynOffsets;
+    for (Value offset : offsets) {
+      std::optional<int64_t> staticVal = getConstantIntValue(offset);
+      if (!staticVal)
+        dynOffsets.push_back(offset);
+      constOffsets.push_back(staticVal ? *staticVal : ShapedType::kDynamic);
+    }
+
+    SmallVector<Value> dynShapes;
+    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
+      if (shape == ShapedType::kDynamic)
+        dynShapes.push_back(sourceDims[idx]);
+    }
+
+    // Compute strides in reverse order.
+    SmallVector<Value> dynStrides;
+    Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    // Last stride is guaranteed to be static and unit.
+    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
+      accStride =
+          rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
+      if (strides[i] == ShapedType::kDynamic)
+        dynStrides.push_back(accStride);
+    }
+    std::reverse(dynStrides.begin(), dynStrides.end());
+
+    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
+        loc, descType, src, dynOffsets, dynShapes, dynStrides,
+        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
+        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
+        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+  }
+
+  return ndDesc;
+}
+
+struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = readOp.getLoc();
+
+    if (failed(transferPreconditions(rewriter, readOp)))
+      return failure();
+
+    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
+    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
+      return rewriter.notifyMatchFailure(
+          readOp, "Unsupported non-zero padded out-of-bounds read");
+
+    AffineMap readMap = readOp.getPermutationMap();
+    bool isTransposeLoad = !readMap.isMinorIdentity();
+
+    VectorType vecTy = readOp.getVectorType();
+    Type elementType = vecTy.getElementType();
+    unsigned minTransposeBitWidth = 32;
+    if (isTransposeLoad &&
+        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
+      return rewriter.notifyMatchFailure(
+          readOp, "Unsupported data type for transposition");
+
+    // If load is transposed, get the base shape for the tensor descriptor.
+    SmallVector<int64_t> descShape{vecTy.getShape()};
+    if (isTransposeLoad)
+      std::reverse(descShape.begin(), descShape.end());
+    auto descType = xegpu::TensorDescType::get(
+        descShape, elementType, /*scattered=*/false, /*array_length=*/1,
+        xegpu::MemoryScope::Global,
+        /*boundary_check=*/isOutOfBounds);
+
+    xegpu::CreateNdDescOp ndDesc =
+        createNdDescriptor(rewriter, loc, descType,
+                           dyn_cast<TypedValue<MemRefType>>(readOp.getSource()),
+                           readOp.getIndices());
+
+    DenseI64ArrayAttr transposeAttr =
+        !isTransposeLoad ? nullptr
+                         : DenseI64ArrayAttr::get(rewriter.getContext(),
+                                                  ArrayRef<int64_t>{1, 0});
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
+    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
+        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
+        /*l1_hint=*/hint,
+        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    rewriter.replaceOp(readOp, loadOp);
+
+    return success();
+  }
+};
+
+struct TransferWriteLowering
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = writeOp.getLoc();
+
+    if (failed(transferPreconditions(rewriter, writeOp)))
+      return failure();
+
+    if (writeOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "Unsupported out-of-bounds write");
+    AffineMap map = writeOp.getPermutationMap();
+    if (!map.isMinorIdentity())
+      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
+
+    VectorType vecTy = writeOp.getVectorType();
+    auto descType = xegpu::TensorDescType::get(
+        vecTy.getShape(), vecTy.getElementType(),
+        /*scattered=*/false, /*array_length=*/1, xegpu::MemoryScope::Global,
+        /*boundary_check=*/false);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType,
+        dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
+        writeOp.getIndices());
+
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
+    auto storeOp =
+        rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
+                                          /*l1_hint=*/hint,
+                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
+    rewriter.replaceOp(writeOp, storeOp);
+
+    return success();
+  }
+};
+
+struct ConvertVectorToXeGPUPass
+    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorToXeGPUConversionPatterns(patterns);
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+void mlir::populateVectorToXeGPUConversionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<TransferReadLowering, TransferWriteLowering>(
+      patterns.getContext());
+}
+
+std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
+  return std::make_unique<ConvertVectorToXeGPUPass>();
+}
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
new file mode 100644
index 0000000000000..4841ecbb62e80
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -0,0 +1,200 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector<8xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+    {in_bounds = [true]} : memref<8x16x32xf32>, vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: @load_1D_vector(
+// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME:    boundary_check = false
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// CHECK:       return %[[VEC]]
+
+// -----
+
+func.func @load_2D_vector(%source: memref<8x16x32xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+    {in_bounds = [true, true]} : memref<8x16x32xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_2D_vector(
+// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    boundary_check = false
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       return %[[VEC]]
+
+// -----
+
+func.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset], %c0
+    {in_bounds = [false, true]} : memref<32x64xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_zero_pad_out_of_bounds(
+// CHECK-SAME:  %[[SRC:.+]]: memref<32x64xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    boundary_check = true
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       return %[[VEC]]
+
+// -----
+
+func.func @load_transposed(%source: memref<32x64xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset], %c0
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+    in_bounds = [true, true]} : memref<32x64xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_transposed(
+// CHECK-SAME:  %[[SRC:.+]]: memref<32x64xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
+// CHECK-SAME:    -> vector<8x16xf32>
+// CHECK:       return %[[VEC]]
+
+// -----
+
+func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+    {in_bounds = [true, true]} : memref<?x?x?xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @load_dynamic_source(
+// CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
+// CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
+// CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       return %[[VEC]]
+
+// -----
+
+func.func @no_load_out_of_bounds_non_zero_pad(%source: memref<32x64xf32>,
+    %offset: index, %arg2: index, %pad: f32) -> (vector<8x16xf32>, vector<8x16xf32>) {
+  %c1 = arith.constant 1.0 : f32
+  %0 = vector.transfer_read %source[%offset, %arg2], %c1
+    {in_bounds = [true, false]} : memref<32x64xf32>, vector<8x16xf32>
+  %1 = vector.transfer_read %source[%arg2, %offset], %pad
+    {in_bounds = [false, true]} : memref<32x64xf32>, vector<8x16xf32>
+  return %0, %1 : vector<8x16xf32>, vector<8x16xf32>
+}
+
+// CHECK-LABEL: @no_load_out_of_bounds_non_zero_pad(
+// CHECK-COUNT-2: vector.transfer_read
+
+// -----
+
+func.func @no_load_masked(%source : memref<4xf32>,
+    %offset : index) -> vector<4xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %mask = arith.constant dense<[0, 1, 0, 1]> : vector<4xi1>
+  %0 = vector.transfer_read %source[%offset], %c0, %mask
+    {in_bounds = [true]} : memref<4xf32>, vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @no_load_masked(
+// CHECK:       vector.transfer_read
+
+// -----
+
+func.func @no_load_tensor(%source: tensor<32x64xf32>,
+    %offset: index, %arg2: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %arg2], %c0
+    {in_bounds = [true, true]} : tensor<32x64xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @no_load_tensor(
+// CHECK:       vector.transfer_read
+
+// -----
+
+func.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
+    %offset: index, %arg2: index) -> vector<8x16x32xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %arg2, %offset], %c0
+    {in_bounds = [true, true, true]} : memref<16x32x64xf32>, vector<8x16x32xf32>
+  return %0 : vector<8x16x32xf32>
+}
+
+// CHECK-LABEL: @no_load_high_dim_vector(
+// CHECK:       vector.transfer_read
+
+// -----
+
+func.func @no_load_non_unit_inner_stride(
+    %source: memref<32xf32, strided<[?], offset: ?>>,
+    %offset: index) -> vector<8xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset], %c0 {in_bounds = [true]}
+    : memref<32xf32, strided<[?], offset: ?>>, vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: @no_load_non_unit_inner_stride(
+// CHECK:       vector.transfer_read
+
+// -----
+
+func.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
+    %offset: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+    {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>,
+    in_bounds = [true, true]} : memref<16x32x64xf32>, vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @no_load_unsupported_map(
+// CHECK:       vector.transfer_read
+
+// -----
+
+func.func @no_load_transpose_unsupported_data_type(%source: memref<32x64xf16>,
+    %offset: index) -> vector<8x16xf16> {
+  %c0 = arith.constant 0.0 : f16
+  %0 = vector.transfer_read %source[%offset, %offset], %c0
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+    in_bounds = [true, true]} : memref<32x64xf16>, vector<8x16xf16>
+  return %0 : vector<8x16xf16>
+}
+
+// CHECK-LABEL: @no_load_transpose_unsupported_data_type(
+// CHECK:       vector.transfer_read
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
new file mode 100644
index 0000000000000..361919c47b097
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -0,0 +1,159 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+func.func @store_1D_vector(%vec: vector<8xf32>,
+    %source: memref<8x16x32xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    {in_bounds = [true]}
+    : vector<8xf32>, memref<8x16x32xf32>
+  return
+}
+
+// CHECK-LABEL: @store_1D_vector(
+// CHECK-SAME:  %[[VEC:.+]]: vector<8xf32>,
+// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME:    boundary_check = false
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+
+// -----
+
+func.func @store_2D_vector(%vec: vector<8x16xf32>,
+    %source: memref<8x16x32xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    {in_bounds = [true, true]}
+    : vector<8x16xf32>, memref<8x16x32xf32>
+  return
+}
+
+// CHECK-LABEL: @store_2D_vector(
+// CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
+// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// CHECK-SAME:    boundary_check = false
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// -----
+
+func.func @store_dynamic_source(%vec: vector<8x16xf32>,
+    %source: memref<?x?x?xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    {in_bounds = [true, true]}
+    : vector<8x16xf32>, memref<?x?x?xf32>
+  return
+}
+
+// CHECK-LABEL: @store_dynamic_source(
+// CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
+// CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
+// CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
+// CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// -----
+
+func.func @no_store_transposed(%vec: vector<8x16xf32>,
+    %source: memref<32x64xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset]
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+    in_bounds = [true, true]}
+    : vector<8x16xf32>, memref<32x64xf32>
+  return
+}
+
+// CHECK-LABEL: @no_store_transposed(
+// CHECK:       vector.transfer_write
+
+// -----
+
+func.func @no_store_out_of_bounds(%vec: vector<8x16xf32>,
+    %source: memref<32x64xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset]
+    {in_bounds = [false, true]}
+    : vector<8x16xf32>, memref<32x64xf32>
+  return
+}
+
+// CHECK-LABEL: @no_store_out_of_bounds(
+// CHECK:       vector.transfer_write
+
+// -----
+
+func.func @no_store_masked(%vec: vector<4xf32>,
+    %source: memref<4xf32>, %offset: index) {
+  %mask = arith.constant dense<[0, 1, 0, 1]> : vector<4xi1>
+  vector.transfer_write %vec, %source[%offset], %mask
+    {in_bounds = [true]}
+    : vector<4xf32>, memref<4xf32>
+  return
+}
+
+// CHECK-LABEL: @no_store_masked(
+// CHECK:       vector.transfer_write
+
+// -----
+
+func.func @no_store_tensor(%vec: vector<8x16xf32>,
+    %source: tensor<32x64xf32>, %offset: index) -> tensor<32x64xf32> {
+  %0 = vector.transfer_write %vec, %source[%offset, %offset]
+    {in_bounds = [true, true]}
+    : vector<8x16xf32>, tensor<32x64xf32>
+  return %0 : tensor<32x64xf32>
+}
+
+// CHECK-LABEL: @no_store_tensor(
+// CHECK:       vector.transfer_write
+
+// -----
+
+func.func @no_store_high_dim_vector(%vec: vector<8x16x32xf32>,
+    %source: memref<16x32x64xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    {in_bounds = [true, true, true]}
+    : vector<8x16x32xf32>, memref<16x32x64xf32>
+  return
+}
+
+// CHECK-LABEL: @no_store_high_dim_vector(
+// CHECK:       vector.transfer_write
+
+// -----
+
+func.func @no_store_non_unit_inner_stride(%vec: vector<8xf32>,
+    %source: memref<32xf32, strided<[?], offset: ?>>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset]
+    {in_bounds = [true]}
+    : vector<8xf32>, memref<32xf32, strided<[?], offset: ?>>
+  return
+}
+
+// CHECK-LABEL: @no_store_non_unit_inner_stride(
+// CHECK:       vector.transfer_write
+
+// -----
+
+func.func @no_store_unsupported_map(%vec: vector<8x16xf32>,
+    %source: memref<16x32x64xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>,
+    in_bounds = [true, true]}
+    : vector<8x16xf32>, memref<16x32x64xf32>
+  return
+}
+
+// CHECK-LABEL: @no_store_unsupported_map(
+// CHECK:       vector.transfer_write