Skip to content

[mlir][mesh] Use one type for mesh axis #76830

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
sogartar merged 1 commit into llvm:main
Jan 3, 2024

Conversation

sogartar
Copy link
Contributor

@sogartar sogartar commented Jan 3, 2024

Make all ops and attributes use the types MeshAxis and MeshAxesAttr instead of int16_t, int32_t, DenseI16ArrayAttr and DenseI32ArrayAttr.

Make all ops and attributes use the types MeshAxis and MeshAxesAttr
instead of int16_t, int32_t, DenseI16ArrayAttr and DenseI32ArrayAttr.
@llvmbot llvmbot added the mlir label Jan 3, 2024
@llvmbot
Copy link
Member

llvmbot commented Jan 3, 2024

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

Changes

Make all ops and attributes use the types MeshAxis and MeshAxesAttr instead of int16_t, int32_t, DenseI16ArrayAttr and DenseI32ArrayAttr.


Full diff: https://github.com/llvm/llvm-project/pull/76830.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+8-8)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+9-3)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+2-2)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+9-9)
  • (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+11-11)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+10-10)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index a9d30dfbb9a76e..060d54b82efa63 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -80,8 +80,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
 
   let parameters = (ins
     AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
-    ArrayRefParameter<"::mlir::DenseI32ArrayAttr">:$split_axes,
-    OptionalArrayRefParameter<"int32_t">:$partial_axes,
+    ArrayRefParameter<"MeshAxesAttr">:$split_axes,
+    OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
     OptionalParameter<"::mlir::mesh::Partial">:$partial_type
   );
 
@@ -146,18 +146,18 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
 
   let builders = [
     AttrBuilder<(ins "SymbolRefAttr":$cluster, 
-                     "ArrayRef<SmallVector<int32_t>>":$split_axes,
-                     "ArrayRef<int32_t>": $partial_axes,
+                     "ArrayRef<SmallVector<MeshAxis>>":$split_axes,
+                     "ArrayRef<MeshAxis>": $partial_axes,
                      "mesh::Partial": $partial_type), [{
-      SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::map_to_vector(
-                  split_axes, [&](ArrayRef<int32_t> array) {
-          return DenseI32ArrayAttr::get($_ctxt, array);
+      SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
+                  split_axes, [&](ArrayRef<MeshAxis> array) {
+          return MeshAxesAttr::get($_ctxt, array);
       });
       return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
                    partial_type);
     }]>,
     AttrBuilder<(ins "SymbolRefAttr":$cluster, 
-                     "ArrayRef<SmallVector<int32_t>>":$split_axes), [{
+                     "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
       return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
     }]>
   ];
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index ce7d5d045122d9..83452dcc2e8abe 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -17,6 +17,15 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include <algorithm>
 
