Skip to content

Commit 26e59cc

Browse files
committed
[mlir] factor math-to-llvm out of standard-to-llvm
After the Math has been split out of the Standard dialect, the conversion to the LLVM dialect remained as a huge monolithic pass. This is undesirable for the same complexity management reasons as having a huge Standard dialect itself, and is even more confusing given the existence of a separate dialect. Extract the conversion of the Math dialect operations to LLVM into a separate library and a separate conversion pass. Reviewed By: silvas Differential Revision: https://reviews.llvm.org/D105702
1 parent d2e4ccc commit 26e59cc

File tree

10 files changed

+433
-320
lines changed

10 files changed

+433
-320
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- MathToLLVM.h - Math to LLVM dialect conversion -----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H
10+
#define MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
16+
class LLVMTypeConverter;
17+
class RewritePatternSet;
18+
class Pass;
19+
20+
void populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
21+
RewritePatternSet &patterns);
22+
23+
std::unique_ptr<Pass> createConvertMathToLLVMPass();
24+
} // namespace mlir
25+
26+
#endif // MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
2323
#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
2424
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
25+
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
2526
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
2627
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
2728
#include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,19 @@ def ConvertMathToLibm : Pass<"convert-math-to-libm", "ModuleOp"> {
255255
let dependentDialects = ["StandardOpsDialect", "vector::VectorDialect"];
256256
}
257257

