[mlir][memref] Fix emulate narrow types for strided memref offset #68181

Merged: 4 commits merged on Oct 5, 2023
94 changes: 89 additions & 5 deletions mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
@@ -209,6 +210,76 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
return success();
}
};

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Emulating narrow ints on subview has limited support: only static offsets
/// and sizes with a stride of 1 are supported. Ideally, the subview should be
/// folded away before running narrow type emulation, and this pattern would
/// never run. This pattern is mostly used for testing purposes.
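/// For example, with i4 values packed into i8 storage (as in the test added
/// below), the subview
///   memref.subview %src[32] [32] [1]
///       : memref<128xi4> to memref<32xi4, strided<[1], offset: 32>>
/// is rewritten to
///   memref.subview %src[16] [16] [1]
///       : memref<64xi8> to memref<16xi8, strided<[1], offset: 16>>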
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType newTy =
dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
if (!newTy) {
return rewriter.notifyMatchFailure(
op->getLoc(),
llvm::formatv("failed to convert memref type: {0}", op.getType()));
}

auto convertedElementType = newTy.getElementType();
auto oldElementType = op.getType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = convertedElementType.getIntOrFloatBitWidth();
if (dstBits % srcBits != 0) {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
}

// Only support offset for 1-D subview.
if (op.getType().getRank() != 1) {
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with rank > 1 is not supported");
}

// Only support stride of 1.
if (op.getStaticStride(0) != 1) {
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with stride != 1 is not supported");
}

int64_t size = op.getStaticSize(0);
int64_t offset = op.getStaticOffset(0);
// Only support static sizes and offsets.
if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic) {
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with dynamic size or offset is not supported");
}

int elementsPerByte = dstBits / srcBits;
if (offset % elementsPerByte != 0) {
return rewriter.notifyMatchFailure(
op->getLoc(),
"subview with offset not multiple of elementsPerByte is not "
"supported");
}

size = ceilDiv(size, elementsPerByte);
offset = offset / elementsPerByte;
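// The size is rounded up to cover any partially filled trailing byte, while
// the offset divides evenly (checked above); e.g. 32 i4 elements at offset
// 32 become 16 i8 elements at offset 16.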

rewriter.replaceOpWithNewOp<memref::SubViewOp>(
op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
op.getStaticStrides());
return success();
}
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
@@ -220,9 +291,9 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {

// Populate `memref.*` conversion patterns.
patterns
.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment>(
typeConverter, patterns.getContext());
patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}

@@ -271,9 +342,22 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
return std::nullopt;

StridedLayoutAttr layoutAttr;
// The linearized layout always has a stride of 1, so a strided layout
// attribute is only needed when the offset is non-zero.
if (offset != 0) {
layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
ArrayRef<int64_t>{1});
if (offset == ShapedType::kDynamic) {
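// A dynamic offset is preserved as a dynamic offset in the new layout.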
layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
ArrayRef<int64_t>{1});
} else {
// Check that the offset in bits (offset * width) is a multiple of the
// loadStoreWidth and, if so, divide by the loadStoreWidth to get the
// offset in the converted element type.
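// E.g. a static offset of 32 i4 elements with an i8 load/store width
// becomes (32 * 4) / 8 = 16.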
if ((offset * width) % loadStoreWidth != 0)
return std::nullopt;
offset = (offset * width) / loadStoreWidth;

layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
ArrayRef<int64_t>{1});
}
}

return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
19 changes: 19 additions & 0 deletions mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -155,3 +155,22 @@ func.func @rank_zero_memref() -> i4 {
// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][] : memref<i32>
// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i32 to i4
// CHECK32: return %[[TRUNC]]

// -----

func.func @memref_strided_i4(%idx : index) -> i4 {
%arr = memref.alloc() : memref<128xi4>
%subview = memref.subview %arr[32] [32] [1] : memref<128xi4> to memref<32xi4, strided<[1], offset:32>>
%1 = memref.load %subview[%idx] : memref<32xi4, strided<[1], offset:32>>
return %1 : i4
}

// CHECK-LABEL: func @memref_strided_i4
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<64xi8>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16] [16] [1] : memref<64xi8> to memref<16xi8, strided<[1], offset: 16>>
// CHECK: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]

// CHECK32-LABEL: func @memref_strided_i4
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
// CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
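
For context, the sketch below shows one way these patterns are typically driven from a conversion pass; it is loosely modeled on the in-tree narrow-type emulation test setup. The wrapper name `runNarrowTypeEmulation` and the hard-coded 8-bit load/store width are illustrative assumptions, not part of this PR.

#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Run memref narrow type emulation on `op`, emulating sub-byte element types
// (such as i4) on top of i8 loads and stores.
static LogicalResult runNarrowTypeEmulation(Operation *op) {
  MLIRContext *ctx = op->getContext();

  // The converter linearizes memrefs of narrow types into memrefs of the
  // 8-bit load/store type, rescaling shapes and offsets as in the diff above.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

  ConversionTarget target(*ctx);
  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp funcOp) {
    return typeConverter.isSignatureLegal(funcOp.getFunctionType());
  });
  target.markUnknownOpDynamicallyLegal(
      [&](Operation *nestedOp) { return typeConverter.isLegal(nestedOp); });

  RewritePatternSet patterns(ctx);
  memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);

  return applyPartialConversion(op, target, std::move(patterns));
}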