diff --git a/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt b/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt index f33061b2d87cf..9f57627c321fb 100644 --- a/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index e09c63295515c..b8f8f1e2d818d 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -19,6 +19,7 @@ include "mlir/Dialect/EmitC/IR/EmitCTypes.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/RegionKindInterface.td" //===----------------------------------------------------------------------===// // EmitC op definitions @@ -247,6 +248,83 @@ def EmitC_DivOp : EmitC_BinaryOp<"div", []> { let results = (outs FloatIntegerIndexOrOpaqueType); } +def EmitC_ExpressionOp : EmitC_Op<"expression", + [HasOnlyGraphRegion, SingleBlockImplicitTerminator<"emitc::YieldOp">, + NoRegionArguments]> { + let summary = "Expression operation"; + let description = [{ + The `expression` operation returns a single SSA value which is yielded by + its single-basic-block region. The operation doesn't take any arguments. + + As the operation is to be emitted as a C expression, the operations within + its body must form a single Def-Use tree of emitc ops whose result is + yielded by a terminating `yield`. + + Example: + + ```mlir + %r = emitc.expression : () -> i32 { + %0 = emitc.add %a, %b : (i32, i32) -> i32 + %1 = emitc.call "foo"(%0) : () -> i32 + %2 = emitc.add %c, %d : (i32, i32) -> i32 + %3 = emitc.mul %1, %2 : (i32, i32) -> i32 + yield %3 + } + ``` + + May be emitted as + + ```c++ + int32_t v7 = foo(v1 + v2) * (v3 + v4); + ``` + + The operations allowed within expression body are emitc.add, emitc.apply, + emitc.call, emitc.cast, emitc.cmp, emitc.div, emitc.mul, emitc.rem and + emitc.sub. + + When specified, the optional `do_not_inline` indicates that the expression is + to be emitted as seen above, i.e. as the rhs of an EmitC SSA value + definition. Otherwise, the expression may be emitted inline, i.e. directly + at its use. + }]; + + let arguments = (ins UnitAttr:$do_not_inline); + let results = (outs AnyType:$result); + let regions = (region SizedRegion<1>:$region); + + let hasVerifier = 1; + let assemblyFormat = "attr-dict (`noinline` $do_not_inline^)? `:` type($result) $region"; + + let extraClassDeclaration = [{ + static bool isCExpression(Operation &op) { + return isa(op); + } + bool hasSideEffects() { + auto predicate = [](Operation &op) { + assert(isCExpression(op) && "Expected a C expression"); + // Conservatively assume calls to read and write memory. + if (isa(op)) + return true; + // De-referencing reads modifiable memory, address-taking has no + // side-effect. + auto applyOp = dyn_cast(op); + if (applyOp) + return applyOp.getApplicableOperator() == "*"; + // Any operation using variables is assumed to have a side effect of + // reading memory mutable by emitc::assign ops. + return llvm::any_of(op.getOperands(), [](Value operand) { + Operation *def = operand.getDefiningOp(); + return def && isa(def); + }); + }; + return llvm::any_of(getRegion().front().without_terminator(), predicate); + }; + Operation *getRootOp(); + }]; +} + def EmitC_ForOp : EmitC_Op<"for", [AllTypesMatch<["lowerBound", "upperBound", "step"]>, SingleBlockImplicitTerminator<"emitc::YieldOp">, @@ -494,18 +572,24 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> { } def EmitC_YieldOp : EmitC_Op<"yield", - [Pure, Terminator, ParentOneOf<["IfOp", "ForOp"]>]> { + [Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp"]>]> { let summary = "block termination operation"; let description = [{ - "yield" terminates blocks within EmitC control-flow operations. Since - control-flow constructs in C do not return values, this operation doesn't - take any arguments. + "yield" terminates its parent EmitC op's region, optionally yielding + an SSA value. The semantics of how the values are yielded is defined by the + parent operation. + If "yield" has an operand, the operand must match the parent operation's + result. If the parent operation defines no values, then the "emitc.yield" + may be left out in the custom syntax and the builders will insert one + implicitly. Otherwise, it has to be present in the syntax to indicate which + value is yielded. }]; - let arguments = (ins); + let arguments = (ins Optional:$result); let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; - let assemblyFormat = [{ attr-dict }]; + let hasVerifier = 1; + let assemblyFormat = [{ attr-dict ($result^ `:` type($result))? }]; } def EmitC_IfOp : EmitC_Op<"if", diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt new file mode 100644 index 0000000000000..0b507d75fa07a --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name EmitC) +add_public_tablegen_target(MLIREmitCTransformsIncGen) + +add_mlir_doc(Passes EmitCPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h new file mode 100644 index 0000000000000..5cd27149d366e --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h @@ -0,0 +1,35 @@ +//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_ +#define MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace emitc { + +//===----------------------------------------------------------------------===// +// Passes +//===----------------------------------------------------------------------===// + +/// Creates an instance of the C-style expressions forming pass. +std::unique_ptr createFormExpressionsPass(); + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc" + +} // namespace emitc +} // namespace mlir + +#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td new file mode 100644 index 0000000000000..fd083abc95715 --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td @@ -0,0 +1,24 @@ +//===-- Passes.td - pass definition file -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_PASSES +#define MLIR_DIALECT_EMITC_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def FormExpressions : Pass<"form-expressions"> { + let summary = "Form C-style expressions from C-operator ops"; + let description = [{ + The pass wraps emitc ops modelling C operators in emitc.expression ops and + then folds single-use expressions into their users where possible. + }]; + let constructor = "mlir::emitc::createFormExpressionsPass()"; + let dependentDialects = ["emitc::EmitCDialect"]; +} + +#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h new file mode 100644 index 0000000000000..2574acd7d48e0 --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h @@ -0,0 +1,34 @@ +//===- Transforms.h - EmitC transformations as patterns --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H +#define MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace emitc { + +//===----------------------------------------------------------------------===// +// Expression transforms +//===----------------------------------------------------------------------===// + +ExpressionOp createExpression(Operation *op, OpBuilder &builder); + +//===----------------------------------------------------------------------===// +// Populate functions +//===----------------------------------------------------------------------===// + +/// Populates `patterns` with expression-related patterns. +void populateExpressionPatterns(RewritePatternSet &patterns); + +} // namespace emitc +} // namespace mlir + +#endif // MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h index f22980036ffcf..5207559f36250 100644 --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -23,6 +23,7 @@ #include "mlir/Dialect/Async/Passes.h" #include "mlir/Dialect/Bufferization/Pipelines/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/EmitC/Transforms/Passes.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" @@ -86,6 +87,7 @@ inline void registerAllPasses() { vector::registerVectorPasses(); arm_sme::registerArmSMEPasses(); arm_sve::registerArmSVEPasses(); + emitc::registerEmitCPasses(); // Dialect pipelines bufferization::registerBufferizationPipelines(); diff --git a/mlir/lib/Dialect/EmitC/CMakeLists.txt b/mlir/lib/Dialect/EmitC/CMakeLists.txt index f33061b2d87cf..9f57627c321fb 100644 --- a/mlir/lib/Dialect/EmitC/CMakeLists.txt +++ b/mlir/lib/Dialect/EmitC/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index e8ea4da0b089c..fd32efe783bcf 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -189,6 +189,50 @@ LogicalResult emitc::ConstantOp::verify() { OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } +//===----------------------------------------------------------------------===// +// ExpressionOp +//===----------------------------------------------------------------------===// + +Operation *ExpressionOp::getRootOp() { + auto yieldOp = cast(getBody()->getTerminator()); + Value yieldedValue = yieldOp.getResult(); + Operation *rootOp = yieldedValue.getDefiningOp(); + assert(rootOp && "Yielded value not defined within expression"); + return rootOp; +} + +LogicalResult ExpressionOp::verify() { + Type resultType = getResult().getType(); + Region ®ion = getRegion(); + + Block &body = region.front(); + + if (!body.mightHaveTerminator()) + return emitOpError("must yield a value at termination"); + + auto yield = cast(body.getTerminator()); + Value yieldResult = yield.getResult(); + + if (!yieldResult) + return emitOpError("must yield a value at termination"); + + Type yieldType = yieldResult.getType(); + + if (resultType != yieldType) + return emitOpError("requires yielded type to match return type"); + + for (Operation &op : region.front().without_terminator()) { + if (!isCExpression(op)) + return emitOpError("contains an unsupported operation"); + if (op.getNumResults() != 1) + return emitOpError("requires exactly one result for each operation"); + if (!op.getResult(0).hasOneUse()) + return emitOpError("requires exactly one use for each operation"); + } + + return success(); +} + //===----------------------------------------------------------------------===// // ForOp //===----------------------------------------------------------------------===// @@ -530,6 +574,23 @@ LogicalResult emitc::VariableOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// YieldOp +//===----------------------------------------------------------------------===// + +LogicalResult emitc::YieldOp::verify() { + Value result = getResult(); + Operation *containingOp = getOperation()->getParentOp(); + + if (result && containingOp->getNumResults() != 1) + return emitOpError() << "yields a value not returned by parent"; + + if (!result && containingOp->getNumResults() != 0) + return emitOpError() << "does not yield a value to be returned by parent"; + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt new file mode 100644 index 0000000000000..bfcc14523f137 --- /dev/null +++ b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_dialect_library(MLIREmitCTransforms + Transforms.cpp + FormExpressions.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms + + DEPENDS + MLIREmitCTransformsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIREmitCDialect + MLIRTransforms +) diff --git a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp new file mode 100644 index 0000000000000..21212155ffb22 --- /dev/null +++ b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp @@ -0,0 +1,60 @@ +//===- FormExpressions.cpp - Form C-style expressions --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass that forms EmitC operations modeling C operators +// into C-style expressions using the emitc.expression op. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/Transforms/Passes.h" +#include "mlir/Dialect/EmitC/Transforms/Transforms.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace emitc { +#define GEN_PASS_DEF_FORMEXPRESSIONS +#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc" +} // namespace emitc +} // namespace mlir + +using namespace mlir; +using namespace emitc; + +namespace { +struct FormExpressionsPass + : public emitc::impl::FormExpressionsBase { + void runOnOperation() override { + Operation *rootOp = getOperation(); + MLIRContext *context = rootOp->getContext(); + + // Wrap each C operator op with an expression op. + OpBuilder builder(context); + auto matchFun = [&](Operation *op) { + if (emitc::ExpressionOp::isCExpression(*op)) + createExpression(op, builder); + }; + rootOp->walk(matchFun); + + // Fold expressions where possible. + RewritePatternSet patterns(context); + populateExpressionPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(rootOp, std::move(patterns)))) + return signalPassFailure(); + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } +}; +} // namespace + +std::unique_ptr mlir::emitc::createFormExpressionsPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp new file mode 100644 index 0000000000000..593d774cac73b --- /dev/null +++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp @@ -0,0 +1,114 @@ +//===- Transforms.cpp - Patterns and transforms for the EmitC dialect -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/EmitC/Transforms/Transforms.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace emitc { + +ExpressionOp createExpression(Operation *op, OpBuilder &builder) { + assert(ExpressionOp::isCExpression(*op) && "Expected a C expression"); + + // Create an expression yielding the value returned by op. + assert(op->getNumResults() == 1 && "Expected exactly one result"); + Value result = op->getResult(0); + Type resultType = result.getType(); + Location loc = op->getLoc(); + + builder.setInsertionPointAfter(op); + auto expressionOp = builder.create(loc, resultType); + + // Replace all op's uses with the new expression's result. + result.replaceAllUsesWith(expressionOp.getResult()); + + // Create an op to yield op's value. + Region ®ion = expressionOp.getRegion(); + Block &block = region.emplaceBlock(); + builder.setInsertionPointToEnd(&block); + auto yieldOp = builder.create(loc, result); + + // Move op into the new expression. + op->moveBefore(yieldOp); + + return expressionOp; +} + +} // namespace emitc +} // namespace mlir + +using namespace mlir; +using namespace mlir::emitc; + +namespace { + +struct FoldExpressionOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ExpressionOp expressionOp, + PatternRewriter &rewriter) const override { + bool anythingFolded = false; + for (Operation &op : llvm::make_early_inc_range( + expressionOp.getBody()->without_terminator())) { + // Don't fold expressions whose result value has its address taken. + auto applyOp = dyn_cast(op); + if (applyOp && applyOp.getApplicableOperator() == "&") + continue; + + for (Value operand : op.getOperands()) { + auto usedExpression = + dyn_cast_if_present(operand.getDefiningOp()); + + if (!usedExpression) + continue; + + // Don't fold expressions with multiple users: assume any + // re-materialization was done separately. + if (!usedExpression.getResult().hasOneUse()) + continue; + + // Don't fold expressions with side effects. + if (usedExpression.hasSideEffects()) + continue; + + // Fold the used expression into this expression by cloning all + // instructions in the used expression just before the operation using + // its value. + rewriter.setInsertionPoint(&op); + IRMapping mapper; + for (Operation &opToClone : + usedExpression.getBody()->without_terminator()) { + Operation *clone = rewriter.clone(opToClone, mapper); + mapper.map(&opToClone, clone); + } + + Operation *expressionRoot = usedExpression.getRootOp(); + Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot); + assert(clonedExpressionRootOp && + "Expected cloned expression root to be in mapper"); + assert(clonedExpressionRootOp->getNumResults() == 1 && + "Expected cloned root to have a single result"); + + Value clonedExpressionResult = clonedExpressionRootOp->getResult(0); + + usedExpression.getResult().replaceAllUsesWith(clonedExpressionResult); + rewriter.eraseOp(usedExpression); + anythingFolded = true; + } + } + return anythingFolded ? success() : failure(); + } +}; + +} // namespace + +void mlir::emitc::populateExpressionPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 1b4ec9eae9367..c32cb03caf9db 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" #include "mlir/Support/IndentedOstream.h" +#include "mlir/Support/LLVM.h" #include "mlir/Target/Cpp/CppEmitter.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringExtras.h" @@ -65,6 +66,35 @@ inline LogicalResult interleaveCommaWithError(const Container &c, return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; }); } +/// Return the precedence of a operator as an integer, higher values +/// imply higher precedence. +static int getOperatorPrecedence(Operation *operation) { + return llvm::TypeSwitch(operation) + .Case([&](auto op) { return 11; }) + .Case([&](auto op) { return 13; }) + .Case([&](auto op) { return 13; }) + .Case([&](auto op) { + switch (op.getPredicate()) { + case emitc::CmpPredicate::eq: + case emitc::CmpPredicate::ne: + return 8; + case emitc::CmpPredicate::lt: + case emitc::CmpPredicate::le: + case emitc::CmpPredicate::gt: + case emitc::CmpPredicate::ge: + return 9; + case emitc::CmpPredicate::three_way: + return 10; + } + }) + .Case([&](auto op) { return 12; }) + .Case([&](auto op) { return 12; }) + .Case([&](auto op) { return 12; }) + .Case([&](auto op) { return 11; }) + .Case([&](auto op) { return 14; }); + llvm_unreachable("Unsupported operator"); +} + namespace { /// Emitter that uses dialect specific emitters to emit C++ code. struct CppEmitter { @@ -115,6 +145,12 @@ struct CppEmitter { /// Emits the operands of the operation. All operands are emitted in order. LogicalResult emitOperands(Operation &op); + /// Emits value as an operands of an operation + LogicalResult emitOperand(Value value); + + /// Emit an expression as a C expression. + LogicalResult emitExpression(ExpressionOp expressionOp); + /// Return the existing or a new name for a Value. StringRef getOrCreateName(Value val); @@ -156,6 +192,21 @@ struct CppEmitter { /// be declared at the beginning of a function. bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; }; + /// Get expression currently being emitted. + ExpressionOp getEmittedExpression() { return emittedExpression; } + + /// Determine whether given value is part of the expression potentially being + /// emitted. + bool isPartOfCurrentExpression(Value value) { + if (!emittedExpression) + return false; + Operation *def = value.getDefiningOp(); + if (!def) + return false; + auto operandExpression = dyn_cast(def->getParentOp()); + return operandExpression == emittedExpression; + }; + private: using ValueMapper = llvm::ScopedHashTable; using BlockMapper = llvm::ScopedHashTable; @@ -178,9 +229,50 @@ struct CppEmitter { /// names of values in a scope. std::stack valueInScopeCount; std::stack labelInScopeCount; + + /// State of the current expression being emitted. + ExpressionOp emittedExpression; + SmallVector emittedExpressionPrecedence; + + void pushExpressionPrecedence(int precedence) { + emittedExpressionPrecedence.push_back(precedence); + } + void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); } + static int lowestPrecedence() { return 0; } + int getExpressionPrecedence() { + if (emittedExpressionPrecedence.empty()) + return lowestPrecedence(); + return emittedExpressionPrecedence.back(); + } }; } // namespace +/// Determine whether expression \p expressionOp should be emitted inline, i.e. +/// as part of its user. This function recommends inlining of any expressions +/// that can be inlined unless it is used by another expression, under the +/// assumption that any expression fusion/re-materialization was taken care of +/// by transformations run by the backend. +static bool shouldBeInlined(ExpressionOp expressionOp) { + // Do not inline if expression is marked as such. + if (expressionOp.getDoNotInline()) + return false; + + // Do not inline expressions with side effects to prevent side-effect + // reordering. + if (expressionOp.hasSideEffects()) + return false; + + // Do not inline expressions with multiple uses. + Value result = expressionOp.getResult(); + if (!result.hasOneUse()) + return false; + + // Do not inline expressions used by other expressions, as any desired + // expression folding was taken care of by transformations. + Operation *user = *result.getUsers().begin(); + return !user->getParentOfType(); +} + static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value) { OpResult result = operation->getResult(0); @@ -253,9 +345,7 @@ static LogicalResult printOperation(CppEmitter &emitter, if (failed(emitter.emitVariableAssignment(result))) return failure(); - emitter.ostream() << emitter.getOrCreateName(assignOp.getValue()); - - return success(); + return emitter.emitOperand(assignOp.getValue()); } static LogicalResult printBinaryOperation(CppEmitter &emitter, @@ -265,9 +355,14 @@ static LogicalResult printBinaryOperation(CppEmitter &emitter, if (failed(emitter.emitAssignPrefix(*operation))) return failure(); - os << emitter.getOrCreateName(operation->getOperand(0)); - os << " " << binaryOperator; - os << " " << emitter.getOrCreateName(operation->getOperand(1)); + + if (failed(emitter.emitOperand(operation->getOperand(0)))) + return failure(); + + os << " " << binaryOperator << " "; + + if (failed(emitter.emitOperand(operation->getOperand(1)))) + return failure(); return success(); } @@ -485,9 +580,20 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) { if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType()))) return failure(); os << ") "; - os << emitter.getOrCreateName(castOp.getOperand()); + return emitter.emitOperand(castOp.getOperand()); +} - return success(); +static LogicalResult printOperation(CppEmitter &emitter, + emitc::ExpressionOp expressionOp) { + if (shouldBeInlined(expressionOp)) + return success(); + + Operation &op = *expressionOp.getOperation(); + + if (failed(emitter.emitAssignPrefix(op))) + return failure(); + + return emitter.emitExpression(expressionOp); } static LogicalResult printOperation(CppEmitter &emitter, @@ -507,6 +613,17 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { raw_indented_ostream &os = emitter.ostream(); + // Utility function to determine whether a value is an expression that will be + // inlined, and as such should be wrapped in parentheses in order to guarantee + // its precedence and associativity. + auto requiresParentheses = [&](Value value) { + auto expressionOp = + dyn_cast_if_present(value.getDefiningOp()); + if (!expressionOp) + return false; + return shouldBeInlined(expressionOp); + }; + os << "for ("; if (failed( emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType()))) @@ -514,15 +631,24 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { os << " "; os << emitter.getOrCreateName(forOp.getInductionVar()); os << " = "; - os << emitter.getOrCreateName(forOp.getLowerBound()); + if (failed(emitter.emitOperand(forOp.getLowerBound()))) + return failure(); os << "; "; os << emitter.getOrCreateName(forOp.getInductionVar()); os << " < "; - os << emitter.getOrCreateName(forOp.getUpperBound()); + Value upperBound = forOp.getUpperBound(); + bool upperBoundRequiresParentheses = requiresParentheses(upperBound); + if (upperBoundRequiresParentheses) + os << "("; + if (failed(emitter.emitOperand(upperBound))) + return failure(); + if (upperBoundRequiresParentheses) + os << ")"; os << "; "; os << emitter.getOrCreateName(forOp.getInductionVar()); os << " += "; - os << emitter.getOrCreateName(forOp.getStep()); + if (failed(emitter.emitOperand(forOp.getStep()))) + return failure(); os << ") {\n"; os.indent(); @@ -557,7 +683,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) { }; os << "if ("; - if (failed(emitter.emitOperands(*ifOp.getOperation()))) + if (failed(emitter.emitOperand(ifOp.getCondition()))) return failure(); os << ") {\n"; os.indent(); @@ -585,8 +711,10 @@ static LogicalResult printOperation(CppEmitter &emitter, case 0: return success(); case 1: - os << " " << emitter.getOrCreateName(returnOp.getOperand(0)); - return success(emitter.hasValueInScope(returnOp.getOperand(0))); + os << " "; + if (failed(emitter.emitOperand(returnOp.getOperand(0)))) + return failure(); + return success(); default: os << " std::make_tuple("; if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation()))) @@ -639,7 +767,10 @@ static LogicalResult printOperation(CppEmitter &emitter, // regions. WalkResult result = functionOp.walk([&](Operation *op) -> WalkResult { - if (isa(op)) + if (isa(op) || + isa(op->getParentOp()) || + (isa(op) && + shouldBeInlined(cast(op)))) return WalkResult::skip(); for (OpResult result : op->getResults()) { if (failed(emitter.emitVariableDeclaration( @@ -841,15 +972,70 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { return emitError(loc, "cannot emit attribute: ") << attr; } +LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { + assert(emittedExpressionPrecedence.empty() && + "Expected precedence stack to be empty"); + Operation *rootOp = expressionOp.getRootOp(); + + emittedExpression = expressionOp; + pushExpressionPrecedence(getOperatorPrecedence(rootOp)); + + if (failed(emitOperation(*rootOp, /*trailingSemicolon=*/false))) + return failure(); + + popExpressionPrecedence(); + assert(emittedExpressionPrecedence.empty() && + "Expected precedence stack to be empty"); + emittedExpression = nullptr; + + return success(); +} + +LogicalResult CppEmitter::emitOperand(Value value) { + if (isPartOfCurrentExpression(value)) { + Operation *def = value.getDefiningOp(); + assert(def && "Expected operand to be defined by an operation"); + int precedence = getOperatorPrecedence(def); + bool encloseInParenthesis = precedence < getExpressionPrecedence(); + if (encloseInParenthesis) { + os << "("; + pushExpressionPrecedence(lowestPrecedence()); + } else + pushExpressionPrecedence(precedence); + + if (failed(emitOperation(*def, /*trailingSemicolon=*/false))) + return failure(); + + if (encloseInParenthesis) + os << ")"; + + popExpressionPrecedence(); + return success(); + } + + auto expressionOp = dyn_cast_if_present(value.getDefiningOp()); + if (expressionOp && shouldBeInlined(expressionOp)) + return emitExpression(expressionOp); + + auto literalOp = dyn_cast_if_present(value.getDefiningOp()); + if (!literalOp && !hasValueInScope(value)) + return failure(); + os << getOrCreateName(value); + return success(); +} + LogicalResult CppEmitter::emitOperands(Operation &op) { - auto emitOperandName = [&](Value result) -> LogicalResult { - auto literalDef = dyn_cast_if_present(result.getDefiningOp()); - if (!literalDef && !hasValueInScope(result)) - return op.emitOpError() << "operand value not in scope"; - os << getOrCreateName(result); + return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) { + // If an expression is being emitted, push lowest precedence as these + // operands are either wrapped by parenthesis. + if (getEmittedExpression()) + pushExpressionPrecedence(lowestPrecedence()); + if (failed(emitOperand(operand))) + return failure(); + if (getEmittedExpression()) + popExpressionPrecedence(); return success(); - }; - return interleaveCommaWithError(op.getOperands(), os, emitOperandName); + }); } LogicalResult @@ -902,6 +1088,10 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, } LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { + // If op is being emitted as part of an expression, bail out. + if (getEmittedExpression()) + return success(); + switch (op.getNumResults()) { case 0: break; @@ -952,9 +1142,9 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { // EmitC ops. .Case( + emitc::ConstantOp, emitc::DivOp, emitc::ExpressionOp, + emitc::ForOp, emitc::IfOp, emitc::IncludeOp, emitc::MulOp, + emitc::RemOp, emitc::SubOp, emitc::VariableOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case( @@ -973,7 +1163,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { if (isa(op)) return success(); + if (getEmittedExpression() || + (isa(op) && + shouldBeInlined(cast(op)))) + return success(); + os << (trailingSemicolon ? ";\n" : "\n"); + return success(); } diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index 49efb962dfa25..6ad646d7c62f1 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -203,7 +203,7 @@ func.func @sub_pointer_pointer(%arg0: !emitc.ptr, %arg1: !emitc.ptr) { // ----- func.func @test_misplaced_yield() { - // expected-error @+1 {{'emitc.yield' op expects parent op to be one of 'emitc.if, emitc.for'}} + // expected-error @+1 {{'emitc.yield' op expects parent op to be one of 'emitc.expression, emitc.if, emitc.for'}} emitc.yield return } @@ -224,3 +224,60 @@ func.func @test_assign_type_mismatch(%arg1: f32) { emitc.assign %arg1 : f32 to %v : i32 return } + +// ----- + +func.func @test_expression_no_yield() -> i32 { + // expected-error @+1 {{'emitc.expression' op must yield a value at termination}} + %r = emitc.expression : i32 { + %c7 = "emitc.constant"(){value = 7 : i32} : () -> i32 + } + return %r : i32 +} + +// ----- + +func.func @test_expression_illegal_op(%arg0 : i1) -> i32 { + // expected-error @+1 {{'emitc.expression' op contains an unsupported operation}} + %r = emitc.expression : i32 { + %x = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32 + emitc.yield %x : i32 + } + return %r : i32 +} + +// ----- + +func.func @test_expression_no_use(%arg0: i32, %arg1: i32) -> i32 { + // expected-error @+1 {{'emitc.expression' op requires exactly one use for each operation}} + %r = emitc.expression : i32 { + %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.rem %arg0, %arg1 : (i32, i32) -> i32 + emitc.yield %a : i32 + } + return %r : i32 +} + +// ----- + +func.func @test_expression_multiple_uses(%arg0: i32, %arg1: i32) -> i32 { + // expected-error @+1 {{'emitc.expression' op requires exactly one use for each operation}} + %r = emitc.expression : i32 { + %a = emitc.rem %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.add %a, %arg0 : (i32, i32) -> i32 + %c = emitc.mul %arg1, %a : (i32, i32) -> i32 + emitc.yield %a : i32 + } + return %r : i32 +} + +// ----- + +func.func @test_expression_multiple_results(%arg0: i32) -> i32 { + // expected-error @+1 {{'emitc.expression' op requires exactly one result for each operation}} + %r = emitc.expression : i32 { + %a:2 = emitc.call_opaque "bar" (%arg0) : (i32) -> (i32, i32) + emitc.yield %a : i32 + } + return %r : i32 +} diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index b3a24c26b96ca..45ce2bcb99092 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -128,6 +128,23 @@ func.func @test_assign(%arg1: f32) { return } +func.func @test_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> i32 { + %c7 = "emitc.constant"() {value = 7 : i32} : () -> i32 + %q = emitc.expression : i32 { + %a = emitc.rem %arg1, %c7 : (i32, i32) -> i32 + emitc.yield %a : i32 + } + %r = emitc.expression noinline : i32 { + %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.call_opaque "bar" (%a, %arg2, %q) : (i32, i32, i32) -> (i32) + %c = emitc.mul %arg3, %arg4 : (f32, f32) -> f32 + %d = emitc.cast %c : f32 to i32 + %e = emitc.sub %b, %d : (i32, i32) -> i32 + emitc.yield %e : i32 + } + return %r : i32 +} + func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) { emitc.for %i0 = %arg0 to %arg1 step %arg2 { %0 = emitc.call_opaque "func_const"(%i0) : (index) -> i32 diff --git a/mlir/test/Dialect/EmitC/transforms.mlir b/mlir/test/Dialect/EmitC/transforms.mlir new file mode 100644 index 0000000000000..ad167fa455a1a --- /dev/null +++ b/mlir/test/Dialect/EmitC/transforms.mlir @@ -0,0 +1,109 @@ +// RUN: mlir-opt %s --form-expressions --verify-diagnostics --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @single_expression( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> i1 { +// CHECK: %[[VAL_4:.*]] = "emitc.constant"() <{value = 42 : i32}> : () -> i32 +// CHECK: %[[VAL_5:.*]] = emitc.expression : i1 { +// CHECK: %[[VAL_6:.*]] = emitc.mul %[[VAL_0]], %[[VAL_4]] : (i32, i32) -> i32 +// CHECK: %[[VAL_7:.*]] = emitc.sub %[[VAL_6]], %[[VAL_2]] : (i32, i32) -> i32 +// CHECK: %[[VAL_8:.*]] = emitc.cmp lt, %[[VAL_7]], %[[VAL_3]] : (i32, i32) -> i1 +// CHECK: emitc.yield %[[VAL_8]] : i1 +// CHECK: } +// CHECK: return %[[VAL_5]] : i1 +// CHECK: } + +func.func @single_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i1 { + %c42 = "emitc.constant"(){value = 42 : i32} : () -> i32 + %a = emitc.mul %arg0, %c42 : (i32, i32) -> i32 + %b = emitc.sub %a, %arg2 : (i32, i32) -> i32 + %c = emitc.cmp lt, %b, %arg3 :(i32, i32) -> i1 + return %c : i1 +} + +// CHECK-LABEL: func.func @multiple_expressions( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> (i32, i32) { +// CHECK: %[[VAL_4:.*]] = emitc.expression : i32 { +// CHECK: %[[VAL_5:.*]] = emitc.mul %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32 +// CHECK: %[[VAL_6:.*]] = emitc.sub %[[VAL_5]], %[[VAL_2]] : (i32, i32) -> i32 +// CHECK: emitc.yield %[[VAL_6]] : i32 +// CHECK: } +// CHECK: %[[VAL_7:.*]] = emitc.expression : i32 { +// CHECK: %[[VAL_8:.*]] = emitc.add %[[VAL_1]], %[[VAL_3]] : (i32, i32) -> i32 +// CHECK: %[[VAL_9:.*]] = emitc.div %[[VAL_8]], %[[VAL_2]] : (i32, i32) -> i32 +// CHECK: emitc.yield %[[VAL_9]] : i32 +// CHECK: } +// CHECK: return %[[VAL_4]], %[[VAL_7]] : i32, i32 +// CHECK: } + +func.func @multiple_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> (i32, i32) { + %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.sub %a, %arg2 : (i32, i32) -> i32 + %c = emitc.add %arg1, %arg3 : (i32, i32) -> i32 + %d = emitc.div %c, %arg2 : (i32, i32) -> i32 + return %b, %d : i32, i32 +} + +// CHECK-LABEL: func.func @expression_with_call( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> i1 { +// CHECK: %[[VAL_4:.*]] = emitc.expression : i32 { +// CHECK: %[[VAL_5:.*]] = emitc.mul %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32 +// CHECK: %[[VAL_6:.*]] = emitc.call_opaque "foo"(%[[VAL_5]], %[[VAL_2]]) : (i32, i32) -> i32 +// CHECK: emitc.yield %[[VAL_6]] : i32 +// CHECK: } +// CHECK: %[[VAL_7:.*]] = emitc.expression : i1 { +// CHECK: %[[VAL_8:.*]] = emitc.cmp lt, %[[VAL_4]], %[[VAL_1]] : (i32, i32) -> i1 +// CHECK: emitc.yield %[[VAL_8]] : i1 +// CHECK: } +// CHECK: return %[[VAL_7]] : i1 +// CHECK: } + +func.func @expression_with_call(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i1 { + %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.call_opaque "foo" (%a, %arg2) : (i32, i32) -> (i32) + %c = emitc.cmp lt, %b, %arg1 :(i32, i32) -> i1 + return %c : i1 +} + +// CHECK-LABEL: func.func @expression_with_dereference( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: !emitc.ptr) -> i1 { +// CHECK: %[[VAL_3:.*]] = emitc.expression : i32 { +// CHECK: %[[VAL_4:.*]] = emitc.apply "*"(%[[VAL_2]]) : (!emitc.ptr) -> i32 +// CHECK: emitc.yield %[[VAL_4]] : i32 +// CHECK: } +// CHECK: %[[VAL_5:.*]] = emitc.expression : i1 { +// CHECK: %[[VAL_6:.*]] = emitc.mul %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32 +// CHECK: %[[VAL_7:.*]] = emitc.cmp lt, %[[VAL_6]], %[[VAL_3]] : (i32, i32) -> i1 +// CHECK: emitc.yield %[[VAL_7]] : i1 +// CHECK: } +// CHECK: return %[[VAL_5]] : i1 +// CHECK: } + +func.func @expression_with_dereference(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr) -> i1 { + %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.apply "*"(%arg2) : (!emitc.ptr) -> (i32) + %c = emitc.cmp lt, %a, %b :(i32, i32) -> i1 + return %c : i1 +} + +// CHECK-LABEL: func.func @expression_with_address_taken( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: !emitc.ptr) -> i1 { +// CHECK: %[[VAL_3:.*]] = emitc.expression : i32 { +// CHECK: %[[VAL_4:.*]] = emitc.rem %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32 +// CHECK: emitc.yield %[[VAL_4]] : i32 +// CHECK: } +// CHECK: %[[VAL_5:.*]] = emitc.expression : i1 { +// CHECK: %[[VAL_6:.*]] = emitc.apply "&"(%[[VAL_3]]) : (i32) -> !emitc.ptr +// CHECK: %[[VAL_7:.*]] = emitc.add %[[VAL_6]], %[[VAL_1]] : (!emitc.ptr, i32) -> !emitc.ptr +// CHECK: %[[VAL_8:.*]] = emitc.cmp lt, %[[VAL_7]], %[[VAL_2]] : (!emitc.ptr, !emitc.ptr) -> i1 +// CHECK: emitc.yield %[[VAL_8]] : i1 +// CHECK: } +// CHECK: return %[[VAL_5]] : i1 +// CHECK: } + +func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr) -> i1 { + %a = emitc.rem %arg0, %arg1 : (i32, i32) -> (i32) + %b = emitc.apply "&"(%a) : (i32) -> !emitc.ptr + %c = emitc.add %b, %arg1 : (!emitc.ptr, i32) -> !emitc.ptr + %d = emitc.cmp lt, %c, %arg2 :(!emitc.ptr, !emitc.ptr) -> i1 + return %d : i1 +} diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir new file mode 100644 index 0000000000000..9ec9dcc3c6a84 --- /dev/null +++ b/mlir/test/Target/Cpp/expressions.mlir @@ -0,0 +1,212 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT +// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP + +// CPP-DEFAULT: int32_t single_use(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * M_PI, [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DEFAULT-NEXT: if ([[VAL_5]]) { +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } else { +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } +// CPP-DEFAULT-NEXT: return [[VAL_6]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: int32_t single_use(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: bool [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_5]] = bar([[VAL_1]] * M_PI, [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: if ([[VAL_5]]) { +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } else { +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } +// CPP-DECLTOP-NEXT: return [[VAL_6]]; +// CPP-DECLTOP-NEXT: } + +func.func @single_use(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 { + %p0 = emitc.literal "M_PI" : i32 + %e = emitc.expression : i1 { + %a = emitc.mul %arg0, %p0 : (i32, i32) -> i32 + %b = emitc.call_opaque "bar" (%a, %arg2) : (i32, i32) -> (i32) + %c = emitc.sub %b, %arg3 : (i32, i32) -> i32 + %d = emitc.cmp lt, %c, %arg1 :(i32, i32) -> i1 + emitc.yield %d : i1 + } + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> i32 + emitc.if %e { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } else { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } + return %v : i32 +} + +// CPP-DEFAULT: int32_t do_not_inline(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = ([[VAL_1]] + [[VAL_2]]) * [[VAL_3]]; +// CPP-DEFAULT-NEXT: return [[VAL_4]]; +// CPP-DEFAULT-NEXT:} + +// CPP-DECLTOP: int32_t do_not_inline(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_4]] = ([[VAL_1]] + [[VAL_2]]) * [[VAL_3]]; +// CPP-DECLTOP-NEXT: return [[VAL_4]]; +// CPP-DECLTOP-NEXT:} + +func.func @do_not_inline(%arg0: i32, %arg1: i32, %arg2 : i32) -> i32 { + %e = emitc.expression noinline : i32 { + %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.mul %a, %arg2 : (i32, i32) -> i32 + emitc.yield %b : i32 + } + return %e : i32 +} + +// CPP-DEFAULT: float paranthesis_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]); +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: float paranthesis_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]); +// CPP-DECLTOP-NEXT: } + +func.func @paranthesis_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> f32 { + %e = emitc.expression : f32 { + %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.mul %a, %arg2 : (i32, i32) -> i32 + %d = emitc.cast %b : i32 to f32 + emitc.yield %d : f32 + } + return %e : f32 +} + +// CPP-DEFAULT: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DEFAULT-NEXT: if ([[VAL_5]]) { +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } else { +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } +// CPP-DEFAULT-NEXT: bool [[VAL_7:v[0-9]+]]; +// CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_5]]; +// CPP-DEFAULT-NEXT: return [[VAL_6]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: bool [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DECLTOP-NEXT: bool [[VAL_7:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_5]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: if ([[VAL_5]]) { +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } else { +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_5]]; +// CPP-DECLTOP-NEXT: return [[VAL_6]]; +// CPP-DECLTOP-NEXT: } + +func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 { + %e = emitc.expression : i1 { + %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.call_opaque "bar" (%a, %arg2) : (i32, i32) -> (i32) + %c = emitc.sub %b, %arg3 : (i32, i32) -> i32 + %d = emitc.cmp lt, %c, %arg1 :(i32, i32) -> i1 + emitc.yield %d : i1 + } + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> i32 + emitc.if %e { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } else { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } + %q = "emitc.variable"(){value = #emitc.opaque<"">} : () -> i1 + emitc.assign %e : i1 to %q : i1 + return %v : i32 +} + +// CPP-DEFAULT: int32_t different_expressions(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: int32_t [[VAL_5:v[0-9]+]] = [[VAL_3]] % [[VAL_4]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: int32_t [[VAL_7:v[0-9]+]]; +// CPP-DEFAULT-NEXT: if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) { +// CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } else { +// CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } +// CPP-DEFAULT-NEXT: return [[VAL_7]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: int32_t different_expressions(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: int32_t [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_7:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_5]] = [[VAL_3]] % [[VAL_4]]; +// CPP-DECLTOP-NEXT: [[VAL_6]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) { +// CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } else { +// CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } +// CPP-DECLTOP-NEXT: return [[VAL_7]]; +// CPP-DECLTOP-NEXT: } + +func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 { + %e1 = emitc.expression : i32 { + %a = emitc.rem %arg2, %arg3 : (i32, i32) -> i32 + emitc.yield %a : i32 + } + %e2 = emitc.expression : i32 { + %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.call_opaque "bar" (%e1, %a) : (i32, i32) -> (i32) + emitc.yield %b : i32 + } + %e3 = emitc.expression : i1 { + %c = emitc.sub %e2, %arg3 : (i32, i32) -> i32 + %d = emitc.cmp lt, %c, %arg1 :(i32, i32) -> i1 + emitc.yield %d : i1 + } + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> i32 + emitc.if %e3 { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } else { + emitc.assign %arg0 : i32 to %v : i32 + emitc.yield + } + return %v : i32 +} + +// CPP-DEFAULT: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) { +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = [[VAL_1]] % [[VAL_2]]; +// CPP-DEFAULT-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) { +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_1]] % [[VAL_2]]; +// CPP-DECLTOP-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]]; +// CPP-DECLTOP-NEXT: } + +func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr) -> i1 { + %a = emitc.expression : i32 { + %b = emitc.rem %arg0, %arg1 : (i32, i32) -> i32 + emitc.yield %b : i32 + } + %c = emitc.expression : i1 { + %d = emitc.apply "&"(%a) : (i32) -> !emitc.ptr + %e = emitc.sub %d, %arg1 : (!emitc.ptr, i32) -> !emitc.ptr + %f = emitc.cmp lt, %e, %arg2 : (!emitc.ptr, !emitc.ptr) -> i1 + emitc.yield %f : i1 + } + return %c : i1 +} diff --git a/mlir/test/Target/Cpp/for.mlir b/mlir/test/Target/Cpp/for.mlir index 90504b1347bb4..b9bd3d98465a2 100644 --- a/mlir/test/Target/Cpp/for.mlir +++ b/mlir/test/Target/Cpp/for.mlir @@ -2,20 +2,32 @@ // RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) { - emitc.for %i0 = %arg0 to %arg1 step %arg2 { + %lb = emitc.expression : index { + %a = emitc.add %arg0, %arg1 : (index, index) -> index + emitc.yield %a : index + } + %ub = emitc.expression : index { + %a = emitc.mul %arg1, %arg2 : (index, index) -> index + emitc.yield %a : index + } + %step = emitc.expression : index { + %a = emitc.div %arg0, %arg2 : (index, index) -> index + emitc.yield %a : index + } + emitc.for %i0 = %lb to %ub step %step { %0 = emitc.call_opaque "f"() : () -> i32 } return } -// CPP-DEFAULT: void test_for(size_t [[START:[^ ]*]], size_t [[STOP:[^ ]*]], size_t [[STEP:[^ ]*]]) { -// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[START]]; [[ITER]] < [[STOP]]; [[ITER]] += [[STEP]]) { +// CPP-DEFAULT: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) { +// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) { // CPP-DEFAULT-NEXT: int32_t [[V4:[^ ]*]] = f(); // CPP-DEFAULT-NEXT: } // CPP-DEFAULT-NEXT: return; -// CPP-DECLTOP: void test_for(size_t [[START:[^ ]*]], size_t [[STOP:[^ ]*]], size_t [[STEP:[^ ]*]]) { +// CPP-DECLTOP: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) { // CPP-DECLTOP-NEXT: int32_t [[V4:[^ ]*]]; -// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[START]]; [[ITER]] < [[STOP]]; [[ITER]] += [[STEP]]) { +// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) { // CPP-DECLTOP-NEXT: [[V4]] = f(); // CPP-DECLTOP-NEXT: } // CPP-DECLTOP-NEXT: return;