diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index 88d56a8fbec74..a45b79194a758 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -234,6 +234,46 @@ struct ConvertMemRefAssumeAlignment final } }; +//===----------------------------------------------------------------------===// +// ConvertMemRefCopy +//===----------------------------------------------------------------------===// + +struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType()); + auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType()); + if (maybeRankedSource && maybeRankedDest && + maybeRankedSource.getLayout() != maybeRankedDest.getLayout()) + return rewriter.notifyMatchFailure( + op, llvm::formatv("memref.copy emulation with distinct layouts ({0} " + "and {1}) is currently unimplemented", + maybeRankedSource.getLayout(), + maybeRankedDest.getLayout())); + rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(), + adaptor.getTarget()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ConvertMemRefDealloc +//===----------------------------------------------------------------------===// + +struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref()); + return success(); + } +}; + //===----------------------------------------------------------------------===// // ConvertMemRefLoad
//===----------------------------------------------------------------------===// @@ -300,6 +340,30 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> { } }; +//===----------------------------------------------------------------------===// +// ConvertMemRefMemorySpaceCast +//===----------------------------------------------------------------------===// + +struct ConvertMemRefMemorySpaceCast final + : OpConversionPattern<memref::MemorySpaceCastOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getDest().getType()); + if (!newTy) { + return rewriter.notifyMatchFailure( + op->getLoc(), llvm::formatv("failed to convert memref type: {0}", + op.getDest().getType())); + } + + rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, newTy, + adaptor.getSource()); + return success(); + } +}; + //===----------------------------------------------------------------------===// // ConvertMemRefReinterpretCast //===----------------------------------------------------------------------===// @@ -490,6 +554,28 @@ struct ConvertMemRefCollapseShape final } }; +/// Emulating a `memref.expand_shape` becomes a no-op after emulation given +/// that we flatten memrefs to a single dimension as part of the emulation and +/// the expansion would just have been undone.
+struct ConvertMemRefExpandShape final + : OpConversionPattern<memref::ExpandShapeOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value srcVal = adaptor.getSrc(); + auto newTy = dyn_cast<MemRefType>(srcVal.getType()); + if (!newTy) + return failure(); + + if (newTy.getRank() != 1) + return failure(); + + rewriter.replaceOp(expandShapeOp, srcVal); + return success(); + } +}; } // end anonymous namespace //===----------------------------------------------------------------------===// @@ -502,9 +588,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns( // Populate `memref.*` conversion patterns. patterns.add<ConvertMemRefAllocation<memref::AllocOp>, - ConvertMemRefAllocation<memref::AllocaOp>, - ConvertMemRefCollapseShape, ConvertMemRefLoad, - ConvertMemrefStore, ConvertMemRefAssumeAlignment, + ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy, + ConvertMemRefDealloc, ConvertMemRefCollapseShape, + ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore, + ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast, ConvertMemRefSubview, ConvertMemRefReinterpretCast>( typeConverter, patterns.getContext()); memref::populateResolveExtractStridedMetadataPatterns(patterns); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index 585c5b7381421..a2049ba4a4924 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -726,6 +726,41 @@ struct ExtractStridedMetadataOpCollapseShapeFolder } }; +/// Pattern to replace `extract_strided_metadata(expand_shape)` +/// with the results of computing the sizes and strides on the expanded shape +/// and dividing up dimensions into static and dynamic parts as needed.
+struct ExtractStridedMetadataOpExpandShapeFolder + : OpRewritePattern<memref::ExtractStridedMetadataOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, + PatternRewriter &rewriter) const override { + auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>(); + if (!expandShapeOp) + return failure(); + + FailureOr<StridedMetadata> stridedMetadata = + resolveReshapeStridedMetadata<memref::ExpandShapeOp>( + rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides); + if (failed(stridedMetadata)) { + return rewriter.notifyMatchFailure( + op, "failed to resolve metadata in terms of source expand_shape op"); + } + + Location loc = expandShapeOp.getLoc(); + SmallVector<Value> results; + results.push_back(stridedMetadata->basePtr); + results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, + stridedMetadata->offset)); + results.append( + getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes)); + results.append(getValueOrCreateConstantIndexOp(rewriter, loc, + stridedMetadata->strides)); + rewriter.replaceOp(op, results); + return success(); + } +}; + /// Replace `base, offset, sizes, strides = /// extract_strided_metadata(allocLikeOp)` /// @@ -1060,6 +1095,54 @@ class ExtractStridedMetadataOpCastFolder } }; +/// Replace `base, offset, sizes, strides = extract_strided_metadata( +/// memory_space_cast(src) to dstTy)` +/// with +/// ``` +/// oldBase, offset, sizes, strides = extract_strided_metadata(src) +/// destBaseTy = type(oldBase) with memory space from destTy +/// base = memory_space_cast(oldBase) to destBaseTy +/// ``` +/// +/// In other words, propagate metadata extraction across memory space casts.
+class ExtractStridedMetadataOpMemorySpaceCastFolder + : public OpRewritePattern<memref::ExtractStridedMetadataOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult + matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp, + PatternRewriter &rewriter) const override { + Location loc = extractStridedMetadataOp.getLoc(); + Value source = extractStridedMetadataOp.getSource(); + auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>(); + if (!memSpaceCastOp) + return failure(); + auto newExtractStridedMetadata = + rewriter.create<memref::ExtractStridedMetadataOp>( + loc, memSpaceCastOp.getSource()); + SmallVector<Value> results(newExtractStridedMetadata.getResults()); + // As with most other strided metadata rewrite patterns, don't introduce + // a use of the base pointer where none existed. This needs to happen here, + // as opposed to in later dead-code elimination, because these patterns are + // sometimes used during dialect conversion (see EmulateNarrowType, for + // example), so adding spurious usages would cause a pre-legalization value + // to be live that would be dead had this pattern not run.
+ if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) { + auto baseBuffer = results[0]; + auto baseBufferType = cast<MemRefType>(baseBuffer.getType()); + MemRefType::Builder newTypeBuilder(baseBufferType); + newTypeBuilder.setMemorySpace( + memSpaceCastOp.getResult().getType().getMemorySpace()); + results[0] = rewriter.create<memref::MemorySpaceCastOp>( + loc, Type{newTypeBuilder}, baseBuffer); + } else { + results[0] = nullptr; + } + rewriter.replaceOp(extractStridedMetadataOp, results); + return success(); + } +}; + /// Replace `base, offset = /// extract_strided_metadata(extract_strided_metadata(src)#0)` /// With @@ -1099,11 +1182,13 @@ void memref::populateExpandStridedMetadataPatterns( ExtractStridedMetadataOpAllocFolder<memref::AllocOp>, ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>, ExtractStridedMetadataOpCollapseShapeFolder, + ExtractStridedMetadataOpExpandShapeFolder, ExtractStridedMetadataOpGetGlobalFolder, RewriteExtractAlignedPointerAsIndexOfViewLikeOp, ExtractStridedMetadataOpReinterpretCastFolder, ExtractStridedMetadataOpSubviewFolder, ExtractStridedMetadataOpCastFolder, + ExtractStridedMetadataOpMemorySpaceCastFolder, ExtractStridedMetadataOpExtractStridedMetadataFolder>( patterns.getContext()); } @@ -1113,11 +1198,13 @@ void memref::populateResolveExtractStridedMetadataPatterns( patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>, ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>, ExtractStridedMetadataOpCollapseShapeFolder, + ExtractStridedMetadataOpExpandShapeFolder, ExtractStridedMetadataOpGetGlobalFolder, ExtractStridedMetadataOpSubviewFolder, RewriteExtractAlignedPointerAsIndexOfViewLikeOp, ExtractStridedMetadataOpReinterpretCastFolder, ExtractStridedMetadataOpCastFolder, + ExtractStridedMetadataOpMemorySpaceCastFolder, ExtractStridedMetadataOpExtractStridedMetadataFolder>( patterns.getContext()); } diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir index a67237b5e4dd1..540da239fced0 100644 --- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir +++
b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir @@ -6,11 +6,13 @@ func.func @memref_i8() -> i8 { %c3 = arith.constant 3 : index %m = memref.alloc() : memref<4xi8, 1> %v = memref.load %m[%c3] : memref<4xi8, 1> + memref.dealloc %m : memref<4xi8, 1> return %v : i8 } // CHECK-LABEL: func @memref_i8() // CHECK: %[[M:.+]] = memref.alloc() : memref<4xi8, 1> // CHECK-NEXT: %[[V:.+]] = memref.load %[[M]][%{{.+}}] : memref<4xi8, 1> +// CHECK-NEXT: memref.dealloc %[[M]] // CHECK-NEXT: return %[[V]] // CHECK32-LABEL: func @memref_i8() @@ -21,6 +23,7 @@ func.func @memref_i8() -> i8 { // CHECK32: %[[CAST:.+]] = arith.index_cast %[[C24]] : index to i32 // CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[V]], %[[CAST]] // CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i8 +// CHECK32-NEXT: memref.dealloc %[[M]] // CHECK32-NEXT: return %[[TRUNC]] // ----- @@ -485,3 +488,68 @@ func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 { // CHECK32-NOT: memref.collapse_shape // CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32> +// ----- + +func.func @memref_expand_shape_i4(%idx0 : index, %idx1 : index, %idx2 : index) -> i4 { + %arr = memref.alloc() : memref<256x128xi4> + %expand = memref.expand_shape %arr[[0, 1], [2]] output_shape [32, 8, 128] : memref<256x128xi4> into memref<32x8x128xi4> + %1 = memref.load %expand[%idx0, %idx1, %idx2] : memref<32x8x128xi4> + return %1 : i4 +} + +// CHECK-LABEL: func.func @memref_expand_shape_i4( +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8> +// CHECK-NOT: memref.expand_shape +// CHECK: memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8> + +// CHECK32-LABEL: func.func @memref_expand_shape_i4( +// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32> +// CHECK32-NOT: memref.expand_shape +// CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32> + +// ----- + +func.func @memref_memory_space_cast_i4(%arg0: memref<32x128xi4, 1>) -> memref<32x128xi4> { + %cast = memref.memory_space_cast %arg0 : 
memref<32x128xi4, 1> to memref<32x128xi4> + return %cast : memref<32x128xi4> +} + +// CHECK-LABEL: func.func @memref_memory_space_cast_i4( +// CHECK-SAME: %[[ARG0:.*]]: memref<2048xi8, 1> +// CHECK: %[[CAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<2048xi8, 1> to memref<2048xi8> +// CHECK: return %[[CAST]] + +// CHECK32-LABEL: func.func @memref_memory_space_cast_i4( +// CHECK32-SAME: %[[ARG0:.*]]: memref<512xi32, 1> +// CHECK32: %[[CAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<512xi32, 1> to memref<512xi32> +// CHECK32: return %[[CAST]] + +// ----- + +func.func @memref_copy_i4(%arg0: memref<32x128xi4, 1>, %arg1: memref<32x128xi4>) { + memref.copy %arg0, %arg1 : memref<32x128xi4, 1> to memref<32x128xi4> + return +} + +// CHECK-LABEL: func.func @memref_copy_i4( +// CHECK-SAME: %[[ARG0:.*]]: memref<2048xi8, 1>, %[[ARG1:.*]]: memref<2048xi8> +// CHECK: memref.copy %[[ARG0]], %[[ARG1]] +// CHECK: return + +// CHECK32-LABEL: func.func @memref_copy_i4( +// CHECK32-SAME: %[[ARG0:.*]]: memref<512xi32, 1>, %[[ARG1:.*]]: memref<512xi32> +// CHECK32: memref.copy %[[ARG0]], %[[ARG1]] +// CHECK32: return + +// ----- + +!colMajor = memref<8x8xi4, strided<[1, 8]>> +func.func @copy_distinct_layouts(%idx : index) -> i4 { + %c0 = arith.constant 0 : index + %arr = memref.alloc() : memref<8x8xi4> + %arr2 = memref.alloc() : !colMajor + // expected-error @+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}} + memref.copy %arr, %arr2 : memref<8x8xi4> to !colMajor + %ld = memref.load %arr2[%c0, %c0] : !colMajor + return %ld : i4 +} diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir index d884ade319532..8aac802ba10ae 100644 --- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -1553,3 +1553,41 @@ func.func @extract_strided_metadata_of_collapse_shape(%base: memref<5x4xf32>) // CHECK-DAG: %[[STEP:.*]] = 
arith.constant 1 : index // CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata // CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32>, index, index, index + +// ----- + +func.func @extract_strided_metadata_of_memory_space_cast(%base: memref<20xf32>) + -> (memref<f32, 1>, index, index, index) { + + %memory_space_cast = memref.memory_space_cast %base : memref<20xf32> to memref<20xf32, 1> + + %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %memory_space_cast : + memref<20xf32, 1> -> memref<f32, 1>, index, index, index + + return %base_buffer, %offset, %size, %stride : + memref<f32, 1>, index, index, index +} + +// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast +// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[SIZE:.*]] = arith.constant 20 : index +// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index +// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata +// CHECK: %[[CAST:.*]] = memref.memory_space_cast %[[BASE]] +// CHECK: return %[[CAST]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32, 1>, index, index, index + +// ----- + +func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<20xf32>) + -> (index, index, index) { + + %memory_space_cast = memref.memory_space_cast %base : memref<20xf32> to memref<20xf32, 1> + + %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %memory_space_cast : + memref<20xf32, 1> -> memref<f32, 1>, index, index, index + + return %offset, %size, %stride : index, index, index +} + +// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast_no_base +// CHECK-NOT: memref.memory_space_cast