Skip to content

Commit fe68c3c

Browse files
[mlir][LLVM] LLVMTypeConverter: Tighten materialization checks
1 parent 4e4a5c8 commit fe68c3c

File tree

5 files changed

+154
-15
lines changed

5 files changed

+154
-15
lines changed

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153
type.isVarArg());
154154
});
155155

156+
// Helper function that checks if the given value range is a bare pointer.
157+
auto isBarePointer = [](ValueRange values) {
158+
return values.size() == 1 &&
159+
isa<LLVM::LLVMPointerType>(values.front().getType());
160+
};
161+
156162
// Argument materializations convert from the new block argument types
157163
// (multiple SSA values that make up a memref descriptor) back to the
158164
// original block argument type. The dialect conversion framework will then
@@ -161,11 +167,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
161167
addArgumentMaterialization([&](OpBuilder &builder,
162168
UnrankedMemRefType resultType,
163169
ValueRange inputs, Location loc) {
164-
if (inputs.size() == 1) {
165-
// Bare pointers are not supported for unranked memrefs because a
166-
// memref descriptor cannot be built just from a bare pointer.
170+
// Note: Bare pointers are not supported for unranked memrefs because a
171+
// memref descriptor cannot be built just from a bare pointer.
172+
if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
167173
return Value();
168-
}
169174
Value desc =
170175
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
171176
// An argument materialization must return a value of type
@@ -177,20 +182,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
177182
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
178183
ValueRange inputs, Location loc) {
179184
Value desc;
180-
if (inputs.size() == 1) {
181-
// This is a bare pointer. We allow bare pointers only for function entry
182-
// blocks.
183-
BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
184-
if (!barePtr)
185-
return Value();
186-
Block *block = barePtr.getOwner();
187-
if (!block->isEntryBlock() ||
188-
!isa<FunctionOpInterface>(block->getParentOp()))
189-
return Value();
185+
if (isBarePointer(inputs)) {
190186
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
191187
inputs[0]);
192-
} else {
188+
} else if (TypeRange(inputs) ==
189+
getMemRefDescriptorFields(resultType,
190+
/*unpackAggregates=*/true)) {
193191
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
192+
} else {
193+
// The inputs are neither a bare pointer nor an unpacked memref
194+
// descriptor. This materialization function cannot be used.
195+
return Value();
194196
}
195197
// An argument materialization must return a value of type `resultType`,
196198
// so insert a cast from the memref descriptor type (!llvm.struct) to the
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file
2+
3+
// Test the argument materializer for ranked MemRef types.
4+
5+
// CHECK-LABEL: func @construct_ranked_memref_descriptor(
6+
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
7+
// CHECK-COUNT-7: llvm.insertvalue
8+
// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<5x4xf32>
9+
func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) {
10+
%0 = "test.direct_replacement"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> (memref<5x4xf32>)
11+
"test.legal_op"(%0) : (memref<5x4xf32>) -> ()
12+
return
13+
}
14+
15+
// -----
16+
17+
// The argument materializer for ranked MemRef types is called with incorrect
18+
// input types. Make sure that the materializer is skipped and we do not
19+
// generate invalid IR.
20+
21+
// CHECK-LABEL: func @invalid_ranked_memref_descriptor(
22+
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<5x4xf32>
23+
// CHECK: "test.legal_op"(%[[cast]])
24+
func.func @invalid_ranked_memref_descriptor(%arg0: i1) {
25+
%0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<5x4xf32>)
26+
"test.legal_op"(%0) : (memref<5x4xf32>) -> ()
27+
return
28+
}
29+
30+
// -----
31+
32+
// Test the argument materializer for unranked MemRef types.
33+
34+
// CHECK-LABEL: func @construct_unranked_memref_descriptor(
35+
// CHECK: llvm.mlir.undef : !llvm.struct<(i64, ptr)>
36+
// CHECK-COUNT-2: llvm.insertvalue
37+
// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(i64, ptr)> to memref<*xf32>
38+
func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) {
39+
%0 = "test.direct_replacement"(%arg0, %arg1) : (i64, !llvm.ptr) -> (memref<*xf32>)
40+
"test.legal_op"(%0) : (memref<*xf32>) -> ()
41+
return
42+
}
43+
44+
// -----
45+
46+
// The argument materializer for unranked MemRef types is called with incorrect
47+
// input types. Make sure that the materializer is skipped and we do not
48+
// generate invalid IR.
49+
50+
// CHECK-LABEL: func @invalid_unranked_memref_descriptor(
51+
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<*xf32>
52+
// CHECK: "test.legal_op"(%[[cast]])
53+
func.func @invalid_unranked_memref_descriptor(%arg0: i1) {
54+
%0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<*xf32>)
55+
"test.legal_op"(%0) : (memref<*xf32>) -> ()
56+
return
57+
}

