diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 4841f94de75f4..0136b18ccfa94 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -59,19 +59,29 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scatter_tdesc_attr"> { let summary = [{a composite attribute for `TensorDescType`}]; - let description = [{`ScatterTensorDesc` (or `scatter_tdesc_attr`) is a composite - attribute defined for `TensorDescType` for describing following - properties of a `TensorDesc`. + let description = [{ + `ScatterTensorDesc` is a composite attribute defined for `TensorDescType` + for describing following properties of a `TensorDesc`: + 1. `memory_space`: It describes where the data block described by the TensorDesc is located, `Global` device memory or `Shared` local memory. It is default to `Global`. - 2. `chunk_size`: indicates number of continious elements accessed for each + + 2. `chunk_size`: indicates number of contiguous elements accessed for each offset, default is 1. It is used with `scattered` attr only. 
}]; let parameters = (ins - OptionalParameter<"MemorySpaceAttr">: $memory_space, - OptionalParameter<"IntegerAttr", "1">: $chunk_size + DefaultValuedParameter< + "MemorySpaceAttr", + "MemorySpaceAttr::get($_ctxt, xegpu::MemorySpace::Global)", + "Data memory location" + >: $memory_space, + DefaultValuedParameter< + "IntegerAttr", + "IntegerAttr::get(IntegerType::get($_ctxt, 64), 1)", + "Number of contiguous elements" + >: $chunk_size ); let builders = [ @@ -80,6 +90,8 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat CArg<"int", "1">: $chunk_size )> ]; + + let genVerifyDecl = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 494f11f041b71..cc2e93fb19a70 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -172,7 +172,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", auto attr = getEncoding(); auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr); assert((!attr || scatter_attr) && "invalid on non ScatterTensorDescAttr."); - if (scatter_attr && scatter_attr.getChunkSize()) + if (scatter_attr) return scatter_attr.getChunkSize().getInt(); return 1; } diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index becc32d122697..06fd03f3af3ad 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -55,6 +55,18 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context, return Base::get(context, scopeAttr, chunkSizeAttr); } +LogicalResult ScatterTensorDescAttr::verify( + llvm::function_ref<InFlightDiagnostic()> emitError, + MemorySpaceAttr memory_space, IntegerAttr chunk_size) { + int64_t chunkSize = chunk_size.getInt(); + SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8, + 16, 32, 64, 128, 256}; + if 
(!llvm::is_contained(supportedChunkSizes, chunkSize)) + return emitError() << "invalid chunk size"; + + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_SGMapAttr //===----------------------------------------------------------------------===// @@ -166,8 +178,6 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) { continue; } } - parser.emitError(parser.getCurrentLocation(), - "Failed to parse the attribute.\n"); return {}; } @@ -237,8 +247,7 @@ LogicalResult TensorDescType::verify( // Expected tensor ranks for scattered data: // - 1D tensor for fully non-contiguous elements (chunk size == 1) // - 2D tensor for scattered blocks (chunk size > 1) - IntegerAttr chunkAttr = scatterAttr.getChunkSize(); - unsigned chunkSize = chunkAttr ? chunkAttr.getInt() : 1; + unsigned chunkSize = scatterAttr.getChunkSize().getInt(); if (rank == 1 && chunkSize != 1) return emitError() << "expected non-contiguous elements for 1D tensor"; if (rank == 2 && chunkSize < 2) @@ -273,8 +282,7 @@ LogicalResult TensorDescType::verify( return emitError() << "cannot map over non-contiguous scattered row elements"; - IntegerAttr chunkAttr = scatterAttr.getChunkSize(); - unsigned chunkSize = chunkAttr ? 
chunkAttr.getInt() : 1; + unsigned chunkSize = scatterAttr.getChunkSize().getInt(); if (wiData[1] != chunkSize) return emitError() << "work item data mapping must match the number of " "contiguous elements"; diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index e06d99ac20bb7..25dc1f22f0432 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -419,16 +419,8 @@ LogicalResult CreateDescOp::verify() { << " Source: " << srcMemorySpace << ", TensorDesc: " << tdescMemorySpace; - auto chunkSize = tdescTy.getChunkSize(); - - // check chunk_size - llvm::SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8, - 16, 32, 64, 128, 256}; - if (!llvm::is_contained(supportedChunkSizes, chunkSize)) - return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, " - "8, 16, 32, 64, 128, or 256."); - // check total size + auto chunkSize = tdescTy.getChunkSize(); auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth(); auto bitsPerLane = elemBits * chunkSize; if (chunkSize > 1 && bitsPerLane % 32) { diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir index 8af1b600ad0a4..472176af72b19 100644 --- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir +++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir @@ -181,6 +181,15 @@ gpu.func @test_create_tdesc_vc_1(%src: memref) { gpu.return } +// CHECK: gpu.func @test_create_tdesc_vc_2(%[[arg0:.*]]: memref) { +gpu.func @test_create_tdesc_vc_2(%src: memref) { + //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<> + %1 = xegpu.create_tdesc %src, %0 : memref, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr> + gpu.return +} + // CHECK: gpu.func 
@test_create_tdesc_vc_with_sg_map(%[[arg0:.*]]: ui64) { gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 9162e0012f6d5..86356e09de57c 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -190,7 +190,7 @@ func.func @test_create_tdesc_vc_2(%src: ui64) { } // ----- -func.func @test_create_tdesc_vc_1(%src: memref) { +func.func @test_create_tdesc_vc_3(%src: memref) { %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> // expected-error@+1 {{Memory space mismatch}} %1 = xegpu.create_tdesc %src, %0 : memref, vector<4xindex> @@ -198,6 +198,15 @@ func.func @test_create_tdesc_vc_1(%src: memref) { return } +// ----- +func.func @test_create_tdesc_vc_4(%src: memref) { + %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %1 = xegpu.create_tdesc %src, %0 : memref, vector<4xindex> + // expected-error@+1 {{invalid chunk size}} + -> !xegpu.tensor_desc<4x5xf32, #xegpu.scatter_tdesc_attr<chunk_size = 5>> + return +} + // ----- func.func @test_prefetch_vc_1(%src: memref<24x32xf16>) { %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16>