Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//===- MathToEmitC.h - Math to EmitC Patterns -------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
#include "mlir/Dialect/EmitC/IR/EmitC.h"
namespace mlir {
class RewritePatternSet;
namespace emitc {

/// Enum to specify the language target for EmitC code generation.
enum class LanguageTarget { c99, cpp11 };

} // namespace emitc

void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns,
emitc::LanguageTarget languageTarget);
} // namespace mlir

#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
21 changes: 21 additions & 0 deletions mlir/include/mlir/Conversion/MathToEmitC/MathToEmitCPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- MathToEmitCPass.h - Math to EmitC Pass -------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H
#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H

#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
#include <memory>
namespace mlir {
class Pass;

#define GEN_PASS_DECL_CONVERTMATHTOEMITC
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
Expand Down
22 changes: 22 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,28 @@ def ConvertMathToSPIRV : Pass<"convert-math-to-spirv"> {
let dependentDialects = ["spirv::SPIRVDialect"];
}

//===----------------------------------------------------------------------===//
// MathToEmitC
//===----------------------------------------------------------------------===//

def ConvertMathToEmitC : Pass<"convert-math-to-emitc"> {
let summary = "Convert some Math operations to EmitC call_opaque operations";
let description = [{
This pass converts supported Math ops to `call_opaque` ops targeting libc/libm
functions. Unlike convert-math-to-funcs pass, converting to `call_opaque` ops
allows to overload the same function with different argument types.
}];
let dependentDialects = ["emitc::EmitCDialect"];
let options = [
Option<"languageTarget", "language-target", "::mlir::emitc::LanguageTarget",
/*default=*/"::mlir::emitc::LanguageTarget::c99", "Select the language standard target for callees (c99 or cpp11).",
[{::llvm::cl::values(
clEnumValN(::mlir::emitc::LanguageTarget::c99, "c99", "c99"),
clEnumValN(::mlir::emitc::LanguageTarget::cpp11, "cpp11", "cpp11")
)}]>
];
}

//===----------------------------------------------------------------------===//
// MathToFuncs
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ add_subdirectory(IndexToLLVM)
add_subdirectory(IndexToSPIRV)
add_subdirectory(LinalgToStandard)
add_subdirectory(LLVMCommon)
add_subdirectory(MathToEmitC)
add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
Expand Down
19 changes: 19 additions & 0 deletions mlir/lib/Conversion/MathToEmitC/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
add_mlir_conversion_library(MLIRMathToEmitC
MathToEmitC.cpp
MathToEmitCPass.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToEmitC

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIREmitCDialect
MLIRMathDialect
MLIRPass
MLIRTransformUtils
)
85 changes: 85 additions & 0 deletions mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
//===- MathToEmitC.cpp - Math to EmitC Patterns -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {
template <typename OpType>
class LowerToEmitCCallOpaque : public OpRewritePattern<OpType> {
std::string calleeStr;
emitc::LanguageTarget languageTarget;

public:
LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr,
emitc::LanguageTarget languageTarget)
: OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)),
languageTarget(languageTarget) {}

LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rewriter) const override;
};

template <typename OpType>
LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
OpType op, PatternRewriter &rewriter) const {
if (!llvm::all_of(op->getOperandTypes(),
llvm::IsaPred<Float32Type, Float64Type>) ||
!llvm::all_of(op->getResultTypes(),
llvm::IsaPred<Float32Type, Float64Type>))
return rewriter.notifyMatchFailure(
op.getLoc(),
"expected all operands and results to be of type f32 or f64");
std::string modifiedCalleeStr = calleeStr;
if (languageTarget == emitc::LanguageTarget::cpp11) {
modifiedCalleeStr = "std::" + calleeStr;
} else if (languageTarget == emitc::LanguageTarget::c99) {
auto operandType = op->getOperandTypes()[0];
if (operandType.isF32())
modifiedCalleeStr = calleeStr + "f";
}
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
op, op.getType(), modifiedCalleeStr, op->getOperands());
return success();
}

} // namespace

// Populates patterns to replace `math` operations with `emitc.call_opaque`,
// using function names consistent with those in <math.h>.
void mlir::populateConvertMathToEmitCPatterns(
RewritePatternSet &patterns, emitc::LanguageTarget languageTarget) {
auto *context = patterns.getContext();
patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::RoundOp>>(context, "round",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow",
languageTarget);
}
53 changes: 53 additions & 0 deletions mlir/lib/Conversion/MathToEmitC/MathToEmitCPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//===- MathToEmitCPass.cpp - Math to EmitC Pass -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the Math dialect to the EmitC dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOEMITC
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
namespace {

// Replaces Math operations with `emitc.call_opaque` operations.
struct ConvertMathToEmitC
: public impl::ConvertMathToEmitCBase<ConvertMathToEmitC> {
using ConvertMathToEmitCBase::ConvertMathToEmitCBase;

public:
void runOnOperation() final;
};

} // namespace

void ConvertMathToEmitC::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalOp<emitc::CallOpaqueOp>();