258+
//===----------------------------------------------------------------------===//
259+
// MathToLLVM
260+
//===----------------------------------------------------------------------===//
261+
262+
def ConvertMathToLLVM : FunctionPass<"convert-math-to-llvm"> {
263+
let summary = "Convert Math dialect to LLVM dialect";
264+
let description = [{
265+
This pass converts supported Math ops to LLVM dialect intrinsics.
266+
}];
267+
let constructor = "mlir::createConvertMathToLLVMPass()";
268+
let dependentDialects = ["LLVM::LLVMDialect"];
269+
}
270+
258271
//===----------------------------------------------------------------------===//
259272
// MemRefToLLVM
260273
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_subdirectory(LinalgToSPIRV)
1313
add_subdirectory(LinalgToStandard)
1414
add_subdirectory(LLVMCommon)
1515
add_subdirectory(MathToLibm)
16+
add_subdirectory(MathToLLVM)
1617
add_subdirectory(MemRefToLLVM)
1718
add_subdirectory(OpenACCToLLVM)
1819
add_subdirectory(OpenACCToSCF)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
add_mlir_conversion_library(MLIRMathToLLVM
2+
MathToLLVM.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVM
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRLLVMCommonConversion
15+
MLIRLLVMIR
16+
MLIRMath
17+
MLIRPass
18+
MLIRTransforms
19+
)
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
//===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
10+
#include "../PassDetail.h"
11+
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12+
#include "mlir/Conversion/LLVMCommon/Pattern.h"
13+
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
14+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15+
#include "mlir/Dialect/Math/IR/Math.h"
16+
#include "mlir/IR/TypeUtilities.h"
17+
18+
using namespace mlir;
19+
20+
namespace {
21+
using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
22+
using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
23+
using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
24+
using Log10OpLowering =
25+
VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
26+
using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
27+
using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
28+
using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
29+
using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
30+
using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
31+
32+
// A `expm1` is converted into `exp - 1`.
33+
struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
34+
using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
35+
36+
LogicalResult
37+
matchAndRewrite(math::ExpM1Op op, ArrayRef<Value> operands,
38+
ConversionPatternRewriter &rewriter) const override {
39+
math::ExpM1Op::Adaptor transformed(operands);
40+
auto operandType = transformed.operand().getType();
41+
42+
if (!operandType || !LLVM::isCompatibleType(operandType))
43+
return failure();
44+
45+
auto loc = op.getLoc();
46+
auto resultType = op.getResult().getType();
47+
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
48+
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
49+
50+
if (!operandType.isa<LLVM::LLVMArrayType>()) {
51+
LLVM::ConstantOp one;
52+
if (LLVM::isCompatibleVectorType(operandType)) {
53+
one = rewriter.create<LLVM::ConstantOp>(
54+
loc, operandType,
55+
SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
56+
} else {
57+
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
58+
}
59+
auto exp = rewriter.create<LLVM::ExpOp>(loc, transformed.operand());
60+
rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
61+
return success();
62+
}
63+
64+
auto vectorType = resultType.dyn_cast<VectorType>();
65+
if (!vectorType)
66+
return rewriter.notifyMatchFailure(op, "expected vector result type");
67+
68+
return LLVM::detail::handleMultidimensionalVectors(
69+
op.getOperation(), operands, *getTypeConverter(),
70+
[&](Type llvm1DVectorTy, ValueRange operands) {
71+
auto splatAttr = SplatElementsAttr::get(
72+
mlir::VectorType::get(
73+
{LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
74+
floatType),
75+
floatOne);
76+
auto one =
77+
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
78+
auto exp =
79+
rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
80+
return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
81+
},
82+
rewriter);
83+
}
84+
};
85+
86+
// A `log1p` is converted into `log(1 + ...)`.
87+
struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
88+
using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
89+
90+
LogicalResult
91+
matchAndRewrite(math::Log1pOp op, ArrayRef<Value> operands,
92+
ConversionPatternRewriter &rewriter) const override {
93+
math::Log1pOp::Adaptor transformed(operands);
94+
auto operandType = transformed.operand().getType();
95+
96+
if (!operandType || !LLVM::isCompatibleType(operandType))
97+
return rewriter.notifyMatchFailure(op, "unsupported operand type");
98+
99+
auto loc = op.getLoc();
100+
auto resultType = op.getResult().getType();
101+
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
102+
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
103+
104+
if (!operandType.isa<LLVM::LLVMArrayType>()) {
105+
LLVM::ConstantOp one =
106+
LLVM::isCompatibleVectorType(operandType)
107+
? rewriter.create<LLVM::ConstantOp>(
108+
loc, operandType,
109+
SplatElementsAttr::get(resultType.cast<ShapedType>(),
110+
floatOne))
111+
: rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
112+
113+
auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
114+
transformed.operand());
115+
rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
116+
return success();
117+
}
118+
119+
auto vectorType = resultType.dyn_cast<VectorType>();
120+
if (!vectorType)
121+
return rewriter.notifyMatchFailure(op, "expected vector result type");
122+
123+
return LLVM::detail::handleMultidimensionalVectors(
124+
op.getOperation(), operands, *getTypeConverter(),
125+
[&](Type llvm1DVectorTy, ValueRange operands) {
126+
auto splatAttr = SplatElementsAttr::get(
127+
mlir::VectorType::get(
128+
{LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
129+
floatType),
130+
floatOne);
131+
auto one =
132+
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
133+
auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
134+
operands[0]);
135+
return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
136+
},
137+
rewriter);
138+
}
139+
};
140+
141+
// A `rsqrt` is converted into `1 / sqrt`.
142+
struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
143+
using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
144+
145+
LogicalResult
146+
matchAndRewrite(math::RsqrtOp op, ArrayRef<Value> operands,
147+
ConversionPatternRewriter &rewriter) const override {
148+
math::RsqrtOp::Adaptor transformed(operands);
149+
auto operandType = transformed.operand().getType();
150+
151+
if (!operandType || !LLVM::isCompatibleType(operandType))
152+
return failure();
153+
154+
auto loc = op.getLoc();
155+
auto resultType = op.getResult().getType();
156+
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
157+
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
158+
159+
if (!operandType.isa<LLVM::LLVMArrayType>()) {
160+
LLVM::ConstantOp one;
161+
if (LLVM::isCompatibleVectorType(operandType)) {
162+
one = rewriter.create<LLVM::ConstantOp>(
163+
loc, operandType,
164+
SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
165+
} else {
166+
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
167+
}
168+
auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
169+
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
170+
return success();
171+
}
172+
173+
auto vectorType = resultType.dyn_cast<VectorType>();
174+
if (!vectorType)
175+
return failure();
176+
177+
return LLVM::detail::handleMultidimensionalVectors(
178+
op.getOperation(), operands, *getTypeConverter(),
179+
[&](Type llvm1DVectorTy, ValueRange operands) {
180+
auto splatAttr = SplatElementsAttr::get(
181+
mlir::VectorType::get(
182+
{LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
183+
floatType),
184+
floatOne);
185+
auto one =
186+
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
187+
auto sqrt =
188+
rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
189+
return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
190+
},
191+
rewriter);
192+
}
193+
};
194+
195+
struct ConvertMathToLLVMPass
196+
: public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
197+
ConvertMathToLLVMPass() = default;
198+
199+
void runOnFunction() override {
200+
RewritePatternSet patterns(&getContext());
201+
LLVMTypeConverter converter(&getContext());
202+
populateMathToLLVMConversionPatterns(converter, patterns);
203+
LLVMConversionTarget target(getContext());
204+
target.addLegalOp<LLVM::DialectCastOp>();
205+
if (failed(
206+
applyPartialConversion(getFunction(), target, std::move(patterns))))
207+
signalPassFailure();
208+
}
209+
};
210+
} // namespace
211+
212+
void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
213+
RewritePatternSet &patterns) {
214+
// clang-format off
215+
patterns.add<
216+
CosOpLowering,
217+
ExpOpLowering,
218+
Exp2OpLowering,
219+
ExpM1OpLowering,
220+
Log10OpLowering,
221+
Log1pOpLowering,
222+
Log2OpLowering,
223+
LogOpLowering,
224+
PowFOpLowering,
225+
RsqrtOpLowering,
226+
SinOpLowering,
227+
SqrtOpLowering
228+
>(converter);
229+
// clang-format on
230+
}
231+
232+
std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
233+
return std::make_unique<ConvertMathToLLVMPass>();
234+
}

0 commit comments

Comments
 (0)