[mlir] [memref] add more checks to the memref.reinterpret_cast #112669

Merged: 4 commits, Oct 26, 2024

13 changes: 12 additions & 1 deletion mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
@@ -29,6 +29,17 @@ namespace mlir {

 using namespace mlir;
 
+static MemRefType inferCastResultType(Value source, OpFoldResult offset) {
+  auto sourceType = cast<BaseMemRefType>(source.getType());
+  SmallVector<int64_t> staticOffsets;
+  SmallVector<Value> dynamicOffsets;
+  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
+  auto stridedLayout =
+      StridedLayoutAttr::get(source.getContext(), staticOffsets.front(), {});
+  return MemRefType::get({}, sourceType.getElementType(), stridedLayout,
+                         sourceType.getMemorySpace());
+}
+
 static void setInsertionPointToStart(OpBuilder &builder, Value val) {
   if (auto *parentOp = val.getDefiningOp()) {
     builder.setInsertionPointAfter(parentOp);
@@ -98,7 +109,7 @@ static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
   SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
   auto &&[base, offset, ignore] =
       getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
-  auto retType = cast<MemRefType>(base.getType());
+  MemRefType retType = inferCastResultType(base, offset);
   return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
                                                     std::nullopt, std::nullopt);
 }
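The effect of the inferred type, as a minimal MLIR sketch (value names are illustrative; compare the updated decompose-memrefs.mlir checks below): the 0-d cast result now advertises its dynamic offset instead of claiming the identity layout.

func.func @flat_access(%base: memref<f32>, %idx: index) -> memref<f32, strided<[], offset: ?>> {
  // Before this change the result type was plain memref<f32>, which asserts
  // offset 0 even though %idx is dynamic.
  %ptr = memref.reinterpret_cast %base to offset: [%idx], sizes: [], strides: []
      : memref<f32> to memref<f32, strided<[], offset: ?>>
  return %ptr : memref<f32, strided<[], offset: ?>>
}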
27 changes: 15 additions & 12 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1892,11 +1892,12 @@ LogicalResult ReinterpretCastOp::verify() {
   // Match sizes in result memref type and in static_sizes attribute.
   for (auto [idx, resultSize, expectedSize] :
        llvm::enumerate(resultType.getShape(), getStaticSizes())) {
-    if (!ShapedType::isDynamic(resultSize) &&
-        !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
+    if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
       return emitError("expected result type with size = ")
-             << expectedSize << " instead of " << resultSize
-             << " in dim = " << idx;
+             << (ShapedType::isDynamic(expectedSize)
+                     ? std::string("dynamic")
+                     : std::to_string(expectedSize))
+             << " instead of " << resultSize << " in dim = " << idx;
   }
 
   // Match offset and strides in static_offset and static_strides attributes. If

Review discussion on the new "dynamic" diagnostic:

Contributor: If I am reading the logic correctly, this check is not needed (here and everywhere below).

Contributor Author: Thank you for your review. However, I believe these checks are necessary. As the PR description mentions, a mismatch between the result type and the operands may lead other transforms to obtain values of the wrong type, producing erroneous results. This change is specifically designed to address that issue.
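To make the hazard concrete, a sketch (the invalid.mlir test added below exercises this exact case): memref<10xf32> implies offset 0 in its canonical strided form, so a transform trusting only the type would treat the dynamic %offset as 0. The verifier now rejects the mismatch:

func.func @offset_mismatch(%in: memref<?xf32>, %offset: index) {
  // Verification now fails: expected result type with offset = dynamic instead of 0.
  %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
      : memref<?xf32> to memref<10xf32>
  return
}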
@@ -1910,20 +1911,22 @@ LogicalResult ReinterpretCastOp::verify() {

   // Match offset in result memref type and in static_offsets attribute.
   int64_t expectedOffset = getStaticOffsets().front();
-  if (!ShapedType::isDynamic(resultOffset) &&
-      !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
+  if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
     return emitError("expected result type with offset = ")
-           << expectedOffset << " instead of " << resultOffset;
+           << (ShapedType::isDynamic(expectedOffset)
+                   ? std::string("dynamic")
+                   : std::to_string(expectedOffset))
+           << " instead of " << resultOffset;
 
   // Match strides in result memref type and in static_strides attribute.
   for (auto [idx, resultStride, expectedStride] :
        llvm::enumerate(resultStrides, getStaticStrides())) {
-    if (!ShapedType::isDynamic(resultStride) &&
-        !ShapedType::isDynamic(expectedStride) &&
-        resultStride != expectedStride)
+    if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
       return emitError("expected result type with stride = ")
-             << expectedStride << " instead of " << resultStride
-             << " in dim = " << idx;
+             << (ShapedType::isDynamic(expectedStride)
+                     ? std::string("dynamic")
+                     : std::to_string(expectedStride))
+             << " instead of " << resultStride << " in dim = " << idx;
   }
 
   return success();
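Note the asymmetry: each check fires only when the result type pins a static value. A fully dynamic result type still verifies, since no static value can contradict the operands — a sketch:

func.func @offset_ok(%in: memref<?xf32>, %offset: index) -> memref<10xf32, strided<[1], offset: ?>> {
  // Accepted: the result type leaves the offset dynamic.
  %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
      : memref<?xf32> to memref<10xf32, strided<[1], offset: ?>>
  return %out : memref<10xf32, strided<[1], offset: ?>>
}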
22 changes: 18 additions & 4 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -89,7 +89,8 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
     strides.resize(rank);
 
     Location loc = op.getLoc();
-    Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value stride = nullptr;
+    int64_t staticStride = 1;
     for (int i = rank - 1; i >= 0; --i) {
       Value size;
       // Load dynamic sizes from the shape input, use constants for static dims.
@@ -105,9 +106,22 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
         size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
         sizes[i] = sizeAttr;
       }
-      strides[i] = stride;
-      if (i > 0)
-        stride = rewriter.create<arith::MulIOp>(loc, stride, size);
+      if (stride)
+        strides[i] = stride;
+      else
+        strides[i] = rewriter.getIndexAttr(staticStride);
+
+      if (i > 0) {
+        if (stride) {
+          stride = rewriter.create<arith::MulIOp>(loc, stride, size);
+        } else if (op.getType().isDynamicDim(i)) {
+          stride = rewriter.create<arith::MulIOp>(
+              loc, rewriter.create<arith::ConstantIndexOp>(loc, staticStride),
+              size);
+        } else {
+          staticStride *= op.getType().getDimSize(i);
+        }
+      }
     }
     rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
         op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),
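For a memref<?x?x8xf32> result this now materializes only the strides that are truly dynamic (a sketch with illustrative value names; the updated expand-ops.mlir test below checks this lowering). Walking dims from the inside out, staticStride stays an attribute for the static tail, and only dim 0 needs an SSA product:

// i = 2: strides[2] = 1 (attribute); staticStride becomes 8
// i = 1: strides[1] = 8 (attribute); dim 1 is dynamic, so
//        %stride_0 = arith.muli %c8, %size_1 : index
// i = 0: strides[0] = %stride_0 (SSA value)
%result = memref.reinterpret_cast %src to offset: [0], sizes: [%size_0, %size_1, 8],
    strides: [%stride_0, 8, 1] : memref<*xf32> to memref<?x?x8xf32>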
17 changes: 7 additions & 10 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -507,6 +507,8 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,

   SmallVector<OpFoldResult> groupStrides;
   ArrayRef<int64_t> srcShape = sourceType.getShape();
+
+  OpFoldResult lastValidStride = nullptr;
   for (int64_t currentDim : reassocGroup) {
     // Skip size-of-1 dimensions, since right now their strides may be
     // meaningless.
@@ -517,11 +519,11 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
       continue;
 
     int64_t currentStride = strides[currentDim];
-    groupStrides.push_back(ShapedType::isDynamic(currentStride)
-                               ? origStrides[currentDim]
-                               : builder.getIndexAttr(currentStride));
+    lastValidStride = ShapedType::isDynamic(currentStride)
+                          ? origStrides[currentDim]
+                          : builder.getIndexAttr(currentStride);
   }
-  if (groupStrides.empty()) {
+  if (!lastValidStride) {
     // We're dealing with a 1x1x...x1 shape. The stride is meaningless,
     // but we still have to make the type system happy.
     MemRefType collapsedType = collapseShape.getResultType();
@@ -543,12 +545,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
     return {builder.getIndexAttr(finalStride)};
   }
 
-  // For the general case, we just want the minimum stride
-  // since the collapsed dimensions are contiguous.
-  auto minMap = AffineMap::getMultiDimIdentityMap(groupStrides.size(),
-                                                  builder.getContext());
-  return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
-                                      groupStrides)};
+  return {lastValidStride};
 }
 
 /// From `reshape_like(memref, subSizes, subStrides))` compute
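Why the last valid stride suffices, sketched: a collapse group must be contiguous, so its innermost non-unit dimension always carries the smallest stride of the group — reading it directly replaces the old affine.min over all group strides:

func.func @collapse(%arg0: memref<2x3xf32>) -> memref<6xf32> {
  // Group [0, 1] has strides [3, 1]; the collapsed stride is the last
  // one in the group, 1, which equals min(3, 1).
  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<2x3xf32> into memref<6xf32>
  return %0 : memref<6xf32>
}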
@@ -425,8 +425,6 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
 // CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: %[[STRIDE0_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0]] : i64 to index
-// CHECK: %[[STRIDE0:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0_TO_IDX]] : index to i64
 // CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] : i64
 // CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
 // CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
@@ -548,23 +546,19 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
 // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
 // CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]] : i64
 // CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
 // CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
-// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK: %[[MIN_STRIDE1:.*]] = llvm.intr.smin(%[[STRIDE1]], %[[C1]]) : (i64, i64) -> i64
-// CHECK: %[[MIN_STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1]] : i64 to index
-// CHECK: %[[MIN_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1_TO_IDX]] : index to i64
 // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[FINAL_SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[MIN_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[C1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<1x?xf32>
 // CHECK: return %[[RES]] : memref<1x?xf32>
 // CHECK: }
12 changes: 6 additions & 6 deletions mlir/test/Dialect/GPU/decompose-memrefs.mlir
@@ -7,8 +7,8 @@
 // CHECK: gpu.launch
 // CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
 // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
-// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
-// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32>
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32, strided<[], offset: ?>>
 func.func @decompose_store(%arg0 : f32, %arg1 : memref<?x?x?xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -33,8 +33,8 @@ func.func @decompose_store(%arg0 : f32, %arg1 : memref<?x?x?xf32>) {
 // CHECK: gpu.launch
 // CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
 // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[OFFSET]], %[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]], %[[STRIDES]]#2]
-// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
-// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32>
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32, strided<[], offset: ?>>
 func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -59,8 +59,8 @@ func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref<?x?x?xf32, stride
 // CHECK: gpu.launch
 // CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
 // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
-// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
-// CHECK: %[[RES:.*]] = memref.load %[[PTR]][] : memref<f32>
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: %[[RES:.*]] = memref.load %[[PTR]][] : memref<f32, strided<[], offset: ?>>
 // CHECK: "test.test"(%[[RES]]) : (f32) -> ()
 func.func @decompose_load(%arg0 : memref<?x?x?xf32>) {
   %c0 = arith.constant 0 : index
13 changes: 6 additions & 7 deletions mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -52,20 +52,19 @@ func.func @memref_reshape(%input: memref<*xf32>,
 // CHECK-SAME: [[SRC:%.*]]: memref<*xf32>,
 // CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x8xf32> {
 
-// CHECK: [[C1:%.*]] = arith.constant 1 : index
-// CHECK: [[C8:%.*]] = arith.constant 8 : index
-// CHECK: [[STRIDE_1:%.*]] = arith.muli [[C1]], [[C8]] : index
 
-// CHECK: [[C1_:%.*]] = arith.constant 1 : index
-// CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1_]]] : memref<3xi32>
+// CHECK: [[C1:%.*]] = arith.constant 1 : index
+// CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1]]] : memref<3xi32>
 // CHECK: [[SIZE_1:%.*]] = arith.index_cast [[DIM_1]] : i32 to index
-// CHECK: [[STRIDE_0:%.*]] = arith.muli [[STRIDE_1]], [[SIZE_1]] : index
 
+// CHECK: [[C8_:%.*]] = arith.constant 8 : index
+// CHECK: [[STRIDE_0:%.*]] = arith.muli [[C8_]], [[SIZE_1]] : index
+
 // CHECK: [[C0:%.*]] = arith.constant 0 : index
 // CHECK: [[DIM_0:%.*]] = memref.load [[SHAPE]]{{\[}}[[C0]]] : memref<3xi32>
 // CHECK: [[SIZE_0:%.*]] = arith.index_cast [[DIM_0]] : i32 to index
 
 // CHECK: [[RESULT:%.*]] = memref.reinterpret_cast [[SRC]]
 // CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
-// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]]
+// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], 8, 1]
 // CHECK-SAME: : memref<*xf32> to memref<?x?x8xf32>
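For context, the (unchanged) input of @memref_reshape, reconstructed as a sketch from the CHECK-LABEL above:

func.func @memref_reshape(%input: memref<*xf32>,
                          %shape: memref<3xi32>) -> memref<?x?x8xf32> {
  %result = memref.reshape %input(%shape)
      : (memref<*xf32>, memref<3xi32>) -> memref<?x?x8xf32>
  return %result : memref<?x?x8xf32>
}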
21 changes: 6 additions & 15 deletions mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -931,19 +931,15 @@ func.func @extract_aligned_pointer_as_index_of_unranked_source(%arg0: memref<*xf
 //           = min(7, 1)
 //           = 1
 //
-// CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0] -> (s0)>
-// CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
-// CHECK-DAG: #[[$STRIDE1_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1, 42)>
+// CHECK: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
 // CHECK-LABEL: func @simplify_collapse(
 // CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
 //
 // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
 //
-// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0]
-// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
-// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.min #[[$STRIDE1_MIN_MAP]]()[%[[STRIDES]]#1, %[[STRIDES]]#2]
+// CHECK: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
 //
-// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], 1]
+// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[STRIDES]]#0, 42, 1]
 func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
     -> memref<?x?x42xi32> {
 
@@ -1046,15 +1042,12 @@ func.func @simplify_collapse_with_dim_of_size1_and_non_1_stride
 // We just return the first dynamic one for this group.
 //
 //
-// CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1)>
 // CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride(
 // CHECK-SAME: %[[ARG:.*]]: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2]
 //
 // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:5, %[[STRIDES:.*]]:5 = memref.extract_strided_metadata %[[ARG]] : memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>
 //
-// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0, %[[STRIDES]]#1]
-//
-// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[DYN_STRIDE0]], %[[STRIDES]]#2]
+// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[STRIDES]]#1, %[[STRIDES]]#2]
 func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
     (%arg0: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>)
     -> memref<6x1xi32, strided<[?, ?], offset: ?>> {
@@ -1083,8 +1076,7 @@ func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
 //   Stride 2 = origStride5
 //            = 1
 //
-// CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
-// CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0] -> (s0)>
+// CHECK: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
 // CHECK-LABEL: func @extract_strided_metadata_of_collapse(
 // CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
 //
@@ -1094,10 +1086,9 @@ func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
 //
 // CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
 //
-// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MAP]]()[%[[STRIDES]]#0]
 // CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
 //
-// CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[DYN_STRIDE0]], %[[C42]], %[[C1]]
+// CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[STRIDES]]#0, %[[C42]], %[[C1]]
 func.func @extract_strided_metadata_of_collapse(%arg : memref<?x?x4x?x6x7xi32>)
     -> (memref<i32>, index,
         index, index, index,
9 changes: 9 additions & 0 deletions mlir/test/Dialect/MemRef/invalid.mlir
@@ -217,6 +217,15 @@ func.func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {

 // -----
 
+func.func @memref_reinterpret_cast_offset_mismatch_dynamic(%in: memref<?xf32>, %offset : index) {
+  // expected-error @+1 {{expected result type with offset = dynamic instead of 0}}
+  %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
+      : memref<?xf32> to memref<10xf32>
+  return
+}
+
+// -----
+
 func.func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
   // expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}}
   %out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10]
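A stride analogue of the new offset test, as a hypothetical sketch (not part of this PR), which the updated verifier would reject the same way:

func.func @memref_reinterpret_cast_stride_mismatch_dynamic(%in: memref<?xf32>, %stride : index) {
  // Would fail verification: expected result type with stride = dynamic instead of 1 in dim = 0.
  %out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [%stride]
      : memref<?xf32> to memref<10xf32>
  return
}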