Skip to content

Commit 9144fed

Browse files
authored
[mlir] Add option for a cleanup pattern set to SCF tiling helper (llvm#109554)
The SCF helper for tiling an operation implementing the TilingInterface and greedily fusing consumers requires an uninterrupted chain of operations implementing the tiling interface to succeed. There can be cases with intermediate ops that don't implement the interface but have producers that could be fused if various canonicalization/simplification patterns could run in between fusion steps. This adds an option to SCFTileAndFuseOptions for a pattern set to run between fusion steps to the ops that result from fusion/tiling. Removed and newly inserted slices are tracked for continued fusion applications. See this RFC for more discussion: https://discourse.llvm.org/t/rfc-split-fusion-portions-of-the-tilinginterface-into-a-new-interface/81155
1 parent f74879c commit 9144fed

File tree

5 files changed

+252
-28
lines changed

5 files changed

+252
-28
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,18 +295,23 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
295295
let description = [{
296296
Tiles the operations pointed to by the target handle and fuses their
297297
producers greedily using the options provided as attributes.
298+
299+
If `apply_cleanup` is true then slice canonicalization is applied between
300+
fusion steps.
298301
}];
299302

300303
let arguments =
301304
(ins TransformHandleTypeInterface:$target,
302305
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
303-
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange);
306+
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
307+
DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup);
304308
let results = (outs TransformHandleTypeInterface:$transformed,
305309
Variadic<TransformHandleTypeInterface>:$loops);
306310

307311
let assemblyFormat = [{
308312
$target ($tile_sizes^)? (`interchange` $tile_interchange^)?
309-
attr-dict `:` functional-type(operands, results)
313+
(`apply_cleanup` `=` $apply_cleanup^)? attr-dict
314+
`:` functional-type(operands, results)
310315
}];
311316
let hasVerifier = 1;
312317
}

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Interfaces/LoopLikeInterface.h"
1616
#include "mlir/Interfaces/TilingInterface.h"
1717
#include "mlir/Interfaces/ViewLikeInterface.h"
18+
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
1819

1920
#include <deque>
2021

@@ -153,6 +154,11 @@ struct SCFTileAndFuseOptions {
153154
fusionControlFn = controlFn;
154155
return *this;
155156
}
157+
158+
/// An optional set of rewrite patterns to apply to the results of tiling
159+
/// before fusion. This will track deleted and newly inserted
160+
/// `tensor.extract_slice` ops and update the worklist.
161+
std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;
156162
};
157163

158164
/// Fuse the producer of the source of `candidateSliceOp` by computing the

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,15 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
562562
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
563563
scf::SCFTileAndFuseOptions tileAndFuseOptions;
564564
tileAndFuseOptions.tilingOptions = tilingOptions;
565+
566+
if (getApplyCleanup()) {
567+
MLIRContext *context = rewriter.getContext();
568+
RewritePatternSet patterns(context);
569+
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
570+
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
571+
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
572+
}
573+
565574
LogicalResult result = applyTilingToAll(
566575
rewriter, getOperation(), state.getPayloadOps(getTarget()),
567576
tileSizes.size() - llvm::count(tileSizes, 0), transformResults,

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 130 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#include "mlir/IR/PatternMatch.h"
2525
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2626
#include "mlir/Interfaces/TilingInterface.h"
27+
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
28+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2729
#include "llvm/ADT/TypeSwitch.h"
2830
#include "llvm/Support/Debug.h"
2931
#include <optional>
@@ -1315,6 +1317,104 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
13151317
return generatedSlices;
13161318
}
13171319

