diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp index 49ea6bb5f8614..4d49efecbe05c 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp @@ -232,14 +232,11 @@ struct AssignTileIDsPattern static_cast(getDiscardableIntAttr(kTilesInUseAttr)); auto tileId = allocateTileId(*tileType, tilesInUse); bool tileIsInMemory = failed(tileId); - if (!tileIsInMemory) - setDiscardableIntAttr(kTilesInUseAttr, tilesInUse); - else { + if (tileIsInMemory) { // If we could not find a real tile ID, use an in-memory tile ID (ID >= // 16). A later pass will insert the necessary spills and reloads. tileId = getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase); - setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1); tileOp->emitWarning( "failed to allocate SME virtual tile to operation, all tile " "operations will go through memory, expect degraded performance"); @@ -263,14 +260,25 @@ struct AssignTileIDsPattern SetVector dependantOps; findDependantOps(tileOp->getResult(0), dependantOps); auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId); - rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); }); for (auto *op : dependantOps) { if (auto dependantTileOp = llvm::dyn_cast(op)) { auto currentTileId = dependantTileOp.getTileId(); if (currentTileId && unsigned(currentTileId.getInt()) != tileId) return dependantTileOp.emitOpError( "already assigned different SME virtual tile!"); - dependantTileOp.setTileId(tileIDAttr); + } + } + + // Rewrite IR. + if (!tileIsInMemory) + setDiscardableIntAttr(kTilesInUseAttr, tilesInUse); + else + setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1); + rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); }); + for (auto *op : dependantOps) { + if (auto dependantTileOp = llvm::dyn_cast(op)) { + rewriter.updateRootInPlace( + dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); }); } } diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp index bee1f3659753b..bf627d95ae557 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp @@ -106,7 +106,8 @@ struct RelaxScalableVectorAllocaAlignment // Set alignment based on the defaults for SVE vectors and predicates. unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16; - allocaOp.setAlignment(aligment); + rewriter.updateRootInPlace(allocaOp, + [&] { allocaOp.setAlignment(aligment); }); return success(); }