Skip to content

Commit 492b3de

Browse files
committed
Auto merge of #13971 - lowr:fix/more-precise-builtin-binop-types, r=Veykril
fix: more precise binop inference While inferring binary operator expressions, Rust puts some extra constraints on the types of the operands for better inference. Relevant part in rustc is [this](https://github.com/rust-lang/rust/blob/159ba8a92c9e2fa4121f106176309521f4af87e9/compiler/rustc_hir_typeck/src/op.rs#L128-L152). There are two things we currently fail to consider: - we should enforce them only when both lhs and rhs type are builtin types that are applicable to the binop - lhs and rhs types may be single reference to applicable builtin types This PR basically ports [`enforce_builtin_binop_types()`](https://github.com/rust-lang/rust/blob/159ba8a92c9e2fa4121f106176309521f4af87e9/compiler/rustc_hir_typeck/src/op.rs#L159) and [`is_builtin_binop()`](https://github.com/rust-lang/rust/blob/159ba8a92c9e2fa4121f106176309521f4af87e9/compiler/rustc_hir_typeck/src/op.rs#LL927) to our inference context.
2 parents fa87462 + c53064f commit 492b3de

File tree

4 files changed

+247
-123
lines changed

4 files changed

+247
-123
lines changed

crates/hir-ty/src/chalk_ext.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Various extensions traits for Chalk types.
22
3-
use chalk_ir::{FloatTy, IntTy, Mutability, Scalar, UintTy};
3+
use chalk_ir::{FloatTy, IntTy, Mutability, Scalar, TyVariableKind, UintTy};
44
use hir_def::{
55
builtin_type::{BuiltinFloat, BuiltinInt, BuiltinType, BuiltinUint},
66
generics::TypeOrConstParamData,
@@ -18,6 +18,8 @@ use crate::{
1818

1919
pub trait TyExt {
2020
fn is_unit(&self) -> bool;
21+
fn is_integral(&self) -> bool;
22+
fn is_floating_point(&self) -> bool;
2123
fn is_never(&self) -> bool;
2224
fn is_unknown(&self) -> bool;
2325
fn is_ty_var(&self) -> bool;
@@ -51,6 +53,21 @@ impl TyExt for Ty {
5153
matches!(self.kind(Interner), TyKind::Tuple(0, _))
5254
}
5355

56+
fn is_integral(&self) -> bool {
57+
matches!(
58+
self.kind(Interner),
59+
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_))
60+
| TyKind::InferenceVar(_, TyVariableKind::Integer)
61+
)
62+
}
63+
64+
fn is_floating_point(&self) -> bool {
65+
matches!(
66+
self.kind(Interner),
67+
TyKind::Scalar(Scalar::Float(_)) | TyKind::InferenceVar(_, TyVariableKind::Float)
68+
)
69+
}
70+
5471
fn is_never(&self) -> bool {
5572
matches!(self.kind(Interner), TyKind::Never)
5673
}

crates/hir-ty/src/infer.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,10 +1041,6 @@ impl Expectation {
10411041
}
10421042
}
10431043

1044-
fn from_option(ty: Option<Ty>) -> Self {
1045-
ty.map_or(Expectation::None, Expectation::HasType)
1046-
}
1047-
10481044
/// The following explanation is copied straight from rustc:
10491045
/// Provides an expectation for an rvalue expression given an *optional*
10501046
/// hint, which is not required for type safety (the resulting type might

crates/hir-ty/src/infer/expr.rs

Lines changed: 131 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ use chalk_ir::{
1010
};
1111
use hir_def::{
1212
expr::{
13-
ArithOp, Array, BinaryOp, ClosureKind, CmpOp, Expr, ExprId, LabelId, Literal, Statement,
14-
UnaryOp,
13+
ArithOp, Array, BinaryOp, ClosureKind, Expr, ExprId, LabelId, Literal, Statement, UnaryOp,
1514
},
1615
generics::TypeOrConstParamData,
1716
path::{GenericArg, GenericArgs},
@@ -1017,11 +1016,21 @@ impl<'a> InferenceContext<'a> {
10171016
let (trait_, func) = match trait_func {
10181017
Some(it) => it,
10191018
None => {
1020-
let rhs_ty = self.builtin_binary_op_rhs_expectation(op, lhs_ty.clone());
1021-
let rhs_ty = self.infer_expr_coerce(rhs, &Expectation::from_option(rhs_ty));
1022-
return self
1023-
.builtin_binary_op_return_ty(op, lhs_ty, rhs_ty)
1024-
.unwrap_or_else(|| self.err_ty());
1019+
// HACK: `rhs_ty` is a general inference variable with no clue at all at this
1020+
// point. Passing `lhs_ty` as both operands just to check if `lhs_ty` is a builtin
1021+
// type applicable to `op`.
1022+
let ret_ty = if self.is_builtin_binop(&lhs_ty, &lhs_ty, op) {
1023+
// Assume both operands are builtin so we can continue inference. No guarantee
1024+
// on the correctness, rustc would complain as necessary lang items don't seem
1025+
// to exist anyway.
1026+
self.enforce_builtin_binop_types(&lhs_ty, &rhs_ty, op)
1027+
} else {
1028+
self.err_ty()
1029+
};
1030+
1031+
self.infer_expr_coerce(rhs, &Expectation::has_type(rhs_ty));
1032+
1033+
return ret_ty;
10251034
}
10261035
};
10271036

@@ -1071,11 +1080,9 @@ impl<'a> InferenceContext<'a> {
10711080

10721081
let ret_ty = self.normalize_associated_types_in(ret_ty);
10731082

1074-
// use knowledge of built-in binary ops, which can sometimes help inference
1075-
if let Some(builtin_rhs) = self.builtin_binary_op_rhs_expectation(op, lhs_ty.clone()) {
1076-
self.unify(&builtin_rhs, &rhs_ty);
1077-
}
1078-
if let Some(builtin_ret) = self.builtin_binary_op_return_ty(op, lhs_ty, rhs_ty) {
1083+
if self.is_builtin_binop(&lhs_ty, &rhs_ty, op) {
1084+
// use knowledge of built-in binary ops, which can sometimes help inference
1085+
let builtin_ret = self.enforce_builtin_binop_types(&lhs_ty, &rhs_ty, op);
10791086
self.unify(&builtin_ret, &ret_ty);
10801087
}
10811088

@@ -1477,92 +1484,124 @@ impl<'a> InferenceContext<'a> {
14771484
indices
14781485
}
14791486

1480-
fn builtin_binary_op_return_ty(&mut self, op: BinaryOp, lhs_ty: Ty, rhs_ty: Ty) -> Option<Ty> {
1481-
let lhs_ty = self.resolve_ty_shallow(&lhs_ty);
1482-
let rhs_ty = self.resolve_ty_shallow(&rhs_ty);
1483-
match op {
1484-
BinaryOp::LogicOp(_) | BinaryOp::CmpOp(_) => {
1485-
Some(TyKind::Scalar(Scalar::Bool).intern(Interner))
1487+
/// Dereferences a single level of immutable referencing.
1488+
fn deref_ty_if_possible(&mut self, ty: &Ty) -> Ty {
1489+
let ty = self.resolve_ty_shallow(ty);
1490+
match ty.kind(Interner) {
1491+
TyKind::Ref(Mutability::Not, _, inner) => self.resolve_ty_shallow(inner),
1492+
_ => ty,
1493+
}
1494+
}
1495+
1496+
/// Enforces expectations on lhs type and rhs type depending on the operator and returns the
1497+
/// output type of the binary op.
1498+
fn enforce_builtin_binop_types(&mut self, lhs: &Ty, rhs: &Ty, op: BinaryOp) -> Ty {
1499+
// Special-case a single layer of referencing, so that things like `5.0 + &6.0f32` work (See rust-lang/rust#57447).
1500+
let lhs = self.deref_ty_if_possible(lhs);
1501+
let rhs = self.deref_ty_if_possible(rhs);
1502+
1503+
let (op, is_assign) = match op {
1504+
BinaryOp::Assignment { op: Some(inner) } => (BinaryOp::ArithOp(inner), true),
1505+
_ => (op, false),
1506+
};
1507+
1508+
let output_ty = match op {
1509+
BinaryOp::LogicOp(_) => {
1510+
let bool_ = self.result.standard_types.bool_.clone();
1511+
self.unify(&lhs, &bool_);
1512+
self.unify(&rhs, &bool_);
1513+
bool_
14861514
}
1487-
BinaryOp::Assignment { .. } => Some(TyBuilder::unit()),
1515+
14881516
BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) => {
1489-
// all integer combinations are valid here
1490-
if matches!(
1491-
lhs_ty.kind(Interner),
1492-
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_))
1493-
| TyKind::InferenceVar(_, TyVariableKind::Integer)
1494-
) && matches!(
1495-
rhs_ty.kind(Interner),
1496-
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_))
1497-
| TyKind::InferenceVar(_, TyVariableKind::Integer)
1498-
) {
1499-
Some(lhs_ty)
1500-
} else {
1501-
None
1502-
}
1517+
// result type is same as LHS always
1518+
lhs
15031519
}
1504-
BinaryOp::ArithOp(_) => match (lhs_ty.kind(Interner), rhs_ty.kind(Interner)) {
1505-
// (int, int) | (uint, uint) | (float, float)
1506-
(TyKind::Scalar(Scalar::Int(_)), TyKind::Scalar(Scalar::Int(_)))
1507-
| (TyKind::Scalar(Scalar::Uint(_)), TyKind::Scalar(Scalar::Uint(_)))
1508-
| (TyKind::Scalar(Scalar::Float(_)), TyKind::Scalar(Scalar::Float(_))) => {
1509-
Some(rhs_ty)
1510-
}
1511-
// ({int}, int) | ({int}, uint)
1512-
(
1513-
TyKind::InferenceVar(_, TyVariableKind::Integer),
1514-
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_)),
1515-
) => Some(rhs_ty),
1516-
// (int, {int}) | (uint, {int})
1517-
(
1518-
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_)),
1519-
TyKind::InferenceVar(_, TyVariableKind::Integer),
1520-
) => Some(lhs_ty),
1521-
// ({float} | float)
1522-
(
1523-
TyKind::InferenceVar(_, TyVariableKind::Float),
1524-
TyKind::Scalar(Scalar::Float(_)),
1525-
) => Some(rhs_ty),
1526-
// (float, {float})
1527-
(
1528-
TyKind::Scalar(Scalar::Float(_)),
1529-
TyKind::InferenceVar(_, TyVariableKind::Float),
1530-
) => Some(lhs_ty),
1531-
// ({int}, {int}) | ({float}, {float})
1532-
(
1533-
TyKind::InferenceVar(_, TyVariableKind::Integer),
1534-
TyKind::InferenceVar(_, TyVariableKind::Integer),
1535-
)
1536-
| (
1537-
TyKind::InferenceVar(_, TyVariableKind::Float),
1538-
TyKind::InferenceVar(_, TyVariableKind::Float),
1539-
) => Some(rhs_ty),
1540-
_ => None,
1541-
},
1520+
1521+
BinaryOp::ArithOp(_) => {
1522+
// LHS, RHS, and result will have the same type
1523+
self.unify(&lhs, &rhs);
1524+
lhs
1525+
}
1526+
1527+
BinaryOp::CmpOp(_) => {
1528+
// LHS and RHS will have the same type
1529+
self.unify(&lhs, &rhs);
1530+
self.result.standard_types.bool_.clone()
1531+
}
1532+
1533+
BinaryOp::Assignment { op: None } => {
1534+
stdx::never!("Simple assignment operator is not binary op.");
1535+
lhs
1536+
}
1537+
1538+
BinaryOp::Assignment { .. } => unreachable!("handled above"),
1539+
};
1540+
1541+
if is_assign {
1542+
self.result.standard_types.unit.clone()
1543+
} else {
1544+
output_ty
15421545
}
15431546
}
15441547

