Commit 9a9ca98

[mlir][MemRef] Add more ops to narrow type support, strided metadata expansion (#102228)
- Add support for memory_space_cast to strided metadata expansion and narrow type emulation
- Add support for expand_shape to narrow type emulation (like collapse_shape, it is a no-op after linearization) and to expand-strided-metadata (mirroring the collapse_shape pattern)
- Add support for memref.dealloc to narrow type emulation (it is a trivial rewrite) and for memref.copy (which is unsupported when it is used for a layout change, but a trivial rewrite otherwise)
1 parent 9a666de commit 9a9ca98
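For context, the newly covered ops are handled by the same dialect-conversion flow as the existing narrow-type patterns. The sketch below shows one way the entry points touched by this commit are typically wired together; it is illustrative only. The helper name `emulateNarrowMemrefTypes` and its `loadStoreBitwidth` parameter are made up for this example, and the use of `arith::NarrowTypeEmulationConverter` and `memref::populateMemRefNarrowTypeEmulationConversions` is assumed from the surrounding narrow-type emulation infrastructure rather than shown in this diff.

// Minimal sketch (not part of this commit): driving the memref narrow-type
// emulation patterns, including the new copy/dealloc/expand_shape/
// memory_space_cast support, through a dialect conversion.
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical helper; pass boilerplate and error reporting are elided.
static LogicalResult emulateNarrowMemrefTypes(Operation *root,
                                              unsigned loadStoreBitwidth) {
  MLIRContext *ctx = root->getContext();

  // Convert narrow integer memref element types to the wider load/store type.
  arith::NarrowTypeEmulationConverter typeConverter(loadStoreBitwidth);
  memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

  RewritePatternSet patterns(ctx);
  // Registers the patterns extended by this commit alongside the existing
  // load/store/alloc/subview/collapse_shape ones.
  memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);

  // Ops whose types already converted cleanly are legal; anything still
  // carrying a narrow memref type must be rewritten.
  ConversionTarget target(*ctx);
  target.markUnknownOpDynamicallyLegal(
      [&](Operation *op) { return typeConverter.isLegal(op); });

  return applyPartialConversion(root, target, std::move(patterns));
}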

4 files changed: +283, -3 lines

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 90 additions & 3 deletions
@@ -234,6 +234,46 @@ struct ConvertMemRefAssumeAlignment final
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemRefCopy
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
+    auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
+    if (maybeRankedSource && maybeRankedDest &&
+        maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
+                            "and {1}) is currently unimplemented",
+                            maybeRankedSource.getLayout(),
+                            maybeRankedDest.getLayout()));
+    rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
+                                                adaptor.getTarget());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefDealloc
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefLoad
 //===----------------------------------------------------------------------===//
@@ -300,6 +340,30 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemRefMemorySpaceCast
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefMemorySpaceCast final
+    : OpConversionPattern<memref::MemorySpaceCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getDest().getType());
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getDest().getType()));
+    }
+
+    rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, newTy,
+                                                           adaptor.getSource());
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefReinterpretCast
 //===----------------------------------------------------------------------===//
@@ -490,6 +554,28 @@ struct ConvertMemRefCollapseShape final
   }
 };
 
+/// Emulating a `memref.expand_shape` becomes a no-op after emulation given
+/// that we flatten memrefs to a single dimension as part of the emulation and
+/// the expansion would just have been undone.
+struct ConvertMemRefExpandShape final
+    : OpConversionPattern<memref::ExpandShapeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value srcVal = adaptor.getSrc();
+    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
+    if (!newTy)
+      return failure();
+
+    if (newTy.getRank() != 1)
+      return failure();
+
+    rewriter.replaceOp(expandShapeOp, srcVal);
+    return success();
+  }
+};
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
@@ -502,9 +588,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
 
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
-               ConvertMemRefAllocation<memref::AllocaOp>,
-               ConvertMemRefCollapseShape, ConvertMemRefLoad,
-               ConvertMemrefStore, ConvertMemRefAssumeAlignment,
+               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
+               ConvertMemRefDealloc, ConvertMemRefCollapseShape,
+               ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
+               ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
                ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
       typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

Lines changed: 87 additions & 0 deletions
@@ -726,6 +726,41 @@ struct ExtractStridedMetadataOpCollapseShapeFolder
   }
 };
 
+/// Pattern to replace `extract_strided_metadata(expand_shape)`
+/// with the results of computing the sizes and strides on the expanded shape
+/// and dividing up dimensions into static and dynamic parts as needed.
+struct ExtractStridedMetadataOpExpandShapeFolder
+    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
+    if (!expandShapeOp)
+      return failure();
+
+    FailureOr<StridedMetadata> stridedMetadata =
+        resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
+            rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
+    if (failed(stridedMetadata)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to resolve metadata in terms of source expand_shape op");
+    }
+
+    Location loc = expandShapeOp.getLoc();
+    SmallVector<Value> results;
+    results.push_back(stridedMetadata->basePtr);
+    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
+    results.append(
+        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
+    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
+
 /// Replace `base, offset, sizes, strides =
 ///           extract_strided_metadata(allocLikeOp)`
 ///
@@ -1060,6 +1095,54 @@ class ExtractStridedMetadataOpCastFolder
   }
 };
 
+/// Replace `base, offset, sizes, strides = extract_strided_metadata(
+///            memory_space_cast(src) to dstTy)`
+/// with
+/// ```
+///   oldBase, offset, sizes, strides = extract_strided_metadata(src)
+///   destBaseTy = type(oldBase) with memory space from destTy
+///   base = memory_space_cast(oldBase) to destBaseTy
+/// ```
+///
+/// In other words, propagate metadata extraction across memory space casts.
+class ExtractStridedMetadataOpMemorySpaceCastFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
+                  PatternRewriter &rewriter) const override {
+    Location loc = extractStridedMetadataOp.getLoc();
+    Value source = extractStridedMetadataOp.getSource();
+    auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>();
+    if (!memSpaceCastOp)
+      return failure();
+    auto newExtractStridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(
+            loc, memSpaceCastOp.getSource());
+    SmallVector<Value> results(newExtractStridedMetadata.getResults());
+    // As with most other strided metadata rewrite patterns, don't introduce
+    // a use of the base pointer where none existed. This needs to happen here,
+    // as opposed to in later dead-code elimination, because these patterns are
+    // sometimes used during dialect conversion (see EmulateNarrowType, for
+    // example), so adding spurious usages would cause a pre-legalization value
+    // to be live that would be dead had this pattern not run.
+    if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
+      auto baseBuffer = results[0];
+      auto baseBufferType = cast<MemRefType>(baseBuffer.getType());
+      MemRefType::Builder newTypeBuilder(baseBufferType);
+      newTypeBuilder.setMemorySpace(
+          memSpaceCastOp.getResult().getType().getMemorySpace());
+      results[0] = rewriter.create<memref::MemorySpaceCastOp>(
+          loc, Type{newTypeBuilder}, baseBuffer);
+    } else {
+      results[0] = nullptr;
+    }
+    rewriter.replaceOp(extractStridedMetadataOp, results);
+    return success();
+  }
+};
+
 /// Replace `base, offset =
 ///            extract_strided_metadata(extract_strided_metadata(src)#0)`
 /// With
@@ -1099,11 +1182,13 @@ void memref::populateExpandStridedMetadataPatterns(
                ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
                ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
                ExtractStridedMetadataOpCollapseShapeFolder,
+               ExtractStridedMetadataOpExpandShapeFolder,
                ExtractStridedMetadataOpGetGlobalFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
                ExtractStridedMetadataOpSubviewFolder,
                ExtractStridedMetadataOpCastFolder,
+               ExtractStridedMetadataOpMemorySpaceCastFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
 }
@@ -1113,11 +1198,13 @@ void memref::populateResolveExtractStridedMetadataPatterns(
   patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
              ExtractStridedMetadataOpCollapseShapeFolder,
+              ExtractStridedMetadataOpExpandShapeFolder,
              ExtractStridedMetadataOpGetGlobalFolder,
              ExtractStridedMetadataOpSubviewFolder,
              RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
              ExtractStridedMetadataOpReinterpretCastFolder,
              ExtractStridedMetadataOpCastFolder,
+              ExtractStridedMetadataOpMemorySpaceCastFolder,
              ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
 }
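The two new folders above are registered through `populateExpandStridedMetadataPatterns` and `populateResolveExtractStridedMetadataPatterns`. Outside of dialect conversion they are normally driven to a fixpoint by the greedy rewriter; the following is a minimal sketch of that usage, with the helper name `expandStridedMetadata` invented for illustration and `applyPatternsAndFoldGreedily` assumed from the standard greedy-driver API, not shown in this diff.

// Minimal sketch (not part of this commit): running the strided-metadata
// expansion patterns, including the new expand_shape and memory_space_cast
// folders, with the greedy pattern rewrite driver.
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper; a real pass would do essentially this in its
// runOnOperation().
static LogicalResult expandStridedMetadata(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  memref::populateExpandStridedMetadataPatterns(patterns);
  // Folds extract_strided_metadata over expand_shape, collapse_shape,
  // memory_space_cast, subview, etc., until a fixpoint is reached.
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}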

mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Lines changed: 68 additions & 0 deletions
@@ -6,11 +6,13 @@ func.func @memref_i8() -> i8 {
   %c3 = arith.constant 3 : index
   %m = memref.alloc() : memref<4xi8, 1>
   %v = memref.load %m[%c3] : memref<4xi8, 1>
+  memref.dealloc %m : memref<4xi8, 1>
   return %v : i8
 }
 // CHECK-LABEL: func @memref_i8()
 // CHECK: %[[M:.+]] = memref.alloc() : memref<4xi8, 1>
 // CHECK-NEXT: %[[V:.+]] = memref.load %[[M]][%{{.+}}] : memref<4xi8, 1>
+// CHECK-NEXT: memref.dealloc %[[M]]
 // CHECK-NEXT: return %[[V]]
 
 // CHECK32-LABEL: func @memref_i8()
@@ -21,6 +23,7 @@ func.func @memref_i8() -> i8 {
 // CHECK32: %[[CAST:.+]] = arith.index_cast %[[C24]] : index to i32
 // CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[V]], %[[CAST]]
 // CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i8
+// CHECK32-NEXT: memref.dealloc %[[M]]
 // CHECK32-NEXT: return %[[TRUNC]]
 
 // -----
@@ -485,3 +488,68 @@ func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
 // CHECK32-NOT: memref.collapse_shape
 // CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>
 
+// -----
+
+func.func @memref_expand_shape_i4(%idx0 : index, %idx1 : index, %idx2 : index) -> i4 {
+  %arr = memref.alloc() : memref<256x128xi4>
+  %expand = memref.expand_shape %arr[[0, 1], [2]] output_shape [32, 8, 128] : memref<256x128xi4> into memref<32x8x128xi4>
+  %1 = memref.load %expand[%idx0, %idx1, %idx2] : memref<32x8x128xi4>
+  return %1 : i4
+}
+
+// CHECK-LABEL: func.func @memref_expand_shape_i4(
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8>
+// CHECK-NOT: memref.expand_shape
+// CHECK: memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8>
+
+// CHECK32-LABEL: func.func @memref_expand_shape_i4(
+// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32>
+// CHECK32-NOT: memref.expand_shape
+// CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>
+
+// -----
+
+func.func @memref_memory_space_cast_i4(%arg0: memref<32x128xi4, 1>) -> memref<32x128xi4> {
+  %cast = memref.memory_space_cast %arg0 : memref<32x128xi4, 1> to memref<32x128xi4>
+  return %cast : memref<32x128xi4>
+}
+
+// CHECK-LABEL: func.func @memref_memory_space_cast_i4(
+// CHECK-SAME: %[[ARG0:.*]]: memref<2048xi8, 1>
+// CHECK: %[[CAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<2048xi8, 1> to memref<2048xi8>
+// CHECK: return %[[CAST]]
+
+// CHECK32-LABEL: func.func @memref_memory_space_cast_i4(
+// CHECK32-SAME: %[[ARG0:.*]]: memref<512xi32, 1>
+// CHECK32: %[[CAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<512xi32, 1> to memref<512xi32>
+// CHECK32: return %[[CAST]]
+
+// -----
+
+func.func @memref_copy_i4(%arg0: memref<32x128xi4, 1>, %arg1: memref<32x128xi4>) {
+  memref.copy %arg0, %arg1 : memref<32x128xi4, 1> to memref<32x128xi4>
+  return
+}
+
+// CHECK-LABEL: func.func @memref_copy_i4(
+// CHECK-SAME: %[[ARG0:.*]]: memref<2048xi8, 1>, %[[ARG1:.*]]: memref<2048xi8>
+// CHECK: memref.copy %[[ARG0]], %[[ARG1]]
+// CHECK: return
+
+// CHECK32-LABEL: func.func @memref_copy_i4(
+// CHECK32-SAME: %[[ARG0:.*]]: memref<512xi32, 1>, %[[ARG1:.*]]: memref<512xi32>
+// CHECK32: memref.copy %[[ARG0]], %[[ARG1]]
+// CHECK32: return
+
+// -----
+
+!colMajor = memref<8x8xi4, strided<[1, 8]>>
+func.func @copy_distinct_layouts(%idx : index) -> i4 {
+  %c0 = arith.constant 0 : index
+  %arr = memref.alloc() : memref<8x8xi4>
+  %arr2 = memref.alloc() : !colMajor
+  // expected-error @+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
+  memref.copy %arr, %arr2 : memref<8x8xi4> to !colMajor
+  %ld = memref.load %arr2[%c0, %c0] : !colMajor
+  return %ld : i4
+}

mlir/test/Dialect/MemRef/expand-strided-metadata.mlir

Lines changed: 38 additions & 0 deletions
@@ -1553,3 +1553,41 @@ func.func @extract_strided_metadata_of_collapse_shape(%base: memref<5x4xf32>)
 // CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
 // CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
 // CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32>, index, index, index
+
+// -----
+
+func.func @extract_strided_metadata_of_memory_space_cast(%base: memref<20xf32>)
+    -> (memref<f32, 1>, index, index, index) {
+
+  %memory_space_cast = memref.memory_space_cast %base : memref<20xf32> to memref<20xf32, 1>
+
+  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %memory_space_cast :
+      memref<20xf32, 1> -> memref<f32, 1>, index, index, index
+
+  return %base_buffer, %offset, %size, %stride :
+      memref<f32, 1>, index, index, index
+}
+
+// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast
+// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[SIZE:.*]] = arith.constant 20 : index
+// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
+// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
+// CHECK: %[[CAST:.*]] = memref.memory_space_cast %[[BASE]]
+// CHECK: return %[[CAST]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32, 1>, index, index, index
+
+// -----
+
+func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<20xf32>)
+    -> (index, index, index) {
+
+  %memory_space_cast = memref.memory_space_cast %base : memref<20xf32> to memref<20xf32, 1>
+
+  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %memory_space_cast :
+      memref<20xf32, 1> -> memref<f32, 1>, index, index, index
+
+  return %offset, %size, %stride : index, index, index
+}
+
+// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast_no_base
+// CHECK-NOT: memref.memory_space_cast