target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundOp, math::CosOp,
math::SinOp, math::Atan2Op, math::CeilOp, math::AcosOp,
math::AsinOp, math::AbsFOp, math::PowFOp>();

RewritePatternSet patterns(&getContext());
populateConvertMathToEmitCPatterns(patterns, languageTarget);

if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}
23 changes: 23 additions & 0 deletions mlir/test/Conversion/MathToEmitC/math-to-emitc-failed.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: mlir-opt -split-input-file -convert-math-to-emitc -verify-diagnostics %s

func.func @unsupported_tensor_type(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}}
%0 = math.absf %arg0 : tensor<4xf32>
return %0 : tensor<4xf32>
}

// -----

func.func @unsupported_f16_type(%arg0 : f16) -> f16 {
// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}}
%0 = math.absf %arg0 : f16
return %0 : f16
}

// -----

func.func @unsupported_f128_type(%arg0 : f128) -> f128 {
// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}}
%0 = math.absf %arg0 : f128
return %0 : f128
}
112 changes: 112 additions & 0 deletions mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// RUN: mlir-opt -convert-math-to-emitc=language-target=c99 %s | FileCheck %s --check-prefix=c99
// RUN: mlir-opt -convert-math-to-emitc=language-target=cpp11 %s | FileCheck %s --check-prefix=cpp11

func.func @absf(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "fabsf"
// c99-NEXT: emitc.call_opaque "fabs"
// cpp11: emitc.call_opaque "std::fabs"
// cpp11-NEXT: emitc.call_opaque "std::fabs"
%0 = math.absf %arg0 : f32
%1 = math.absf %arg1 : f64
return
}

func.func @floor(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "floorf"
// c99-NEXT: emitc.call_opaque "floor"
// cpp11: emitc.call_opaque "std::floor"
// cpp11-NEXT: emitc.call_opaque "std::floor"
%0 = math.floor %arg0 : f32
%1 = math.floor %arg1 : f64
return
}

func.func @sin(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "sinf"
// c99-NEXT: emitc.call_opaque "sin"
// cpp11: emitc.call_opaque "std::sin"
// cpp11-NEXT: emitc.call_opaque "std::sin"
%0 = math.sin %arg0 : f32
%1 = math.sin %arg1 : f64
return
}

func.func @cos(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "cosf"
// c99-NEXT: emitc.call_opaque "cos"
// cpp11: emitc.call_opaque "std::cos"
// cpp11-NEXT: emitc.call_opaque "std::cos"
%0 = math.cos %arg0 : f32
%1 = math.cos %arg1 : f64
return
}

func.func @asin(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "asinf"
// c99-NEXT: emitc.call_opaque "asin"
// cpp11: emitc.call_opaque "std::asin"
// cpp11-NEXT: emitc.call_opaque "std::asin"
%0 = math.asin %arg0 : f32
%1 = math.asin %arg1 : f64
return
}

func.func @acos(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "acosf"
// c99-NEXT: emitc.call_opaque "acos"
// cpp11: emitc.call_opaque "std::acos"
// cpp11-NEXT: emitc.call_opaque "std::acos"
%0 = math.acos %arg0 : f32
%1 = math.acos %arg1 : f64
return
}

func.func @atan2(%arg0: f32, %arg1: f32, %arg2: f64, %arg3: f64) {
// c99: emitc.call_opaque "atan2f"
// c99-NEXT: emitc.call_opaque "atan2"
// cpp11: emitc.call_opaque "std::atan2"
// cpp11-NEXT: emitc.call_opaque "std::atan2"
%0 = math.atan2 %arg0, %arg1 : f32
%1 = math.atan2 %arg2, %arg3 : f64
return
}

func.func @ceil(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "ceilf"
// c99-NEXT: emitc.call_opaque "ceil"
// cpp11: emitc.call_opaque "std::ceil"
// cpp11-NEXT: emitc.call_opaque "std::ceil"
%0 = math.ceil %arg0 : f32
%1 = math.ceil %arg1 : f64
return
}

func.func @exp(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "expf"
// c99-NEXT: emitc.call_opaque "exp"
// cpp11: emitc.call_opaque "std::exp"
// cpp11-NEXT: emitc.call_opaque "std::exp"
%0 = math.exp %arg0 : f32
%1 = math.exp %arg1 : f64
return
}

func.func @powf(%arg0: f32, %arg1: f32, %arg2: f64, %arg3: f64) {
// c99: emitc.call_opaque "powf"
// c99-NEXT: emitc.call_opaque "pow"
// cpp11: emitc.call_opaque "std::pow"
// cpp11-NEXT: emitc.call_opaque "std::pow"
%0 = math.powf %arg0, %arg1 : f32
%1 = math.powf %arg2, %arg3 : f64
return
}

func.func @round(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "roundf"
// c99-NEXT: emitc.call_opaque "round"
// cpp11: emitc.call_opaque "std::round"
// cpp11-NEXT: emitc.call_opaque "std::round"
%0 = math.round %arg0 : f32
%1 = math.round %arg1 : f64
return
}
Loading
Loading