From f5e1d9596b09942b80effa79773f10aac9405063 Mon Sep 17 00:00:00 2001 From: Simon Camphausen Date: Wed, 29 May 2024 14:10:03 +0000 Subject: [PATCH 1/5] [mlir][EmitC] Emit parentheses for users of expression ops --- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 9 ++- mlir/test/Target/Cpp/expressions.mlir | 84 +++++++++++++++++++++++--- mlir/test/Target/Cpp/for.mlir | 4 +- 3 files changed, 85 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index f19e0f8c4c2a4..e7d80d80855a5 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -1338,8 +1338,13 @@ LogicalResult CppEmitter::emitOperand(Value value) { } auto expressionOp = dyn_cast_if_present(value.getDefiningOp()); - if (expressionOp && shouldBeInlined(expressionOp)) - return emitExpression(expressionOp); + if (expressionOp && shouldBeInlined(expressionOp)) { + os << "("; + if (failed(emitExpression(expressionOp))) + return failure(); + os << ")"; + return success(); + } auto literalOp = dyn_cast_if_present(value.getDefiningOp()); if (!literalOp && !hasValueInScope(value)) diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir index aaddd5af874a9..37e0a0ffbdeb1 100644 --- a/mlir/test/Target/Cpp/expressions.mlir +++ b/mlir/test/Target/Cpp/expressions.mlir @@ -66,11 +66,11 @@ func.func @do_not_inline(%arg0: i32, %arg1: i32, %arg2 : i32) -> i32 { } // CPP-DEFAULT: float parentheses_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { -// CPP-DEFAULT-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]); +// CPP-DEFAULT-NEXT: return ((float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]])); // CPP-DEFAULT-NEXT: } // CPP-DECLTOP: float parentheses_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { -// CPP-DECLTOP-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]); +// CPP-DECLTOP-NEXT: return ((float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]])); // CPP-DECLTOP-NEXT: } func.func @parentheses_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> f32 { @@ -84,11 +84,11 @@ func.func @parentheses_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> } // CPP-DEFAULT: int32_t parentheses_for_same_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { -// CPP-DEFAULT-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); // CPP-DEFAULT-NEXT: } // CPP-DECLTOP: int32_t parentheses_for_same_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { -// CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); // CPP-DECLTOP-NEXT: } func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 { %e = emitc.expression : i32 { @@ -100,6 +100,74 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) - return %e : i32 } +// CPP-DEFAULT: int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: int32_t v4 = 0; +// CPP-DEFAULT-NEXT: bool v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); +// CPP-DEFAULT-NEXT: int32_t v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4; +// CPP-DEFAULT-NEXT: int32_t v7 = bar(([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])), v4); +// CPP-DEFAULT-NEXT: int32_t v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4; +// CPP-DEFAULT-NEXT: int32_t v9; +// CPP-DEFAULT-NEXT: v9 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); +// CPP-DEFAULT-NEXT: return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: int32_t v4; +// CPP-DECLTOP-NEXT: bool v5; +// CPP-DECLTOP-NEXT: int32_t v6; +// CPP-DECLTOP-NEXT: int32_t v7; +// CPP-DECLTOP-NEXT: int32_t v8; +// CPP-DECLTOP-NEXT: int32_t v9; +// CPP-DECLTOP-NEXT: v4 = 0; +// CPP-DECLTOP-NEXT: v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); +// CPP-DECLTOP-NEXT: v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4; +// CPP-DECLTOP-NEXT: v7 = bar(([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])), v4); +// CPP-DECLTOP-NEXT: v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: v9 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); +// CPP-DECLTOP-NEXT: return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); +// CPP-DECLTOP-NEXT: } +func.func @parentheses_for_expression_users(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 { + %c0 = "emitc.constant"() {value = 0 : i32} : () -> i32 + %e0 = emitc.expression : i32 { + %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %1 = emitc.div %arg2, %0 : (i32, i32) -> i32 + emitc.yield %1 : i32 + } + %e1 = emitc.expression : i32 { + %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %1 = emitc.div %arg2, %0 : (i32, i32) -> i32 + emitc.yield %1 : i32 + } + %e2 = emitc.expression : i32 { + %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %1 = emitc.div %arg2, %0 : (i32, i32) -> i32 + emitc.yield %1 : i32 + } + %e3 = emitc.expression : i32 { + %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %1 = emitc.div %arg2, %0 : (i32, i32) -> i32 + emitc.yield %1 : i32 + } + %e4 = emitc.expression : i32 { + %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %1 = emitc.div %arg2, %0 : (i32, i32) -> i32 + emitc.yield %1 : i32 + } + %e5 = emitc.expression : i32 { + %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 + %1 = emitc.div %arg2, %0 : (i32, i32) -> i32 + emitc.yield %1 : i32 + } + %cast = emitc.cast %e0 : i32 to i1 + %add = emitc.add %e1, %c0 : (i32, i32) -> i32 + %call = emitc.call_opaque "bar" (%e2, %c0) : (i32, i32) -> (i32) + %cond = emitc.conditional %cast, %e3, %c0 : i32 + %var = "emitc.variable"() {value = #emitc.opaque<"">} : () -> i32 + emitc.assign %e4 : i32 to %var : i32 + return %e5 : i32 +} + // CPP-DEFAULT: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { // CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; // CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]]; @@ -154,7 +222,7 @@ func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 // CPP-DEFAULT-NEXT: int32_t [[VAL_5:v[0-9]+]] = [[VAL_3]] % [[VAL_4]]; // CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]); // CPP-DEFAULT-NEXT: int32_t [[VAL_7:v[0-9]+]]; -// CPP-DEFAULT-NEXT: if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) { +// CPP-DEFAULT-NEXT: if (([[VAL_6]] - [[VAL_4]] < [[VAL_2]])) { // CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_1]]; // CPP-DEFAULT-NEXT: } else { // CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_1]]; @@ -169,7 +237,7 @@ func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 // CPP-DECLTOP-NEXT: [[VAL_5]] = [[VAL_3]] % [[VAL_4]]; // CPP-DECLTOP-NEXT: [[VAL_6]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]); // CPP-DECLTOP-NEXT: ; -// CPP-DECLTOP-NEXT: if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) { +// CPP-DECLTOP-NEXT: if (([[VAL_6]] - [[VAL_4]] < [[VAL_2]])) { // CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_1]]; // CPP-DECLTOP-NEXT: } else { // CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_1]]; @@ -205,13 +273,13 @@ func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) // CPP-DEFAULT: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) { // CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = [[VAL_1]] % [[VAL_2]]; -// CPP-DEFAULT-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]]; +// CPP-DEFAULT-NEXT: return (&[[VAL_4]] - [[VAL_2]] < [[VAL_3]]); // CPP-DEFAULT-NEXT: } // CPP-DECLTOP: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) { // CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; // CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_1]] % [[VAL_2]]; -// CPP-DECLTOP-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]]; +// CPP-DECLTOP-NEXT: return (&[[VAL_4]] - [[VAL_2]] < [[VAL_3]]); // CPP-DECLTOP-NEXT: } func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr) -> i1 { diff --git a/mlir/test/Target/Cpp/for.mlir b/mlir/test/Target/Cpp/for.mlir index 60988bcb46556..2e41dce45f580 100644 --- a/mlir/test/Target/Cpp/for.mlir +++ b/mlir/test/Target/Cpp/for.mlir @@ -20,14 +20,14 @@ func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) { return } // CPP-DEFAULT: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) { -// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) { +// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = ([[V1]] + [[V2]]); [[ITER]] < (([[V2]] * [[V3]])); [[ITER]] += ([[V1]] / [[V3]])) { // CPP-DEFAULT-NEXT: int32_t [[V4:[^ ]*]] = f(); // CPP-DEFAULT-NEXT: } // CPP-DEFAULT-NEXT: return; // CPP-DECLTOP: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) { // CPP-DECLTOP-NEXT: int32_t [[V4:[^ ]*]]; -// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) { +// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = ([[V1]] + [[V2]]); [[ITER]] < (([[V2]] * [[V3]])); [[ITER]] += ([[V1]] / [[V3]])) { // CPP-DECLTOP-NEXT: [[V4]] = f(); // CPP-DECLTOP-NEXT: } // CPP-DECLTOP-NEXT: return; From aa4b1bfbf320886e3855fbb0b2a0a3f76cae455a Mon Sep 17 00:00:00 2001 From: Simon Camphausen Date: Wed, 29 May 2024 14:15:53 +0000 Subject: [PATCH 2/5] Skip parenthesis where its safe --- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 10 +++++++-- mlir/test/Target/Cpp/expressions.mlir | 28 +++++++++++++------------- mlir/test/Target/Cpp/for.mlir | 4 ++-- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index e7d80d80855a5..83ef2a39950f2 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -1339,10 +1339,16 @@ LogicalResult CppEmitter::emitOperand(Value value) { auto expressionOp = dyn_cast_if_present(value.getDefiningOp()); if (expressionOp && shouldBeInlined(expressionOp)) { - os << "("; + Operation *user = *expressionOp->getUsers().begin(); + const bool safeToSkipParentheses = + isa(user); + if (!safeToSkipParentheses) + os << "("; if (failed(emitExpression(expressionOp))) return failure(); - os << ")"; + if (!safeToSkipParentheses) + os << ")"; return success(); } diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir index 37e0a0ffbdeb1..1c55b9404225d 100644 --- a/mlir/test/Target/Cpp/expressions.mlir +++ b/mlir/test/Target/Cpp/expressions.mlir @@ -66,11 +66,11 @@ func.func @do_not_inline(%arg0: i32, %arg1: i32, %arg2 : i32) -> i32 { } // CPP-DEFAULT: float parentheses_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { -// CPP-DEFAULT-NEXT: return ((float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]])); +// CPP-DEFAULT-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]); // CPP-DEFAULT-NEXT: } // CPP-DECLTOP: float parentheses_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { -// CPP-DECLTOP-NEXT: return ((float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]])); +// CPP-DECLTOP-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]); // CPP-DECLTOP-NEXT: } func.func @parentheses_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> f32 { @@ -84,11 +84,11 @@ func.func @parentheses_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> } // CPP-DEFAULT: int32_t parentheses_for_same_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { -// CPP-DEFAULT-NEXT: return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); +// CPP-DEFAULT-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); // CPP-DEFAULT-NEXT: } // CPP-DECLTOP: int32_t parentheses_for_same_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { -// CPP-DECLTOP-NEXT: return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); +// CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); // CPP-DECLTOP-NEXT: } func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 { %e = emitc.expression : i32 { @@ -104,11 +104,11 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) - // CPP-DEFAULT-NEXT: int32_t v4 = 0; // CPP-DEFAULT-NEXT: bool v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); // CPP-DEFAULT-NEXT: int32_t v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4; -// CPP-DEFAULT-NEXT: int32_t v7 = bar(([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])), v4); +// CPP-DEFAULT-NEXT: int32_t v7 = bar([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]), v4); // CPP-DEFAULT-NEXT: int32_t v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4; // CPP-DEFAULT-NEXT: int32_t v9; -// CPP-DEFAULT-NEXT: v9 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); -// CPP-DEFAULT-NEXT: return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); +// CPP-DEFAULT-NEXT: v9 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); // CPP-DEFAULT-NEXT: } // CPP-DECLTOP: int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { @@ -121,11 +121,11 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) - // CPP-DECLTOP-NEXT: v4 = 0; // CPP-DECLTOP-NEXT: v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); // CPP-DECLTOP-NEXT: v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4; -// CPP-DECLTOP-NEXT: v7 = bar(([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])), v4); +// CPP-DECLTOP-NEXT: v7 = bar([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]), v4); // CPP-DECLTOP-NEXT: v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4; // CPP-DECLTOP-NEXT: ; -// CPP-DECLTOP-NEXT: v9 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); -// CPP-DECLTOP-NEXT: return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); +// CPP-DECLTOP-NEXT: v9 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); // CPP-DECLTOP-NEXT: } func.func @parentheses_for_expression_users(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 { %c0 = "emitc.constant"() {value = 0 : i32} : () -> i32 @@ -222,7 +222,7 @@ func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 // CPP-DEFAULT-NEXT: int32_t [[VAL_5:v[0-9]+]] = [[VAL_3]] % [[VAL_4]]; // CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]); // CPP-DEFAULT-NEXT: int32_t [[VAL_7:v[0-9]+]]; -// CPP-DEFAULT-NEXT: if (([[VAL_6]] - [[VAL_4]] < [[VAL_2]])) { +// CPP-DEFAULT-NEXT: if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) { // CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_1]]; // CPP-DEFAULT-NEXT: } else { // CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_1]]; @@ -237,7 +237,7 @@ func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 // CPP-DECLTOP-NEXT: [[VAL_5]] = [[VAL_3]] % [[VAL_4]]; // CPP-DECLTOP-NEXT: [[VAL_6]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]); // CPP-DECLTOP-NEXT: ; -// CPP-DECLTOP-NEXT: if (([[VAL_6]] - [[VAL_4]] < [[VAL_2]])) { +// CPP-DECLTOP-NEXT: if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) { // CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_1]]; // CPP-DECLTOP-NEXT: } else { // CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_1]]; @@ -273,13 +273,13 @@ func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) // CPP-DEFAULT: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) { // CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = [[VAL_1]] % [[VAL_2]]; -// CPP-DEFAULT-NEXT: return (&[[VAL_4]] - [[VAL_2]] < [[VAL_3]]); +// CPP-DEFAULT-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]]; // CPP-DEFAULT-NEXT: } // CPP-DECLTOP: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) { // CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; // CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_1]] % [[VAL_2]]; -// CPP-DECLTOP-NEXT: return (&[[VAL_4]] - [[VAL_2]] < [[VAL_3]]); +// CPP-DECLTOP-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]]; // CPP-DECLTOP-NEXT: } func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr) -> i1 { diff --git a/mlir/test/Target/Cpp/for.mlir b/mlir/test/Target/Cpp/for.mlir index 2e41dce45f580..60988bcb46556 100644 --- a/mlir/test/Target/Cpp/for.mlir +++ b/mlir/test/Target/Cpp/for.mlir @@ -20,14 +20,14 @@ func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) { return } // CPP-DEFAULT: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) { -// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = ([[V1]] + [[V2]]); [[ITER]] < (([[V2]] * [[V3]])); [[ITER]] += ([[V1]] / [[V3]])) { +// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) { // CPP-DEFAULT-NEXT: int32_t [[V4:[^ ]*]] = f(); // CPP-DEFAULT-NEXT: } // CPP-DEFAULT-NEXT: return; // CPP-DECLTOP: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) { // CPP-DECLTOP-NEXT: int32_t [[V4:[^ ]*]]; -// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = ([[V1]] + [[V2]]); [[ITER]] < (([[V2]] * [[V3]])); [[ITER]] += ([[V1]] / [[V3]])) { +// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) { // CPP-DECLTOP-NEXT: [[V4]] = f(); // CPP-DECLTOP-NEXT: } // CPP-DECLTOP-NEXT: return; From c6817f88c0ce4cb7e0a86bb6137242f2d13742cc Mon Sep 17 00:00:00 2001 From: Simon Camphausen Date: Mon, 3 Jun 2024 11:55:33 +0000 Subject: [PATCH 3/5] Do not inline expressions into ops with the CExpression trait --- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 19 ++++------- mlir/test/Target/Cpp/expressions.mlir | 44 ++++++++++++++++---------- 2 files changed, 35 insertions(+), 28 deletions(-) diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 83ef2a39950f2..01648ba693180 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -303,7 +303,12 @@ static bool shouldBeInlined(ExpressionOp expressionOp) { // Do not inline expressions used by other expressions, as any desired // expression folding was taken care of by transformations. - return !user->getParentOfType(); + if (user->getParentOfType()) + return false; + + // Do not inline expressions used by ops with the CExpression trait. If this + // was intended, the user could have been merged into the expression op. + return !user->hasTrait(); } static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, @@ -1339,17 +1344,7 @@ LogicalResult CppEmitter::emitOperand(Value value) { auto expressionOp = dyn_cast_if_present(value.getDefiningOp()); if (expressionOp && shouldBeInlined(expressionOp)) { - Operation *user = *expressionOp->getUsers().begin(); - const bool safeToSkipParentheses = - isa(user); - if (!safeToSkipParentheses) - os << "("; - if (failed(emitExpression(expressionOp))) - return failure(); - if (!safeToSkipParentheses) - os << ")"; - return success(); + return emitExpression(expressionOp); } auto literalOp = dyn_cast_if_present(value.getDefiningOp()); diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir index 1c55b9404225d..810a629c71533 100644 --- a/mlir/test/Target/Cpp/expressions.mlir +++ b/mlir/test/Target/Cpp/expressions.mlir @@ -100,34 +100,46 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) - return %e : i32 } -// CPP-DEFAULT: int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DEFAULT: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { // CPP-DEFAULT-NEXT: int32_t v4 = 0; -// CPP-DEFAULT-NEXT: bool v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); -// CPP-DEFAULT-NEXT: int32_t v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4; -// CPP-DEFAULT-NEXT: int32_t v7 = bar([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]), v4); -// CPP-DEFAULT-NEXT: int32_t v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4; -// CPP-DEFAULT-NEXT: int32_t v9; -// CPP-DEFAULT-NEXT: v9 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: int32_t [[EXP_0:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: int32_t [[EXP_1:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: int32_t [[EXP_2:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: int32_t [[EXP_3:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: bool v9 = (bool) [[EXP_0]]; +// CPP-DEFAULT-NEXT: int32_t v10 = [[EXP_1]] + v4; +// CPP-DEFAULT-NEXT: int32_t v11 = bar([[EXP_2]], v4); +// CPP-DEFAULT-NEXT: int32_t v12 = v9 ? [[EXP_3]] : v4; +// CPP-DEFAULT-NEXT: int32_t v13; +// CPP-DEFAULT-NEXT: v13 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); // CPP-DEFAULT-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); // CPP-DEFAULT-NEXT: } -// CPP-DECLTOP: int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DECLTOP: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { // CPP-DECLTOP-NEXT: int32_t v4; -// CPP-DECLTOP-NEXT: bool v5; +// CPP-DECLTOP-NEXT: int32_t v5; // CPP-DECLTOP-NEXT: int32_t v6; // CPP-DECLTOP-NEXT: int32_t v7; // CPP-DECLTOP-NEXT: int32_t v8; -// CPP-DECLTOP-NEXT: int32_t v9; +// CPP-DECLTOP-NEXT: bool v9; +// CPP-DECLTOP-NEXT: int32_t v10; +// CPP-DECLTOP-NEXT: int32_t v11; +// CPP-DECLTOP-NEXT: int32_t v12; +// CPP-DECLTOP-NEXT: int32_t v13; // CPP-DECLTOP-NEXT: v4 = 0; -// CPP-DECLTOP-NEXT: v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])); -// CPP-DECLTOP-NEXT: v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4; -// CPP-DECLTOP-NEXT: v7 = bar([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]), v4); -// CPP-DECLTOP-NEXT: v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4; +// CPP-DECLTOP-NEXT: v5 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: v6 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: v7 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: v8 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: v9 = (bool) v5; +// CPP-DECLTOP-NEXT: v10 = v6 + v4; +// CPP-DECLTOP-NEXT: v11 = bar(v7, v4); +// CPP-DECLTOP-NEXT: v12 = v9 ? v8 : v4; // CPP-DECLTOP-NEXT: ; -// CPP-DECLTOP-NEXT: v9 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: v13 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); // CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); // CPP-DECLTOP-NEXT: } -func.func @parentheses_for_expression_users(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 { +func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 { %c0 = "emitc.constant"() {value = 0 : i32} : () -> i32 %e0 = emitc.expression : i32 { %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 From fd5b962dffdfb071209004027fba6d9d8633ec0e Mon Sep 17 00:00:00 2001 From: Simon Camphausen Date: Tue, 4 Jun 2024 08:03:53 +0000 Subject: [PATCH 4/5] Review comments --- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 3 +- mlir/test/Target/Cpp/expressions.mlir | 54 +++++++++++++------------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 01648ba693180..6cfe846a785dd 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -1343,9 +1343,8 @@ LogicalResult CppEmitter::emitOperand(Value value) { } auto expressionOp = dyn_cast_if_present(value.getDefiningOp()); - if (expressionOp && shouldBeInlined(expressionOp)) { + if (expressionOp && shouldBeInlined(expressionOp)) return emitExpression(expressionOp); - } auto literalOp = dyn_cast_if_present(value.getDefiningOp()); if (!literalOp && !hasValueInScope(value)) diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir index 810a629c71533..caa0a340d3e0a 100644 --- a/mlir/test/Target/Cpp/expressions.mlir +++ b/mlir/test/Target/Cpp/expressions.mlir @@ -101,42 +101,42 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) - } // CPP-DEFAULT: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { -// CPP-DEFAULT-NEXT: int32_t v4 = 0; +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = 0; // CPP-DEFAULT-NEXT: int32_t [[EXP_0:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); // CPP-DEFAULT-NEXT: int32_t [[EXP_1:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); // CPP-DEFAULT-NEXT: int32_t [[EXP_2:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); // CPP-DEFAULT-NEXT: int32_t [[EXP_3:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); -// CPP-DEFAULT-NEXT: bool v9 = (bool) [[EXP_0]]; -// CPP-DEFAULT-NEXT: int32_t v10 = [[EXP_1]] + v4; -// CPP-DEFAULT-NEXT: int32_t v11 = bar([[EXP_2]], v4); -// CPP-DEFAULT-NEXT: int32_t v12 = v9 ? [[EXP_3]] : v4; -// CPP-DEFAULT-NEXT: int32_t v13; -// CPP-DEFAULT-NEXT: v13 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DEFAULT-NEXT: bool [[CAST:v[0-9]+]] = (bool) [[EXP_0]]; +// CPP-DEFAULT-NEXT: int32_t [[ADD:v[0-9]+]] = [[EXP_1]] + [[VAL_4]]; +// CPP-DEFAULT-NEXT: int32_t [[CALL:v[0-9]+]] = bar([[EXP_2]], [[VAL_4]]); +// CPP-DEFAULT-NEXT: int32_t [[COND:v[0-9]+]] = [[CAST]] ? [[EXP_3]] : [[VAL_4]]; +// CPP-DEFAULT-NEXT: int32_t [[VAR:v[0-9]+]]; +// CPP-DEFAULT-NEXT: [[VAR]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); // CPP-DEFAULT-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); // CPP-DEFAULT-NEXT: } // CPP-DECLTOP: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { -// CPP-DECLTOP-NEXT: int32_t v4; -// CPP-DECLTOP-NEXT: int32_t v5; -// CPP-DECLTOP-NEXT: int32_t v6; -// CPP-DECLTOP-NEXT: int32_t v7; -// CPP-DECLTOP-NEXT: int32_t v8; -// CPP-DECLTOP-NEXT: bool v9; -// CPP-DECLTOP-NEXT: int32_t v10; -// CPP-DECLTOP-NEXT: int32_t v11; -// CPP-DECLTOP-NEXT: int32_t v12; -// CPP-DECLTOP-NEXT: int32_t v13; -// CPP-DECLTOP-NEXT: v4 = 0; -// CPP-DECLTOP-NEXT: v5 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); -// CPP-DECLTOP-NEXT: v6 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); -// CPP-DECLTOP-NEXT: v7 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); -// CPP-DECLTOP-NEXT: v8 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); -// CPP-DECLTOP-NEXT: v9 = (bool) v5; -// CPP-DECLTOP-NEXT: v10 = v6 + v4; -// CPP-DECLTOP-NEXT: v11 = bar(v7, v4); -// CPP-DECLTOP-NEXT: v12 = v9 ? v8 : v4; +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[EXP_0:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[EXP_1:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[EXP_2:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[EXP_3:v[0-9]+]]; +// CPP-DECLTOP-NEXT: bool [[CAST:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[ADD:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[CALL:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[COND:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAR:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_4]] = 0; +// CPP-DECLTOP-NEXT: [[EXP_0]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: [[EXP_1]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: [[EXP_2]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: [[EXP_3]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: [[CAST]] = (bool) [[EXP_0]]; +// CPP-DECLTOP-NEXT: [[ADD]] = [[EXP_1]] + [[VAL_4]]; +// CPP-DECLTOP-NEXT: [[CALL]] = bar([[EXP_2]], [[VAL_4]]); +// CPP-DECLTOP-NEXT: [[COND]] = [[CAST]] ? [[EXP_3]] : [[VAL_4]]; // CPP-DECLTOP-NEXT: ; -// CPP-DECLTOP-NEXT: v13 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); +// CPP-DECLTOP-NEXT: [[VAR]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); // CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]); // CPP-DECLTOP-NEXT: } func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 { From 4c9b5da3c6b98a49aa76c88f9dafdba24fd140a6 Mon Sep 17 00:00:00 2001 From: Simon Camphausen Date: Tue, 4 Jun 2024 10:06:31 +0000 Subject: [PATCH 5/5] Remove redundant check in shouldBeInlined --- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 6cfe846a785dd..202df89025f26 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -301,11 +301,6 @@ static bool shouldBeInlined(ExpressionOp expressionOp) { if (isa(user)) return false; - // Do not inline expressions used by other expressions, as any desired - // expression folding was taken care of by transformations. - if (user->getParentOfType()) - return false; - // Do not inline expressions used by ops with the CExpression trait. If this // was intended, the user could have been merged into the expression op. return !user->hasTrait();