24
24
#include " mlir/IR/PatternMatch.h"
25
25
#include " mlir/Interfaces/DestinationStyleOpInterface.h"
26
26
#include " mlir/Interfaces/TilingInterface.h"
27
+ #include " mlir/Rewrite/FrozenRewritePatternSet.h"
28
+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
27
29
#include " llvm/ADT/TypeSwitch.h"
28
30
#include " llvm/Support/Debug.h"
29
31
#include < optional>
@@ -1315,6 +1317,104 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1315
1317
return generatedSlices;
1316
1318
}
1317
1319
1320
+ namespace {
1321
+
1322
+ // ===----------------------------------------------------------------------===//
1323
+ // SliceTrackingListener
1324
+ // ===----------------------------------------------------------------------===//
1325
+
1326
+ // / This class is a listener for tracking the insertion and removal of
1327
+ // / `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1328
+ // / fusion algorithm to apply cleanup patterns in between fusion steps.
1329
+ class SliceTrackingListener : public RewriterBase ::Listener {
1330
+ public:
1331
+ explicit SliceTrackingListener (
1332
+ std::optional<FrozenRewritePatternSet> patterns);
1333
+ SliceTrackingListener () = default ;
1334
+
1335
+ // / Adds the given list of operations to the worklist, and if present, applies
1336
+ // / the list of `patterns` to the newly added operations. This only processes
1337
+ // / the given operations and any newly inserted ones by the pattern set.
1338
+ LogicalResult insertAndApplyPatterns (ArrayRef<Operation *> newOps);
1339
+
1340
+ // / Add to the new operation worklist if it is an extract_slice.
1341
+ void notifyOperationInserted (Operation *op,
1342
+ OpBuilder::InsertPoint previous) override ;
1343
+
1344
+ // / Shared helper for operation removal from the worklist.
1345
+ void removeOp (Operation *op);
1346
+
1347
+ // / Remove the operation from the worklist.
1348
+ void notifyOperationErased (Operation *op) override ;
1349
+
1350
+ // / Remove the operation from the worklist.
1351
+ void notifyOperationReplaced (Operation *op, ValueRange replacement) override ;
1352
+
1353
+ // / The worklist for this transformation keeps track of the slices to visit
1354
+ // / next for fusion.
1355
+ std::deque<tensor::ExtractSliceOp> worklist;
1356
+
1357
+ private:
1358
+ // / Optional pattern set to apply when adding new operations to the worklist.
1359
+ std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1360
+ };
1361
+
1362
+ SliceTrackingListener::SliceTrackingListener (
1363
+ std::optional<FrozenRewritePatternSet> p) {
1364
+ patterns = std::move (p);
1365
+ }
1366
+
1367
+ LogicalResult
1368
+ SliceTrackingListener::insertAndApplyPatterns (ArrayRef<Operation *> ops) {
1369
+ for (Operation *op : ops) {
1370
+ if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1371
+ worklist.push_back (slice);
1372
+ }
1373
+
1374
+ if (!patterns)
1375
+ return success ();
1376
+
1377
+ GreedyRewriteConfig config;
1378
+ config.listener = this ;
1379
+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
1380
+ return applyOpPatternsAndFold (ops, patterns.value (), config);
1381
+ }
1382
+
1383
+ void SliceTrackingListener::notifyOperationInserted (
1384
+ Operation *op, OpBuilder::InsertPoint previous) {
1385
+ auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1386
+ if (!slice)
1387
+ return ;
1388
+ worklist.push_back (slice);
1389
+ }
1390
+
1391
+ // Scan the worklist for the given op and remove it if present. The expectation
1392
+ // is for the worklist to be small and for removal to be relatively rare.
1393
+ void SliceTrackingListener::removeOp (Operation *op) {
1394
+ if (!isa<tensor::ExtractSliceOp>(op))
1395
+ return ;
1396
+ auto iter = worklist.begin ();
1397
+ while (iter != worklist.end ()) {
1398
+ if (*iter == op)
1399
+ break ;
1400
+ iter++;
1401
+ }
1402
+ if (iter == worklist.end ())
1403
+ return ;
1404
+
1405
+ worklist.erase (iter);
1406
+ }
1407
+
1408
+ void SliceTrackingListener::notifyOperationErased (Operation *op) {
1409
+ removeOp (op);
1410
+ }
1411
+
1412
+ void SliceTrackingListener::notifyOperationReplaced (Operation *op,
1413
+ ValueRange replacement) {
1414
+ removeOp (op);
1415
+ }
1416
+ } // namespace
1417
+
1318
1418
// / Implementation of tile consumer and fuse producer greedily.
1319
1419
FailureOr<scf::SCFTileAndFuseResult>
1320
1420
mlir::scf::tileConsumerAndFuseProducersUsingSCF (
@@ -1370,33 +1470,32 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1370
1470
tensor::ExtractSliceOp candidateSlice;
1371
1471
SCFTileAndFuseOptions::ControlFnResult controlFnResult;
1372
1472
};
1373
- std::deque<WorklistItem> worklist;
1374
- auto addCandidateSlices = [&worklist, &options,
1375
- &loops](ArrayRef<Operation *> candidates) {
1376
- for (auto candidate : candidates) {
1377
- auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate);
1378
- if (!sliceOp || sliceOp.use_empty ())
1379
- continue ;
1380
1473
1381
- auto [fusableProducer, destinationInitArg] =
1382
- getUntiledProducerFromSliceSource (&sliceOp.getSourceMutable (), loops);
1383
- if (!fusableProducer)
1384
- continue ;
1385
- std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1386
- options.fusionControlFn (sliceOp, fusableProducer,
1387
- destinationInitArg.has_value ());
1388
- if (!controlFnResult)
1389
- continue ;
1390
- worklist.emplace_back (WorklistItem{sliceOp, controlFnResult.value ()});
1391
- }
1392
- };
1474
+ SliceTrackingListener sliceTracker =
1475
+ SliceTrackingListener (options.cleanupPatterns );
1393
1476
1394
- addCandidateSlices (tilingResult->generatedSlices );
1477
+ if (failed (
1478
+ sliceTracker.insertAndApplyPatterns (tilingResult->generatedSlices ))) {
1479
+ return rewriter.notifyMatchFailure (consumer, " cleanup patterns failed" );
1480
+ }
1395
1481
OpBuilder::InsertionGuard g (rewriter);
1396
- while (!worklist.empty ()) {
1397
- // Traverse the slices in BFS fashion.
1398
- WorklistItem worklistItem = worklist.front ();
1399
- worklist.pop_front ();
1482
+ while (!sliceTracker.worklist .empty ()) {
1483
+ auto candidateSlice = sliceTracker.worklist .front ();
1484
+ sliceTracker.worklist .pop_front ();
1485
+
1486
+ auto [fusableProducer, destinationInitArg] =
1487
+ getUntiledProducerFromSliceSource (&candidateSlice.getSourceMutable (),
1488
+ loops);
1489
+ if (!fusableProducer)
1490
+ continue ;
1491
+
1492
+ std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1493
+ options.fusionControlFn (candidateSlice, fusableProducer,
1494
+ destinationInitArg.has_value ());
1495
+ if (!controlFnResult)
1496
+ continue ;
1497
+
1498
+ WorklistItem worklistItem = {candidateSlice, controlFnResult.value ()};
1400
1499
1401
1500
// The operands of the fused producer might themselved be slices of
1402
1501
// values produced by operations that implement the `TilingInterface`.
@@ -1407,6 +1506,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1407
1506
if (!fusedResult)
1408
1507
continue ;
1409
1508
1509
+ SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices ;
1510
+
1410
1511
if (worklistItem.controlFnResult .yieldProducerReplacement ) {
1411
1512
// Reconstruct and yield all opResult of fusableProducerOp by default. The
1412
1513
// caller can specific which one to yield by designating optional argument
@@ -1421,20 +1522,23 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1421
1522
fusableProducerOp, " failed to replacement value for this "
1422
1523
" operation from within the tiled loop" );
1423
1524
}
1424
- addCandidateSlices (newSlices.value ());
1525
+ worklistCandidates. append (newSlices.value ());
1425
1526
for (auto [index , result] :
1426
1527
llvm::enumerate (fusableProducerOp->getResults ())) {
1427
1528
origValToResultNumber[result] = loops.front ()->getNumResults () -
1428
1529
fusableProducerOp->getNumResults () +
1429
1530
index ;
1430
1531
}
1431
1532
}
1432
- addCandidateSlices (fusedResult->generatedSlices );
1433
1533
if (Operation *tiledAndFusedOp =
1434
1534
fusedResult->tiledAndFusedProducer .getDefiningOp ()) {
1435
1535
fusedProducers.insert (fusedResult->origProducer .getDefiningOp ());
1436
1536
tiledAndFusedOps.insert (tiledAndFusedOp);
1437
1537
}
1538
+
1539
+ if (failed (sliceTracker.insertAndApplyPatterns (worklistCandidates))) {
1540
+ return rewriter.notifyMatchFailure (consumer, " cleanup patterns failed" );
1541
+ }
1438
1542
}
1439
1543
1440
1544
DenseMap<Value, Value> replacements;
0 commit comments