1320+
namespace {
1321+
1322+
//===----------------------------------------------------------------------===//
1323+
// SliceTrackingListener
1324+
//===----------------------------------------------------------------------===//
1325+
1326+
/// This class is a listener for tracking the insertion and removal of
1327+
/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1328+
/// fusion algorithm to apply cleanup patterns in between fusion steps.
1329+
class SliceTrackingListener : public RewriterBase::Listener {
1330+
public:
1331+
explicit SliceTrackingListener(
1332+
std::optional<FrozenRewritePatternSet> patterns);
1333+
SliceTrackingListener() = default;
1334+
1335+
/// Adds the given list of operations to the worklist, and if present, applies
1336+
/// the list of `patterns` to the newly added operations. This only processes
1337+
/// the given operations and any newly inserted ones by the pattern set.
1338+
LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
1339+
1340+
/// Add to the new operation worklist if it is an extract_slice.
1341+
void notifyOperationInserted(Operation *op,
1342+
OpBuilder::InsertPoint previous) override;
1343+
1344+
/// Shared helper for operation removal from the worklist.
1345+
void removeOp(Operation *op);
1346+
1347+
/// Remove the operation from the worklist.
1348+
void notifyOperationErased(Operation *op) override;
1349+
1350+
/// Remove the operation from the worklist.
1351+
void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
1352+
1353+
/// The worklist for this transformation keeps track of the slices to visit
1354+
/// next for fusion.
1355+
std::deque<tensor::ExtractSliceOp> worklist;
1356+
1357+
private:
1358+
/// Optional pattern set to apply when adding new operations to the worklist.
1359+
std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1360+
};
1361+
1362+
SliceTrackingListener::SliceTrackingListener(
1363+
std::optional<FrozenRewritePatternSet> p) {
1364+
patterns = std::move(p);
1365+
}
1366+
1367+
LogicalResult
1368+
SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1369+
for (Operation *op : ops) {
1370+
if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1371+
worklist.push_back(slice);
1372+
}
1373+
1374+
if (!patterns)
1375+
return success();
1376+
1377+
GreedyRewriteConfig config;
1378+
config.listener = this;
1379+
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
1380+
return applyOpPatternsAndFold(ops, patterns.value(), config);
1381+
}
1382+
1383+
void SliceTrackingListener::notifyOperationInserted(
1384+
Operation *op, OpBuilder::InsertPoint previous) {
1385+
auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1386+
if (!slice)
1387+
return;
1388+
worklist.push_back(slice);
1389+
}
1390+
1391+
// Scan the worklist for the given op and remove it if present. The expectation
1392+
// is for the worklist to be small and for removal to be relatively rare.
1393+
void SliceTrackingListener::removeOp(Operation *op) {
1394+
if (!isa<tensor::ExtractSliceOp>(op))
1395+
return;
1396+
auto iter = worklist.begin();
1397+
while (iter != worklist.end()) {
1398+
if (*iter == op)
1399+
break;
1400+
iter++;
1401+
}
1402+
if (iter == worklist.end())
1403+
return;
1404+
1405+
worklist.erase(iter);
1406+
}
1407+
1408+
void SliceTrackingListener::notifyOperationErased(Operation *op) {
1409+
removeOp(op);
1410+
}
1411+
1412+
void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1413+
ValueRange replacement) {
1414+
removeOp(op);
1415+
}
1416+
} // namespace
1417+
13181418
/// Implementation of tile consumer and fuse producer greedily.
13191419
FailureOr<scf::SCFTileAndFuseResult>
13201420
mlir::scf::tileConsumerAndFuseProducersUsingSCF(
@@ -1370,33 +1470,32 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
13701470
tensor::ExtractSliceOp candidateSlice;
13711471
SCFTileAndFuseOptions::ControlFnResult controlFnResult;
13721472
};
1373-
std::deque<WorklistItem> worklist;
1374-
auto addCandidateSlices = [&worklist, &options,
1375-
&loops](ArrayRef<Operation *> candidates) {
1376-
for (auto candidate : candidates) {
1377-
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate);
1378-
if (!sliceOp || sliceOp.use_empty())
1379-
continue;
13801473

1381-
auto [fusableProducer, destinationInitArg] =
1382-
getUntiledProducerFromSliceSource(&sliceOp.getSourceMutable(), loops);
1383-
if (!fusableProducer)
1384-
continue;
1385-
std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1386-
options.fusionControlFn(sliceOp, fusableProducer,
1387-
destinationInitArg.has_value());
1388-
if (!controlFnResult)
1389-
continue;
1390-
worklist.emplace_back(WorklistItem{sliceOp, controlFnResult.value()});
1391-
}
1392-
};
1474+
SliceTrackingListener sliceTracker =
1475+
SliceTrackingListener(options.cleanupPatterns);
13931476

1394-
addCandidateSlices(tilingResult->generatedSlices);
1477+
if (failed(
1478+
sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1479+
return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1480+
}
13951481
OpBuilder::InsertionGuard g(rewriter);
1396-
while (!worklist.empty()) {
1397-
// Traverse the slices in BFS fashion.
1398-
WorklistItem worklistItem = worklist.front();
1399-
worklist.pop_front();
1482+
while (!sliceTracker.worklist.empty()) {
1483+
auto candidateSlice = sliceTracker.worklist.front();
1484+
sliceTracker.worklist.pop_front();
1485+
1486+
auto [fusableProducer, destinationInitArg] =
1487+
getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
1488+
loops);
1489+
if (!fusableProducer)
1490+
continue;
1491+
1492+
std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1493+
options.fusionControlFn(candidateSlice, fusableProducer,
1494+
destinationInitArg.has_value());
1495+
if (!controlFnResult)
1496+
continue;
1497+
1498+
WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
14001499

14011500
// The operands of the fused producer might themselved be slices of
14021501
// values produced by operations that implement the `TilingInterface`.
@@ -1407,6 +1506,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
14071506
if (!fusedResult)
14081507
continue;
14091508

1509+
SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1510+
14101511
if (worklistItem.controlFnResult.yieldProducerReplacement) {
14111512
// Reconstruct and yield all opResult of fusableProducerOp by default. The
14121513
// caller can specific which one to yield by designating optional argument
@@ -1421,20 +1522,23 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
14211522
fusableProducerOp, "failed to replacement value for this "
14221523
"operation from within the tiled loop");
14231524
}
1424-
addCandidateSlices(newSlices.value());
1525+
worklistCandidates.append(newSlices.value());
14251526
for (auto [index, result] :
14261527
llvm::enumerate(fusableProducerOp->getResults())) {
14271528
origValToResultNumber[result] = loops.front()->getNumResults() -
14281529
fusableProducerOp->getNumResults() +
14291530
index;
14301531
}
14311532
}
1432-
addCandidateSlices(fusedResult->generatedSlices);
14331533
if (Operation *tiledAndFusedOp =
14341534
fusedResult->tiledAndFusedProducer.getDefiningOp()) {
14351535
fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
14361536
tiledAndFusedOps.insert(tiledAndFusedOp);
14371537
}
1538+
1539+
if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1540+
return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1541+
}
14381542
}
14391543

