Skip to content

Commit 20d6def

Browse files
authored
[mlir][xegpu] refine basic routines (#138701)
This PR adds two interfaces for `LayoutAttr` and updates the builder of `CreateNdOp` for convenience.
1 parent 512a5d0 commit 20d6def

File tree

3 files changed

+51
-36
lines changed

3 files changed

+51
-36
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,8 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
243243
);
244244

245245
let builders = [
246-
AttrBuilder<(ins "llvm::ArrayRef<int>": $lane_layout,
247-
"llvm::ArrayRef<int>": $lane_data),
246+
AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
247+
"llvm::ArrayRef<int32_t>": $lane_data),
248248
[{
249249
auto sg_layout = DenseI32ArrayAttr();
250250
auto sg_data = DenseI32ArrayAttr();
@@ -253,6 +253,25 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
253253
return $_get($_ctxt, sg_layout, sg_data, inst_data,
254254
DenseI32ArrayAttr::get($_ctxt, lane_layout),
255255
DenseI32ArrayAttr::get($_ctxt, lane_data), order);
256+
}]>,
257+
AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
258+
"llvm::ArrayRef<int32_t>": $lane_data,
259+
"llvm::ArrayRef<int32_t>": $order),
260+
[{
261+
return $_get($_ctxt,
262+
/*sg_layout =*/ nullptr,
263+
/*sg_data =*/ nullptr,
264+
/*inst_data =*/ nullptr,
265+
DenseI32ArrayAttr::get($_ctxt, lane_layout),
266+
DenseI32ArrayAttr::get($_ctxt, lane_data),
267+
DenseI32ArrayAttr::get($_ctxt, order));
268+
}]>,
269+
AttrBuilder<(ins "DenseI32ArrayAttr": $lane_layout,
270+
"DenseI32ArrayAttr": $lane_data,
271+
"DenseI32ArrayAttr": $order),
272+
[{
273+
return $_get($_ctxt, /*sg_layout =*/ nullptr, /*sg_data =*/ nullptr,
274+
/*inst_data =*/ nullptr, lane_layout, lane_data, order);
256275
}]>
257276
];
258277

@@ -262,7 +281,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
262281
}
263282

264283
bool isSgLayout() {
265-
return getSgLayout() == nullptr && getLaneLayout() != nullptr;
284+
return !isWgLayout();
266285
}
267286

268287
int64_t getRank() {
@@ -274,6 +293,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
274293
return attr.size();
275294
return 0;
276295
}
296+
297+
LayoutAttr dropSgLayoutAndData() {
298+
return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(),
299+
getLaneLayout(), getLaneData(), getOrder());
300+
}
301+
302+
LayoutAttr dropInstData() {
303+
return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
304+
getLaneLayout(), getLaneData(), getOrder());
305+
}
306+
277307
}];
278308

279309
let assemblyFormat = "`<` struct(params) `>`";

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
142142
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
143143
"llvm::ArrayRef<OpFoldResult>": $offsets)>,
144144

145-
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
146-
"llvm::ArrayRef<OpFoldResult>": $offsets,
147-
"llvm::ArrayRef<OpFoldResult>": $shape,
148-
"llvm::ArrayRef<OpFoldResult>": $strides)>,
149-
150-
OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
145+
OpBuilder<(ins "Type": $tdesc, "Value": $source,
151146
"llvm::ArrayRef<OpFoldResult>": $offsets,
152147
"llvm::ArrayRef<OpFoldResult>": $shape,
153148
"llvm::ArrayRef<OpFoldResult>": $strides)>

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -141,46 +141,24 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
141141
}
142142

143143
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
144-
Type tdesc, TypedValue<MemRefType> source,
144+
Type tdesc, Value source,
145145
llvm::ArrayRef<OpFoldResult> offsets,
146146
llvm::ArrayRef<OpFoldResult> shape,
147147
llvm::ArrayRef<OpFoldResult> strides) {
148148
assert(shape.size() && offsets.size() && strides.size() &&
149149
shape.size() == strides.size() && shape.size() == offsets.size());
150150

151-
llvm::SmallVector<int64_t> staticOffsets;
152-
llvm::SmallVector<int64_t> staticShape;
153-
llvm::SmallVector<int64_t> staticStrides;
151+
Type srcTy = source.getType();
152+
assert(isa<IntegerType>(srcTy) ||
153+
isa<MemRefType>(srcTy) && "Source has to be either int or memref.");
154+
154155
llvm::SmallVector<Value> dynamicOffsets;
155156
llvm::SmallVector<Value> dynamicShape;
156157
llvm::SmallVector<Value> dynamicStrides;
157158

158-
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
159-
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
160-
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
161-
162-
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
163-
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
164-
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
165-
166-
build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
167-
dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
168-
}
169-
170-
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
171-
Type tdesc, TypedValue<IntegerType> source,
172-
llvm::ArrayRef<OpFoldResult> offsets,
173-
llvm::ArrayRef<OpFoldResult> shape,
174-
llvm::ArrayRef<OpFoldResult> strides) {
175-
assert(shape.size() && offsets.size() && strides.size() &&
176-
shape.size() == strides.size() && shape.size() == offsets.size());
177-
178159
llvm::SmallVector<int64_t> staticOffsets;
179160
llvm::SmallVector<int64_t> staticShape;
180161
llvm::SmallVector<int64_t> staticStrides;
181-
llvm::SmallVector<Value> dynamicOffsets;
182-
llvm::SmallVector<Value> dynamicShape;
183-
llvm::SmallVector<Value> dynamicStrides;
184162

185163
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
186164
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
@@ -190,6 +168,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
190168
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
191169
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
192170

171+
if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
172+
auto memrefShape = memrefTy.getShape();
173+
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
174+
175+
// if shape and strides are from Memref, we don't need attributes for them
176+
// to keep the IR print clean.
177+
if (staticShape == memrefShape && staticStrides == memrefStrides) {
178+
staticShapeAttr = DenseI64ArrayAttr();
179+
staticStridesAttr = DenseI64ArrayAttr();
180+
}
181+
}
182+
193183
build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
194184
dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
195185
}

0 commit comments

Comments
 (0)