From b3b0910baa2c28f0ff713bcdc04ead550503438e Mon Sep 17 00:00:00 2001 From: Yacin Tmimi Date: Mon, 13 Jun 2022 22:12:00 -0400 Subject: [PATCH] Wrap cast expr in parens when the ty ends with empty generic args Fixes 4621 Previously rustfmt would remove empty generic args from cast expressions e.g. `x as i32<>` => `x as i32`. This lead to code that could not compile when the cast occurred in the context of a binary expression where the operand was a less than (`<`) or left shift (`<<`). The advice from the parse error emitted by the compiler was to wrap the cast in parentheses, which is now the behavior rustfmt employs. * `x as i32<> < y` => `(x as i32) < y` * `x as i32<> << y` => `(x as i32) << y` --- src/expr.rs | 36 ++++++++++++++++++ src/pairs.rs | 21 ++++++++-- src/patterns.rs | 1 + src/types.rs | 1 + .../issue-4621/do_not_wrap_in_parens.rs | 13 +++++++ tests/source/issue-4621/wrap_in_parens.rs | 27 +++++++++++++ .../issue-4621/do_not_wrap_in_parens.rs | 20 ++++++++++ tests/target/issue-4621/wrap_in_parens.rs | 38 +++++++++++++++++++ 8 files changed, 154 insertions(+), 3 deletions(-) create mode 100644 tests/source/issue-4621/do_not_wrap_in_parens.rs create mode 100644 tests/source/issue-4621/wrap_in_parens.rs create mode 100644 tests/target/issue-4621/do_not_wrap_in_parens.rs create mode 100644 tests/target/issue-4621/wrap_in_parens.rs diff --git a/src/expr.rs b/src/expr.rs index 13d068d0c2d..cfcaab54145 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -1,5 +1,6 @@ use std::borrow::Cow; use std::cmp::min; +use std::ops::Deref; use itertools::Itertools; use rustc_ast::token::{Delimiter, LitKind}; @@ -95,6 +96,7 @@ pub(crate) fn format_expr( ast::ExprKind::Binary(op, ref lhs, ref rhs) => { // FIXME: format comments between operands and operator rewrite_all_pairs(expr, shape, context).or_else(|| { + let wrap_lhs_in_parns = lhs_needs_parens(&op, lhs); rewrite_pair( &**lhs, &**rhs, @@ -102,6 +104,7 @@ pub(crate) fn format_expr( context, shape, context.config.binop_separator(), + wrap_lhs_in_parns, ) }) } @@ -240,6 +243,7 @@ pub(crate) fn format_expr( context, shape, SeparatorPlace::Front, + false, ), ast::ExprKind::Type(ref expr, ref ty) => rewrite_pair( &**expr, @@ -248,6 +252,7 @@ pub(crate) fn format_expr( context, shape, SeparatorPlace::Back, + false, ), ast::ExprKind::Index(ref expr, ref index) => { rewrite_index(&**expr, &**index, context, shape) @@ -259,6 +264,7 @@ pub(crate) fn format_expr( context, shape, SeparatorPlace::Back, + false, ), ast::ExprKind::Range(ref lhs, ref rhs, limits) => { let delim = match limits { @@ -313,6 +319,7 @@ pub(crate) fn format_expr( context, shape, context.config.binop_separator(), + false, ) } (None, Some(rhs)) => { @@ -409,6 +416,35 @@ pub(crate) fn format_expr( }) } +/// Check if we need to wrap the lhs of a binary expression in parens to avoid compilation errors. +/// See +pub(crate) fn lhs_needs_parens(op: &ast::BinOp, lhs: &ast::Expr) -> bool { + let is_lt_or_shl = matches!(op.node, ast::BinOpKind::Shl | ast::BinOpKind::Lt); + if !is_lt_or_shl { + return false; + } + matches!( + lhs.kind, + ast::ExprKind::Cast(_, ref ty) if ty_ends_with_empty_angle_brackets(ty) + ) +} + +/// Check if they type ends with an empty generic argument list e.g. `i32<>`. +fn ty_ends_with_empty_angle_brackets(ty: &ast::Ty) -> bool { + if let ast::TyKind::Path(_, path) = &ty.kind { + matches!( + path.segments.last(), + Some(ast::PathSegment {args: Some(generic_args), ..}) + if matches!( + generic_args.deref(), + ast::GenericArgs::AngleBracketed(bracket_args) if bracket_args.args.is_empty() + ) + ) + } else { + false + } +} + pub(crate) fn rewrite_array<'a, T: 'a + IntoOverflowableItem<'a>>( name: &'a str, exprs: impl Iterator, diff --git a/src/pairs.rs b/src/pairs.rs index d1c75126ea4..032fda893b2 100644 --- a/src/pairs.rs +++ b/src/pairs.rs @@ -2,6 +2,7 @@ use rustc_ast::ast; use crate::config::lists::*; use crate::config::IndentStyle; +use crate::expr::lhs_needs_parens; use crate::rewrite::{Rewrite, RewriteContext}; use crate::shape::Shape; use crate::utils::{ @@ -157,6 +158,7 @@ pub(crate) fn rewrite_pair( context: &RewriteContext<'_>, shape: Shape, separator_place: SeparatorPlace, + wrap_lhs_in_parens: bool, ) -> Option where LHS: Rewrite, @@ -164,6 +166,9 @@ where { let tab_spaces = context.config.tab_spaces(); let lhs_overhead = match separator_place { + SeparatorPlace::Back if wrap_lhs_in_parens => { + shape.used_width() + pp.prefix.len() + pp.infix.trim_end().len() + 2 + } SeparatorPlace::Back => shape.used_width() + pp.prefix.len() + pp.infix.trim_end().len(), SeparatorPlace::Front => shape.used_width(), }; @@ -171,9 +176,13 @@ where width: context.budget(lhs_overhead), ..shape }; - let lhs_result = lhs - .rewrite(context, lhs_shape) - .map(|lhs_str| format!("{}{}", pp.prefix, lhs_str))?; + let lhs_result = lhs.rewrite(context, lhs_shape).map(|lhs_str| { + if wrap_lhs_in_parens { + format!("{}({})", pp.prefix, lhs_str) + } else { + format!("{}{}", pp.prefix, lhs_str) + } + })?; // Try to put both lhs and rhs on the same line. let rhs_orig_result = shape @@ -298,6 +307,12 @@ impl FlattenPair for ast::Expr { match pop.kind { ast::ExprKind::Binary(op, _, ref rhs) => { separators.push(op.node.to_string()); + if lhs_needs_parens(&op, node) { + // safe to unwrap since we just pushed onto the list + let (lhs, rw) = list.pop().unwrap(); + let rw = rw.and_then(|s| Some(format!("({})", s))); + list.push((lhs, rw)); + } node = rhs; } _ => unreachable!(), diff --git a/src/patterns.rs b/src/patterns.rs index 9b74b35f314..60946fe0fc2 100644 --- a/src/patterns.rs +++ b/src/patterns.rs @@ -217,6 +217,7 @@ impl Rewrite for Pat { context, shape, SeparatorPlace::Front, + false, ) } PatKind::Ref(ref pat, mutability) => { diff --git a/src/types.rs b/src/types.rs index 64a201e45dd..fcf06833cc3 100644 --- a/src/types.rs +++ b/src/types.rs @@ -809,6 +809,7 @@ impl Rewrite for ast::Ty { context, shape, SeparatorPlace::Back, + false, ), ast::TyKind::Infer => { if shape.width >= 1 { diff --git a/tests/source/issue-4621/do_not_wrap_in_parens.rs b/tests/source/issue-4621/do_not_wrap_in_parens.rs new file mode 100644 index 00000000000..f68edd942aa --- /dev/null +++ b/tests/source/issue-4621/do_not_wrap_in_parens.rs @@ -0,0 +1,13 @@ +fn less_than_or_equal_operand() { + let x: u32 = 100; + if x as i32<> <= 0 { + // ... + } +} + +fn long_binary_op_chain_no_wrap() { + let x: u32 = 100; + if x as i32 <= 0 && x as i32 <= 0 && x as i32 <= 0 && x as i32 <= 0 && x as i32 <= 0 && x as i32 <= 0 && x as i32 <= 0 { + // ... + } +} diff --git a/tests/source/issue-4621/wrap_in_parens.rs b/tests/source/issue-4621/wrap_in_parens.rs new file mode 100644 index 00000000000..962a40a53e6 --- /dev/null +++ b/tests/source/issue-4621/wrap_in_parens.rs @@ -0,0 +1,27 @@ +fn less_than_operand() { + let x: u32 = 100; + if x as i32<> < 0 { + // ... + } +} + +fn left_shift_operand() { + let x: u32 = 100; + if x as i32<> << 1 < 0 { + // ... + } +} + +fn long_binary_op_chain_wrap_all() { + let x: u32 = 100; + if x as i32<> < 0 && x as i32<> < 0 && x as i32<> << 1 < 0 && x as i32<> << 1 < 0 && x as i32<> << 1 < 0 && x as i32<> << 1 < 0 { + // ... + } +} + +fn long_binary_op_chain_wrap_some() { + let x: u32 = 100; + if x as i32<> < 0 && x as i32<> <= 0 && x as i32<> << 1 < 0 && x as i32<> <= 0 && x as i32<> << 1 < 0 { + // ... + } +} diff --git a/tests/target/issue-4621/do_not_wrap_in_parens.rs b/tests/target/issue-4621/do_not_wrap_in_parens.rs new file mode 100644 index 00000000000..b63c6540f95 --- /dev/null +++ b/tests/target/issue-4621/do_not_wrap_in_parens.rs @@ -0,0 +1,20 @@ +fn less_than_or_equal_operand() { + let x: u32 = 100; + if x as i32 <= 0 { + // ... + } +} + +fn long_binary_op_chain_no_wrap() { + let x: u32 = 100; + if x as i32 <= 0 + && x as i32 <= 0 + && x as i32 <= 0 + && x as i32 <= 0 + && x as i32 <= 0 + && x as i32 <= 0 + && x as i32 <= 0 + { + // ... + } +} diff --git a/tests/target/issue-4621/wrap_in_parens.rs b/tests/target/issue-4621/wrap_in_parens.rs new file mode 100644 index 00000000000..b2e5f4b0ba2 --- /dev/null +++ b/tests/target/issue-4621/wrap_in_parens.rs @@ -0,0 +1,38 @@ +fn less_than_operand() { + let x: u32 = 100; + if (x as i32) < 0 { + // ... + } +} + +fn left_shift_operand() { + let x: u32 = 100; + if (x as i32) << 1 < 0 { + // ... + } +} + +fn long_binary_op_chain_wrap_all() { + let x: u32 = 100; + if (x as i32) < 0 + && (x as i32) < 0 + && (x as i32) << 1 < 0 + && (x as i32) << 1 < 0 + && (x as i32) << 1 < 0 + && (x as i32) << 1 < 0 + { + // ... + } +} + +fn long_binary_op_chain_wrap_some() { + let x: u32 = 100; + if (x as i32) < 0 + && x as i32 <= 0 + && (x as i32) << 1 < 0 + && x as i32 <= 0 + && (x as i32) << 1 < 0 + { + // ... + } +}