1545-
fn builtin_binary_op_rhs_expectation(&mut self, op: BinaryOp, lhs_ty: Ty) -> Option<Ty> {
1546-
Some(match op {
1547-
BinaryOp::LogicOp(..) => TyKind::Scalar(Scalar::Bool).intern(Interner),
1548-
BinaryOp::Assignment { op: None } => lhs_ty,
1549-
BinaryOp::CmpOp(CmpOp::Eq { .. }) => match self
1550-
.resolve_ty_shallow(&lhs_ty)
1551-
.kind(Interner)
1552-
{
1553-
TyKind::Scalar(_) | TyKind::Str => lhs_ty,
1554-
TyKind::InferenceVar(_, TyVariableKind::Integer | TyVariableKind::Float) => lhs_ty,
1555-
_ => return None,
1556-
},
1557-
BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) => return None,
1558-
BinaryOp::CmpOp(CmpOp::Ord { .. })
1559-
| BinaryOp::Assignment { op: Some(_) }
1560-
| BinaryOp::ArithOp(_) => match self.resolve_ty_shallow(&lhs_ty).kind(Interner) {
1561-
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_) | Scalar::Float(_)) => lhs_ty,
1562-
TyKind::InferenceVar(_, TyVariableKind::Integer | TyVariableKind::Float) => lhs_ty,
1563-
_ => return None,
1564-
},
1565-
})
1548+
fn is_builtin_binop(&mut self, lhs: &Ty, rhs: &Ty, op: BinaryOp) -> bool {
1549+
// Special-case a single layer of referencing, so that things like `5.0 + &6.0f32` work (See rust-lang/rust#57447).
1550+
let lhs = self.deref_ty_if_possible(lhs);
1551+
let rhs = self.deref_ty_if_possible(rhs);
1552+
1553+
let op = match op {
1554+
BinaryOp::Assignment { op: Some(inner) } => BinaryOp::ArithOp(inner),
1555+
_ => op,
1556+
};
1557+
1558+
match op {
1559+
BinaryOp::LogicOp(_) => true,
1560+
1561+
BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) => {
1562+
lhs.is_integral() && rhs.is_integral()
1563+
}
1564+
1565+
BinaryOp::ArithOp(
1566+
ArithOp::Add | ArithOp::Sub | ArithOp::Mul | ArithOp::Div | ArithOp::Rem,
1567+
) => {
1568+
lhs.is_integral() && rhs.is_integral()
1569+
|| lhs.is_floating_point() && rhs.is_floating_point()
1570+
}
1571+
1572+
BinaryOp::ArithOp(ArithOp::BitAnd | ArithOp::BitOr | ArithOp::BitXor) => {
1573+
lhs.is_integral() && rhs.is_integral()
1574+
|| lhs.is_floating_point() && rhs.is_floating_point()
1575+
|| matches!(
1576+
(lhs.kind(Interner), rhs.kind(Interner)),
1577+
(TyKind::Scalar(Scalar::Bool), TyKind::Scalar(Scalar::Bool))
1578+
)
1579+
}
1580+
1581+
BinaryOp::CmpOp(_) => {
1582+
let is_scalar = |kind| {
1583+
matches!(
1584+
kind,
1585+
&TyKind::Scalar(_)
1586+
| TyKind::FnDef(..)
1587+
| TyKind::Function(_)
1588+
| TyKind::Raw(..)
1589+
| TyKind::InferenceVar(
1590+
_,
1591+
TyVariableKind::Integer | TyVariableKind::Float
1592+
)
1593+
)
1594+
};
1595+
is_scalar(lhs.kind(Interner)) && is_scalar(rhs.kind(Interner))
1596+
}
1597+
1598+
BinaryOp::Assignment { op: None } => {
1599+
stdx::never!("Simple assignment operator is not binary op.");
1600+
false
1601+
}
1602+
1603+
BinaryOp::Assignment { .. } => unreachable!("handled above"),
1604+
}
15661605
}
15671606

15681607
fn with_breakable_ctx<T>(

0 commit comments

Comments
 (0)