Skip to content

Commit 7a4c497

Browse files
authored
[mlir][mesh] Use one type for mesh axis (#76830)
Make all ops and attributes use the types MeshAxis and MeshAxesAttr instead of int16_t, int32_t, DenseI16ArrayAttr and DenseI32ArrayAttr.
1 parent 428cf71 commit 7a4c497

File tree

7 files changed: 51 additions and 45 deletions.

mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
8080

8181
let parameters = (ins
8282
AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
83-
ArrayRefParameter<"::mlir::DenseI32ArrayAttr">:$split_axes,
84-
OptionalArrayRefParameter<"int32_t">:$partial_axes,
83+
ArrayRefParameter<"MeshAxesAttr">:$split_axes,
84+
OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
8585
OptionalParameter<"::mlir::mesh::Partial">:$partial_type
8686
);
8787

@@ -146,18 +146,18 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
146146

147147
let builders = [
148148
AttrBuilder<(ins "SymbolRefAttr":$cluster,
149-
"ArrayRef<SmallVector<int32_t>>":$split_axes,
150-
"ArrayRef<int32_t>": $partial_axes,
149+
"ArrayRef<SmallVector<MeshAxis>>":$split_axes,
150+
"ArrayRef<MeshAxis>": $partial_axes,
151151
"mesh::Partial": $partial_type), [{
152-
SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::map_to_vector(
153-
split_axes, [&](ArrayRef<int32_t> array) {
154-
return DenseI32ArrayAttr::get($_ctxt, array);
152+
SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
153+
split_axes, [&](ArrayRef<MeshAxis> array) {
154+
return MeshAxesAttr::get($_ctxt, array);
155155
});
156156
return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
157157
partial_type);
158158
}]>,
159159
AttrBuilder<(ins "SymbolRefAttr":$cluster,
160-
"ArrayRef<SmallVector<int32_t>>":$split_axes), [{
160+
"ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
161161
return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
162162
}]>
163163
];

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@
1717
#include "mlir/Interfaces/SideEffectInterfaces.h"
1818
#include <algorithm>
1919

20+
namespace mlir {
21+
namespace mesh {
22+
23+
using MeshAxis = int16_t;
24+
using MeshAxesAttr = DenseI16ArrayAttr;
25+
26+
} // namespace mesh
27+
} // namespace mlir
28+
2029
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
2130

2231
#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.h.inc"
@@ -30,9 +39,6 @@
3039
namespace mlir {
3140
namespace mesh {
3241

33-
using MeshAxis = int16_t;
34-
using MeshAxesAttr = DenseI16ArrayAttr;
35-
3642
bool isReductionLoop(IteratorType iType);
3743

3844
bool areReductionAndPartialMatch(IteratorType iType, Partial partial);

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMeth
114114

115115
let builders = [
116116
OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
117-
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
117+
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
118118
];
119119
}
120120

@@ -228,7 +228,7 @@ def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMeth
228228
}];
229229
let builders = [
230230
OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
231-
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
231+
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
232232
];
233233
}
234234

mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ class Operation;
1818

1919
namespace mesh {
2020

21-
using ShardingArray = SmallVector<SmallVector<int32_t>>;
22-
using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
21+
using ShardingArray = SmallVector<SmallVector<MeshAxis>>;
22+
using ShardingArrayRef = ArrayRef<SmallVector<MeshAxis>>;
2323

2424
struct ShardingOption {
2525
// An array of int array. The sub-array at the i-th position signifies the

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -266,15 +266,15 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
266266

267267
LogicalResult
268268
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
269-
SymbolRefAttr, ArrayRef<DenseI32ArrayAttr> splitAxes,
270-
ArrayRef<int32_t> partialAxes, Partial) {
269+
SymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
270+
ArrayRef<MeshAxis> partialAxes, Partial) {
271271
// TODO: At present cluster symbol ref is not verified. This is due to the
272272
// difficulty in fetching the corresponding symbol op based on an attribute.
273273

274-
llvm::SmallSet<int32_t, 4> visitedAxes;
274+
llvm::SmallSet<MeshAxis, 4> visitedAxes;
275275

276-
auto checkMeshAxis = [&](ArrayRef<int32_t> axesArray) -> LogicalResult {
277-
for (int32_t axis : axesArray) {
276+
auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
277+
for (MeshAxis axis : axesArray) {
278278
if (axis < 0)
279279
return emitError() << "mesh axis is expected to be non-negative";
280280
if (!visitedAxes.insert(axis).second)
@@ -283,8 +283,8 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
283283
return success();
284284
};
285285

286-
for (DenseI32ArrayAttr subAxes : splitAxes) {
287-
ArrayRef<int32_t> subAxesArray = subAxes.asArrayRef();
286+
for (MeshAxesAttr subAxes : splitAxes) {
287+
ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
288288
if (failed(checkMeshAxis(subAxesArray)))
289289
return failure();
290290
}
@@ -318,10 +318,10 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
318318

319319
return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
320320
getSplitAxes().end()),
321-
std::mem_fn(&DenseI32ArrayAttr::empty)) &&
321+
std::mem_fn(&MeshAxesAttr::empty)) &&
322322
llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
323323
rhs.getSplitAxes().end()),
324-
std::mem_fn(&DenseI32ArrayAttr::empty));
324+
std::mem_fn(&MeshAxesAttr::empty));
325325
}
326326

