-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][EmitC] Add MathToEmitC pass for math function lowering to EmitC #113799
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
d5bd00c
[mlir][EmitC] Add MathToEmitC pass for math function lowering to EmitC
recursion-man ad9af42
[MLIR][MathToEmitC] Ensure scalar type handling and refactor
recursion-man f6c2406
[MLIR][MathToEmitC] Refactor code, add tests for unsupported types, a…
recursion-man 23ada46
[MLIR][MathToEmitC] Add support for C and C++ targets with Lit tests
recursion-man 8992fd5
[MLIR][MathToEmitC] Add language standard option, create LanguageTarg…
recursion-man 52c35a6
[mlir][EmitC] Apply formatting fixes
recursion-man File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
//===- MathToEmitC.h - Math to EmitC Patterns -------------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H | ||
#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H | ||
#include "mlir/Dialect/EmitC/IR/EmitC.h" | ||
namespace mlir { | ||
class RewritePatternSet; | ||
namespace emitc { | ||
|
||
/// Enum to specify the language target for EmitC code generation. | ||
enum class LanguageTarget { c99, cpp11 }; | ||
|
||
} // namespace emitc | ||
|
||
void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns, | ||
emitc::LanguageTarget languageTarget); | ||
} // namespace mlir | ||
|
||
#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H |
21 changes: 21 additions & 0 deletions
21
mlir/include/mlir/Conversion/MathToEmitC/MathToEmitCPass.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
//===- MathToEmitCPass.h - Math to EmitC Pass -------------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H | ||
#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H | ||
|
||
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h" | ||
#include <memory> | ||
namespace mlir { | ||
class Pass; | ||
|
||
#define GEN_PASS_DECL_CONVERTMATHTOEMITC | ||
#include "mlir/Conversion/Passes.h.inc" | ||
} // namespace mlir | ||
|
||
#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
add_mlir_conversion_library(MLIRMathToEmitC | ||
MathToEmitC.cpp | ||
MathToEmitCPass.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToEmitC | ||
|
||
DEPENDS | ||
MLIRConversionPassIncGen | ||
|
||
LINK_COMPONENTS | ||
Core | ||
|
||
LINK_LIBS PUBLIC | ||
MLIREmitCDialect | ||
MLIRMathDialect | ||
MLIRPass | ||
MLIRTransformUtils | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
//===- MathToEmitC.cpp - Math to EmitC Patterns -----------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h" | ||
|
||
#include "mlir/Dialect/EmitC/IR/EmitC.h" | ||
#include "mlir/Dialect/Math/IR/Math.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
template <typename OpType> | ||
class LowerToEmitCCallOpaque : public OpRewritePattern<OpType> { | ||
std::string calleeStr; | ||
emitc::LanguageTarget languageTarget; | ||
|
||
public: | ||
LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr, | ||
emitc::LanguageTarget languageTarget) | ||
: OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)), | ||
languageTarget(languageTarget) {} | ||
|
||
LogicalResult matchAndRewrite(OpType op, | ||
PatternRewriter &rewriter) const override; | ||
}; | ||
|
||
template <typename OpType> | ||
LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite( | ||
OpType op, PatternRewriter &rewriter) const { | ||
if (!llvm::all_of(op->getOperandTypes(), | ||
llvm::IsaPred<Float32Type, Float64Type>) || | ||
!llvm::all_of(op->getResultTypes(), | ||
llvm::IsaPred<Float32Type, Float64Type>)) | ||
return rewriter.notifyMatchFailure( | ||
op.getLoc(), | ||
"expected all operands and results to be of type f32 or f64"); | ||
std::string modifiedCalleeStr = calleeStr; | ||
if (languageTarget == emitc::LanguageTarget::cpp11) { | ||
modifiedCalleeStr = "std::" + calleeStr; | ||
} else if (languageTarget == emitc::LanguageTarget::c99) { | ||
auto operandType = op->getOperandTypes()[0]; | ||
if (operandType.isF32()) | ||
modifiedCalleeStr = calleeStr + "f"; | ||
} | ||
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>( | ||
op, op.getType(), modifiedCalleeStr, op->getOperands()); | ||
return success(); | ||
} | ||
|
||
} // namespace | ||
|
||
// Populates patterns to replace `math` operations with `emitc.call_opaque`, | ||
// using function names consistent with those in <math.h>. | ||
void mlir::populateConvertMathToEmitCPatterns( | ||
RewritePatternSet &patterns, emitc::LanguageTarget languageTarget) { | ||
auto *context = patterns.getContext(); | ||
patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor", | ||
languageTarget); | ||
patterns.insert<LowerToEmitCCallOpaque<math::RoundOp>>(context, "round", | ||
languageTarget); | ||
patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp", | ||
languageTarget); | ||
patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos", | ||
languageTarget); | ||
patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin", | ||
languageTarget); | ||
patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos", | ||
languageTarget); | ||
patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin", | ||
languageTarget); | ||
patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2", | ||
languageTarget); | ||
patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil", | ||
languageTarget); | ||
patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs", | ||
languageTarget); | ||
patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow", | ||
languageTarget); | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
//===- MathToEmitCPass.cpp - Math to EmitC Pass -----------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements a pass to convert the Math dialect to the EmitC dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h" | ||
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h" | ||
#include "mlir/Dialect/EmitC/IR/EmitC.h" | ||
#include "mlir/Dialect/Math/IR/Math.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
namespace mlir { | ||
#define GEN_PASS_DEF_CONVERTMATHTOEMITC | ||
#include "mlir/Conversion/Passes.h.inc" | ||
} // namespace mlir | ||
|
||
using namespace mlir; | ||
namespace { | ||
|
||
// Replaces Math operations with `emitc.call_opaque` operations. | ||
struct ConvertMathToEmitC | ||
: public impl::ConvertMathToEmitCBase<ConvertMathToEmitC> { | ||
using ConvertMathToEmitCBase::ConvertMathToEmitCBase; | ||
|
||
public: | ||
void runOnOperation() final; | ||
}; | ||
|
||
} // namespace | ||
|
||
void ConvertMathToEmitC::runOnOperation() { | ||
ConversionTarget target(getContext()); | ||
target.addLegalOp<emitc::CallOpaqueOp>(); | ||
|
||
target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundOp, math::CosOp, | ||
math::SinOp, math::Atan2Op, math::CeilOp, math::AcosOp, | ||
math::AsinOp, math::AbsFOp, math::PowFOp>(); | ||
|
||
RewritePatternSet patterns(&getContext()); | ||
populateConvertMathToEmitCPatterns(patterns, languageTarget); | ||
|
||
if (failed( | ||
applyPartialConversion(getOperation(), target, std::move(patterns)))) | ||
signalPassFailure(); | ||
} |
23 changes: 23 additions & 0 deletions
23
mlir/test/Conversion/MathToEmitC/math-to-emitc-failed.mlir
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
// RUN: mlir-opt -split-input-file -convert-math-to-emitc -verify-diagnostics %s | ||
|
||
func.func @unsupported_tensor_type(%arg0 : tensor<4xf32>) -> tensor<4xf32> { | ||
// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}} | ||
%0 = math.absf %arg0 : tensor<4xf32> | ||
return %0 : tensor<4xf32> | ||
} | ||
|
||
// ----- | ||
|
||
func.func @unsupported_f16_type(%arg0 : f16) -> f16 { | ||
// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}} | ||
%0 = math.absf %arg0 : f16 | ||
return %0 : f16 | ||
} | ||
|
||
// ----- | ||
|
||
func.func @unsupported_f128_type(%arg0 : f128) -> f128 { | ||
// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}} | ||
%0 = math.absf %arg0 : f128 | ||
return %0 : f128 | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
// RUN: mlir-opt -convert-math-to-emitc=language-target=c99 %s | FileCheck %s --check-prefix=c99 | ||
// RUN: mlir-opt -convert-math-to-emitc=language-target=cpp11 %s | FileCheck %s --check-prefix=cpp11 | ||
|
||
func.func @absf(%arg0: f32, %arg1: f64) { | ||
// c99: emitc.call_opaque "fabsf" | ||
// c99-NEXT: emitc.call_opaque "fabs" | ||
// cpp11: emitc.call_opaque "std::fabs" | ||
// cpp11-NEXT: emitc.call_opaque "std::fabs" | ||
%0 = math.absf %arg0 : f32 | ||
%1 = math.absf %arg1 : f64 | ||
return | ||
} | ||
|
||
func.func @floor(%arg0: f32, %arg1: f64) { | ||
// c99: emitc.call_opaque "floorf" | ||
// c99-NEXT: emitc.call_opaque "floor" | ||
// cpp11: emitc.call_opaque "std::floor" | ||
// cpp11-NEXT: emitc.call_opaque "std::floor" | ||
%0 = math.floor %arg0 : f32 | ||
%1 = math.floor %arg1 : f64 | ||
return | ||
} | ||
|
||
func.func @sin(%arg0: f32, %arg1: f64) { | ||
// c99: emitc.call_opaque "sinf" | ||
// c99-NEXT: emitc.call_opaque "sin" | ||
// cpp11: emitc.call_opaque "std::sin" | ||
// cpp11-NEXT: emitc.call_opaque "std::sin" | ||
%0 = math.sin %arg0 : f32 | ||
%1 = math.sin %arg1 : f64 | ||
return | ||
} | ||
|
||
func.func @cos(%arg0: f32, %arg1: f64) { | ||
// c99: emitc.call_opaque "cosf" | ||
// c99-NEXT: emitc.call_opaque "cos" | ||
// cpp11: emitc.call_opaque "std::cos" | ||
// cpp11-NEXT: emitc.call_opaque "std::cos" | ||
%0 = math.cos %arg0 : f32 | ||
%1 = math.cos %arg1 : f64 | ||
return | ||
} | ||
|
||
func.func @asin(%arg0: f32, %arg1: f64) { | ||
// c99: emitc.call_opaque "asinf" | ||
// c99-NEXT: emitc.call_opaque "asin" | ||
// cpp11: emitc.call_opaque "std::asin" | ||
// cpp11-NEXT: emitc.call_opaque "std::asin" | ||
%0 = math.asin %arg0 : f32 | ||
%1 = math.asin %arg1 : f64 | ||
return | ||
} | ||
|
||
func.func @acos(%arg0: f32, %arg1: f64) { | ||
// c99: emitc.call_opaque "acosf" | ||
// c99-NEXT: emitc.call_opaque "acos" | ||
// cpp11: emitc.call_opaque "std::acos" | ||
// cpp11-NEXT: emitc.call_opaque "std::acos" | ||
%0 = math.acos %arg0 : f32 | ||
%1 = math.acos %arg1 : f64 | ||
return | ||
} | ||
|
||
func.func @atan2(%arg0: f32, %arg1: f32, %arg2: f64, %arg3: f64) { | ||
// c99: emitc.call_opaque "atan2f" | ||
// c99-NEXT: emitc.call_opaque "atan2" | ||
// cpp11: emitc.call_opaque "std::atan2" | ||
// cpp11-NEXT: emitc.call_opaque "std::atan2" | ||
%0 = math.atan2 %arg0, %arg1 : f32 | ||
%1 = math.atan2 %arg2, %arg3 : f64 | ||
return | ||
} | ||
|
||
func.func @ceil(%arg0: f32, %arg1: f64) { | ||
// c99: emitc.call_opaque "ceilf" | ||
// c99-NEXT: emitc.call_opaque "ceil" | ||
// cpp11: emitc.call_opaque "std::ceil" | ||
// cpp11-NEXT: emitc.call_opaque "std::ceil" | ||
%0 = math.ceil %arg0 : f32 | ||
%1 = math.ceil %arg1 : f64 | ||
return | ||
} | ||
|
||
func.func @exp(%arg0: f32, %arg1: f64) { | ||
// c99: emitc.call_opaque "expf" | ||
// c99-NEXT: emitc.call_opaque "exp" | ||
// cpp11: emitc.call_opaque "std::exp" | ||
// cpp11-NEXT: emitc.call_opaque "std::exp" | ||
%0 = math.exp %arg0 : f32 | ||
%1 = math.exp %arg1 : f64 | ||
return | ||
} | ||
|
||
func.func @powf(%arg0: f32, %arg1: f32, %arg2: f64, %arg3: f64) { | ||
// c99: emitc.call_opaque "powf" | ||
// c99-NEXT: emitc.call_opaque "pow" | ||
// cpp11: emitc.call_opaque "std::pow" | ||
// cpp11-NEXT: emitc.call_opaque "std::pow" | ||
%0 = math.powf %arg0, %arg1 : f32 | ||
%1 = math.powf %arg2, %arg3 : f64 | ||
return | ||
} | ||
|
||
func.func @round(%arg0: f32, %arg1: f64) { | ||
// c99: emitc.call_opaque "roundf" | ||
// c99-NEXT: emitc.call_opaque "round" | ||
// cpp11: emitc.call_opaque "std::round" | ||
// cpp11-NEXT: emitc.call_opaque "std::round" | ||
%0 = math.round %arg0 : f32 | ||
%1 = math.round %arg1 : f64 | ||
return | ||
} |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.