diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index f19e0f8c4c2a4..202df89025f26 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -301,9 +301,9 @@ static bool shouldBeInlined(ExpressionOp expressionOp) { if (isa(user)) return false; - // Do not inline expressions used by other expressions, as any desired - // expression folding was taken care of by transformations. - return !user->getParentOfType(); + // Do not inline expressions used by ops with the CExpression trait. If this + // was intended, the user could have been merged into the expression op. + return !user->hasTrait(); } static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir index aaddd5af874a9..caa0a340d3e0a 100644 --- a/mlir/test/Target/Cpp/expressions.mlir +++ b/mlir/test/Target/Cpp/expressions.mlir @@ -100,6 +100,86 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) - return %e : i32 } +// CPP-DEFAULT: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = 0; +// CPP-DEFAULT-NEXT: int32_t [[EXP_0:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: int32_t [[EXP_1:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: int32_t [[EXP_2:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: int32_t [[EXP_3:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: bool [[CAST:v[0-9]+]] = (bool) [[EXP_0]]; +// CPP-DEFAULT-NEXT: int32_t [[ADD:v[0-9]+]] = [[EXP_1]] + [[VAL_4]]; +// CPP-DEFAULT-NEXT: int32_t [[CALL:v[0-9]+]] = bar([[EXP_2]], [[VAL_4]]); +// CPP-DEFAULT-NEXT: int32_t [[COND:v[0-9]+]] = [[CAST]] ? [[EXP_3]] : [[VAL_4]]; +// CPP-DEFAULT-NEXT: int32_t [[VAR:v[0-9]+]]; +// CPP-DEFAULT-NEXT: [[VAR]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[EXP_0:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[EXP_1:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[EXP_2:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[EXP_3:v[0-9]+]]; +// CPP-DECLTOP-NEXT: bool [[CAST:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[ADD:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[CALL:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[COND:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAR:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_4]] = 0; +// CPP-DECLTOP-NEXT: [[EXP_0]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: [[EXP_1]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: [[EXP_2]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: [[EXP_3]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: [[CAST]] = (bool) [[EXP_0]]; +// CPP-DECLTOP-NEXT: [[ADD]] = [[EXP_1]] + [[VAL_4]]; +// CPP-DECLTOP-NEXT: [[CALL]] = bar([[EXP_2]], [[VAL_4]]); +// CPP-DECLTOP-NEXT: [[COND]] = [[CAST]] ? [[EXP_3]] : [[VAL_4]]; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: [[VAR]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: } +func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 { + %c0 = "emitc.constant"() {value = 0 : i32} : () -> i32 + %e0 = emitc.expression : i32 { + %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %1 = emitc.div %arg2, %0 : (i32, i32) -> i32 + emitc.yield %1 : i32 + } + %e1 = emitc.expression : i32 { + %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %1 = emitc.div %arg2, %0 : (i32, i32) -> i32 + emitc.yield %1 : i32 + } + %e2 = emitc.expression : i32 { + %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %1 = emitc.div %arg2, %0 : (i32, i32) -> i32 + emitc.yield %1 : i32 + } + %e3 = emitc.expression : i32 { + %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %1 = emitc.div %arg2, %0 : (i32, i32) -> i32 + emitc.yield %1 : i32 + } + %e4 = emitc.expression : i32 { + %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %1 = emitc.div %arg2, %0 : (i32, i32) -> i32 + emitc.yield %1 : i32 + } + %e5 = emitc.expression : i32 { + %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %1 = emitc.div %arg2, %0 : (i32, i32) -> i32 + emitc.yield %1 : i32 + } + %cast = emitc.cast %e0 : i32 to i1 + %add = emitc.add %e1, %c0 : (i32, i32) -> i32 + %call = emitc.call_opaque "bar" (%e2, %c0) : (i32, i32) -> (i32) + %cond = emitc.conditional %cast, %e3, %c0 : i32 + %var = "emitc.variable"() {value = #emitc.opaque<"">} : () -> i32 + emitc.assign %e4 : i32 to %var : i32 + return %e5 : i32 +} + // CPP-DEFAULT: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { // CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; // CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]];