+namespace mlir {
+namespace mesh {
+
+using MeshAxis = int16_t;
+using MeshAxesAttr = DenseI16ArrayAttr;
+
+} // namespace mesh
+} // namespace mlir
+
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsEnums.h.inc"
@@ -30,9 +39,6 @@
 namespace mlir {
 namespace mesh {
 
-using MeshAxis = int16_t;
-using MeshAxesAttr = DenseI16ArrayAttr;
-
 bool isReductionLoop(IteratorType iType);
 
 bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 1ed54b6519e4d8..1934bdfb427059 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -114,7 +114,7 @@ def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMeth
 
   let builders = [
     OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
-    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
   ];
 }
 
@@ -228,7 +228,7 @@ def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMeth
   }];
   let builders = [
     OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
-    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 270955a3036e89..201c0151754eba 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -18,8 +18,8 @@ class Operation;
 
 namespace mesh {
 
-using ShardingArray = SmallVector<SmallVector<int32_t>>;
-using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
+using ShardingArray = SmallVector<SmallVector<MeshAxis>>;
+using ShardingArrayRef = ArrayRef<SmallVector<MeshAxis>>;
 
 struct ShardingOption {
   // An array of int array. The sub-array at the i-th position signifies the
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index de4f58d54e8ca5..c3d8f1d456106d 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -266,15 +266,15 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
 
 LogicalResult
 MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                         SymbolRefAttr, ArrayRef<DenseI32ArrayAttr> splitAxes,
-                         ArrayRef<int32_t> partialAxes, Partial) {
+                         SymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
+                         ArrayRef<MeshAxis> partialAxes, Partial) {
   // TODO: At present cluster symbol ref is not verified. This is due to the
   // difficulty in fetching the corresponding symbol op based on an attribute.
 
-  llvm::SmallSet<int32_t, 4> visitedAxes;
+  llvm::SmallSet<MeshAxis, 4> visitedAxes;
 
-  auto checkMeshAxis = [&](ArrayRef<int32_t> axesArray) -> LogicalResult {
-    for (int32_t axis : axesArray) {
+  auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
+    for (MeshAxis axis : axesArray) {
       if (axis < 0)
         return emitError() << "mesh axis is expected to be non-negative";
       if (!visitedAxes.insert(axis).second)
@@ -283,8 +283,8 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
     return success();
   };
 
-  for (DenseI32ArrayAttr subAxes : splitAxes) {
-    ArrayRef<int32_t> subAxesArray = subAxes.asArrayRef();
+  for (MeshAxesAttr subAxes : splitAxes) {
+    ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
     if (failed(checkMeshAxis(subAxesArray)))
       return failure();
   }
@@ -318,10 +318,10 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
 
   return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
                                        getSplitAxes().end()),
-                      std::mem_fn(&DenseI32ArrayAttr::empty)) &&
+                      std::mem_fn(&MeshAxesAttr::empty)) &&
          llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
                                        rhs.getSplitAxes().end()),
-                      std::mem_fn(&DenseI32ArrayAttr::empty));
+                      std::mem_fn(&MeshAxesAttr::empty));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index a6f2f435f36d68..ee885ab16b7b06 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -216,7 +216,7 @@ namespace {
 static LogicalResult fillShardingOption(Operation *op,
                                         ShardingOption &shardingOption,
                                         SymbolRefAttr cluster,
-                                        ArrayRef<int32_t> meshAxes,
+                                        ArrayRef<MeshAxis> meshAxes,
                                         unsigned loopIdx) {
   if ((shardingOption.cluster && cluster &&
        shardingOption.cluster != cluster) ||
@@ -230,7 +230,7 @@ static LogicalResult fillShardingOption(Operation *op,
     if (i == loopIdx)
       continue;
 
-    for (int32_t axis : meshAxes) {
+    for (MeshAxis axis : meshAxes) {
       if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
         LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
                           << axis << " duplicate");
@@ -260,7 +260,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
   SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
   unsigned numOperands = op->getNumOperands();
   shardingOption.shardingArray.resize(loopTypes.size());
-  llvm::SmallVector<int32_t> partialMeshAxes;
+  llvm::SmallVector<MeshAxis> partialMeshAxes;
   Partial partialType;
   llvm::SmallSet<unsigned, 4> visitedLoopIndices;
   bool anyShardingInResultsOrOperands = false;
@@ -277,7 +277,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
     // shardingOption[index]
     for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
       AffineExpr expr = std::get<0>(it);
-      ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
+      ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
       auto dim = cast<AffineDimExpr>(expr);
       unsigned index = dim.getPosition();
       visitedLoopIndices.insert(index);
@@ -288,7 +288,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
 
     // Handle the partial axes: at this stage, the exact loop index/indices
     // cannot be decided because there could be multiple reduction loops.
-    ArrayRef<int32_t> partialAxes = shardAttr.getPartialAxes();
+    ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
     if (!partialAxes.empty()) {
       if (!partialMeshAxes.empty())
         return op->emitOpError() << "at most one result with partial axes is "
@@ -321,7 +321,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
     // then the operands with multiple loop indices.
     for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
       AffineExpr expr = std::get<0>(it);
-      ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
+      ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
       FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
           checkOperandAffineExpr(expr, numDims);
       if (failed(loopIndices))
@@ -362,7 +362,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
   if (!partialMeshAxes.empty()) {
     bool anyNonEmptyReductionLoop = llvm::any_of(
         llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
-          SmallVector<int32_t> &subArray = it.value();
+          SmallVector<MeshAxis> &subArray = it.value();
           int64_t idx = it.index();
           return isReductionLoop(loopTypes[idx]) && !subArray.empty();
         });
@@ -406,8 +406,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
     return success();
 
   auto resultType = result.getType().cast<RankedTensorType>();
-  SmallVector<SmallVector<int32_t>> splitAxes(resultType.getRank());
-  SmallVector<int32_t> partialAxes;
+  SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
+  SmallVector<MeshAxis> partialAxes;
 
   // process the split axes
   for (auto it : llvm::enumerate(map.getResults())) {
@@ -431,7 +431,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
         assert(partialType == curPartialType &&
                "Only one reduction type is supported");
       partialType = curPartialType;
-      const SmallVector<int32_t> &axis = std::get<1>(it);
+      const SmallVector<MeshAxis> &axis = std::get<1>(it);
       partialAxes.append(axis);
     }
   }
@@ -459,7 +459,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
     return success();
   Value operand = opOperand.get();
   auto operandType = operand.getType().cast<RankedTensorType>();
-  SmallVector<SmallVector<int32_t>> splitAxes(operandType.getRank());
+  SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
   unsigned numDims = map.getNumDims();
   for (auto it : llvm::enumerate(map.getResults())) {
     int64_t idx = it.index();
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 8d7e89662131a0..37b86535959652 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -147,7 +147,7 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
           .getResult()
           .cast<TypedValue<ShapedType>>();
 
-  llvm::SmallVector<int32_t> remainingPartialAxes;
+  llvm::SmallVector<MeshAxis> remainingPartialAxes;
   llvm::copy_if(sourceShardingPartialAxesSet,
                 std::back_inserter(allReduceMeshAxes),
                 [&targetShardingPartialAxesSet](Axis a) {
@@ -163,17 +163,17 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
 static MeshShardingAttr
 targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
                               int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
-  SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
       llvm::to_vector(sourceSharding.getSplitAxes());
   while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
          splitTensorAxis) {
-    targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {}));
+    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
   }
   auto targetSplitAxes =
       llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
   targetSplitAxes.push_back(splitMeshAxis);
   targetShardingSplitAxes[splitTensorAxis] =
-      DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+      MeshAxesAttr::get(ctx, targetSplitAxes);
   return MeshShardingAttr::get(
       ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
@@ -356,7 +356,7 @@ static MeshShardingAttr
 targetShardingInUnsplitLastAxis(MLIRContext *ctx,
                                 MeshShardingAttr sourceSharding,
                                 int64_t splitTensorAxis) {
-  SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
       llvm::to_vector(sourceSharding.getSplitAxes());
   assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
          splitTensorAxis);
@@ -365,7 +365,7 @@ targetShardingInUnsplitLastAxis(MLIRContext *ctx,
 
   targetSplitAxes.pop_back();
   targetShardingSplitAxes[splitTensorAxis] =
-      DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+      MeshAxesAttr::get(ctx, targetSplitAxes);
   return MeshShardingAttr::get(
       ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
@@ -475,11 +475,11 @@ static MeshShardingAttr
 targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
                              int64_t sourceTensorAxis,
                              int64_t targetTensorAxis) {
-  SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
       llvm::to_vector(sourceSharding.getSplitAxes());
   while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
          targetTensorAxis) {
-    targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {}));
+    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
   }
 
   auto sourceSplitAxes =
@@ -488,13 +488,13 @@ targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
   auto meshAxis = sourceSplitAxes.back();
   sourceSplitAxes.pop_back();
   targetShardingSplitAxes[sourceTensorAxis] =
-      DenseI32ArrayAttr::get(ctx, sourceSplitAxes);
+      MeshAxesAttr::get(ctx, sourceSplitAxes);
 
   auto targetSplitAxes =
       llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
   targetSplitAxes.push_back(meshAxis);
   targetShardingSplitAxes[targetTensorAxis] =
-      DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+      MeshAxesAttr::get(ctx, targetSplitAxes);
 
   return MeshShardingAttr::get(
       ctx, sourceSharding.getCluster(), targetShardingSplitAxes,

@sogartar
Copy link
Contributor Author

sogartar commented Jan 3, 2024

@yaochengji, could you review this?

@sogartar sogartar requested a review from joker-eph January 3, 2024 16:03
@sogartar sogartar merged commit 7a4c497 into llvm:main Jan 3, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants