diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index 2a524ceb9db88..2a9b27debaece 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -271,9 +271,22 @@ void memref::populateMemRefNarrowTypeEmulationConversions( return std::nullopt; StridedLayoutAttr layoutAttr; + // If the offset is 0, we do not need a strided layout as the stride is + // 1, so we only use the strided layout if the offset is not 0. if (offset != 0) { - layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset, - ArrayRef{1}); + if (offset == ShapedType::kDynamic) { + layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset, + ArrayRef{1}); + } else { + // Check if the number of bytes are a multiple of the loadStoreWidth + // and if so, divide it by the loadStoreWidth to get the offset. + if ((offset * width) % loadStoreWidth != 0) + return std::nullopt; + offset = (offset * width) / loadStoreWidth; + + layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset, + ArrayRef{1}); + } } return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),