Skip to content

[mlir] [memref] add more checks to the memref.reinterpret_cast #112669

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 26, 2024

Conversation

cxy-1993
Copy link
Contributor

Operation memref.reinterpret_cast was accept input like:

%out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
: memref<?xf32> to memref<10xf32>

A problem arises: while lowering, the true offset of %out is %offset, but its data type indicates an offset of 0. Permitting this inconsistency can result in incorrect outcomes, as certain pass might erroneously extract the offset from the data type of %out.

This patch fixes this by enforcing that the return value's data type aligns
with the input parameter.

Operation memref.reinterpret_cast was accept input like:

  %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
         : memref<?xf32> to memref<10xf32>

A problem arises: while lowering, the true offset of %out is %offset, but its
data type indicates an offset of 0. Permitting this inconsistency can result in
incorrect outcomes, as certain pass might erroneously extract the offset from
the data type of %out.

 This patch fixes this by enforcing that the return value's data type aligns
with the input parameter.
@llvmbot
Copy link
Member

llvmbot commented Oct 17, 2024

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: donald chen (cxy-1993)

Changes

Operation memref.reinterpret_cast was accept input like:

%out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
: memref<?xf32> to memref<10xf32>

A problem arises: while lowering, the true offset of %out is %offset, but its data type indicates an offset of 0. Permitting this inconsistency can result in incorrect outcomes, as certain pass might erroneously extract the offset from the data type of %out.

This patch fixes this by enforcing that the return value's data type aligns
with the input parameter.


