-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir] [memref] add more checks to the memref.reinterpret_cast #112669
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Operation memref.reinterpret_cast was accept input like: %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1] : memref<?xf32> to memref<10xf32> A problem arises: while lowering, the true offset of %out is %offset, but its data type indicates an offset of 0. Permitting this inconsistency can result in incorrect outcomes, as certain pass might erroneously extract the offset from the data type of %out. This patch fixes this by enforcing that the return value's data type aligns with the input parameter.
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir-memref Author: donald chen (cxy-1993) ChangesOperation memref.reinterpret_cast was accept input like: %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1] A problem arises: while lowering, the true offset of %out is %offset, but its data type indicates an offset of 0. Permitting this inconsistency can result in incorrect outcomes, as certain pass might erroneously extract the offset from the data type of %out. This patch fixes this by enforcing that the return value's data type aligns Full diff: https://github.com/llvm/llvm-project/pull/112669.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d579a27359dfa0..add78e78a97a8e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1892,8 +1892,7 @@ LogicalResult ReinterpretCastOp::verify() {
// Match sizes in result memref type and in static_sizes attribute.
for (auto [idx, resultSize, expectedSize] :
llvm::enumerate(resultType.getShape(), getStaticSizes())) {
- if (!ShapedType::isDynamic(resultSize) &&
- !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
+ if (resultSize != expectedSize)
return emitError("expected result type with size = ")
<< expectedSize << " instead of " << resultSize
<< " in dim = " << idx;
@@ -1910,17 +1909,14 @@ LogicalResult ReinterpretCastOp::verify() {
// Match offset in result memref type and in static_offsets attribute.
int64_t expectedOffset = getStaticOffsets().front();
- if (!ShapedType::isDynamic(resultOffset) &&
- !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
+ if (resultOffset != expectedOffset)
return emitError("expected result type with offset = ")
<< expectedOffset << " instead of " << resultOffset;
// Match strides in result memref type and in static_strides attribute.
for (auto [idx, resultStride, expectedStride] :
llvm::enumerate(resultStrides, getStaticStrides())) {
- if (!ShapedType::isDynamic(resultStride) &&
- !ShapedType::isDynamic(expectedStride) &&
- resultStride != expectedStride)
+ if (resultStride != expectedStride)
return emitError("expected result type with stride = ")
<< expectedStride << " instead of " << resultStride
<< " in dim = " << idx;
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 0f533cb95a0ca9..739cf76429c045 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -217,6 +217,15 @@ func.func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {
// -----
+func.func @memref_reinterpret_cast_offset_mismatch_dynamic(%in: memref<?xf32>, %offset : index) {
+ // expected-error @+1 {{expected result type with offset = -9223372036854775808 instead of 0}}
+ %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
+ : memref<?xf32> to memref<10xf32>
+ return
+}
+
+// -----
+
func.func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
// expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}}
%out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10]
|
…sult type This commit simplifies the result type of materialization functions. Previously: `std::optional<Value>` Now: `Value` The previous implementation allowed 3 possible return values: - Non-null value: The materialization function produced a valid materialization. - `std::nullopt`: The materialization function failed, but another materialization can be attempted. - `Value()`: The materialization failed and so should the dialect conversion. (Previously: Dialect conversion can roll back.) This commit removes the last variant. It is not particularly useful because the dialect conversion will fail anyway if all other materialization functions produced `std::nullopt`. In contrast to type conversions, at least one materialization callback is expected to succeed. In case of a failing type conversion, the current dialect conversion can roll back and try a different pattern. This also used to be the case for materializations, but that functionality was removed with #112669: failed materializations can no longer trigger a rollback. (They can just make the entire dialect conversion fail immediately without rollback.) With this in mind, it is even less useful to have an additional error state for materialization functions. This commit is in preparation of merging the 1:1 and 1:N type converters. Target materializations will have to return multiple values instead of a single one. With this commit, we can keep the API simple: `SmallVector<Value>` instead of `std::optional<SmallVector<Value>>`.
…sult type This commit simplifies the result type of materialization functions. Previously: `std::optional<Value>` Now: `Value` The previous implementation allowed 3 possible return values: - Non-null value: The materialization function produced a valid materialization. - `std::nullopt`: The materialization function failed, but another materialization can be attempted. - `Value()`: The materialization failed and so should the dialect conversion. (Previously: Dialect conversion can roll back.) This commit removes the last variant. It is not particularly useful because the dialect conversion will fail anyway if all other materialization functions produced `std::nullopt`. In contrast to type conversions, at least one materialization callback is expected to succeed. In case of a failing type conversion, the current dialect conversion can roll back and try a different pattern. This also used to be the case for materializations, but that functionality was removed with #112669: failed materializations can no longer trigger a rollback. (They can just make the entire dialect conversion fail immediately without rollback.) With this in mind, it is even less useful to have an additional error state for materialization functions. This commit is in preparation of merging the 1:1 and 1:N type converters. Target materializations will have to return multiple values instead of a single one. With this commit, we can keep the API simple: `SmallVector<Value>` instead of `std::optional<SmallVector<Value>>`.
Hi @MaheshRavishankar ,do you have any further comments on this patch? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tried to look through all the lit test changes. They all look good to me visually.
return emitError("expected result type with size = ") | ||
<< expectedSize << " instead of " << resultSize | ||
<< " in dim = " << idx; | ||
<< (ShapedType::isDynamic(expectedSize) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I am reading the logic correctly, this check is not needed (here and everywhere below.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your review. However, I believe these checks are necessary. As the PR description mentions, mismatches between return value data types and operands may lead to other transforms incorrectly obtaining values of the wrong data type, resulting in erroneous outcomes. This submission is specifically designed to address this issue.
…112669) Operation memref.reinterpret_cast was accept input like: %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1] : memref<?xf32> to memref<10xf32> A problem arises: while lowering, the true offset of %out is %offset, but its data type indicates an offset of 0. Permitting this inconsistency can result in incorrect outcomes, as certain pass might erroneously extract the offset from the data type of %out. This patch fixes this by enforcing that the return value's data type aligns with the input parameter.
…112669) Operation memref.reinterpret_cast was accept input like: %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1] : memref<?xf32> to memref<10xf32> A problem arises: while lowering, the true offset of %out is %offset, but its data type indicates an offset of 0. Permitting this inconsistency can result in incorrect outcomes, as certain pass might erroneously extract the offset from the data type of %out. This patch fixes this by enforcing that the return value's data type aligns with the input parameter.
Operation memref.reinterpret_cast was accept input like:
%out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
: memref<?xf32> to memref<10xf32>
A problem arises: while lowering, the true offset of %out is %offset, but its data type indicates an offset of 0. Permitting this inconsistency can result in incorrect outcomes, as certain pass might erroneously extract the offset from the data type of %out.
This patch fixes this by enforcing that the return value's data type aligns
with the input parameter.