
Commit fd8349c

[mlir][Linalg] Move the linalg.fill -> tensor.pack folding pattern into the fill canonicalization patterns. (#66002)
This pattern fits better alongside the other canonicalization patterns that already exist for `linalg.fill`.
1 parent: 8bc676c · commit: fd8349c
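
In IR terms, the moved pattern folds a tensor.pack whose source is produced by a linalg.fill into a single linalg.fill on the packed destination. A minimal before/after sketch, adapted from the @fill_pack test added in this commit (SSA names are illustrative):

// Before: fill the flat tensor, then pack it.
%cst = arith.constant 0.000000e+00 : f32
%flat = tensor.empty() : tensor<384x512xf32>
%fill = linalg.fill ins(%cst : f32) outs(%flat : tensor<384x512xf32>) -> tensor<384x512xf32>
%dest = tensor.empty() : tensor<24x32x16x16xf32>
%pack = tensor.pack %fill inner_dims_pos = [0, 1] inner_tiles = [16, 16]
    into %dest : tensor<384x512xf32> -> tensor<24x32x16x16xf32>

// After: fill the packed destination directly; the pack disappears.
%cst = arith.constant 0.000000e+00 : f32
%dest = tensor.empty() : tensor<24x32x16x16xf32>
%fill = linalg.fill ins(%cst : f32) outs(%dest : tensor<24x32x16x16xf32>) -> tensor<24x32x16x16xf32>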

File tree

4 files changed: +101 −110 lines

4 files changed

+101
-110
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 57 additions & 6 deletions
@@ -737,7 +737,8 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
 
   LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
-    // See if tensor input of tensor.extract op is the result of a linalg.fill op.
+    // See if tensor input of tensor.extract op is the result of a linalg.fill
+    // op.
     auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
     if (!fillOp)
       return failure();
@@ -751,15 +752,65 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
   }
 };
 
+/// Folds pack(fill) into a single fill op if
+/// 1. The pack op does not have padding value, or
+/// 2. The filled value and padding value are the same.
+static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
+                                                tensor::PackOp packOp) {
+  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
+  if (!fillOp)
+    return failure();
+
+  if (auto paddingValue = packOp.getPaddingValue())
+    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
+      return failure();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(fillOp);
+
+  Value packOpDest = packOp.getDest();
+  if (!packOpDest.hasOneUse())
+    return failure();
+  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
+    packOpDest = tensor::PackOp::createDestinationTensor(
+        rewriter, fillOp.getLoc(), fillOp.getDpsInitOperand(0)->get(),
+        packOp.getMixedTiles(), packOp.getInnerDimsPos(),
+        packOp.getOuterDimsPerm());
+  } else {
+    DominanceInfo dom(fillOp);
+    if (!dom.properlyDominates(packOpDest, fillOp))
+      return failure();
+  }
+
+  Value fillDest = packOpDest;
+  return clone(rewriter, fillOp, packOpDest.getType(),
+               {fillOp.value(), fillDest});
+}
+
+/// Wrapper pattern that applies foldFillPackIntoFillOp method.
+struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
+public:
+  FoldFillWithPack(MLIRContext *context)
+      : OpRewritePattern<tensor::PackOp>(context) {}
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
+    if (failed(fillOp))
+      return failure();
+    rewriter.replaceOp(packOp, fillOp.value().result());
+    return success();
+  }
+};
+
 } // namespace
 
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results
-      .add<FoldFillWithTensorExtract,
-           FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
-           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
-           FoldInsertPadIntoFill>(context);
+  results.add<FoldFillWithTensorExtract, FoldFillWithPack, FoldFillWithPad,
+              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
+              FoldInsertPadIntoFill>(context);
 }
 
 //===----------------------------------------------------------------------===//
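
The padding-value guard above is what keeps the fold sound for packs that pad: when the tensor.pack carries a padding_value, foldFillPackIntoFillOp only fires if that value is provably equal to the filled value (via isEqualConstantIntOrValue). A sketch of the accepted case, mirroring the @dynamic_fill_pack test below; the shared %cst makes padded elements indistinguishable from filled ones (names are illustrative):

%cst = arith.constant 0.000000e+00 : f32
%fill = linalg.fill ins(%cst : f32) outs(%src : tensor<?x?xf32>) -> tensor<?x?xf32>
// Folds: the padding value and the filled value are the same %cst.
%pack = tensor.pack %fill padding_value(%cst : f32)
    inner_dims_pos = [0, 1] inner_tiles = [16, 16]
    into %dest : tensor<?x?xf32> -> tensor<?x?x16x16xf32>

Had the padding value been a different constant, the helper would return failure() and the pack would be left in place.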

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 0 additions & 60 deletions
@@ -448,46 +448,6 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
                                  *packInfo);
 }
 
-/// Folds pack(fill) into a single fill op if
-/// 1. The pack op does not have padding value, or
-/// 2. The filled value and padding value are the same.
-static FailureOr<FillOp>
-foldFillPackIntoFillOp(RewriterBase &rewriter, tensor::PackOp packOp,
-                       ControlPropagationFn controlFn) {
-  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
-  if (!fillOp)
-    return failure();
-
-  // User controlled propagation function.
-  if (!controlFn(fillOp))
-    return failure();
-
-  if (auto paddingValue = packOp.getPaddingValue())
-    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
-      return failure();
-
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(fillOp);
-
-  Value packOpDest = packOp.getDest();
-  if (!packOpDest.hasOneUse())
-    return failure();
-  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
-    packOpDest = tensor::PackOp::createDestinationTensor(
-        rewriter, fillOp.getLoc(), fillOp.getDpsInitOperand(0)->get(),
-        packOp.getMixedTiles(), packOp.getInnerDimsPos(),
-        packOp.getOuterDimsPerm());
-  } else {
-    DominanceInfo dom(fillOp);
-    if (!dom.properlyDominates(packOpDest, fillOp))
-      return failure();
-  }
-
-  Value fillDest = packOpDest;
-  return clone(rewriter, fillOp, packOpDest.getType(),
-               {fillOp.value(), fillDest});
-}
-
 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
 struct BubbleUpPackOpThroughGenericOpPattern
     : public OpRewritePattern<tensor::PackOp> {
@@ -510,25 +470,6 @@ struct BubbleUpPackOpThroughGenericOpPattern
   ControlPropagationFn controlFn;
 };
 
-/// Wrapper pattern that applies foldFillPackIntoFillOp method.
-struct FoldFillPackIntoFillOpPattern : public OpRewritePattern<tensor::PackOp> {
-public:
-  FoldFillPackIntoFillOpPattern(MLIRContext *context, ControlPropagationFn fun)
-      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
-
-  LogicalResult matchAndRewrite(tensor::PackOp packOp,
-                                PatternRewriter &rewriter) const override {
-    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp, controlFn);
-    if (failed(fillOp))
-      return failure();
-    rewriter.replaceOp(packOp, fillOp.value().result());
-    return success();
-  }
-
-private:
-  ControlPropagationFn controlFn;
-};
-
 // TODO: Relax this restriction. We should unpack a generic op also
 // in the presence of multiple unpack ops as producers.
 /// Return the unpacked operand, if present, for the current generic op.
@@ -750,7 +691,6 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
     const ControlPropagationFn &controlPackUnPackPropagation) {
   patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
-                  FoldFillPackIntoFillOpPattern,
                   PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
       patterns.getContext(), controlPackUnPackPropagation);
 }
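
One behavioral consequence of this removal, visible in the deleted `if (!controlFn(fillOp))` check: the fold is no longer gated by the user-supplied ControlPropagationFn, and it now fires wherever linalg.fill canonicalization patterns run rather than only under data-layout propagation. Assuming the usual test drivers for these files (the pass flags below are an assumption, not part of the diff):

// Previously exercised through the propagation test pass:
//   mlir-opt input.mlir -test-linalg-data-layout-propagation
// Now reachable through plain canonicalization:
//   mlir-opt input.mlir -canonicalize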

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 44 additions & 0 deletions
@@ -353,6 +353,50 @@ func.func @fold_fill_extract(%arg0 : i1) -> i1 {
 
 // -----
 
+func.func @fill_pack() -> tensor<24x32x16x16xf32> {
+  %dest = tensor.empty() : tensor<384x512xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<24x32x16x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%dest : tensor<384x512xf32>) -> tensor<384x512xf32>
+  %pack = tensor.pack %1 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<384x512xf32> -> tensor<24x32x16x16xf32>
+  return %pack : tensor<24x32x16x16xf32>
+}
+// CHECK-LABEL: func.func @fill_pack
+// CHECK: %[[PACKED_EMPTY:.+]] = tensor.empty() : tensor<24x32x16x16xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]]
+// CHECK: return %[[FILL]]
+
+// -----
+
+#map = affine_map<()[s0] -> (s0 ceildiv 16)>
+func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %dim = tensor.dim %0, %c0 : tensor<?x?xf32>
+  %dim_0 = tensor.dim %0, %c1 : tensor<?x?xf32>
+  %1 = affine.apply #map()[%dim]
+  %2 = affine.apply #map()[%dim_0]
+  %3 = tensor.empty(%1, %2) : tensor<?x?x16x16xf32>
+  %pack = tensor.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %3 : tensor<?x?xf32> -> tensor<?x?x16x16xf32>
+  return %pack : tensor<?x?x16x16xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+// CHECK: func.func @dynamic_fill_pack
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[D0:.+]] = tensor.dim %[[DEST]], %[[C0]]
+// CHECK: %[[D1:.+]] = tensor.dim %[[DEST]], %[[C1]]
+// CHECK: %[[PACKED_D0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: %[[PACKED_D1:.+]] = affine.apply #[[MAP]]()[%[[D1]]]
+// CHECK: %[[PACKED_EMPTY:.+]] = tensor.empty(%[[PACKED_D0]], %[[PACKED_D1]]) : tensor<?x?x16x16xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]]
+// CHECK: return %[[FILL]]
+
+// -----
+
 // CHECK: func @fold_self_copy
 func.func @fold_self_copy(%0 : memref<4x16xf32>) {
 // CHECK-NEXT: return
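
The packed destination shapes in these tests follow from ceildiv arithmetic over the inner tiles: with inner_tiles = [16, 16], a 384x512 source packs into ceil(384/16) x ceil(512/16) x 16 x 16 = 24x32x16x16, and the dynamic test computes the same quantity per dimension with affine.apply. The per-dimension computation in isolation (SSA names are illustrative):

#map = affine_map<()[s0] -> (s0 ceildiv 16)>
// Number of 16-wide tiles along dimension 0 of %t.
%d0 = tensor.dim %t, %c0 : tensor<?x?xf32>
%tiles_d0 = affine.apply #map()[%d0]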

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 0 additions & 44 deletions
@@ -839,47 +839,3 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
 // CHECK-SAME: into %[[UNPACK_NEW_DEST]]
 // CHECK: return %[[UNPACK]] : tensor<16x540x960xi32>
-
-// -----
-
-func.func @fill_pack() -> tensor<24x32x16x16xf32> {
-  %dest = tensor.empty() : tensor<384x512xf32>
-  %cst = arith.constant 0.000000e+00 : f32
-  %0 = tensor.empty() : tensor<24x32x16x16xf32>
-  %1 = linalg.fill ins(%cst : f32) outs(%dest : tensor<384x512xf32>) -> tensor<384x512xf32>
-  %pack = tensor.pack %1 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<384x512xf32> -> tensor<24x32x16x16xf32>
-  return %pack : tensor<24x32x16x16xf32>
-}
-// CHECK-LABEL: func.func @fill_pack
-// CHECK: %[[PACKED_EMPTY:.+]] = tensor.empty() : tensor<24x32x16x16xf32>
-// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]]
-// CHECK: return %[[FILL]]
-
-// -----
-
-#map = affine_map<()[s0] -> (s0 ceildiv 16)>
-func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
-  %dim = tensor.dim %0, %c0 : tensor<?x?xf32>
-  %dim_0 = tensor.dim %0, %c1 : tensor<?x?xf32>
-  %1 = affine.apply #map()[%dim]
-  %2 = affine.apply #map()[%dim_0]
-  %3 = tensor.empty(%1, %2) : tensor<?x?x16x16xf32>
-  %pack = tensor.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %3 : tensor<?x?xf32> -> tensor<?x?x16x16xf32>
-  return %pack : tensor<?x?x16x16xf32>
-}
-// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
-// CHECK: func.func @dynamic_fill_pack
-// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[D0:.+]] = tensor.dim %[[DEST]], %[[C0]]
-// CHECK: %[[D1:.+]] = tensor.dim %[[DEST]], %[[C1]]
-// CHECK: %[[PACKED_D0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: %[[PACKED_D1:.+]] = affine.apply #[[MAP]]()[%[[D1]]]
-// CHECK: %[[PACKED_EMPTY:.+]] = tensor.empty(%[[PACKED_D0]], %[[PACKED_D1]]) : tensor<?x?x16x16xf32>
-// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]]
-// CHECK: return %[[FILL]]
