Skip to content

Commit 889b67c

Browse files
authored
[mlir] [memref] add more checks to the memref.reinterpret_cast (#112669)
Operation memref.reinterpret_cast was accept input like: %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1] : memref<?xf32> to memref<10xf32> A problem arises: while lowering, the true offset of %out is %offset, but its data type indicates an offset of 0. Permitting this inconsistency can result in incorrect outcomes, as certain pass might erroneously extract the offset from the data type of %out. This patch fixes this by enforcing that the return value's data type aligns with the input parameter.
1 parent 5f7bad0 commit 889b67c

File tree

9 files changed

+81
-63
lines changed

9 files changed

+81
-63
lines changed

mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,17 @@ namespace mlir {
2929

3030
using namespace mlir;
3131

32+
static MemRefType inferCastResultType(Value source, OpFoldResult offset) {
33+
auto sourceType = cast<BaseMemRefType>(source.getType());
34+
SmallVector<int64_t> staticOffsets;
35+
SmallVector<Value> dynamicOffsets;
36+
dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
37+
auto stridedLayout =
38+
StridedLayoutAttr::get(source.getContext(), staticOffsets.front(), {});
39+
return MemRefType::get({}, sourceType.getElementType(), stridedLayout,
40+
sourceType.getMemorySpace());
41+
}
42+
3243
static void setInsertionPointToStart(OpBuilder &builder, Value val) {
3344
if (auto *parentOp = val.getDefiningOp()) {
3445
builder.setInsertionPointAfter(parentOp);
@@ -98,7 +109,7 @@ static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
98109
SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
99110
auto &&[base, offset, ignore] =
100111
getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
101-
auto retType = cast<MemRefType>(base.getType());
112+
MemRefType retType = inferCastResultType(base, offset);
102113
return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
103114
std::nullopt, std::nullopt);
104115
}

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1892,11 +1892,12 @@ LogicalResult ReinterpretCastOp::verify() {
18921892
// Match sizes in result memref type and in static_sizes attribute.
18931893
for (auto [idx, resultSize, expectedSize] :
18941894
llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1895-
if (!ShapedType::isDynamic(resultSize) &&
1896-
!ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
1895+
if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
18971896
return emitError("expected result type with size = ")
1898-
<< expectedSize << " instead of " << resultSize
1899-
<< " in dim = " << idx;
1897+
<< (ShapedType::isDynamic(expectedSize)
1898+
? std::string("dynamic")
1899+
: std::to_string(expectedSize))
1900+
<< " instead of " << resultSize << " in dim = " << idx;
19001901
}
19011902

19021903
// Match offset and strides in static_offset and static_strides attributes. If
@@ -1910,20 +1911,22 @@ LogicalResult ReinterpretCastOp::verify() {
19101911

19111912
// Match offset in result memref type and in static_offsets attribute.
19121913
int64_t expectedOffset = getStaticOffsets().front();
1913-
if (!ShapedType::isDynamic(resultOffset) &&
1914-
!ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
1914+
if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
19151915
return emitError("expected result type with offset = ")
1916-
<< expectedOffset << " instead of " << resultOffset;
1916+
<< (ShapedType::isDynamic(expectedOffset)
1917+
? std::string("dynamic")
1918+
: std::to_string(expectedOffset))
1919+
<< " instead of " << resultOffset;
19171920

19181921
// Match strides in result memref type and in static_strides attribute.
19191922
for (auto [idx, resultStride, expectedStride] :
19201923
llvm::enumerate(resultStrides, getStaticStrides())) {
1921-
if (!ShapedType::isDynamic(resultStride) &&
1922-
!ShapedType::isDynamic(expectedStride) &&
1923-
resultStride != expectedStride)
1924+
if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
19241925
return emitError("expected result type with stride = ")
1925-
<< expectedStride << " instead of " << resultStride
1926-
<< " in dim = " << idx;
1926+
<< (ShapedType::isDynamic(expectedStride)
1927+
? std::string("dynamic")
1928+
: std::to_string(expectedStride))
1929+
<< " instead of " << resultStride << " in dim = " << idx;
19271930
}
19281931

19291932
return success();

mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
8989
strides.resize(rank);
9090

9191
Location loc = op.getLoc();
92-
Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
92+
Value stride = nullptr;
93+
int64_t staticStride = 1;
9394
for (int i = rank - 1; i >= 0; --i) {
9495
Value size;
9596
// Load dynamic sizes from the shape input, use constants for static dims.
@@ -105,9 +106,22 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
105106
size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
106107
sizes[i] = sizeAttr;
107108
}
108-
strides[i] = stride;
109-
if (i > 0)
110-
stride = rewriter.create<arith::MulIOp>(loc, stride, size);
109+
if (stride)
110+
strides[i] = stride;
111+
else
112+
strides[i] = rewriter.getIndexAttr(staticStride);
113+
114+
if (i > 0) {
115+
if (stride) {
116+
stride = rewriter.create<arith::MulIOp>(loc, stride, size);
117+
} else if (op.getType().isDynamicDim(i)) {
118+
stride = rewriter.create<arith::MulIOp>(
119+
loc, rewriter.create<arith::ConstantIndexOp>(loc, staticStride),
120+
size);
121+
} else {
122+
staticStride *= op.getType().getDimSize(i);
123+
}
124+
}
111125
}
112126
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
113127
op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,8 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
507507

508508
SmallVector<OpFoldResult> groupStrides;
509509
ArrayRef<int64_t> srcShape = sourceType.getShape();
510+
511+
OpFoldResult lastValidStride = nullptr;
510512
for (int64_t currentDim : reassocGroup) {
511513
// Skip size-of-1 dimensions, since right now their strides may be
512514
// meaningless.
@@ -517,11 +519,11 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
517519
continue;
518520

519521
int64_t currentStride = strides[currentDim];
520-
groupStrides.push_back(ShapedType::isDynamic(currentStride)
521-
? origStrides[currentDim]
522-
: builder.getIndexAttr(currentStride));
522+
lastValidStride = ShapedType::isDynamic(currentStride)
523+
? origStrides[currentDim]
524+
: builder.getIndexAttr(currentStride);
523525
}
524-
if (groupStrides.empty()) {
526+
if (!lastValidStride) {
525527
// We're dealing with a 1x1x...x1 shape. The stride is meaningless,
526528
// but we still have to make the type system happy.
527529
MemRefType collapsedType = collapseShape.getResultType();
@@ -543,12 +545,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
543545
return {builder.getIndexAttr(finalStride)};
544546
}
545547

546-
// For the general case, we just want the minimum stride
547-
// since the collapsed dimensions are contiguous.
548-
auto minMap = AffineMap::getMultiDimIdentityMap(groupStrides.size(),
549-
builder.getContext());
550-
return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
551-
groupStrides)};
548+
return {lastValidStride};
552549
}
553550

554551
/// From `reshape_like(memref, subSizes, subStrides))` compute

mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,6 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
425425
// CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
426426
// CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
427427
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
428-
// CHECK: %[[STRIDE0_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0]] : i64 to index
429-
// CHECK: %[[STRIDE0:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0_TO_IDX]] : index to i64
430428
// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] : i64
431429
// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
432430
// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
@@ -548,23 +546,19 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
548546
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
549547
// CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
550548
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
551-
// CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
552549
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
553550
// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]] : i64
554551
// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
555552
// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
556-
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
557-
// CHECK: %[[MIN_STRIDE1:.*]] = llvm.intr.smin(%[[STRIDE1]], %[[C1]]) : (i64, i64) -> i64
558-
// CHECK: %[[MIN_STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1]] : i64 to index
559-
// CHECK: %[[MIN_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1_TO_IDX]] : index to i64
560553
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
561554
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
562555
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
563556
// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
557+
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
564558
// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
565559
// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
566560
// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[FINAL_SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
567-
// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[MIN_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
561+
// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[C1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
568562
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<1x?xf32>
569563
// CHECK: return %[[RES]] : memref<1x?xf32>
570564
// CHECK: }

mlir/test/Dialect/GPU/decompose-memrefs.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
// CHECK: gpu.launch
88
// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
99
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
10-
// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
11-
// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32>
10+
// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
11+
// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32, strided<[], offset: ?>>
1212
func.func @decompose_store(%arg0 : f32, %arg1 : memref<?x?x?xf32>) {
1313
%c0 = arith.constant 0 : index
1414
%c1 = arith.constant 1 : index
@@ -33,8 +33,8 @@ func.func @decompose_store(%arg0 : f32, %arg1 : memref<?x?x?xf32>) {
3333
// CHECK: gpu.launch
3434
// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
3535
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[OFFSET]], %[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]], %[[STRIDES]]#2]
36-
// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
37-
// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32>
36+
// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
37+
// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32, strided<[], offset: ?>>
3838
func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) {
3939
%c0 = arith.constant 0 : index
4040
%c1 = arith.constant 1 : index
@@ -59,8 +59,8 @@ func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref<?x?x?xf32, stride
5959
// CHECK: gpu.launch
6060
// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
6161
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
62-
// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
63-
// CHECK: %[[RES:.*]] = memref.load %[[PTR]][] : memref<f32>
62+
// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
63+
// CHECK: %[[RES:.*]] = memref.load %[[PTR]][] : memref<f32, strided<[], offset: ?>>
6464
// CHECK: "test.test"(%[[RES]]) : (f32) -> ()
6565
func.func @decompose_load(%arg0 : memref<?x?x?xf32>) {
6666
%c0 = arith.constant 0 : index

mlir/test/Dialect/MemRef/expand-ops.mlir

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,20 +52,19 @@ func.func @memref_reshape(%input: memref<*xf32>,
5252
// CHECK-SAME: [[SRC:%.*]]: memref<*xf32>,
5353
// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x8xf32> {
5454

55-
// CHECK: [[C1:%.*]] = arith.constant 1 : index
5655
// CHECK: [[C8:%.*]] = arith.constant 8 : index
57-
// CHECK: [[STRIDE_1:%.*]] = arith.muli [[C1]], [[C8]] : index
58-
59-
// CHECK: [[C1_:%.*]] = arith.constant 1 : index
60-
// CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1_]]] : memref<3xi32>
56+
// CHECK: [[C1:%.*]] = arith.constant 1 : index
57+
// CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1]]] : memref<3xi32>
6158
// CHECK: [[SIZE_1:%.*]] = arith.index_cast [[DIM_1]] : i32 to index
62-
// CHECK: [[STRIDE_0:%.*]] = arith.muli [[STRIDE_1]], [[SIZE_1]] : index
59+
60+
// CHECK: [[C8_:%.*]] = arith.constant 8 : index
61+
// CHECK: [[STRIDE_0:%.*]] = arith.muli [[C8_]], [[SIZE_1]] : index
6362

6463
// CHECK: [[C0:%.*]] = arith.constant 0 : index
6564
// CHECK: [[DIM_0:%.*]] = memref.load [[SHAPE]]{{\[}}[[C0]]] : memref<3xi32>
6665
// CHECK: [[SIZE_0:%.*]] = arith.index_cast [[DIM_0]] : i32 to index
6766

6867
// CHECK: [[RESULT:%.*]] = memref.reinterpret_cast [[SRC]]
6968
// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
70-
// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]]
69+
// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], 8, 1]
7170
// CHECK-SAME: : memref<*xf32> to memref<?x?x8xf32>

mlir/test/Dialect/MemRef/expand-strided-metadata.mlir

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -931,19 +931,15 @@ func.func @extract_aligned_pointer_as_index_of_unranked_source(%arg0: memref<*xf
931931
// = min(7, 1)
932932
// = 1
933933
//
934-
// CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0] -> (s0)>
935-
// CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
936-
// CHECK-DAG: #[[$STRIDE1_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1, 42)>
934+
// CHECK: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
937935
// CHECK-LABEL: func @simplify_collapse(
938936
// CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
939937
//
940938
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
941939
//
942-
// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0]
943-
// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
944-
// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.min #[[$STRIDE1_MIN_MAP]]()[%[[STRIDES]]#1, %[[STRIDES]]#2]
940+
// CHECK: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
945941
//
946-
// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], 1]
942+
// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[STRIDES]]#0, 42, 1]
947943
func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
948944
-> memref<?x?x42xi32> {
949945

@@ -1046,15 +1042,12 @@ func.func @simplify_collapse_with_dim_of_size1_and_non_1_stride
10461042
// We just return the first dynamic one for this group.
10471043
//
10481044
//
1049-
// CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1)>
10501045
// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride(
10511046
// CHECK-SAME: %[[ARG:.*]]: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2]
10521047
//
10531048
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:5, %[[STRIDES:.*]]:5 = memref.extract_strided_metadata %[[ARG]] : memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>
10541049
//
1055-
// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0, %[[STRIDES]]#1]
1056-
//
1057-
// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[DYN_STRIDE0]], %[[STRIDES]]#2]
1050+
// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[STRIDES]]#1, %[[STRIDES]]#2]
10581051
func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
10591052
(%arg0: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>)
10601053
-> memref<6x1xi32, strided<[?, ?], offset: ?>> {
@@ -1083,8 +1076,7 @@ func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
10831076
// Stride 2 = origStride5
10841077
// = 1
10851078
//
1086-
// CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
1087-
// CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0] -> (s0)>
1079+
// CHECK: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
10881080
// CHECK-LABEL: func @extract_strided_metadata_of_collapse(
10891081
// CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
10901082
//
@@ -1094,10 +1086,9 @@ func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
10941086
//
10951087
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
10961088
//
1097-
// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MAP]]()[%[[STRIDES]]#0]
10981089
// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
10991090
//
1100-
// CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[DYN_STRIDE0]], %[[C42]], %[[C1]]
1091+
// CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[STRIDES]]#0, %[[C42]], %[[C1]]
11011092
func.func @extract_strided_metadata_of_collapse(%arg : memref<?x?x4x?x6x7xi32>)
11021093
-> (memref<i32>, index,
11031094
index, index, index,

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,15 @@ func.func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {
217217

218218
// -----
219219

220+
func.func @memref_reinterpret_cast_offset_mismatch_dynamic(%in: memref<?xf32>, %offset : index) {
221+
// expected-error @+1 {{expected result type with offset = dynamic instead of 0}}
222+
%out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
223+
: memref<?xf32> to memref<10xf32>
224+
return
225+
}
226+
227+
// -----
228+
220229
func.func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
221230
// expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}}
222231
%out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10]

0 commit comments

Comments
 (0)