Skip to content

Commit 75e5f0a

Browse files
committed
[mlir] factor memref-to-llvm lowering out of std-to-llvm
After the MemRef has been split out of the Standard dialect, the conversion to the LLVM dialect remained as a huge monolithic pass. This is undesirable for the same complexity management reasons as having a huge Standard dialect itself, and is even more confusing given the existence of a separate dialect. Extract the conversion of the MemRef dialect operations to LLVM into a separate library and a separate conversion pass. Reviewed By: herhut, silvas Differential Revision: https://reviews.llvm.org/D105625
1 parent 9a01527 commit 75e5f0a

File tree

140 files changed

+3927
-3722
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

140 files changed

+3927
-3722
lines changed

mlir/examples/toy/Ch6/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,10 @@ target_link_libraries(toyc-ch6
4242
MLIRCastInterfaces
4343
MLIRExecutionEngine
4444
MLIRIR
45+
MLIRLLVMCommonConversion
4546
MLIRLLVMIR
4647
MLIRLLVMToLLVMIRTranslation
48+
MLIRMemRef
4749
MLIRParser
4850
MLIRPass
4951
MLIRSideEffectInterfaces

mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
#include "toy/Passes.h"
2626

2727
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
28+
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
29+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
30+
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
2831
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
2932
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
3033
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
@@ -195,6 +198,7 @@ void ToyToLLVMLoweringPass::runOnOperation() {
195198
RewritePatternSet patterns(&getContext());
196199
populateAffineToStdConversionPatterns(patterns);
197200
populateLoopToStdConversionPatterns(patterns);
201+
populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
198202
populateStdToLLVMConversionPatterns(typeConverter, patterns);
199203

200204
// The only remaining operation to lower from the `toy` dialect, is the

mlir/examples/toy/Ch7/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ target_link_libraries(toyc-ch7
4242
MLIRCastInterfaces
4343
MLIRExecutionEngine
4444
MLIRIR
45+
MLIRLLVMCommonConversion
4546
MLIRLLVMToLLVMIRTranslation
47+
MLIRMemRef
4648
MLIRParser
4749
MLIRPass
4850
MLIRSideEffectInterfaces

mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
#include "toy/Passes.h"
2626

2727
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
28+
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
29+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
30+
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
2831
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
2932
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
3033
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
@@ -195,6 +198,7 @@ void ToyToLLVMLoweringPass::runOnOperation() {
195198
RewritePatternSet patterns(&getContext());
196199
populateAffineToStdConversionPatterns(patterns);
197200
populateLoopToStdConversionPatterns(patterns);
201+
populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
198202
populateStdToLLVMConversionPatterns(typeConverter, patterns);
199203

200204
// The only remaining operation to lower from the `toy` dialect, is the

mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
#define MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_
1010

1111
#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
12-
#include "mlir/Transforms/DialectConversion.h"
1312

1413
namespace mlir {
1514
class LLVMTypeConverter;
1615
class ModuleOp;
1716
template <typename T>
1817
class OperationPass;
18+
class RewritePatternSet;
1919

2020
class ComplexStructBuilder : public StructBuilder {
2121
public:
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- ConversionTarget.h - LLVM dialect conversion target ------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_LLVMCOMMON_CONVERSIONTARGET_H
10+
#define MLIR_CONVERSION_LLVMCOMMON_CONVERSIONTARGET_H
11+
12+
#include "mlir/Transforms/DialectConversion.h"
13+
14+
namespace mlir {
15+
/// Derived class that automatically populates legalization information for
16+
/// different LLVM ops.
17+
class LLVMConversionTarget : public ConversionTarget {
18+
public:
19+
explicit LLVMConversionTarget(MLIRContext &ctx);
20+
};
21+
} // namespace mlir
22+
23+
#endif // MLIR_CONVERSION_LLVMCOMMON_CONVERSIONTARGET_H

mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
#ifndef MLIR_CONVERSION_LINALGTOLLVM_LINALGTOLLVM_H_
99
#define MLIR_CONVERSION_LINALGTOLLVM_LINALGTOLLVM_H_
1010

11-
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
12-
#include "mlir/Transforms/DialectConversion.h"
11+
#include <memory>
1312

1413
namespace mlir {
14+
class LLVMTypeConverter;
1515
class MLIRContext;
1616
class ModuleOp;
1717
template <typename T>
1818
class OperationPass;
19+
class RewritePatternSet;
1920

2021
/// Populate the given list with patterns that convert from Linalg to LLVM.
2122
void populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter,
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
//===- AllocLikeConversion.h - Convert allocation ops to LLVM ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
10+
#define MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
11+
12+
#include "mlir/Conversion/LLVMCommon/Pattern.h"
13+
14+
namespace mlir {
15+
16+
/// Lowering for AllocOp and AllocaOp.
17+
struct AllocLikeOpLLVMLowering : public ConvertToLLVMPattern {
18+
using ConvertToLLVMPattern::createIndexConstant;
19+
using ConvertToLLVMPattern::getIndexType;
20+
using ConvertToLLVMPattern::getVoidPtrType;
21+
22+
explicit AllocLikeOpLLVMLowering(StringRef opName,
23+
LLVMTypeConverter &converter)
24+
: ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
25+
26+
protected:
27+
// Returns 'input' aligned up to 'alignment'. Computes
28+
// bumped = input + alignement - 1
29+
// aligned = bumped - bumped % alignment
30+
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
31+
Value input, Value alignment);
32+
33+
/// Allocates the underlying buffer. Returns the allocated pointer and the
34+
/// aligned pointer.
35+
virtual std::tuple<Value, Value>
36+
allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
37+
Value sizeBytes, Operation *op) const = 0;
38+
39+
private:
40+
static MemRefType getMemRefResultType(Operation *op) {
41+
return op->getResult(0).getType().cast<MemRefType>();
42+
}
43+
44+
// An `alloc` is converted into a definition of a memref descriptor value and
45+
// a call to `malloc` to allocate the underlying data buffer. The memref
46+
// descriptor is of the LLVM structure type where:
47+
// 1. the first element is a pointer to the allocated (typed) data buffer,
48+
// 2. the second element is a pointer to the (typed) payload, aligned to the
49+
// specified alignment,
50+
// 3. the remaining elements serve to store all the sizes and strides of the
51+
// memref using LLVM-converted `index` type.
52+
//
53+
// Alignment is performed by allocating `alignment` more bytes than
54+
// requested and shifting the aligned pointer relative to the allocated
55+
// memory. Note: `alignment - <minimum malloc alignment>` would actually be
56+
// sufficient. If alignment is unspecified, the two pointers are equal.
57+
58+
// An `alloca` is converted into a definition of a memref descriptor value and
59+
// an llvm.alloca to allocate the underlying data buffer.
60+
LogicalResult
61+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
62+
ConversionPatternRewriter &rewriter) const override;
63+
};
64+
65+
} // namespace mlir
66+
67+
#endif // MLIR_CONVERSION_MEMREFTOLLVM_ALLOCLIKECONVERSION_H
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- MemRefToLLVM.h - MemRef to LLVM dialect conversion -------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_MEMREFTOLLVM_MEMREFTOLLVM_H
10+
#define MLIR_CONVERSION_MEMREFTOLLVM_MEMREFTOLLVM_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
class LLVMTypeConverter;
17+
class RewritePatternSet;
18+
19+
/// Collect a set of patterns to convert memory-related operations from the
20+
/// MemRef dialect to the LLVM dialect.
21+
void populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
22+
RewritePatternSet &patterns);
23+
24+
std::unique_ptr<Pass> createMemRefToLLVMPass();
25+
} // namespace mlir
26+
27+
#endif // MLIR_CONVERSION_MEMREFTOLLVM_MEMREFTOLLVM_H

mlir/include/mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#ifndef MLIR_CONVERSION_OPENACCTOLLVM_CONVERTOPENACCTOLLVM_H
99
#define MLIR_CONVERSION_OPENACCTOLLVM_CONVERTOPENACCTOLLVM_H
1010

11-
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
11+
#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
1212
#include <memory>
1313

1414
namespace mlir {

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
2424
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
2525
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
26+
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
2627
#include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
2728
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
2829
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,24 @@ def ConvertMathToLibm : Pass<"convert-math-to-libm", "ModuleOp"> {
255255
let dependentDialects = ["StandardOpsDialect", "vector::VectorDialect"];
256256
}
257257

258+
//===----------------------------------------------------------------------===//
259+
// MemRefToLLVM
260+
//===----------------------------------------------------------------------===//
261+
262+
def ConvertMemRefToLLVM : Pass<"convert-memref-to-llvm", "ModuleOp"> {
263+
let summary = "Convert operations from the MemRef dialect to the LLVM "
264+
"dialect";
265+
let constructor = "mlir::createMemRefToLLVMPass()";
266+
let dependentDialects = ["LLVM::LLVMDialect"];
267+
let options = [
268+
Option<"useAlignedAlloc", "use-aligned-alloc", "bool", /*default=*/"false",
269+
"Use aligned_alloc in place of malloc for heap allocations">,
270+
Option<"indexBitwidth", "index-bitwidth", "unsigned",
271+
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
272+
"Bitwidth of the index type, 0 to use size of machine word">,
273+
];
274+
}
275+
258276
//===----------------------------------------------------------------------===//
259277
// OpenACCToSCF
260278
//===----------------------------------------------------------------------===//
@@ -434,8 +452,6 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> {
434452
let constructor = "mlir::createLowerToLLVMPass()";
435453
let dependentDialects = ["LLVM::LLVMDialect"];
436454
let options = [
437-
Option<"useAlignedAlloc", "use-aligned-alloc", "bool", /*default=*/"false",
438-
"Use aligned_alloc in place of malloc for heap allocations">,
439455
Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
440456
/*default=*/"false",
441457
"Replace FuncOp's MemRef arguments with bare pointers to the MemRef "

mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h

Lines changed: 1 addition & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,12 @@
1515
#ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
1616
#define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
1717

18-
#include "mlir/Conversion/LLVMCommon/Pattern.h"
19-
2018
namespace mlir {
2119

20+
class MLIRContext;
2221
class LLVMTypeConverter;
2322
class RewritePatternSet;
2423

25-
/// Collect a set of patterns to convert memory-related operations from the
26-
/// Standard dialect to the LLVM dialect, excluding non-memory-related
27-
/// operations and FuncOp.
28-
void populateStdToLLVMMemoryConversionPatterns(LLVMTypeConverter &converter,
29-
RewritePatternSet &patterns);
30-
31-
/// Collect a set of patterns to convert from the Standard dialect to the LLVM
32-
/// dialect, excluding the memory-related operations.
33-
void populateStdToLLVMNonMemoryConversionPatterns(LLVMTypeConverter &converter,
34-
RewritePatternSet &patterns);
35-
3624
/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
3725
/// `emitCWrappers` is set, the pattern will also produce functions
3826
/// that pass memref descriptors by pointer-to-structure in addition to the
@@ -47,62 +35,6 @@ void populateStdToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter,
4735
void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
4836
RewritePatternSet &patterns);
4937

50-
/// Lowering for AllocOp and AllocaOp.
51-
struct AllocLikeOpLLVMLowering : public ConvertToLLVMPattern {
52-
using ConvertToLLVMPattern::createIndexConstant;
53-
using ConvertToLLVMPattern::getIndexType;
54-
using ConvertToLLVMPattern::getVoidPtrType;
55-
56-
explicit AllocLikeOpLLVMLowering(StringRef opName,
57-
LLVMTypeConverter &converter)
58-
: ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
59-
60-
protected:
61-
// Returns 'input' aligned up to 'alignment'. Computes
62-
// bumped = input + alignement - 1
63-
// aligned = bumped - bumped % alignment
64-
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
65-
Value input, Value alignment);
66-
67-
/// Allocates the underlying buffer. Returns the allocated pointer and the
68-
/// aligned pointer.
69-
virtual std::tuple<Value, Value>
70-
allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
71-
Value sizeBytes, Operation *op) const = 0;
72-
73-
private:
74-
static MemRefType getMemRefResultType(Operation *op) {
75-
return op->getResult(0).getType().cast<MemRefType>();
76-
}
77-
78-
// An `alloc` is converted into a definition of a memref descriptor value and
79-
// a call to `malloc` to allocate the underlying data buffer. The memref
80-
// descriptor is of the LLVM structure type where:
81-
// 1. the first element is a pointer to the allocated (typed) data buffer,
82-
// 2. the second element is a pointer to the (typed) payload, aligned to the
83-
// specified alignment,
84-
// 3. the remaining elements serve to store all the sizes and strides of the
85-
// memref using LLVM-converted `index` type.
86-
//
87-
// Alignment is performed by allocating `alignment` more bytes than
88-
// requested and shifting the aligned pointer relative to the allocated
89-
// memory. Note: `alignment - <minimum malloc alignment>` would actually be
90-
// sufficient. If alignment is unspecified, the two pointers are equal.
91-
92-
// An `alloca` is converted into a definition of a memref descriptor value and
93-
// an llvm.alloca to allocate the underlying data buffer.
94-
LogicalResult
95-
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
96-
ConversionPatternRewriter &rewriter) const override;
97-
};
98-
99-
/// Derived class that automatically populates legalization information for
100-
/// different LLVM ops.
101-
class LLVMConversionTarget : public ConversionTarget {
102-
public:
103-
explicit LLVMConversionTarget(MLIRContext &ctx);
104-
};
105-
10638
} // namespace mlir
10739

10840
#endif // MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H

mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
1010

1111
#include "../PassDetail.h"
12+
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1214
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
1315
#include "mlir/Dialect/Async/IR/Async.h"
1416
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRAsyncToLLVM
1212

1313
LINK_LIBS PUBLIC
1414
MLIRAsync
15+
MLIRLLVMCommonConversion
1516
MLIRLLVMIR
1617
MLIRStandardOpsTransforms
1718
MLIRStandardToLLVM

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_subdirectory(LinalgToSPIRV)
1313
add_subdirectory(LinalgToStandard)
1414
add_subdirectory(LLVMCommon)
1515
add_subdirectory(MathToLibm)
16+
add_subdirectory(MemRefToLLVM)
1617
add_subdirectory(OpenACCToLLVM)
1718
add_subdirectory(OpenACCToSCF)
1819
add_subdirectory(OpenMPToLLVM)

mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,5 @@ add_mlir_conversion_library(MLIRComplexToLLVM
1515
MLIRLLVMCommonConversion
1616
MLIRLLVMIR
1717
MLIRStandardOpsTransforms
18-
MLIRStandardToLLVM
1918
MLIRTransforms
2019
)

mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
1010

1111
#include "../PassDetail.h"
12-
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
12+
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13+
#include "mlir/Conversion/LLVMCommon/Pattern.h"
1314
#include "mlir/Dialect/Complex/IR/Complex.h"
1415
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1516

mlir/lib/Conversion/GPUCommon/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ add_mlir_conversion_library(MLIRGPUToGPURuntimeTransforms
3434
MLIRIR
3535
MLIRLLVMCommonConversion
3636
MLIRLLVMIR
37+
MLIRMemRefToLLVM
3738
MLIRPass
3839
MLIRSupport
3940
MLIRStandardToLLVM

0 commit comments

Comments
 (0)