327327
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ namespace {
216216
static LogicalResult fillShardingOption(Operation *op,
217217
ShardingOption &shardingOption,
218218
SymbolRefAttr cluster,
219-
ArrayRef<int32_t> meshAxes,
219+
ArrayRef<MeshAxis> meshAxes,
220220
unsigned loopIdx) {
221221
if ((shardingOption.cluster && cluster &&
222222
shardingOption.cluster != cluster) ||
@@ -230,7 +230,7 @@ static LogicalResult fillShardingOption(Operation *op,
230230
if (i == loopIdx)
231231
continue;
232232

233-
for (int32_t axis : meshAxes) {
233+
for (MeshAxis axis : meshAxes) {
234234
if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
235235
LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
236236
<< axis << " duplicate");
@@ -260,7 +260,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
260260
SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
261261
unsigned numOperands = op->getNumOperands();
262262
shardingOption.shardingArray.resize(loopTypes.size());
263-
llvm::SmallVector<int32_t> partialMeshAxes;
263+
llvm::SmallVector<MeshAxis> partialMeshAxes;
264264
Partial partialType;
265265
llvm::SmallSet<unsigned, 4> visitedLoopIndices;
266266
bool anyShardingInResultsOrOperands = false;
@@ -277,7 +277,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
277277
// shardingOption[index]
278278
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
279279
AffineExpr expr = std::get<0>(it);
280-
ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
280+
ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
281281
auto dim = cast<AffineDimExpr>(expr);
282282
unsigned index = dim.getPosition();
283283
visitedLoopIndices.insert(index);
@@ -288,7 +288,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
288288

289289
// Handle the partial axes: at this stage, the exact loop index/indices
290290
// cannot be decided because there could be multiple reduction loops.
291-
ArrayRef<int32_t> partialAxes = shardAttr.getPartialAxes();
291+
ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
292292
if (!partialAxes.empty()) {
293293
if (!partialMeshAxes.empty())
294294
return op->emitOpError() << "at most one result with partial axes is "
@@ -321,7 +321,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
321321
// then the operands with multiple loop indices.
322322
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
323323
AffineExpr expr = std::get<0>(it);
324-
ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
324+
ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
325325
FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
326326
checkOperandAffineExpr(expr, numDims);
327327
if (failed(loopIndices))
@@ -362,7 +362,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
362362
if (!partialMeshAxes.empty()) {
363363
bool anyNonEmptyReductionLoop = llvm::any_of(
364364
llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
365-
SmallVector<int32_t> &subArray = it.value();
365+
SmallVector<MeshAxis> &subArray = it.value();
366366
int64_t idx = it.index();
367367
return isReductionLoop(loopTypes[idx]) && !subArray.empty();
368368
});
@@ -406,8 +406,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
406406
return success();
407407

408408
auto resultType = result.getType().cast<RankedTensorType>();
409-
SmallVector<SmallVector<int32_t>> splitAxes(resultType.getRank());
410-
SmallVector<int32_t> partialAxes;
409+
SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
410+
SmallVector<MeshAxis> partialAxes;
411411

412412
// process the split axes
413413
for (auto it : llvm::enumerate(map.getResults())) {
@@ -431,7 +431,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
431431
assert(partialType == curPartialType &&
432432
"Only one reduction type is supported");
433433
partialType = curPartialType;
434-
const SmallVector<int32_t> &axis = std::get<1>(it);
434+
const SmallVector<MeshAxis> &axis = std::get<1>(it);
435435
partialAxes.append(axis);
436436
}
437437
}
@@ -459,7 +459,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
459459
return success();
460460
Value operand = opOperand.get();
461461
auto operandType = operand.getType().cast<RankedTensorType>();
462-
SmallVector<SmallVector<int32_t>> splitAxes(operandType.getRank());
462+
SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
463463
unsigned numDims = map.getNumDims();
464464
for (auto it : llvm::enumerate(map.getResults())) {
465465
int64_t idx = it.index();

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
147147
.getResult()
148148
.cast<TypedValue<ShapedType>>();
149149

150-
llvm::SmallVector<int32_t> remainingPartialAxes;
150+
llvm::SmallVector<MeshAxis> remainingPartialAxes;
151151
llvm::copy_if(sourceShardingPartialAxesSet,
152152
std::back_inserter(allReduceMeshAxes),
153153
[&targetShardingPartialAxesSet](Axis a) {
@@ -163,17 +163,17 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
163163
static MeshShardingAttr
164164
targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
165165
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
166-
SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
166+
SmallVector<MeshAxesAttr> targetShardingSplitAxes =
167167
llvm::to_vector(sourceSharding.getSplitAxes());
168168
while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
169169
splitTensorAxis) {
170-
targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {}));
170+
targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
171171
}
172172
auto targetSplitAxes =
173173
llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
174174
targetSplitAxes.push_back(splitMeshAxis);
175175
targetShardingSplitAxes[splitTensorAxis] =
176-
DenseI32ArrayAttr::get(ctx, targetSplitAxes);
176+
MeshAxesAttr::get(ctx, targetSplitAxes);
177177
return MeshShardingAttr::get(
178178
ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
179179
sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
@@ -356,7 +356,7 @@ static MeshShardingAttr
356356
targetShardingInUnsplitLastAxis(MLIRContext *ctx,
357357
MeshShardingAttr sourceSharding,
358358
int64_t splitTensorAxis) {
359-
SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
359+
SmallVector<MeshAxesAttr> targetShardingSplitAxes =
360360
llvm::to_vector(sourceSharding.getSplitAxes());
361361
assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
362362
splitTensorAxis);
@@ -365,7 +365,7 @@ targetShardingInUnsplitLastAxis(MLIRContext *ctx,
365365

366366
targetSplitAxes.pop_back();
367367
targetShardingSplitAxes[splitTensorAxis] =
368-
DenseI32ArrayAttr::get(ctx, targetSplitAxes);
368+
MeshAxesAttr::get(ctx, targetSplitAxes);
369369
return MeshShardingAttr::get(
370370
ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
371371
sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
@@ -475,11 +475,11 @@ static MeshShardingAttr
475475
targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
476476
int64_t sourceTensorAxis,
477477
int64_t targetTensorAxis) {
478-
SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
478+
SmallVector<MeshAxesAttr> targetShardingSplitAxes =
479479
llvm::to_vector(sourceSharding.getSplitAxes());
480480
while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
481481
targetTensorAxis) {
482-
targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {}));
482+
targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
483483
}
484484

485485
auto sourceSplitAxes =
@@ -488,13 +488,13 @@ targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
488488
auto meshAxis = sourceSplitAxes.back();
489489
sourceSplitAxes.pop_back();
490490
targetShardingSplitAxes[sourceTensorAxis] =
491-
DenseI32ArrayAttr::get(ctx, sourceSplitAxes);
491+
MeshAxesAttr::get(ctx, sourceSplitAxes);
492492

493493
auto targetSplitAxes =
494494
llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
495495
targetSplitAxes.push_back(meshAxis);
496496
targetShardingSplitAxes[targetTensorAxis] =
497-
DenseI32ArrayAttr::get(ctx, targetSplitAxes);
497+
MeshAxesAttr::get(ctx, targetSplitAxes);
498498

499499
return MeshShardingAttr::get(
500500
ctx, sourceSharding.getCluster(), targetShardingSplitAxes,

0 commit comments

Comments
 (0)