Skip to content

Commit b3b0910

Browse files
committed
Wrap cast expr in parens when the ty ends with empty generic args
Fixes 4621 Previously rustfmt would remove empty generic args from cast expressions e.g. `x as i32<>` => `x as i32`. This lead to code that could not compile when the cast occurred in the context of a binary expression where the operand was a less than (`<`) or left shift (`<<`). The advice from the parse error emitted by the compiler was to wrap the cast in parentheses, which is now the behavior rustfmt employs. * `x as i32<> < y` => `(x as i32) < y` * `x as i32<> << y` => `(x as i32) << y`
1 parent 2403f82 commit b3b0910

File tree

8 files changed

+154
-3
lines changed

8 files changed

+154
-3
lines changed

src/expr.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::borrow::Cow;
22
use std::cmp::min;
3+
use std::ops::Deref;
34

45
use itertools::Itertools;
56
use rustc_ast::token::{Delimiter, LitKind};
@@ -95,13 +96,15 @@ pub(crate) fn format_expr(
9596
ast::ExprKind::Binary(op, ref lhs, ref rhs) => {
9697
// FIXME: format comments between operands and operator
9798
rewrite_all_pairs(expr, shape, context).or_else(|| {
99+
let wrap_lhs_in_parns = lhs_needs_parens(&op, lhs);
98100
rewrite_pair(
99101
&**lhs,
100102
&**rhs,
101103
PairParts::infix(&format!(" {} ", context.snippet(op.span))),
102104
context,
103105
shape,
104106
context.config.binop_separator(),
107+
wrap_lhs_in_parns,
105108
)
106109
})
107110
}
@@ -240,6 +243,7 @@ pub(crate) fn format_expr(
240243
context,
241244
shape,
242245
SeparatorPlace::Front,
246+
false,
243247
),
244248
ast::ExprKind::Type(ref expr, ref ty) => rewrite_pair(
245249
&**expr,
@@ -248,6 +252,7 @@ pub(crate) fn format_expr(
248252
context,
249253
shape,
250254
SeparatorPlace::Back,
255+
false,
251256
),
252257
ast::ExprKind::Index(ref expr, ref index) => {
253258
rewrite_index(&**expr, &**index, context, shape)
@@ -259,6 +264,7 @@ pub(crate) fn format_expr(
259264
context,
260265
shape,
261266
SeparatorPlace::Back,
267+
false,
262268
),
263269
ast::ExprKind::Range(ref lhs, ref rhs, limits) => {
264270
let delim = match limits {
@@ -313,6 +319,7 @@ pub(crate) fn format_expr(
313319
context,
314320
shape,
315321
context.config.binop_separator(),
322+
false,
316323
)
317324
}
318325
(None, Some(rhs)) => {
@@ -409,6 +416,35 @@ pub(crate) fn format_expr(
409416
})
410417
}
411418

419+
/// Check if we need to wrap the lhs of a binary expression in parens to avoid compilation errors.
420+
/// See <https://github.com/rust-lang/rustfmt/issues/4621>
421+
pub(crate) fn lhs_needs_parens(op: &ast::BinOp, lhs: &ast::Expr) -> bool {
422+
let is_lt_or_shl = matches!(op.node, ast::BinOpKind::Shl | ast::BinOpKind::Lt);
423+
if !is_lt_or_shl {
424+
return false;
425+
}
426+
matches!(
427+
lhs.kind,
428+
ast::ExprKind::Cast(_, ref ty) if ty_ends_with_empty_angle_brackets(ty)
429+
)
430+
}
431+
432+
/// Check if they type ends with an empty generic argument list e.g. `i32<>`.
433+
fn ty_ends_with_empty_angle_brackets(ty: &ast::Ty) -> bool {
434+
if let ast::TyKind::Path(_, path) = &ty.kind {
435+
matches!(
436+
path.segments.last(),
437+
Some(ast::PathSegment {args: Some(generic_args), ..})
438+
if matches!(
439+
generic_args.deref(),
440+
ast::GenericArgs::AngleBracketed(bracket_args) if bracket_args.args.is_empty()
441+
)
442+
)
443+
} else {
444+
false
445+
}
446+
}
447+
412448
pub(crate) fn rewrite_array<'a, T: 'a + IntoOverflowableItem<'a>>(
413449
name: &'a str,
414450
exprs: impl Iterator<Item = &'a T>,

src/pairs.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use rustc_ast::ast;
22

33
use crate::config::lists::*;
44
use crate::config::IndentStyle;
5+
use crate::expr::lhs_needs_parens;
56
use crate::rewrite::{Rewrite, RewriteContext};
67
use crate::shape::Shape;
78
use crate::utils::{
@@ -157,23 +158,31 @@ pub(crate) fn rewrite_pair<LHS, RHS>(
157158
context: &RewriteContext<'_>,
158159
shape: Shape,
159160
separator_place: SeparatorPlace,
161+
wrap_lhs_in_parens: bool,
160162
) -> Option<String>
161163
where
162164
LHS: Rewrite,
163165
RHS: Rewrite,
164166
{
165167
let tab_spaces = context.config.tab_spaces();
166168
let lhs_overhead = match separator_place {
169+
SeparatorPlace::Back if wrap_lhs_in_parens => {
170+
shape.used_width() + pp.prefix.len() + pp.infix.trim_end().len() + 2
171+
}
167172
SeparatorPlace::Back => shape.used_width() + pp.prefix.len() + pp.infix.trim_end().len(),
168173
SeparatorPlace::Front => shape.used_width(),
169174
};
170175
let lhs_shape = Shape {
171176
width: context.budget(lhs_overhead),
172177
..shape
173178
};
174-
let lhs_result = lhs
175-
.rewrite(context, lhs_shape)
176-
.map(|lhs_str| format!("{}{}", pp.prefix, lhs_str))?;
179+
let lhs_result = lhs.rewrite(context, lhs_shape).map(|lhs_str| {
180+
if wrap_lhs_in_parens {
181+
format!("{}({})", pp.prefix, lhs_str)
182+
} else {
183+
format!("{}{}", pp.prefix, lhs_str)
184+
}
185+
})?;
177186

178187
// Try to put both lhs and rhs on the same line.
179188
let rhs_orig_result = shape
@@ -298,6 +307,12 @@ impl FlattenPair for ast::Expr {
298307
match pop.kind {
299308
ast::ExprKind::Binary(op, _, ref rhs) => {
300309
separators.push(op.node.to_string());
310+
if lhs_needs_parens(&op, node) {
311+
// safe to unwrap since we just pushed onto the list
312+
let (lhs, rw) = list.pop().unwrap();
313+
let rw = rw.and_then(|s| Some(format!("({})", s)));
314+
list.push((lhs, rw));
315+
}
301316
node = rhs;
302317
}
303318
_ => unreachable!(),

src/patterns.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ impl Rewrite for Pat {
217217
context,
218218
shape,
219219
SeparatorPlace::Front,
220+
false,
220221
)
221222
}
222223
PatKind::Ref(ref pat, mutability) => {

src/types.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,7 @@ impl Rewrite for ast::Ty {
809809
context,
810810
shape,
811811
SeparatorPlace::Back,
812+
false,
812813
),
813814
ast::TyKind::Infer => {
814815
if shape.width >= 1 {
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
fn less_than_or_equal_operand() {
2+
let x: u32 = 100;
3+
if x as i32<> <= 0 {
4+
// ...
5+
}
6+
}
7+
8+
fn long_binary_op_chain_no_wrap() {
9+
let x: u32 = 100;
10+
if x as i32 <= 0 && x as i32 <= 0 && x as i32 <= 0 && x as i32 <= 0 && x as i32 <= 0 && x as i32 <= 0 && x as i32 <= 0 {
11+
// ...
12+
}
13+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
fn less_than_operand() {
2+
let x: u32 = 100;
3+
if x as i32<> < 0 {
4+
// ...
5+
}
6+
}
7+
8+
fn left_shift_operand() {
9+
let x: u32 = 100;
10+
if x as i32<> << 1 < 0 {
11+
// ...
12+
}
13+
}
14+
15+
fn long_binary_op_chain_wrap_all() {
16+
let x: u32 = 100;
17+
if x as i32<> < 0 && x as i32<> < 0 && x as i32<> << 1 < 0 && x as i32<> << 1 < 0 && x as i32<> << 1 < 0 && x as i32<> << 1 < 0 {
18+
// ...
19+
}
20+
}
21+
22+
fn long_binary_op_chain_wrap_some() {
23+
let x: u32 = 100;
24+
if x as i32<> < 0 && x as i32<> <= 0 && x as i32<> << 1 < 0 && x as i32<> <= 0 && x as i32<> << 1 < 0 {
25+
// ...
26+
}
27+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
fn less_than_or_equal_operand() {
2+
let x: u32 = 100;
3+
if x as i32 <= 0 {
4+
// ...
5+
}
6+
}
7+
8+
fn long_binary_op_chain_no_wrap() {
9+
let x: u32 = 100;
10+
if x as i32 <= 0
11+
&& x as i32 <= 0
12+
&& x as i32 <= 0
13+
&& x as i32 <= 0
14+
&& x as i32 <= 0
15+
&& x as i32 <= 0
16+
&& x as i32 <= 0
17+
{
18+
// ...
19+
}
20+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
fn less_than_operand() {
2+
let x: u32 = 100;
3+
if (x as i32) < 0 {
4+
// ...
5+
}
6+
}
7+
8+
fn left_shift_operand() {
9+
let x: u32 = 100;
10+
if (x as i32) << 1 < 0 {
11+
// ...
12+
}
13+
}
14+
15+
fn long_binary_op_chain_wrap_all() {
16+
let x: u32 = 100;
17+
if (x as i32) < 0
18+
&& (x as i32) < 0
19+
&& (x as i32) << 1 < 0
20+
&& (x as i32) << 1 < 0
21+
&& (x as i32) << 1 < 0
22+
&& (x as i32) << 1 < 0
23+
{
24+
// ...
25+
}
26+
}
27+
28+
fn long_binary_op_chain_wrap_some() {
29+
let x: u32 = 100;
30+
if (x as i32) < 0
31+
&& x as i32 <= 0
32+
&& (x as i32) << 1 < 0
33+
&& x as i32 <= 0
34+
&& (x as i32) << 1 < 0
35+
{
36+
// ...
37+
}
38+
}

0 commit comments

Comments
 (0)