Full diff: https://github.com/llvm/llvm-project/pull/112669.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+3-7)
  • (modified) mlir/test/Dialect/MemRef/invalid.mlir (+9)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d579a27359dfa0..add78e78a97a8e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1892,8 +1892,7 @@ LogicalResult ReinterpretCastOp::verify() {
   // Match sizes in result memref type and in static_sizes attribute.
   for (auto [idx, resultSize, expectedSize] :
        llvm::enumerate(resultType.getShape(), getStaticSizes())) {
-    if (!ShapedType::isDynamic(resultSize) &&
-        !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
+    if (resultSize != expectedSize)
       return emitError("expected result type with size = ")
              << expectedSize << " instead of " << resultSize
              << " in dim = " << idx;
@@ -1910,17 +1909,14 @@ LogicalResult ReinterpretCastOp::verify() {
 
   // Match offset in result memref type and in static_offsets attribute.
   int64_t expectedOffset = getStaticOffsets().front();
-  if (!ShapedType::isDynamic(resultOffset) &&
-      !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
+  if (resultOffset != expectedOffset)
     return emitError("expected result type with offset = ")
            << expectedOffset << " instead of " << resultOffset;
 
   // Match strides in result memref type and in static_strides attribute.
   for (auto [idx, resultStride, expectedStride] :
        llvm::enumerate(resultStrides, getStaticStrides())) {
-    if (!ShapedType::isDynamic(resultStride) &&
-        !ShapedType::isDynamic(expectedStride) &&
-        resultStride != expectedStride)
+    if (resultStride != expectedStride)
       return emitError("expected result type with stride = ")
              << expectedStride << " instead of " << resultStride
              << " in dim = " << idx;
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 0f533cb95a0ca9..739cf76429c045 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -217,6 +217,15 @@ func.func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {
 
 // -----
 
+func.func @memref_reinterpret_cast_offset_mismatch_dynamic(%in: memref<?xf32>, %offset : index) {
+  // expected-error @+1 {{expected result type with offset = -9223372036854775808 instead of 0}}
+  %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
+         : memref<?xf32> to memref<10xf32>
+  return
+}
+
+// -----
+
 func.func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
   // expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}}
   %out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10]

matthias-springer added a commit that referenced this pull request Oct 19, 2024
…sult type

This commit simplifies the result type of materialization functions.

Previously: `std::optional<Value>`
Now: `Value`

The previous implementation allowed 3 possible return values:
- Non-null value: The materialization function produced a valid materialization.
- `std::nullopt`: The materialization function failed, but another materialization can be attempted.
- `Value()`: The materialization failed and so should the dialect conversion. (Previously: Dialect conversion can roll back.)

This commit removes the last variant. It is not particularly useful because the dialect conversion will fail anyway if all other materialization functions produced `std::nullopt`. In contrast to type conversions, at least one materialization callback is expected to succeed. In case of a failing type conversion, the current dialect conversion can roll back and try a different pattern. This also used to be the case for materializations, but that functionality was removed with #112669: failed materializations can no longer trigger a rollback. (They can just make the entire dialect conversion fail immediately without rollback.) With this in mind, it is even less useful to have an additional error state for materialization functions.

This commit is in preparation of merging the 1:1 and 1:N type converters. Target materializations will have to return multiple values instead of a single one. With this commit, we can keep the API simple: `SmallVector<Value>` instead of `std::optional<SmallVector<Value>>`.
matthias-springer added a commit that referenced this pull request Oct 23, 2024
…sult type

This commit simplifies the result type of materialization functions.

Previously: `std::optional<Value>`
Now: `Value`

The previous implementation allowed 3 possible return values:
- Non-null value: The materialization function produced a valid materialization.
- `std::nullopt`: The materialization function failed, but another materialization can be attempted.
- `Value()`: The materialization failed and so should the dialect conversion. (Previously: Dialect conversion can roll back.)

This commit removes the last variant. It is not particularly useful because the dialect conversion will fail anyway if all other materialization functions produced `std::nullopt`. In contrast to type conversions, at least one materialization callback is expected to succeed. In case of a failing type conversion, the current dialect conversion can roll back and try a different pattern. This also used to be the case for materializations, but that functionality was removed with #112669: failed materializations can no longer trigger a rollback. (They can just make the entire dialect conversion fail immediately without rollback.) With this in mind, it is even less useful to have an additional error state for materialization functions.

This commit is in preparation of merging the 1:1 and 1:N type converters. Target materializations will have to return multiple values instead of a single one. With this commit, we can keep the API simple: `SmallVector<Value>` instead of `std::optional<SmallVector<Value>>`.
@cxy-1993
Copy link
Contributor Author

Hi @MaheshRavishankar ,do you have any further comments on this patch?

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried to look through all the lit test changes. They all look good to me visually.

return emitError("expected result type with size = ")
<< expectedSize << " instead of " << resultSize
<< " in dim = " << idx;
<< (ShapedType::isDynamic(expectedSize)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I am reading the logic correctly, this check is not needed (here and everywhere below.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your review. However, I believe these checks are necessary. As the PR description mentions, mismatches between return value data types and operands may lead to other transforms incorrectly obtaining values of the wrong data type, resulting in erroneous outcomes. This submission is specifically designed to address this issue.

@cxy-1993 cxy-1993 merged commit 889b67c into llvm:main Oct 26, 2024
9 checks passed
winner245 pushed a commit to winner245/llvm-project that referenced this pull request Oct 26, 2024
…112669)

Operation memref.reinterpret_cast was accept input like:

%out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10],
strides: [1]
         : memref<?xf32> to memref<10xf32>

A problem arises: while lowering, the true offset of %out is %offset,
but its data type indicates an offset of 0. Permitting this
inconsistency can result in incorrect outcomes, as certain pass might
erroneously extract the offset from the data type of %out.

This patch fixes this by enforcing that the return value's data type
aligns
with the input parameter.
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
…112669)

Operation memref.reinterpret_cast was accept input like:

%out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10],
strides: [1]
         : memref<?xf32> to memref<10xf32>

A problem arises: while lowering, the true offset of %out is %offset,
but its data type indicates an offset of 0. Permitting this
inconsistency can result in incorrect outcomes, as certain pass might
erroneously extract the offset from the data type of %out.

This patch fixes this by enforcing that the return value's data type
aligns
with the input parameter.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants