diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h index 9cb43689d1ce6..0dfd141214180 100644 --- a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h +++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h @@ -10,11 +10,14 @@ #define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H namespace mlir { +class DialectRegistry; class RewritePatternSet; class TypeConverter; void populateArithToEmitCPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns); + +void registerConvertArithToEmitCInterface(DialectRegistry ®istry); } // namespace mlir #endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H diff --git a/mlir/include/mlir/Conversion/ConvertToEmitC/ConvertToEmitCPass.h b/mlir/include/mlir/Conversion/ConvertToEmitC/ConvertToEmitCPass.h new file mode 100644 index 0000000000000..8b4a5483816b7 --- /dev/null +++ b/mlir/include/mlir/Conversion/ConvertToEmitC/ConvertToEmitCPass.h @@ -0,0 +1,24 @@ +//===- ConvertToEmitCPass.h - Conversion to EmitC pass ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_CONVERTTOEMITC_CONVERTTOEMITCPASS_H +#define MLIR_CONVERSION_CONVERTTOEMITC_CONVERTTOEMITCPASS_H + +#include "llvm/ADT/SmallVector.h" + +#include + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTTOEMITC +#include "mlir/Conversion/Passes.h.inc" + +} // namespace mlir + +#endif // MLIR_CONVERSION_CONVERTTOEMITC_CONVERTTOEMITCPASS_H diff --git a/mlir/include/mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h b/mlir/include/mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h new file mode 100644 index 0000000000000..d438f27006232 --- /dev/null +++ b/mlir/include/mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h @@ -0,0 +1,45 @@ +//===- ToEmitCInterface.h - Conversion to EmitC iface ---*- C++ -*-===========// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_CONVERTTOEMITC_TOEMITCINTERFACE_H +#define MLIR_CONVERSION_CONVERTTOEMITC_TOEMITCINTERFACE_H + +#include "mlir/IR/DialectInterface.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +class ConversionTarget; +class TypeConverter; +class MLIRContext; +class Operation; +class RewritePatternSet; +class AnalysisManager; + +class ConvertToEmitCPatternInterface + : public DialectInterface::Base { +public: + ConvertToEmitCPatternInterface(Dialect *dialect) : Base(dialect) {} + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. + virtual void populateConvertToEmitCConversionPatterns( + ConversionTarget &target, TypeConverter &typeConverter, + RewritePatternSet &patterns) const = 0; +}; + +/// Recursively walk the IR and collect all dialects implementing the interface, +/// and populate the conversion patterns. +void populateConversionTargetFromOperation(Operation *op, + ConversionTarget &target, + TypeConverter &typeConverter, + RewritePatternSet &patterns); + +} // namespace mlir + +#endif // MLIR_CONVERSION_CONVERTTOEMITC_TOEMITCINTERFACE_H diff --git a/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h index 5c7f87e470306..be6b6cfe5a6db 100644 --- a/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h +++ b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h @@ -10,9 +10,14 @@ #define MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H namespace mlir { +class DialectRegistry; class RewritePatternSet; +class TypeConverter; -void populateFuncToEmitCPatterns(RewritePatternSet &patterns); +void populateFuncToEmitCPatterns(const TypeConverter &typeConverter, + RewritePatternSet &patterns); + +void registerConvertFuncToEmitCInterface(DialectRegistry ®istry); } // namespace mlir #endif // MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h index 950d50229bac1..364a70ce6469b 100644 --- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h +++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h @@ -9,6 +9,7 @@ #define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H namespace mlir { +class DialectRegistry; class RewritePatternSet; class TypeConverter; @@ -16,6 +17,8 @@ void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter); void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, const TypeConverter &converter); + +void registerConvertMemRefToEmitCInterface(DialectRegistry ®istry); } // namespace mlir #endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index ccd862f67c068..c9d2a54433736 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -29,6 +29,7 @@ #include "mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h" #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h" #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h" +#include "mlir/Conversion/ConvertToEmitC/ConvertToEmitCPass.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" #include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index bbba495e613b2..ed88ada9778d2 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -12,6 +12,23 @@ include "mlir/Pass/PassBase.td" include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td" +//===----------------------------------------------------------------------===// +// ToEmitC +//===----------------------------------------------------------------------===// + +def ConvertToEmitC : Pass<"convert-to-emitc"> { + let summary = "Convert to EmitC dialect via dialect interfaces"; + let description = [{ + This is a generic pass to convert to the EmitC dialect, it uses the + `ConvertToEmitCPatternInterface` dialect interface to delegate to dialects + the injection of conversion patterns. + }]; + let options = [ + ListOption<"filterDialects", "filter-dialects", "std::string", + "Test conversion patterns of only the specified dialects">, + ]; +} + //===----------------------------------------------------------------------===// // ToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h b/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h index acc39e6acf726..493da54e5294f 100644 --- a/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h +++ b/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h @@ -13,6 +13,7 @@ #include namespace mlir { +class DialectRegistry; class Pass; class RewritePatternSet; @@ -22,6 +23,8 @@ class RewritePatternSet; /// Collect a set of patterns to convert SCF operations to the EmitC dialect. void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter); + +void registerConvertSCFToEmitCInterface(DialectRegistry ®istry); } // namespace mlir #endif // MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index 373644d31a46c..37e4904cb48ed 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -14,18 +14,22 @@ #ifndef MLIR_INITALLEXTENSIONS_H_ #define MLIR_INITALLEXTENSIONS_H_ +#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/GPUCommon/GPUToLLVM.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h" #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" #include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" +#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/AMX/Transforms.h" @@ -63,18 +67,22 @@ namespace mlir { /// pipelines and transformations you are using. inline void registerAllExtensions(DialectRegistry ®istry) { // Register all conversions to LLVM extensions. + registerConvertArithToEmitCInterface(registry); arith::registerConvertArithToLLVMInterface(registry); registerConvertComplexToLLVMInterface(registry); cf::registerConvertControlFlowToLLVMInterface(registry); func::registerAllExtensions(registry); tensor::registerAllExtensions(registry); + registerConvertFuncToEmitCInterface(registry); registerConvertFuncToLLVMInterface(registry); index::registerConvertIndexToLLVMInterface(registry); registerConvertMathToLLVMInterface(registry); mpi::registerConvertMPIToLLVMInterface(registry); + registerConvertMemRefToEmitCInterface(registry); registerConvertMemRefToLLVMInterface(registry); registerConvertNVVMToLLVMInterface(registry); registerConvertOpenMPToLLVMInterface(registry); + registerConvertSCFToEmitCInterface(registry); ub::registerConvertUBToLLVMInterface(registry); registerConvertAMXToLLVMInterface(registry); gpu::registerConvertGpuToLLVMInterface(registry); diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 359d7b2279639..a5c08a6378021 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" +#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" @@ -22,6 +23,27 @@ using namespace mlir; +namespace { +/// Implement the interface to convert Arith to EmitC. +struct ArithToEmitCDialectInterface : public ConvertToEmitCPatternInterface { + using ConvertToEmitCPatternInterface::ConvertToEmitCPatternInterface; + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. + void populateConvertToEmitCConversionPatterns( + ConversionTarget &target, TypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + populateArithToEmitCPatterns(typeConverter, patterns); + } +}; +} // namespace + +void mlir::registerConvertArithToEmitCInterface(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) { + dialect->addInterfaces(); + }); +} + //===----------------------------------------------------------------------===// // Conversion Patterns //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index b6c21440c571c..e4b4974600577 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -19,6 +19,7 @@ add_subdirectory(ComplexToStandard) add_subdirectory(ControlFlowToLLVM) add_subdirectory(ControlFlowToSCF) add_subdirectory(ControlFlowToSPIRV) +add_subdirectory(ConvertToEmitC) add_subdirectory(ConvertToLLVM) add_subdirectory(FuncToEmitC) add_subdirectory(FuncToLLVM) diff --git a/mlir/lib/Conversion/ConvertToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ConvertToEmitC/CMakeLists.txt new file mode 100644 index 0000000000000..e0d766570d5eb --- /dev/null +++ b/mlir/lib/Conversion/ConvertToEmitC/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_conversion_library(MLIRConvertToEmitC + ConvertToEmitCPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ConvertToEmitC + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithToEmitC + MLIRFuncToEmitC + MLIRMemRefToEmitC + MLIRPass + MLIRSCFToEmitC + MLIRTransformUtils + ) diff --git a/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp b/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp new file mode 100644 index 0000000000000..c9b1dc19ab0dd --- /dev/null +++ b/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp @@ -0,0 +1,223 @@ +//===- ConvertToEmitCPass.cpp - Conversion to EmitC pass --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ConvertToEmitC/ConvertToEmitCPass.h" + +#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/Debug.h" + +#include + +#define DEBUG_TYPE "convert-to-emitc" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +/// Base class for creating the internal implementation of `convert-to-emitc` +/// passes. +class ConvertToEmitCPassInterface { +public: + ConvertToEmitCPassInterface(MLIRContext *context, + ArrayRef filterDialects); + virtual ~ConvertToEmitCPassInterface() = default; + + /// Get the dependent dialects used by `convert-to-emitc`. + static void getDependentDialects(DialectRegistry ®istry); + + /// Initialize the internal state of the `convert-to-emitc` pass + /// implementation. This method is invoked by `ConvertToEmitC::initialize`. + /// This method returns whether the initialization process failed. + virtual LogicalResult initialize() = 0; + + /// Transform `op` to the EmitC dialect with the conversions available in the + /// pass. The analysis manager can be used to query analyzes like + /// `DataLayoutAnalysis` to further configure the conversion process. This + /// method is invoked by `ConvertToEmitC::runOnOperation`. This method returns + /// whether the transformation process failed. + virtual LogicalResult transform(Operation *op, + AnalysisManager manager) const = 0; + +protected: + /// Visit the `ConvertToEmitCPatternInterface` dialect interfaces and call + /// `visitor` with each of the interfaces. If `filterDialects` is non-empty, + /// then `visitor` is invoked only with the dialects in the `filterDialects` + /// list. + LogicalResult visitInterfaces( + llvm::function_ref visitor); + MLIRContext *context; + /// List of dialects names to use as filters. + ArrayRef filterDialects; +}; + +/// This DialectExtension can be attached to the context, which will invoke the +/// `apply()` method for every loaded dialect. If a dialect implements the +/// `ConvertToEmitCPatternInterface` interface, we load dependent dialects +/// through the interface. This extension is loaded in the context before +/// starting a pass pipeline that involves dialect conversion to the EmitC +/// dialect. +class LoadDependentDialectExtension : public DialectExtensionBase { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension) + + LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {} + + void apply(MLIRContext *context, + MutableArrayRef dialects) const final { + LLVM_DEBUG(llvm::dbgs() << "Convert to EmitC extension load\n"); + for (Dialect *dialect : dialects) { + auto *iface = dyn_cast(dialect); + if (!iface) + continue; + LLVM_DEBUG(llvm::dbgs() << "Convert to EmitC found dialect interface for " + << dialect->getNamespace() << "\n"); + } + } + + /// Return a copy of this extension. + std::unique_ptr clone() const final { + return std::make_unique(*this); + } +}; + +//===----------------------------------------------------------------------===// +// StaticConvertToEmitC +//===----------------------------------------------------------------------===// + +/// Static implementation of the `convert-to-emitc` pass. This version only +/// looks at dialect interfaces to configure the conversion process. +struct StaticConvertToEmitC : public ConvertToEmitCPassInterface { + /// Pattern set with conversions to the EmitC dialect. + std::shared_ptr patterns; + /// The conversion target. + std::shared_ptr target; + /// The type converter. + std::shared_ptr typeConverter; + using ConvertToEmitCPassInterface::ConvertToEmitCPassInterface; + + /// Configure the conversion to EmitC at pass initialization. + LogicalResult initialize() final { + auto target = std::make_shared(*context); + auto typeConverter = std::make_shared(); + + // Add fallback identity converison. + typeConverter->addConversion([](Type type) -> std::optional { + if (emitc::isSupportedEmitCType(type)) + return type; + return std::nullopt; + }); + + RewritePatternSet tempPatterns(context); + target->addLegalDialect(); + // Populate the patterns with the dialect interface. + if (failed(visitInterfaces([&](ConvertToEmitCPatternInterface *iface) { + iface->populateConvertToEmitCConversionPatterns( + *target, *typeConverter, tempPatterns); + }))) + return failure(); + this->patterns = + std::make_unique(std::move(tempPatterns)); + this->target = target; + this->typeConverter = typeConverter; + return success(); + } + + /// Apply the conversion driver. + LogicalResult transform(Operation *op, AnalysisManager manager) const final { + if (failed(applyPartialConversion(op, *target, *patterns))) + return failure(); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ConvertToEmitC +//===----------------------------------------------------------------------===// + +/// This is a generic pass to convert to the EmitC dialect. It uses the +/// `ConvertToEmitCPatternInterface` dialect interface to delegate the injection +/// of conversion patterns to dialects. +class ConvertToEmitC : public impl::ConvertToEmitCBase { + std::shared_ptr impl; + +public: + using impl::ConvertToEmitCBase::ConvertToEmitCBase; + void getDependentDialects(DialectRegistry ®istry) const final { + ConvertToEmitCPassInterface::getDependentDialects(registry); + } + + LogicalResult initialize(MLIRContext *context) final { + std::shared_ptr impl; + impl = std::make_shared(context, filterDialects); + if (failed(impl->initialize())) + return failure(); + this->impl = impl; + return success(); + } + + void runOnOperation() final { + if (failed(impl->transform(getOperation(), getAnalysisManager()))) + return signalPassFailure(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// ConvertToEmitCPassInterface +//===----------------------------------------------------------------------===// + +ConvertToEmitCPassInterface::ConvertToEmitCPassInterface( + MLIRContext *context, ArrayRef filterDialects) + : context(context), filterDialects(filterDialects) {} + +void ConvertToEmitCPassInterface::getDependentDialects( + DialectRegistry ®istry) { + registry.insert(); + registry.addExtensions(); +} + +LogicalResult ConvertToEmitCPassInterface::visitInterfaces( + llvm::function_ref visitor) { + if (!filterDialects.empty()) { + // Test mode: Populate only patterns from the specified dialects. Produce + // an error if the dialect is not loaded or does not implement the + // interface. + for (StringRef dialectName : filterDialects) { + Dialect *dialect = context->getLoadedDialect(dialectName); + if (!dialect) + return emitError(UnknownLoc::get(context)) + << "dialect not loaded: " << dialectName << "\n"; + auto *iface = dyn_cast(dialect); + if (!iface) + return emitError(UnknownLoc::get(context)) + << "dialect does not implement ConvertToEmitCPatternInterface: " + << dialectName << "\n"; + visitor(iface); + } + } else { + // Normal mode: Populate all patterns from all dialects that implement the + // interface. + for (Dialect *dialect : context->getLoadedDialects()) { + auto *iface = dyn_cast(dialect); + if (!iface) + continue; + visitor(iface); + } + } + return success(); +} diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp index 53b79839da04c..f8dc06f41ab87 100644 --- a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp +++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp @@ -13,12 +13,35 @@ #include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" +#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; +namespace { + +/// Implement the interface to convert Func to EmitC. +struct FuncToEmitCDialectInterface : public ConvertToEmitCPatternInterface { + using ConvertToEmitCPatternInterface::ConvertToEmitCPatternInterface; + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. + void populateConvertToEmitCConversionPatterns( + ConversionTarget &target, TypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + populateFuncToEmitCPatterns(typeConverter, patterns); + } +}; +} // namespace + +void mlir::registerConvertFuncToEmitCInterface(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) { + dialect->addInterfaces(); + }); +} + //===----------------------------------------------------------------------===// // Conversion Patterns //===----------------------------------------------------------------------===// @@ -51,14 +74,36 @@ class FuncOpConversion final : public OpConversionPattern { LogicalResult matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + FunctionType fnType = funcOp.getFunctionType(); - if (funcOp.getFunctionType().getNumResults() > 1) + if (fnType.getNumResults() > 1) return rewriter.notifyMatchFailure( funcOp, "only functions with zero or one result can be converted"); + TypeConverter::SignatureConversion signatureConverter( + fnType.getNumInputs()); + for (const auto &argType : enumerate(fnType.getInputs())) { + auto convertedType = getTypeConverter()->convertType(argType.value()); + if (!convertedType) + return rewriter.notifyMatchFailure(funcOp, + "argument type conversion failed"); + signatureConverter.addInputs(argType.index(), convertedType); + } + + Type resultType; + if (fnType.getNumResults() == 1) { + resultType = getTypeConverter()->convertType(fnType.getResult(0)); + if (!resultType) + return rewriter.notifyMatchFailure(funcOp, + "result type conversion failed"); + } + // Create the converted `emitc.func` op. emitc::FuncOp newFuncOp = rewriter.create( - funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType()); + funcOp.getLoc(), funcOp.getName(), + FunctionType::get(rewriter.getContext(), + signatureConverter.getConvertedTypes(), + resultType ? TypeRange(resultType) : TypeRange())); // Copy over all attributes other than the function name and type. for (const auto &namedAttr : funcOp->getAttrs()) { @@ -80,9 +125,13 @@ class FuncOpConversion final : public OpConversionPattern { newFuncOp.setSpecifiersAttr(specifiers); } - if (!funcOp.isDeclaration()) + if (!funcOp.isDeclaration()) { rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); + if (failed(rewriter.convertRegionTypes( + &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) + return failure(); + } rewriter.eraseOp(funcOp); return success(); @@ -112,8 +161,10 @@ class ReturnOpConversion final : public OpConversionPattern { // Pattern population //===----------------------------------------------------------------------===// -void mlir::populateFuncToEmitCPatterns(RewritePatternSet &patterns) { +void mlir::populateFuncToEmitCPatterns(const TypeConverter &typeConverter, + RewritePatternSet &patterns) { MLIRContext *ctx = patterns.getContext(); - patterns.add(ctx); + patterns.add( + typeConverter, ctx); } diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp index 0b97f2641ad08..5b59e7675d7c6 100644 --- a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp +++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp @@ -39,7 +39,11 @@ void ConvertFuncToEmitC::runOnOperation() { target.addIllegalOp(); RewritePatternSet patterns(&getContext()); - populateFuncToEmitCPatterns(patterns); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + + populateFuncToEmitCPatterns(typeConverter, patterns); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 39532d34f616e..c69890a10d61e 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -12,6 +12,7 @@ #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" +#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" @@ -20,6 +21,32 @@ using namespace mlir; +namespace { +/// Implement the interface to convert MemRef to EmitC. +struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface { + using ConvertToEmitCPatternInterface::ConvertToEmitCPatternInterface; + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. + void populateConvertToEmitCConversionPatterns( + ConversionTarget &target, TypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + populateMemRefToEmitCTypeConversion(typeConverter); + populateMemRefToEmitCConversionPatterns(patterns, typeConverter); + } +}; +} // namespace + +void mlir::registerConvertMemRefToEmitCInterface(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { + dialect->addInterfaces(); + }); +} + +//===----------------------------------------------------------------------===// +// Conversion Patterns +//===----------------------------------------------------------------------===// + namespace { struct ConvertAlloca final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -179,6 +206,19 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { return emitc::ArrayType::get(memRefType.getShape(), convertedElementType); }); + + auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return Value(); + + return builder.create(loc, resultType, inputs) + .getResult(0); + }; + + typeConverter.addSourceMaterialization(materializeAsUnrealizedCast); + typeConverter.addTargetMaterialization(materializeAsUnrealizedCast); } void mlir::populateMemRefToEmitCConversionPatterns( diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index 33097c71e70b1..cf25c09a2c2f3 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -40,19 +40,6 @@ struct ConvertMemRefToEmitCPass populateMemRefToEmitCTypeConversion(converter); - auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType, - ValueRange inputs, - Location loc) -> Value { - if (inputs.size() != 1) - return Value(); - - return builder.create(loc, resultType, inputs) - .getResult(0); - }; - - converter.addSourceMaterialization(materializeAsUnrealizedCast); - converter.addTargetMaterialization(materializeAsUnrealizedCast); - RewritePatternSet patterns(&getContext()); populateMemRefToEmitCConversionPatterns(patterns, converter); diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index d81e92c842369..345e8494194eb 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -12,6 +12,7 @@ #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" +#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" @@ -34,6 +35,29 @@ using namespace mlir::scf; namespace { +/// Implement the interface to convert SCF to EmitC. +struct SCFToEmitCDialectInterface : public ConvertToEmitCPatternInterface { + using ConvertToEmitCPatternInterface::ConvertToEmitCPatternInterface; + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. + void populateConvertToEmitCConversionPatterns( + ConversionTarget &target, TypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + populateEmitCSizeTTypeConversions(typeConverter); + populateSCFToEmitCConversionPatterns(patterns, typeConverter); + } +}; +} // namespace + +void mlir::registerConvertSCFToEmitCInterface(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) { + dialect->addInterfaces(); + }); +} + +namespace { + struct SCFToEmitCPass : public impl::SCFToEmitCBase { void runOnOperation() override; }; diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp index 042acf6100900..ebcb951cf3518 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" @@ -48,6 +49,7 @@ void arith::ArithDialect::initialize() { #include "mlir/Dialect/Arith/IR/ArithOpsAttributes.cpp.inc" >(); addInterfaces(); + declarePromisedInterface(); declarePromisedInterface(); declarePromisedInterface(); diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index b4d7482554fbc..0c68086b7cd17 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -1044,7 +1044,9 @@ Type emitc::ArrayType::parse(AsmParser &parser) { // Check that array is formed from allowed types. if (!isValidElementType(elementType)) - return parser.emitError(typeLoc, "invalid array element type"), Type(); + return parser.emitError(typeLoc, "invalid array element type '") + << elementType << "'", + Type(); if (parser.parseGreater()) return Type(); return parser.getChecked(dimensions, elementType); diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp index ba7b84f27d6a8..aa068966ba607 100644 --- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp +++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/IR/BuiltinOps.h" @@ -41,6 +42,7 @@ void FuncDialect::initialize() { #define GET_OP_LIST #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" >(); + declarePromisedInterface(); declarePromisedInterface(); declarePromisedInterface(); declarePromisedInterfaces(); addInterfaces(); + declarePromisedInterface(); declarePromisedInterface(); declarePromisedInterfaces(); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 344941da260fe..48726db9a3fa3 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" @@ -73,6 +74,7 @@ void SCFDialect::initialize() { #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc" >(); addInterfaces(); + declarePromisedInterface(); declarePromisedInterfaces(); declarePromisedInterfaces index { + // expected-error @+1 {{type mismatch for bb argument #0 of successor #0}} + cf.cond_br %arg0, ^bb1(%arg1: index), ^bb2(%arg2: index) +^bb1(%0: index): + return %0 : index +^bb2(%1: index): + return %1 : index +} diff --git a/mlir/test/Conversion/ConvertToEmitC/func.mlir b/mlir/test/Conversion/ConvertToEmitC/func.mlir new file mode 100644 index 0000000000000..4f2518401581f --- /dev/null +++ b/mlir/test/Conversion/ConvertToEmitC/func.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-opt -convert-to-emitc %s | FileCheck %s + +// CHECK-LABEL emitc.func @int(%[[ARG:.*]]: i32) -> i32 +func.func @int(%arg0: i32) -> i32 { + // CHECK: return + return %arg0 : i32 +} + +// CHECK-LABEL emitc.func @index(%[[ARG:.*]]: !emitc.size_t) -> !emitc.size_t +func.func @index(%arg0: index) -> index { + // CHECK: return + return %arg0 : index +} diff --git a/mlir/test/Conversion/ConvertToEmitC/tosa.mlir b/mlir/test/Conversion/ConvertToEmitC/tosa.mlir new file mode 100644 index 0000000000000..8ced05eced4b9 --- /dev/null +++ b/mlir/test/Conversion/ConvertToEmitC/tosa.mlir @@ -0,0 +1,41 @@ +// DEFINE: %{pipeline} = "builtin.module(\ +// DEFINE: func.func(\ +// DEFINE: tosa-to-linalg\ +// DEFINE: ),\ +// DEFINE: one-shot-bufferize{\ +// DEFINE: bufferize-function-boundaries\ +// DEFINE: function-boundary-type-conversion=identity-layout-map\ +// DEFINE: buffer-alignment=0\ +// DEFINE: },\ +// DEFINE: buffer-results-to-out-params{\ +// DEFINE: hoist-static-allocs=true\ +// DEFINE: },\ +// DEFINE: func.func(\ +// DEFINE: convert-linalg-to-loops\ +// DEFINE: ),\ +// DEFINE: canonicalize,\ +// DEFINE: convert-to-emitc\ +// DEFINE: )" + +// RUN: mlir-opt --pass-pipeline=%{pipeline} %s | FileCheck %s +// ----- + +// CHECK: emitc.func @main(%[[ARG0:.*]]: !emitc.array<2xf32>, %[[ARG1:.*]]: !emitc.array<2xf32>, %[[RES:.*]]: !emitc.array<2xf32>) { +// CHECK-DAG: %[[C0:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t +// CHECK-DAG: %[[C1:.*]] = "emitc.constant"() <{value = 1 : index}> : () -> !emitc.size_t +// CHECK-DAG: %[[C2:.*]] = "emitc.constant"() <{value = 2 : index}> : () -> !emitc.size_t +// CHECK-NEXT: for %[[INDEX:.*]] = %[[C0]] to %[[C2]] step %[[C1]] : !emitc.size_t { +// CHECK-NEXT: %[[V0_LVALUE:.*]] = subscript %[[ARG0]][%[[INDEX]]] : (!emitc.array<2xf32>, !emitc.size_t) -> !emitc.lvalue +// CHECK-NEXT: %[[V0:.*]] = load %[[V0_LVALUE]] : +// CHECK-NEXT: %[[V1_LVALUE:.*]] = subscript %[[ARG1]][%[[INDEX]]] : (!emitc.array<2xf32>, !emitc.size_t) -> !emitc.lvalue +// CHECK-NEXT: %[[V1:.*]] = load %[[V1_LVALUE]] : +// CHECK-NEXT: %[[VADD:.*]] = add %[[V0]], %[[V1]] : (f32, f32) -> f32 +// CHECK-NEXT: %[[RES_LVALUE:.*]] = subscript %[[RES]][%[[INDEX]]] : (!emitc.array<2xf32>, !emitc.size_t) -> !emitc.lvalue +// CHECK-NEXT: assign %[[VADD]] : f32 to %[[RES_LVALUE]] : +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } +func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + %0 = tosa.add %arg0, %arg1 : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} diff --git a/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir b/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir index bd48886ed739e..6824a64dda3ef 100644 --- a/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir +++ b/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt -split-input-file -convert-func-to-emitc %s | FileCheck %s +// RUN: mlir-opt -split-input-file -convert-to-emitc="filter-dialects=func" %s | FileCheck %s // CHECK-LABEL: emitc.func @foo() // CHECK-NEXT: return diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index f5ef821cc9c05..d37fd1de90add 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s +// RUN: mlir-opt -convert-to-emitc="filter-dialects=memref" %s -split-input-file | FileCheck %s // CHECK-LABEL: alloca() func.func @alloca() { diff --git a/mlir/test/Conversion/SCFToEmitC/for.mlir b/mlir/test/Conversion/SCFToEmitC/for.mlir index 232a0fb2e8252..571517fe7bf19 100644 --- a/mlir/test/Conversion/SCFToEmitC/for.mlir +++ b/mlir/test/Conversion/SCFToEmitC/for.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-emitc %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -convert-to-emitc="filter-dialects=scf" %s | FileCheck %s func.func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) { scf.for %i0 = %arg0 to %arg1 step %arg2 { diff --git a/mlir/test/Conversion/SCFToEmitC/if.mlir b/mlir/test/Conversion/SCFToEmitC/if.mlir index 9acd9b0783d21..596a4872b3030 100644 --- a/mlir/test/Conversion/SCFToEmitC/if.mlir +++ b/mlir/test/Conversion/SCFToEmitC/if.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-emitc %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -convert-to-emitc="filter-dialects=scf" %s | FileCheck %s func.func @test_if(%arg0: i1, %arg1: f32) { scf.if %arg0 { diff --git a/mlir/test/Conversion/SCFToEmitC/switch.mlir b/mlir/test/Conversion/SCFToEmitC/switch.mlir index 3f0793ccd7e3b..0dab98c8cc341 100644 --- a/mlir/test/Conversion/SCFToEmitC/switch.mlir +++ b/mlir/test/Conversion/SCFToEmitC/switch.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-emitc %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -convert-to-emitc="filter-dialects=scf" %s | FileCheck %s // CHECK-LABEL: func.func @switch_no_result( // CHECK-SAME: %[[ARG_0:.*]]: index) {