14401544
DenseMap<Value, Value> replacements;

mlir/test/Dialect/Linalg/transform-op-fuse.mlir

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,103 @@ module attributes {transform.with_named_sequence} {
178178
transform.yield
179179
}
180180
}
181+
182+
// -----
183+
184+
// CHECK-LABEL: func.func @fuse_through_slice
185+
func.func @fuse_through_slice(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
186+
187+
// CHECK: %[[RES:.*]] = scf.for
188+
// CHECK: scf.for
189+
// CHECK: linalg.elemwise_unary
190+
// CHECK: linalg.elemwise_binary
191+
// CHECK: return %[[RES]]
192+
%0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
193+
outs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
194+
%c0 = arith.constant 0 : index
195+
%c1 = arith.constant 1 : index
196+
%dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
197+
%dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
198+
%1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
199+
%2 = linalg.elemwise_binary ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
200+
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
201+
return %2 : tensor<?x?xf32>
202+
}
203+
204+
module attributes {transform.with_named_sequence} {
205+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
206+
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
207+
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
208+
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
209+
transform.yield
210+
}
211+
}
212+
213+
// -----
214+
215+
// CHECK-LABEL: func.func @fuse_through_slice_and_cast_chain
216+
func.func @fuse_through_slice_and_cast_chain(%arg0: tensor<100x100xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
217+
218+
// CHECK: %[[RES:.*]] = scf.for
219+
// CHECK: scf.for
220+
// CHECK: linalg.elemwise_unary
221+
// CHECK: linalg.elemwise_binary
222+
// CHECK: return %[[RES]]
223+
%0 = linalg.elemwise_unary ins(%arg0 : tensor<100x100xf32>)
224+
outs(%arg0: tensor<100x100xf32>) -> tensor<100x100xf32>
225+
%1 = tensor.cast %0 : tensor<100x100xf32> to tensor<100x?xf32>
226+
%2 = tensor.extract_slice %1 [1, 1] [98, 98] [1, 1] : tensor<100x?xf32> to tensor<98x98xf32>
227+
%3 = tensor.cast %2 : tensor<98x98xf32> to tensor<?x?xf32>
228+
%c0 = arith.constant 0 : index
229+
%c1 = arith.constant 1 : index
230+
%dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
231+
%dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
232+
%4 = tensor.extract_slice %3 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
233+
%5 = linalg.elemwise_binary ins(%4, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
234+
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
235+
return %5 : tensor<?x?xf32>
236+
}
237+
238+
module attributes {transform.with_named_sequence} {
239+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
240+
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
241+
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
242+
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
243+
transform.yield
244+
}
245+
}
246+
247+
// -----
248+
249+
// CHECK-LABEL: func.func @fuse_unrelated_slice
250+
func.func @fuse_unrelated_slices(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<10x10xf32>) {
251+
252+
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice
253+
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[SLICE1]]
254+
// CHECK: %[[RES:.*]] = scf.for
255+
// CHECK: scf.for
256+
// CHECK: linalg.elemwise_unary
257+
// CHECK: linalg.elemwise_binary
258+
// CHECK: return %[[RES]], %[[SLICE2]]
259+
%c0 = arith.constant 0 : index
260+
%c1 = arith.constant 1 : index
261+
%dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
262+
%dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
263+
%slice1 = tensor.extract_slice %arg0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
264+
%slice2 = tensor.extract_slice %slice1 [1, 1] [10, 10] [1, 1] : tensor<?x?xf32> to tensor<10x10xf32>
265+
%0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
266+
outs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
267+
%1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
268+
%2 = linalg.elemwise_binary ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
269+
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
270+
return %2, %slice2 : tensor<?x?xf32>, tensor<10x10xf32>
271+
}
272+
273+
module attributes {transform.with_named_sequence} {
274+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
275+
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
276+
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
277+
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
278+
transform.yield
279+
}
280+
}

0 commit comments

Comments
 (0)