From 4697178ab883acb96c65d7dedff1a507f24613ec Mon Sep 17 00:00:00 2001 From: Groverkss Date: Wed, 4 Oct 2023 08:56:54 +0530 Subject: [PATCH 1/4] [mlir][memref] Fix emulate narrow types for strided memref offset --- .../MemRef/Transforms/EmulateNarrowType.cpp | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index 2a524ceb9db88..2a9b27debaece 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -271,9 +271,22 @@ void memref::populateMemRefNarrowTypeEmulationConversions( return std::nullopt; StridedLayoutAttr layoutAttr; + // If the offset is 0, we do not need a strided layout as the stride is + // 1, so we only use the strided layout if the offset is not 0. if (offset != 0) { - layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset, - ArrayRef{1}); + if (offset == ShapedType::kDynamic) { + layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset, + ArrayRef{1}); + } else { + // Check if the number of bytes are a multiple of the loadStoreWidth + // and if so, divide it by the loadStoreWidth to get the offset. + if ((offset * width) % loadStoreWidth != 0) + return std::nullopt; + offset = (offset * width) / loadStoreWidth; + + layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset, + ArrayRef{1}); + } } return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth), From f9b25612e2389980166523edb774a807c2228c3a Mon Sep 17 00:00:00 2001 From: Groverkss Date: Wed, 4 Oct 2023 11:19:55 +0530 Subject: [PATCH 2/4] Add tests and memref.subview support in emulate narrow types --- .../MemRef/Transforms/EmulateNarrowType.cpp | 74 ++++++++++++++++++- .../Dialect/MemRef/emulate-narrow-type.mlir | 19 +++++ 2 files changed, 90 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index 2a9b27debaece..453a18ff3c254 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -209,6 +209,74 @@ struct ConvertMemRefLoad final : OpConversionPattern { return success(); } }; + +//===----------------------------------------------------------------------===// +// ConvertMemRefAssumeAlignment +//===----------------------------------------------------------------------===// + +struct ConvertMemRefSubview final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto convertedType = + cast(getTypeConverter()->convertType(op.getSourceType())); + auto convertedElementType = convertedType.getElementType(); + auto oldElementType = op.getSourceType().getElementType(); + int srcBits = oldElementType.getIntOrFloatBitWidth(); + int dstBits = convertedElementType.getIntOrFloatBitWidth(); + if (dstBits % srcBits != 0) { + return rewriter.notifyMatchFailure( + op, "only dstBits % srcBits == 0 supported"); + } + + MemRefType newTy = + cast(getTypeConverter()->convertType(op.getType())); + if (!newTy) { + return rewriter.notifyMatchFailure( + op->getLoc(), + llvm::formatv("failed to convert memref type: {0}", op.getType())); + } + + // Only support offset for 1-D subview. + if (op.getType().getRank() != 1) { + return rewriter.notifyMatchFailure( + op->getLoc(), "subview with rank > 1 is not supported"); + } + + // Only support stride of 1. + if (op.getStaticStride(0) != 1) { + return rewriter.notifyMatchFailure( + op->getLoc(), "subview with stride != 1 is not supported"); + } + + auto size = op.getStaticSize(0); + auto offset = op.getStaticOffset(0); + // Only support static sizes and offsets. + if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic) { + return rewriter.notifyMatchFailure( + op->getLoc(), "subview with dynamic size or offset is not supported"); + } + + int elementsPerByte = dstBits / srcBits; + if (size % elementsPerByte != 0 || offset % elementsPerByte != 0) { + return rewriter.notifyMatchFailure( + op->getLoc(), + "subview with size or offset not multiple of elementsPerByte is not " + "supported"); + } + + size = size / elementsPerByte; + offset = offset / elementsPerByte; + + rewriter.replaceOpWithNewOp( + op, newTy, *adaptor.getODSOperands(0).begin(), offset, size, + op.getStaticStrides()); + return success(); + } +}; + } // end anonymous namespace //===----------------------------------------------------------------------===// @@ -220,9 +288,9 @@ void memref::populateMemRefNarrowTypeEmulationPatterns( RewritePatternSet &patterns) { // Populate `memref.*` conversion patterns. - patterns - .add( - typeConverter, patterns.getContext()); + patterns.add( + typeConverter, patterns.getContext()); memref::populateResolveExtractStridedMetadataPatterns(patterns); } diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir index c0050d8c510d5..6ed97f05aa7cf 100644 --- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir +++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir @@ -155,3 +155,22 @@ func.func @rank_zero_memref() -> i4 { // CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][] : memref // CHECK32: %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i32 to i4 // CHECK32: return %[[TRUNC]] + +// ----- + +func.func @memref_strided_i4(%idx : index) -> i4 { + %arr = memref.alloc() : memref<128xi4> + %subview = memref.subview %arr[32] [32] [1] : memref<128xi4> to memref<32xi4, strided<[1], offset:32>> + %1 = memref.load %subview[%idx] : memref<32xi4, strided<[1], offset:32>> + return %1 : i4 +} + +// CHECK-LABEL: func @memref_strided_i4 +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<64xi8> +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16] [16] [1] : memref<64xi8> to memref<16xi8, strided<[1], offset: 16>> +// CHECK: %[[LOAD:.+]] = memref.load %[[SUBVIEW]] + +// CHECK32-LABEL: func @memref_strided_i4 +// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32> +// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>> +// CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]] From 3c3cd037829a6ed79375ae6b7d76a47dc919a58f Mon Sep 17 00:00:00 2001 From: Groverkss Date: Wed, 4 Oct 2023 11:28:53 +0530 Subject: [PATCH 3/4] Fix doc --- mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index 453a18ff3c254..c98dda27f6c0b 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -211,7 +211,7 @@ struct ConvertMemRefLoad final : OpConversionPattern { }; //===----------------------------------------------------------------------===// -// ConvertMemRefAssumeAlignment +// ConvertMemRefSubview //===----------------------------------------------------------------------===// struct ConvertMemRefSubview final : OpConversionPattern { From 0b1caf93c97e08fcac6bbf6e95680d6adc7a2050 Mon Sep 17 00:00:00 2001 From: Groverkss Date: Thu, 5 Oct 2023 10:47:56 +0530 Subject: [PATCH 4/4] Address Mahesh's comments --- .../MemRef/Transforms/EmulateNarrowType.cpp | 37 ++++++++++--------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index c98dda27f6c0b..9f58e9055acad 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Support/MathExtras.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" @@ -214,31 +215,33 @@ struct ConvertMemRefLoad final : OpConversionPattern { // ConvertMemRefSubview //===----------------------------------------------------------------------===// +/// Emulating narrow ints on subview have limited support, supporting only +/// static offset and size and stride of 1. Ideally, the subview should be +/// folded away before running narrow type emulation, and this pattern would +/// never run. This pattern is mostly used for testing pruposes. struct ConvertMemRefSubview final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto convertedType = - cast(getTypeConverter()->convertType(op.getSourceType())); - auto convertedElementType = convertedType.getElementType(); - auto oldElementType = op.getSourceType().getElementType(); - int srcBits = oldElementType.getIntOrFloatBitWidth(); - int dstBits = convertedElementType.getIntOrFloatBitWidth(); - if (dstBits % srcBits != 0) { - return rewriter.notifyMatchFailure( - op, "only dstBits % srcBits == 0 supported"); - } - MemRefType newTy = - cast(getTypeConverter()->convertType(op.getType())); + dyn_cast(getTypeConverter()->convertType(op.getType())); if (!newTy) { return rewriter.notifyMatchFailure( op->getLoc(), llvm::formatv("failed to convert memref type: {0}", op.getType())); } + auto convertedElementType = newTy.getElementType(); + auto oldElementType = op.getType().getElementType(); + int srcBits = oldElementType.getIntOrFloatBitWidth(); + int dstBits = convertedElementType.getIntOrFloatBitWidth(); + if (dstBits % srcBits != 0) { + return rewriter.notifyMatchFailure( + op, "only dstBits % srcBits == 0 supported"); + } + // Only support offset for 1-D subview. if (op.getType().getRank() != 1) { return rewriter.notifyMatchFailure( @@ -251,8 +254,8 @@ struct ConvertMemRefSubview final : OpConversionPattern { op->getLoc(), "subview with stride != 1 is not supported"); } - auto size = op.getStaticSize(0); - auto offset = op.getStaticOffset(0); + int64_t size = op.getStaticSize(0); + int64_t offset = op.getStaticOffset(0); // Only support static sizes and offsets. if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic) { return rewriter.notifyMatchFailure( @@ -260,14 +263,14 @@ struct ConvertMemRefSubview final : OpConversionPattern { } int elementsPerByte = dstBits / srcBits; - if (size % elementsPerByte != 0 || offset % elementsPerByte != 0) { + if (offset % elementsPerByte != 0) { return rewriter.notifyMatchFailure( op->getLoc(), - "subview with size or offset not multiple of elementsPerByte is not " + "subview with offset not multiple of elementsPerByte is not " "supported"); } - size = size / elementsPerByte; + size = ceilDiv(size, elementsPerByte); offset = offset / elementsPerByte; rewriter.replaceOpWithNewOp(