-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][mesh] Use one type for mesh axis #76830
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Make all ops and attributes use the types MeshAxis and MeshAxesAttr instead of int16_t, int32_t, DenseI16ArrayAttr and DenseI32ArrayAttr.
@llvm/pr-subscribers-mlir Author: Boian Petkantchin (sogartar) ChangesMake all ops and attributes use the types MeshAxis and MeshAxesAttr instead of int16_t, int32_t, DenseI16ArrayAttr and DenseI32ArrayAttr. Full diff: https://github.com/llvm/llvm-project/pull/76830.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index a9d30dfbb9a76e..060d54b82efa63 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -80,8 +80,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
let parameters = (ins
AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
- ArrayRefParameter<"::mlir::DenseI32ArrayAttr">:$split_axes,
- OptionalArrayRefParameter<"int32_t">:$partial_axes,
+ ArrayRefParameter<"MeshAxesAttr">:$split_axes,
+ OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
OptionalParameter<"::mlir::mesh::Partial">:$partial_type
);
@@ -146,18 +146,18 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
let builders = [
AttrBuilder<(ins "SymbolRefAttr":$cluster,
- "ArrayRef<SmallVector<int32_t>>":$split_axes,
- "ArrayRef<int32_t>": $partial_axes,
+ "ArrayRef<SmallVector<MeshAxis>>":$split_axes,
+ "ArrayRef<MeshAxis>": $partial_axes,
"mesh::Partial": $partial_type), [{
- SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::map_to_vector(
- split_axes, [&](ArrayRef<int32_t> array) {
- return DenseI32ArrayAttr::get($_ctxt, array);
+ SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
+ split_axes, [&](ArrayRef<MeshAxis> array) {
+ return MeshAxesAttr::get($_ctxt, array);
});
return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
partial_type);
}]>,
AttrBuilder<(ins "SymbolRefAttr":$cluster,
- "ArrayRef<SmallVector<int32_t>>":$split_axes), [{
+ "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
}]>
];
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index ce7d5d045122d9..83452dcc2e8abe 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -17,6 +17,15 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <algorithm>
+namespace mlir {
+namespace mesh {
+
+using MeshAxis = int16_t;
+using MeshAxesAttr = DenseI16ArrayAttr;
+
+} // namespace mesh
+} // namespace mlir
+
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.h.inc"
@@ -30,9 +39,6 @@
namespace mlir {
namespace mesh {
-using MeshAxis = int16_t;
-using MeshAxesAttr = DenseI16ArrayAttr;
-
bool isReductionLoop(IteratorType iType);
bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 1ed54b6519e4d8..1934bdfb427059 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -114,7 +114,7 @@ def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMeth
let builders = [
OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
- OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+ OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
];
}
@@ -228,7 +228,7 @@ def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMeth
}];
let builders = [
OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
- OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+ OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
];
}
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 270955a3036e89..201c0151754eba 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -18,8 +18,8 @@ class Operation;
namespace mesh {
-using ShardingArray = SmallVector<SmallVector<int32_t>>;
-using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
+using ShardingArray = SmallVector<SmallVector<MeshAxis>>;
+using ShardingArrayRef = ArrayRef<SmallVector<MeshAxis>>;
struct ShardingOption {
// An array of int array. The sub-array at the i-th position signifies the
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index de4f58d54e8ca5..c3d8f1d456106d 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -266,15 +266,15 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
LogicalResult
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
- SymbolRefAttr, ArrayRef<DenseI32ArrayAttr> splitAxes,
- ArrayRef<int32_t> partialAxes, Partial) {
+ SymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
+ ArrayRef<MeshAxis> partialAxes, Partial) {
// TODO: At present cluster symbol ref is not verified. This is due to the
// difficulty in fetching the corresponding symbol op based on an attribute.
- llvm::SmallSet<int32_t, 4> visitedAxes;
+ llvm::SmallSet<MeshAxis, 4> visitedAxes;
- auto checkMeshAxis = [&](ArrayRef<int32_t> axesArray) -> LogicalResult {
- for (int32_t axis : axesArray) {
+ auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
+ for (MeshAxis axis : axesArray) {
if (axis < 0)
return emitError() << "mesh axis is expected to be non-negative";
if (!visitedAxes.insert(axis).second)
@@ -283,8 +283,8 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
};
- for (DenseI32ArrayAttr subAxes : splitAxes) {
- ArrayRef<int32_t> subAxesArray = subAxes.asArrayRef();
+ for (MeshAxesAttr subAxes : splitAxes) {
+ ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
if (failed(checkMeshAxis(subAxesArray)))
return failure();
}
@@ -318,10 +318,10 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
getSplitAxes().end()),
- std::mem_fn(&DenseI32ArrayAttr::empty)) &&
+ std::mem_fn(&MeshAxesAttr::empty)) &&
llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
rhs.getSplitAxes().end()),
- std::mem_fn(&DenseI32ArrayAttr::empty));
+ std::mem_fn(&MeshAxesAttr::empty));
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index a6f2f435f36d68..ee885ab16b7b06 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -216,7 +216,7 @@ namespace {
static LogicalResult fillShardingOption(Operation *op,
ShardingOption &shardingOption,
SymbolRefAttr cluster,
- ArrayRef<int32_t> meshAxes,
+ ArrayRef<MeshAxis> meshAxes,
unsigned loopIdx) {
if ((shardingOption.cluster && cluster &&
shardingOption.cluster != cluster) ||
@@ -230,7 +230,7 @@ static LogicalResult fillShardingOption(Operation *op,
if (i == loopIdx)
continue;
- for (int32_t axis : meshAxes) {
+ for (MeshAxis axis : meshAxes) {
if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
<< axis << " duplicate");
@@ -260,7 +260,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
unsigned numOperands = op->getNumOperands();
shardingOption.shardingArray.resize(loopTypes.size());
- llvm::SmallVector<int32_t> partialMeshAxes;
+ llvm::SmallVector<MeshAxis> partialMeshAxes;
Partial partialType;
llvm::SmallSet<unsigned, 4> visitedLoopIndices;
bool anyShardingInResultsOrOperands = false;
@@ -277,7 +277,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
// shardingOption[index]
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
AffineExpr expr = std::get<0>(it);
- ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
+ ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
auto dim = cast<AffineDimExpr>(expr);
unsigned index = dim.getPosition();
visitedLoopIndices.insert(index);
@@ -288,7 +288,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
// Handle the partial axes: at this stage, the exact loop index/indices
// cannot be decided because there could be multiple reduction loops.
- ArrayRef<int32_t> partialAxes = shardAttr.getPartialAxes();
+ ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
if (!partialAxes.empty()) {
if (!partialMeshAxes.empty())
return op->emitOpError() << "at most one result with partial axes is "
@@ -321,7 +321,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
// then the operands with multiple loop indices.
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
AffineExpr expr = std::get<0>(it);
- ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
+ ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
checkOperandAffineExpr(expr, numDims);
if (failed(loopIndices))
@@ -362,7 +362,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
if (!partialMeshAxes.empty()) {
bool anyNonEmptyReductionLoop = llvm::any_of(
llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
- SmallVector<int32_t> &subArray = it.value();
+ SmallVector<MeshAxis> &subArray = it.value();
int64_t idx = it.index();
return isReductionLoop(loopTypes[idx]) && !subArray.empty();
});
@@ -406,8 +406,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
return success();
auto resultType = result.getType().cast<RankedTensorType>();
- SmallVector<SmallVector<int32_t>> splitAxes(resultType.getRank());
- SmallVector<int32_t> partialAxes;
+ SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
+ SmallVector<MeshAxis> partialAxes;
// process the split axes
for (auto it : llvm::enumerate(map.getResults())) {
@@ -431,7 +431,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
assert(partialType == curPartialType &&
"Only one reduction type is supported");
partialType = curPartialType;
- const SmallVector<int32_t> &axis = std::get<1>(it);
+ const SmallVector<MeshAxis> &axis = std::get<1>(it);
partialAxes.append(axis);
}
}
@@ -459,7 +459,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
return success();
Value operand = opOperand.get();
auto operandType = operand.getType().cast<RankedTensorType>();
- SmallVector<SmallVector<int32_t>> splitAxes(operandType.getRank());
+ SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
unsigned numDims = map.getNumDims();
for (auto it : llvm::enumerate(map.getResults())) {
int64_t idx = it.index();
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 8d7e89662131a0..37b86535959652 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -147,7 +147,7 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
.getResult()
.cast<TypedValue<ShapedType>>();
- llvm::SmallVector<int32_t> remainingPartialAxes;
+ llvm::SmallVector<MeshAxis> remainingPartialAxes;
llvm::copy_if(sourceShardingPartialAxesSet,
std::back_inserter(allReduceMeshAxes),
[&targetShardingPartialAxesSet](Axis a) {
@@ -163,17 +163,17 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
static MeshShardingAttr
targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
- SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+ SmallVector<MeshAxesAttr> targetShardingSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
splitTensorAxis) {
- targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {}));
+ targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
}
auto targetSplitAxes =
llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
targetSplitAxes.push_back(splitMeshAxis);
targetShardingSplitAxes[splitTensorAxis] =
- DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+ MeshAxesAttr::get(ctx, targetSplitAxes);
return MeshShardingAttr::get(
ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
@@ -356,7 +356,7 @@ static MeshShardingAttr
targetShardingInUnsplitLastAxis(MLIRContext *ctx,
MeshShardingAttr sourceSharding,
int64_t splitTensorAxis) {
- SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+ SmallVector<MeshAxesAttr> targetShardingSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
splitTensorAxis);
@@ -365,7 +365,7 @@ targetShardingInUnsplitLastAxis(MLIRContext *ctx,
targetSplitAxes.pop_back();
targetShardingSplitAxes[splitTensorAxis] =
- DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+ MeshAxesAttr::get(ctx, targetSplitAxes);
return MeshShardingAttr::get(
ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
@@ -475,11 +475,11 @@ static MeshShardingAttr
targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
int64_t sourceTensorAxis,
int64_t targetTensorAxis) {
- SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+ SmallVector<MeshAxesAttr> targetShardingSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
targetTensorAxis) {
- targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {}));
+ targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
}
auto sourceSplitAxes =
@@ -488,13 +488,13 @@ targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
auto meshAxis = sourceSplitAxes.back();
sourceSplitAxes.pop_back();
targetShardingSplitAxes[sourceTensorAxis] =
- DenseI32ArrayAttr::get(ctx, sourceSplitAxes);
+ MeshAxesAttr::get(ctx, sourceSplitAxes);
auto targetSplitAxes =
llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
targetSplitAxes.push_back(meshAxis);
targetShardingSplitAxes[targetTensorAxis] =
- DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+ MeshAxesAttr::get(ctx, targetSplitAxes);
return MeshShardingAttr::get(
ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
|
@yaochengji, could you review this? |
joker-eph
approved these changes
Jan 3, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Make all ops and attributes use the types MeshAxis and MeshAxesAttr instead of int16_t, int32_t, DenseI16ArrayAttr and DenseI32ArrayAttr.