mlir/test/lib/Dialect/LLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Exclude tests from libMLIR.so
22
add_mlir_library(MLIRLLVMTestPasses
33
TestLowerToLLVM.cpp
4+
TestPatterns.cpp
45

56
EXCLUDE_FROM_LIBMLIR
67

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
//===- TestPatterns.cpp - LLVM dialect test patterns ----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
10+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11+
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
12+
#include "mlir/Pass/Pass.h"
13+
#include "mlir/Transforms/DialectConversion.h"
14+
15+
using namespace mlir;
16+
17+
namespace {
18+
19+
/// Replace this op (which is expected to have 1 result) with the operands.
20+
struct TestDirectReplacementOp : public ConversionPattern {
21+
TestDirectReplacementOp(MLIRContext *ctx, const TypeConverter &converter)
22+
: ConversionPattern(converter, "test.direct_replacement", 1, ctx) {}
23+
LogicalResult
24+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
25+
ConversionPatternRewriter &rewriter) const final {
26+
if (op->getNumResults() != 1)
27+
return failure();
28+
rewriter.replaceOpWithMultiple(op, {operands});
29+
return success();
30+
}
31+
};
32+
33+
struct TestLLVMLegalizePatternsPass
34+
: public PassWrapper<TestLLVMLegalizePatternsPass, OperationPass<>> {
35+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass)
36+
37+
StringRef getArgument() const final { return "test-llvm-legalize-patterns"; }
38+
StringRef getDescription() const final {
39+
return "Run LLVM dialect legalization patterns";
40+
}
41+
42+
void getDependentDialects(DialectRegistry &registry) const override {
43+
registry.insert<LLVM::LLVMDialect>();
44+
}
45+
46+
void runOnOperation() override {
47+
MLIRContext *ctx = &getContext();
48+
LLVMTypeConverter converter(ctx);
49+
mlir::RewritePatternSet patterns(ctx);
50+
patterns.add<TestDirectReplacementOp>(ctx, converter);
51+
52+
// Define the conversion target used for the test.
53+
ConversionTarget target(*ctx);
54+
target.addLegalOp(OperationName("test.legal_op", ctx));
55+
56+
// Handle a partial conversion.
57+
DenseSet<Operation *> unlegalizedOps;
58+
ConversionConfig config;
59+
config.unlegalizedOps = &unlegalizedOps;
60+
if (failed(applyPartialConversion(getOperation(), target,
61+
std::move(patterns), config)))
62+
getOperation()->emitError() << "applyPartialConversion failed";
63+
}
64+
};
65+
} // namespace
66+
67+
//===----------------------------------------------------------------------===//
68+
// PassRegistration
69+
//===----------------------------------------------------------------------===//
70+
71+
namespace mlir {
72+
namespace test {
73+
void registerTestLLVMLegalizePatternsPass() {
74+
PassRegistration<TestLLVMLegalizePatternsPass>();
75+
}
76+
} // namespace test
77+
} // namespace mlir

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ void registerTestLinalgRankReduceContractionOps();
113113
void registerTestLinalgTransforms();
114114
void registerTestLivenessAnalysisPass();
115115
void registerTestLivenessPass();
116+
void registerTestLLVMLegalizePatternsPass();
116117
void registerTestLoopFusion();
117118
void registerTestLoopMappingPass();
118119
void registerTestLoopUnrollingPass();
@@ -250,6 +251,7 @@ void registerTestPasses() {
250251
mlir::test::registerTestLinalgTransforms();
251252
mlir::test::registerTestLivenessAnalysisPass();
252253
mlir::test::registerTestLivenessPass();
254+
mlir::test::registerTestLLVMLegalizePatternsPass();
253255
mlir::test::registerTestLoopFusion();
254256
mlir::test::registerTestLoopMappingPass();
255257
mlir::test::registerTestLoopUnrollingPass();

0 commit comments

Comments
 (0)